SPU 训练神经网络#

请先阅读 Logistic Regression On SPU

在实验 Logistic Regression On SPU 中,我们展示了如何使用 SecretFlow/SPU 将明文 JAX 训练程序转换为安全 MPC 训练程序。

在本实验中,思路与之非常相似,只是这次我们将使用神经网络模型。

我们将使用相同的数据集和所有设置。

首先,让我们构建并训练一个明文模型。

以下代码仅作为示例,请勿在生产环境直接使用。

本教程所需的资源超过 SecretFlow 的最低配置要求(8 核 16 GB,即 8c16g)。

使用 JAX/FLAX 训练模型#

加载数据集#

以下内容复制自实验 Logistic Regression On SPU,这里不再重复解释。

[ ]:
# Install the Flax version this tutorial was written against (IPython shell escape;
# uses the same interpreter that runs the notebook kernel).
import sys
!{sys.executable} -m pip install flax==0.6.0
[1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer


def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    """Load the breast cancer dataset, optionally as one party's vertical slice.

    Each sample row is scaled by its maximum absolute value (``Normalizer`` with
    ``norm='max'``), then the data is split 80/20 into train/test with a fixed
    random seed.

    For the training split the feature columns are partitioned at column 15:
    party 1 holds columns ``[:15]`` and no labels; any other truthy party id
    holds columns ``[15:]`` together with the labels. Party 1's block comes
    first so that ``concatenate((x1, x2), axis=1)`` restores the original
    column order, which is the order of the (unpartitioned) test split used
    for validation.

    Args:
        party_id: ``1`` or ``2`` to get that party's training slice. A falsy
            value (``None`` or ``0``) returns the full training set.
        train: if False, return the full test split (``party_id`` is ignored).

    Returns:
        ``(x, y)``; ``y`` is ``None`` for party 1, which holds no labels.
    """
    # NOTE(review): keep the tuple-style return annotation above unchanged;
    # SecretFlow's PYU call presumably reads it to infer that this function
    # returns two objects (see `alice(breast_cancer)(...)`) — confirm before
    # changing it.
    scaler = Normalizer(norm='max')
    x, y = load_breast_cancer(return_X_y=True)
    x = scaler.fit_transform(x)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if not train:
        return x_test, y_test
    if not party_id:
        return x_train, y_train
    if party_id == 1:
        # Party 1 holds features only; there are no labels to hand back.
        return x_train[:, :15], None
    return x_train[:, 15:], y_train

定义模型#

我们将使用一个包含 4 个全连接层的 MLP 模型:前 3 层使用 ReLU 激活函数,最后一层为线性输出。

[2]:
from typing import Sequence
import flax.linen as nn


# Output width of each Dense layer, in order; the final entry (1) yields a
# single output score per example.
FEATURES = [30, 15, 8, 1]


class MLP(nn.Module):
    """Multi-layer perceptron described by a list of layer widths.

    Every width except the last becomes a Dense layer followed by ReLU; the
    last width becomes a final Dense layer with no activation.
    """

    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        output_width = self.features[-1]
        for width in hidden_widths:
            hidden_layer = nn.Dense(width)
            x = nn.relu(hidden_layer(x))
        output_layer = nn.Dense(output_width)
        return output_layer(x)

接下来,我们定义预测函数、损失函数、训练方法以及模型参数的初始化方法。

[3]:
import jax.numpy as jnp


def predict(params, x):
    """Run the MLP forward pass with the given parameters.

    Args:
        params: Flax parameter pytree, e.g. as returned by ``model_init`` or by
            training.
        x: input features; presumably a 2-D batch whose last axis has
            FEATURES[0] columns — TODO confirm against callers.

    Returns:
        The raw output of the final Dense layer (no activation). Its last axis
        has width FEATURES[-1], i.e. 1 here.
    """
    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,
    # which is not the case in a normal python program.
    # NOTE(review): the model definition below duplicates the top-level MLP /
    # FEATURES, presumably so this function is self-contained when it is
    # serialized and traced for SPU. Keep the two definitions in sync.
    from typing import Sequence
    import flax.linen as nn

    FEATURES = [30, 15, 8, 1]

    class MLP(nn.Module):
        features: Sequence[int]

        @nn.compact
        def __call__(self, x):
            # Dense + ReLU for every width except the last ...
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
            # ... then a final linear Dense layer.
            x = nn.Dense(self.features[-1])(x)
            return x

    return MLP(FEATURES).apply(params, x)


def loss_func(params, x, y):
    """Half mean squared error between model predictions and labels.

    Args:
        params: model parameter pytree.
        x: one mini-batch of features.
        y: the matching labels for that mini-batch.

    Returns:
        Scalar ``mean((y - pred) ** 2) / 2``.
    """
    pred = predict(params, x)
    # The final Dense layer has width FEATURES[-1] == 1, so predictions carry a
    # trailing axis of size 1, while the labels from `breast_cancer` are 1-D.
    # Align the shapes first: otherwise `y - pred` broadcasts to a
    # (batch, batch) matrix and every prediction is compared with every label
    # in the batch instead of its own.
    pred = jnp.reshape(pred, jnp.shape(y))
    residual = y - pred
    return jnp.mean(jnp.multiply(residual, residual) / 2.0)


def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    """Train the model with mini-batch gradient descent.

    The two feature blocks are concatenated column-wise (``x1`` first), split
    row-wise into mini-batches, and the parameters are updated once per
    mini-batch for ``n_epochs`` passes. The same function is used for the
    plaintext run and, unchanged, for the SPU run.

    Args:
        x1: feature columns held by the first party.
        x2: feature columns held by the second party (same rows as ``x1``).
        y: labels, one per row.
        params: initial parameter pytree (e.g. from ``model_init``).
        n_batch: target mini-batch *size* (rows per batch), not the number of
            batches. It determines the Python-level split, so it must be a
            static Python int when traced.
        n_epochs: number of passes over the data.
        step_size: learning rate.

    Returns:
        The updated parameter pytree, same structure as ``params``.
    """
    # Imported locally: at notebook level `jax` is only imported in a later
    # cell, so this function should not depend on that cell having run.
    import jax

    x = jnp.concatenate((x1, x2), axis=1)
    # Integer number of mini-batches; clamp to 1 so inputs with fewer rows
    # than `n_batch` still form a single batch instead of raising.
    n_sections = max(1, len(x) // n_batch)
    xs = jnp.array_split(x, n_sections, axis=0)
    ys = jnp.array_split(y, n_sections, axis=0)

    def body_fun(_, loop_carry):
        params = loop_carry
        # Plain Python loop: unrolled at trace time, one update per mini-batch.
        for x_batch, y_batch in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, x_batch, y_batch)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params

    return jax.lax.fori_loop(0, n_epochs, body_fun, params)


def model_init(n_batch=10):
    """Create freshly initialised MLP parameters.

    An all-ones dummy batch of shape ``(n_batch, FEATURES[0])`` is pushed
    through the network to build the parameter pytree. The PRNG seed is fixed
    at 1, so repeated calls give the same initialisation.
    """
    dummy_batch = jnp.ones((n_batch, FEATURES[0]))
    seed_key = jax.random.PRNGKey(1)
    network = MLP(FEATURES)
    return network.init(seed_key, dummy_batch)

验证模型#

我们使用 AUC 作为验证指标。

[4]:
from sklearn.metrics import roc_auc_score


def validate_model(params, X_test, y_test):
    """Return the ROC AUC of the model's raw scores on a held-out set."""
    scores = predict(params, X_test)
    return roc_auc_score(y_true=y_test, y_score=scores)

放在一起#

让我们把所有东西放在一起,训练一个明文 NN 模型!

[5]:
import jax

# Load the vertically partitioned training data: party 1 holds only features,
# party 2 holds the remaining features together with the labels.
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)


# Hyperparameters. `n_batch` is used by train_auto_grad as the approximate
# mini-batch size (rows per batch), not as the number of batches.
n_batch = 10
n_epochs = 10
step_size = 0.01


# Train the model in plaintext (no SPU involved yet).
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)

# Evaluate on the held-out split (the full, unpartitioned feature matrix).
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
auc=0.9921388797903701

接下来,让我们使用 SPU 训练同样的神经网络模型!

使用 SPU 训练模型#

[6]:
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

# Start a local SecretFlow runtime with two simulated parties.
sf.init(['alice', 'bob'], address='local')

# One plaintext PYU device per party, plus an SPU device spanning both parties.
alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))

# Each party loads its own slice of the training data on its own PYU.
# NOTE(review): unpacking each call into two objects presumably relies on
# SecretFlow inferring two return values from breast_cancer's tuple return
# annotation — confirm against the SecretFlow version in use.
x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)


# Move the data onto the SPU. The initial parameters are first placed on
# alice's PYU and then moved to the SPU.
device = spu
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)

# Run the same training function on SPU. The hyperparameters are declared
# static so they stay Python values during tracing (n_batch, for example,
# decides how many mini-batches the data is split into).
params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(
    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)

2022-08-25 10:54:11.675312: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944990) 2022-08-25 10:54:17.053663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944994) 2022-08-25 10:54:17.053663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944991) 2022-08-25 10:54:17.053663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944992) 2022-08-25 10:54:17.064685: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944996) 2022-08-25 10:54:17.053663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944988) 2022-08-25 10:54:17.011062: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(_run pid=2944995) 2022-08-25 10:54:17.145220: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(_run pid=2944993) 2022-08-25 10:54:17.154674: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(SPURuntime pid=2944988) I0825 10:54:18.801868 2944988 external/com_github_brpc_brpc/src/brpc/server.cpp:1066] Server[yacl::link::internal::ReceiverServiceImpl] is serving on port=44351.
(SPURuntime pid=2944988) I0825 10:54:18.801938 2944988 external/com_github_brpc_brpc/src/brpc/server.cpp:1069] Check out http://k69b13338.eu95sqa:44351 in web browser.
(_run pid=2944995) 2022-08-25 10:54:18,906,906 WARNING [xla_bridge.py:backends:265] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(SPURuntime pid=2944990) I0825 10:54:18.863563 2944990 external/com_github_brpc_brpc/src/brpc/server.cpp:1066] Server[yacl::link::internal::ReceiverServiceImpl] is serving on port=22227.
(SPURuntime pid=2944990) I0825 10:54:18.863633 2944990 external/com_github_brpc_brpc/src/brpc/server.cpp:1069] Check out http://k69b13338.eu95sqa:22227 in web browser.
(SPURuntime pid=2944988) I0825 10:54:18.902659 2945551 external/com_github_brpc_brpc/src/brpc/socket.cpp:2202] Checking Socket{id=0 addr=127.0.0.1:22227} (0x55f83a4820c0)
(SPURuntime pid=2944988) I0825 10:54:18.902798 2945551 external/com_github_brpc_brpc/src/brpc/socket.cpp:2262] Revived Socket{id=0 addr=127.0.0.1:22227} (0x55f83a4820c0) (Connectable)
(_run pid=2944995) [2022-08-25 10:54:18.906] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(_run pid=2944993) [2022-08-25 10:54:18.941] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(_run pid=2944993) 2022-08-25 10:54:18,941,941 WARNING [xla_bridge.py:backends:265] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2944996) 2022-08-25 10:54:18,952,952 WARNING [xla_bridge.py:backends:265] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

让我们揭示(reveal)SPU 训练得到的模型参数并进行检查。

[7]:
# Reveal the parameters produced by the SPU training run in the previous cell.
# (Calling `spu(train_auto_grad)` again here would repeat the entire MPC
# training, this time with the plaintext `init_params` instead of the
# SPU-resident `init_params_` and without the static hyperparameters.)
params = sf.reveal(params_spu)
print(params)

(SPURuntime pid=2944988) [2022-08-25 10:54:29.343] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(SPURuntime pid=2944990) [2022-08-25 10:54:29.328] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
FrozenDict({
    params: {
        Dense_0: {
            bias: array([-5.1895231e-03,  1.8576235e-03,  6.7055225e-06,  2.3722053e-03,
                    3.1328186e-02,  6.7055225e-06, -1.5438795e-02,  6.7055225e-06,
                    6.7055225e-06, -4.1383177e-02,  6.7055225e-06, -2.2077844e-02,
                    6.7055225e-06,  6.7055225e-06, -4.1123331e-03,  6.7055225e-06,
                    3.9519504e-02, -2.0013824e-02,  6.7055225e-06, -1.7026097e-02,
                   -2.0100027e-03, -1.2290031e-02,  3.5510257e-02,  6.7055225e-06,
                    1.0400787e-02,  6.7055225e-06,  6.7055225e-06,  6.7055225e-06,
                    6.7055225e-06, -2.4091452e-03], dtype=float32),
            kernel: array([[-1.48707509e-01, -2.35312745e-01, -1.49370432e-01,
                    -1.55658722e-02, -1.33344889e-01,  1.91762865e-01,
                    -3.67970318e-02, -3.74531150e-02, -1.41760916e-01,
                     3.24019790e-02,  1.26554981e-01, -4.02508259e-01,
                    -1.68948472e-01,  2.13856786e-01, -1.38430491e-01,
                     1.05858222e-01, -1.16115242e-01,  3.86231482e-01,
                     5.96616715e-02,  6.30981326e-02,  7.79366642e-02,
                    -1.31276995e-02, -2.88023233e-01, -9.60215628e-02,
                     1.11036777e-01, -8.54356140e-02,  7.54796267e-02,
                    -4.11889851e-02, -3.82702172e-01,  2.37623483e-01],
                   [ 1.77950412e-01,  2.29401365e-01, -2.44401962e-01,
                    -1.48464620e-01,  3.36820692e-01,  2.58408487e-02,
                    -4.21461910e-02,  4.10513222e-01,  3.24393630e-01,
                    -1.64218560e-01,  8.17326903e-02,  5.25916368e-02,
                     3.11348379e-01,  2.92945385e-01,  1.22726962e-01,
                    -3.87524545e-01, -3.85492414e-01, -6.53883070e-02,
                    -2.59158403e-01, -3.32410604e-01, -3.15876693e-01,
                    -2.91079849e-01, -6.01480752e-02,  2.29791373e-01,
                     1.01019114e-01, -1.30936205e-02,  1.78829342e-01,
                    -2.32158348e-01,  3.82750481e-01,  3.79820764e-02],
                   [ 1.42894119e-01, -2.13525295e-02,  1.68190986e-01,
                     8.99314433e-02, -3.89296085e-01, -4.85262126e-02,
                     1.38709426e-01, -5.80671132e-02,  2.84729242e-01,
                    -1.26564085e-01,  2.57159889e-01,  9.64836925e-02,
                     1.16704285e-01, -1.90519944e-01, -3.98443341e-02,
                    -9.57311541e-02,  7.24758208e-02,  1.46388173e-01,
                     9.22411978e-02,  3.83930653e-02, -2.46246234e-01,
                     3.76523137e-02, -1.90257579e-02, -2.52097666e-01,
                     1.70252651e-01,  2.49821901e-01, -2.81400979e-03,
                    -9.84508097e-02,  2.07754672e-01,  8.81182998e-02],
                   [-2.25118160e-01, -4.46529686e-03, -4.55757827e-02,
                    -4.27887589e-02, -1.34788275e-01, -3.42829585e-01,
                    -9.25171375e-03,  9.01048332e-02, -2.82488763e-01,
                     2.23106414e-01,  1.17174536e-01,  4.47563082e-02,
                    -5.02136499e-02,  2.90469885e-01, -2.35435143e-01,
                     5.99489361e-02,  2.55998850e-01,  1.87139437e-01,
                     1.46078974e-01,  2.93805301e-01,  3.95906270e-02,
                     1.63574621e-01, -1.12822607e-01, -6.96890652e-02,
                     2.61215031e-01, -3.17869127e-01,  8.15212727e-02,
                     2.55623460e-01, -5.98048419e-02,  1.93995312e-01],
                   [-2.42546692e-01,  7.96881765e-02, -6.76873624e-02,
                    -1.17468625e-01,  1.87561423e-01, -6.13799393e-02,
                    -5.36612719e-02, -6.93449825e-02,  7.92452097e-02,
                     2.54104435e-02, -3.18573356e-01,  2.87047565e-01,
                    -6.02750182e-02, -3.01489800e-01, -1.76609412e-01,
                     7.97384530e-02,  1.61419287e-01,  3.27948928e-01,
                     2.05150560e-01,  3.03487569e-01,  2.71105886e-01,
                     2.76556283e-01,  7.07161725e-02,  2.08005011e-01,
                    -7.33365566e-02, -1.03249148e-01,  1.55347139e-02,
                     3.17582250e-01,  3.16770792e-01, -6.80912733e-02],
                   [ 7.31387287e-02,  6.06941879e-02,  5.30773848e-02,
                     1.78480223e-01,  1.88692018e-01,  1.77043870e-01,
                     8.07239711e-02,  1.48239732e-02, -4.42493856e-02,
                     6.05108887e-02,  2.88271666e-01,  5.82368821e-02,
                     2.64274329e-01,  1.20873421e-01, -2.19291598e-02,
                    -1.51663244e-01, -4.52437997e-02,  6.89959526e-03,
                     2.83989072e-01,  1.00405067e-01, -3.03579658e-01,
                    -3.13743889e-01, -1.75221011e-01,  1.12235680e-01,
                     1.15922809e-01, -2.76054680e-01, -6.86793476e-02,
                     6.13613129e-02,  3.08204144e-01, -2.80041754e-01],
                   [-2.58817077e-01, -1.54795498e-02,  2.77136236e-01,
                    -3.83958876e-01,  4.00266647e-01, -1.29193932e-01,
                    -2.80725807e-02,  1.75729364e-01,  1.30319521e-01,
                     1.56524405e-01, -2.62551606e-02,  2.93114007e-01,
                     3.55902165e-02,  3.58770192e-02, -2.96767652e-01,
                    -1.69694602e-01,  3.19763869e-02,  1.61003023e-01,
                    -1.75322652e-01, -8.83801430e-02, -4.86088842e-02,
                    -1.05747893e-01,  8.46082866e-02, -4.35629487e-02,
                    -2.53890634e-01, -9.85630006e-02, -4.16661799e-02,
                    -4.67738807e-02, -3.35340977e-01, -1.18303642e-01],
                   [ 2.78338730e-01, -6.96775764e-02, -2.48151869e-02,
                     4.41243351e-02, -8.29881430e-02, -1.39038265e-03,
                     8.48067403e-02, -2.48347864e-01,  1.53563991e-01,
                    -1.36309460e-01, -4.34331596e-03, -6.34510815e-03,
                    -7.79297352e-02, -1.02945447e-01,  2.40225092e-01,
                    -2.56992996e-01,  2.98645079e-01, -1.97924539e-01,
                    -2.75554240e-01,  1.29482746e-02,  1.26732409e-01,
                     1.27245054e-01,  1.59333318e-01,  1.43134385e-01,
                     6.40332401e-02, -1.11376449e-01,  5.90692759e-02,
                    -1.16492316e-01,  1.58721358e-02, -2.03465536e-01],
                   [ 6.27327412e-02,  8.74824524e-02, -9.10452157e-02,
                     1.81233525e-01,  1.06525749e-01, -1.38279110e-01,
                     3.20228636e-01, -1.48054436e-01, -3.18492800e-02,
                     3.50982606e-01, -1.33533657e-01,  1.76203504e-01,
                     5.12142032e-02, -4.60760295e-02, -1.33599430e-01,
                    -2.22507954e-01,  1.66286990e-01,  3.71264696e-01,
                    -1.16020367e-01, -1.61022544e-02, -8.04677159e-02,
                     1.58347130e-01,  3.06297302e-01,  6.74358159e-02,
                     8.46540183e-02, -9.24098492e-03, -3.53501737e-03,
                    -3.68553460e-01, -1.23027235e-01,  2.18146190e-01],
                   [ 2.83379406e-02, -1.24499649e-01,  1.79811046e-01,
                     1.53642461e-01,  5.48425168e-02,  1.91716999e-01,
                    -9.49321389e-02,  6.86786622e-02, -7.67818838e-02,
                    -1.93986297e-02,  5.70140183e-02, -3.93389583e-01,
                     5.28795272e-02,  3.79496545e-01,  2.46415466e-01,
                    -1.21219859e-01,  4.00151759e-02, -3.80354345e-01,
                     1.95414037e-01, -9.05120224e-02,  3.20608616e-01,
                     1.48523450e-02, -3.49244773e-02,  1.11090288e-01,
                    -3.37234616e-01, -3.06017160e-01, -1.13247246e-01,
                     1.59685776e-01,  6.75145537e-02,  1.00891382e-01],
                   [ 1.68052852e-01,  1.94981366e-01, -9.76378322e-02,
                    -1.45580143e-01,  1.01531252e-01, -3.17420304e-01,
                    -1.15842447e-01, -2.86557317e-01, -1.01209313e-01,
                    -1.30138457e-01,  1.97996482e-01, -6.92996681e-02,
                     1.83074176e-03, -6.13981634e-02, -2.38128498e-01,
                     1.41838133e-01,  4.12078500e-01, -1.11510590e-01,
                    -7.69597739e-02, -3.93857658e-02, -5.82328439e-02,
                    -2.56166637e-01,  1.75529957e-01, -5.77670783e-02,
                     4.62778807e-02,  1.20462373e-01,  3.14444661e-01,
                    -1.82372808e-01, -1.62539259e-01, -9.76682454e-02],
                   [-6.19109869e-02, -1.15577027e-01,  7.26505965e-02,
                     1.25303209e-01,  2.06839561e-01,  1.57670394e-01,
                    -8.05730969e-02, -1.94498345e-01,  2.13316083e-02,
                     2.35428169e-01, -1.77006692e-01, -3.51173997e-01,
                    -2.20170230e-01,  3.13621610e-02,  1.01004869e-01,
                    -4.00861502e-01, -1.33804008e-01, -6.59427792e-02,
                    -1.41224921e-01, -1.72024697e-01, -6.66107237e-02,
                     9.94113684e-02, -3.09019983e-02,  2.59390533e-01,
                     6.44736290e-02, -2.50633597e-01, -3.49195153e-02,
                     8.02383423e-02,  2.55563200e-01, -2.40900025e-01],
                   [ 8.98088515e-03, -4.00735557e-01, -6.30197525e-02,
                     6.18334711e-02,  3.73557806e-01,  3.17755789e-02,
                     2.75026381e-01, -2.88109362e-01, -2.02475563e-01,
                     1.61148801e-01, -2.17942014e-01,  1.06317997e-01,
                     2.66866386e-03, -2.73037195e-01,  7.52992183e-02,
                     7.77817965e-02,  2.63209343e-02,  9.45713371e-02,
                    -2.83376873e-01,  2.55718678e-02,  1.71330452e-01,
                     4.77515757e-02, -1.29865110e-02, -9.19252783e-02,
                     2.20202535e-01, -1.98967785e-01,  3.41536522e-01,
                     8.68079364e-02, -8.85302424e-02, -9.05130804e-03],
                   [ 1.20346710e-01,  1.25420868e-01, -3.62598658e-01,
                     2.23636329e-01, -7.36351609e-02,  1.05015352e-01,
                     4.35760617e-03, -8.31906050e-02,  2.28634894e-01,
                    -1.49365649e-01,  8.16624165e-02, -3.14154029e-01,
                     8.76248032e-02, -5.55702299e-02,  8.57660025e-02,
                    -2.31696367e-02, -2.23352492e-01,  2.85403579e-02,
                     1.04186803e-01,  9.74222422e-02,  8.70420933e-02,
                     1.58863813e-02,  1.73726633e-01,  8.37596357e-02,
                    -1.77873820e-02, -6.53727055e-02,  5.05047441e-02,
                    -2.29445338e-01,  5.72724342e-02,  2.50912070e-01],
                   [-3.97108495e-01,  1.00127161e-01, -7.08009750e-02,
                    -1.62648782e-01, -1.39103666e-01,  1.61610574e-01,
                     1.60218611e-01, -7.88767636e-03,  5.42842150e-02,
                     1.65928125e-01,  2.23704860e-01,  3.66966009e-01,
                     6.14991188e-02, -4.85754609e-02,  3.34523857e-01,
                    -7.26056099e-02,  1.93905726e-01, -6.00316077e-02,
                    -3.03020537e-01,  1.71825185e-01,  2.90646702e-01,
                     2.13969320e-01,  4.79212999e-02,  9.81050730e-02,
                     1.03307858e-01,  1.27324745e-01, -5.79782724e-02,
                     1.52468219e-01, -3.66631836e-01,  1.77991658e-01],
                   [ 1.05444938e-02,  8.89388323e-02,  2.68580914e-01,
                     3.43994796e-01,  7.01821893e-02,  3.07618678e-01,
                    -1.59211323e-01,  2.65379250e-03,  4.29789275e-02,
                    -6.48853630e-02, -1.19794458e-02, -1.01899371e-01,
                     3.90142053e-02, -2.12668777e-02,  1.34829685e-01,
                    -1.60709843e-01, -2.75420129e-01, -1.10718116e-01,
                     1.22148365e-01, -2.56418556e-01, -8.86735022e-02,
                     3.57692540e-02, -1.72454745e-01,  1.25590444e-01,
                     1.48163527e-01,  9.70243216e-02,  1.78432480e-01,
                     8.07075799e-02,  7.18790591e-02,  8.29363018e-02],
                   [ 1.46744028e-01,  1.35470942e-01, -5.01304418e-02,
                    -2.56607473e-01, -2.36472189e-01,  2.16720849e-01,
                     1.36735886e-01, -3.88280898e-02,  3.90521705e-01,
                     3.07039917e-03,  1.45443454e-01, -7.04357922e-02,
                    -1.51877582e-01,  6.67887330e-02, -2.40254313e-01,
                     3.11602056e-01,  6.78158849e-02, -2.53712773e-01,
                    -2.01759011e-01, -2.26584569e-01,  1.38092458e-01,
                    -1.41293734e-01,  3.44138265e-01,  1.29559129e-01,
                     2.85035908e-01,  6.18830621e-02, -2.29603484e-01,
                     2.91220188e-01, -8.08279067e-02, -3.44568819e-01],
                   [-1.89113468e-02,  1.27276510e-01,  1.18291497e-01,
                    -8.91744196e-02, -3.88738513e-02, -6.17537051e-02,
                    -1.13684982e-01, -6.69639111e-02, -3.41004312e-01,
                    -2.61811137e-01, -1.48398668e-01, -2.04433903e-01,
                    -3.67532670e-03,  5.23980856e-02,  6.42350912e-02,
                     8.27208012e-02,  6.50502890e-02,  2.10662231e-01,
                    -1.37932986e-01,  3.07054341e-01, -9.90331918e-02,
                    -1.52072370e-01, -9.87340361e-02,  3.40974480e-02,
                     9.54811275e-02, -2.24524289e-01,  3.59725446e-01,
                     2.40096241e-01, -3.08321267e-02, -1.03981838e-01],
                   [-1.48987636e-01,  1.81017682e-01,  1.63939446e-02,
                    -2.69196749e-01, -2.44855881e-04, -3.97299677e-02,
                    -4.50673252e-02,  3.15325707e-02, -1.54727101e-01,
                    -1.31139323e-01,  4.07651365e-02,  1.12029955e-01,
                    -1.95583731e-01, -1.76382348e-01,  1.18832231e-01,
                    -2.69036591e-01,  3.04102361e-01, -8.61337632e-02,
                    -7.33592659e-02,  2.46652514e-02,  2.24634409e-01,
                     2.91894495e-01,  5.09455353e-02, -3.00022811e-01,
                    -5.17807007e-02, -7.68917203e-02, -3.71362567e-01,
                     1.96652308e-01, -1.05256483e-01, -2.75513947e-01],
                   [ 5.38429469e-02,  3.15863788e-02, -4.09971178e-03,
                    -4.45098132e-02, -1.00758657e-01, -6.42608255e-02,
                     3.13616514e-01, -1.36063650e-01,  1.24328464e-01,
                    -1.09250084e-01, -3.94054949e-02,  2.20205531e-01,
                    -7.17411488e-02,  8.70945156e-02,  4.95519042e-02,
                     3.63173425e-01,  6.60566986e-03, -1.58391029e-01,
                     9.21002179e-02, -1.74151346e-01, -1.42024517e-01,
                     3.83424699e-01,  2.24799365e-02,  7.36033916e-03,
                    -2.80536860e-02, -1.58879921e-01,  3.91074270e-02,
                    -9.43726748e-02,  2.17871547e-01,  1.44041181e-02],
                   [-9.30076092e-02, -1.98024854e-01, -3.14120024e-01,
                     1.71728119e-01,  1.33051425e-01, -1.41134948e-01,
                    -2.13182554e-01, -1.62383363e-01,  9.43484604e-02,
                     1.46695167e-01,  1.86109841e-02, -2.21152157e-02,
                    -1.46707222e-01,  3.92628670e-01, -2.01349884e-01,
                     1.09045491e-01, -2.89507657e-02, -1.52118295e-01,
                     1.74316362e-01,  7.77858645e-02,  9.58564728e-02,
                     1.02938607e-01,  1.89557344e-01, -1.57446101e-01,
                     9.71454233e-02, -2.65448719e-01, -5.12915403e-02,
                     8.04104656e-02,  5.85189909e-02,  2.47814834e-01],
                   [ 5.50863743e-02,  2.30716497e-01,  2.78447568e-03,
                    -5.15906960e-02, -1.33431196e-01,  1.72307670e-01,
                    -3.83732319e-02,  1.72320321e-01, -1.20988593e-01,
                    -1.21835843e-01, -1.65673420e-01, -8.69580209e-02,
                    -1.52244121e-02, -3.16982597e-01,  1.96166813e-01,
                    -2.08498746e-01,  3.45463872e-01,  2.52553254e-01,
                     3.05833966e-02, -2.36524627e-01, -2.45504826e-02,
                    -7.39030093e-02,  1.80508122e-01,  8.00530612e-02,
                     2.32508481e-02,  5.16086072e-02,  8.30580592e-02,
                    -1.09614551e-01,  2.05054462e-01,  5.47491461e-02],
                   [-2.92942554e-01,  1.58341527e-02, -5.25936484e-04,
                     7.54894614e-02,  1.75417557e-01,  1.60730407e-01,
                     5.91714680e-03, -2.53270268e-02, -2.71934599e-01,
                    -2.63609916e-01,  1.75930902e-01,  2.68442690e-01,
                    -1.60669044e-01,  4.51646745e-03, -4.13359016e-01,
                     1.32156700e-01,  2.06499949e-01, -9.21603739e-02,
                    -3.21212083e-01,  2.94354409e-02, -3.51502746e-02,
                    -1.13733053e-01,  7.04959035e-03,  6.02722168e-02,
                     3.10117364e-01,  3.13739121e-01,  1.54746130e-01,
                     2.38439858e-01,  2.05240831e-01,  1.65418029e-01],
                   [-1.31756335e-01, -9.71664637e-02, -2.11081415e-01,
                     3.06450367e-01,  1.31590337e-01,  2.54443496e-01,
                    -2.31858999e-01,  2.65551567e-01, -2.02050045e-01,
                     2.71081388e-01, -1.37878060e-02, -1.70021936e-01,
                    -1.65380538e-03,  9.54623818e-02,  2.84351885e-01,
                    -1.01871341e-01,  2.01485753e-02,  9.12380219e-02,
                    -6.15977198e-02,  4.80147898e-02, -1.39563948e-01,
                     3.02652836e-01,  1.64852113e-01, -2.00138420e-01,
                     9.66768116e-02, -9.22664106e-02, -9.24981087e-02,
                    -2.47369722e-01,  2.81407475e-01,  1.83511779e-01],
                   [-1.95965677e-01, -2.62236357e-01, -2.39685029e-02,
                     1.40571550e-01, -5.11714965e-02,  9.83205438e-02,
                     1.00091219e-01,  8.76448601e-02, -2.09155291e-01,
                    -4.81763929e-02,  1.15129888e-01, -1.07420832e-02,
                     6.28655702e-02, -1.43948466e-01,  1.83107510e-01,
                     1.86010495e-01, -1.79242194e-02, -1.05096996e-02,
                     2.98826218e-01,  2.92384177e-02,  1.50228307e-01,
                    -2.57354379e-02,  1.05159089e-01,  3.26832682e-01,
                    -6.47494942e-02, -7.94630796e-02, -3.30955148e-01,
                    -3.33948135e-01,  1.46547139e-01, -1.86091036e-01],
                   [-4.33291495e-02,  1.88203737e-01,  3.16037238e-02,
                     1.19405270e-01, -2.26772234e-01,  9.43277180e-02,
                    -8.72166008e-02,  2.56006062e-01, -1.48900136e-01,
                     9.94460434e-02,  1.87725961e-01, -1.95279077e-01,
                     8.27598572e-02, -1.46700770e-01, -1.25416532e-01,
                    -1.37769252e-01,  9.57628340e-02, -2.98057973e-01,
                     1.05415285e-01, -1.18127242e-01, -2.35549033e-01,
                    -1.76974535e-02, -2.97596186e-01,  4.32238877e-02,
                    -4.16904092e-02,  4.33115363e-02, -1.08654559e-01,
                     3.52643073e-01,  2.74524629e-01,  1.66427344e-02],
                   [ 1.77624241e-01, -7.08051473e-02, -1.25589028e-01,
                    -1.33987144e-01, -2.28391200e-01, -2.04036236e-01,
                     7.88579136e-02,  1.33848041e-01, -3.16916883e-01,
                    -1.34889603e-01, -8.19702297e-02,  2.77304649e-02,
                     2.47642547e-02,  1.05886340e-01, -2.58318663e-01,
                    -2.43119746e-01,  3.77329141e-02,  5.44693917e-02,
                    -1.35345489e-01, -1.10017732e-01, -3.13927293e-01,
                     5.12518585e-02, -5.12987375e-04, -1.58919290e-01,
                    -1.70748994e-01,  2.36288786e-01,  8.46761465e-02,
                     1.05235279e-02,  8.87275785e-02,  1.64180309e-01],
                   [ 3.75420362e-01,  6.81641251e-02,  7.72201270e-02,
                    -4.03955430e-01, -5.49944937e-02,  8.78675282e-03,
                     3.32516432e-01,  1.84740856e-01, -7.99065679e-02,
                     1.99955359e-01,  1.41463295e-01, -1.58552006e-01,
                    -1.59612060e-01, -1.87727511e-01, -1.75989211e-01,
                    -1.34044841e-01,  2.13294059e-01, -1.30978391e-01,
                     1.06950596e-01,  2.87034482e-01,  1.33578539e-01,
                     3.30342472e-01, -2.68601805e-01, -2.23761916e-01,
                     2.93608367e-01, -3.48806530e-02,  1.48320645e-01,
                     1.26249835e-01, -2.08334908e-01,  5.82257062e-02],
                   [ 1.78356782e-01, -1.20771334e-01, -7.79799819e-02,
                     1.64743349e-01,  1.32354870e-01,  1.11938149e-01,
                    -2.97888666e-02, -1.83453709e-02,  3.70834023e-02,
                    -3.93104136e-01,  4.38099802e-02,  1.25917226e-01,
                    -2.03110918e-01,  1.49912372e-01,  8.95697773e-02,
                    -1.43052682e-01,  3.78358483e-01,  2.53705919e-01,
                     9.40865725e-02,  2.99577773e-01,  1.92300975e-03,
                    -3.35528851e-02, -3.18194628e-01,  8.42917114e-02,
                    -1.03216350e-01,  1.54624790e-01,  1.52044371e-01,
                    -3.53974104e-03, -1.56484321e-01,  3.17795128e-02],
                   [-3.34984601e-01, -1.87048599e-01,  1.26603231e-01,
                     2.71421999e-01, -4.17920053e-02, -1.65970325e-02,
                    -1.58861458e-01,  1.46432027e-01,  1.03171512e-01,
                     1.39130712e-01,  2.62030542e-01,  3.82863283e-02,
                     1.70419857e-01,  2.81390846e-01,  3.02026719e-02,
                    -2.17159256e-01,  5.98860085e-02,  2.09414154e-01,
                     2.78205037e-01, -3.02840352e-01,  2.17414171e-01,
                     6.87691420e-02, -1.62359178e-02, -9.31997299e-02,
                     1.67161644e-01, -5.67281246e-02, -1.67869329e-02,
                    -3.39672238e-01,  4.14884984e-02,  2.41749167e-01]], dtype=float32),
        },
        Dense_1: {
            bias: array([ 6.7055225e-06,  6.7055225e-06,  6.7055225e-06,  6.7055225e-06,
                    6.7055225e-06,  6.7055225e-06, -4.6221808e-02, -3.4295321e-02,
                    4.9075842e-02,  6.9026187e-02,  5.4597139e-02,  6.7055225e-06,
                    6.7055225e-06, -2.3682773e-02, -1.4512062e-02], dtype=float32),
            kernel: array([[-2.15783074e-01, -8.00836086e-02, -3.41679364e-01,
                    -3.61634940e-02, -4.04338688e-02, -1.92787528e-01,
                     7.80968368e-02,  3.84697020e-01, -2.70941556e-01,
                     3.09228301e-02, -1.12026677e-01,  1.21513918e-01,
                     3.84846568e-01,  1.29461914e-01,  3.02034169e-02],
                   [ 3.03589344e-01,  1.49001822e-01,  2.24479139e-02,
                     1.72641233e-01,  1.11690164e-03, -1.60670713e-01,
                     1.60862997e-01, -2.07054362e-01,  3.59371305e-03,
                     8.14582109e-02, -8.32385570e-02, -2.71430016e-01,
                     3.29026878e-01,  6.39057159e-03, -2.08910599e-01],
                   [ 1.42273605e-02,  2.52373248e-01,  2.65929043e-01,
                    -7.87675530e-02,  2.57075429e-02,  1.37467578e-01,
                    -3.03784698e-01, -3.00662845e-01,  2.28536919e-01,
                     7.39716142e-02, -5.44494987e-02,  6.82624280e-02,
                    -1.14752382e-01, -4.36386168e-02, -2.58043408e-03],
                   [-3.78126502e-02,  3.64015222e-01, -6.23669326e-02,
                     1.66056380e-01, -1.78249836e-01,  1.75972238e-01,
                    -2.43168324e-01, -3.16800296e-01, -6.96922392e-02,
                    -5.26114106e-02,  1.36705399e-01, -5.23618162e-02,
                     1.28879577e-01,  2.59115607e-01, -6.60537034e-02],
                   [ 1.62194669e-02, -5.36157191e-03,  6.25325292e-02,
                     3.77438962e-03,  5.34493476e-02,  9.57345217e-02,
                    -3.37895215e-01,  1.54839545e-01, -5.76823950e-05,
                     3.46475482e-01, -1.44219398e-02, -1.57836825e-02,
                     1.02470160e-01, -2.66014189e-02,  1.38210922e-01],
                   [-1.54244944e-01, -1.89163774e-01, -1.25277504e-01,
                    -1.53509825e-01, -2.08621144e-01, -3.57687324e-02,
                     1.82844996e-02,  1.68684870e-01,  5.94071597e-02,
                     3.75792086e-02,  7.39661902e-02,  3.35429460e-02,
                     6.90625012e-02,  1.51754677e-01, -2.60873646e-01],
                   [ 9.05385166e-02,  3.13364267e-01, -1.75755590e-01,
                     5.33922911e-02, -1.96638182e-01,  2.29208365e-01,
                     2.01122493e-01,  1.39216185e-01,  3.18074644e-01,
                    -2.20635474e-01,  1.32455468e-01, -4.49397564e-02,
                    -3.59446257e-02,  3.02191943e-01, -1.93763226e-02],
                   [-9.32638943e-02, -9.22463536e-02,  2.90949643e-01,
                     1.38086572e-01,  1.45970643e-01,  1.11227736e-01,
                     1.92929432e-01, -2.20268294e-01,  8.51592422e-03,
                    -1.14801973e-02, -1.08394876e-01, -4.92374301e-02,
                     1.53573290e-01,  3.89962226e-01, -1.53074652e-01],
                   [ 2.90715694e-03,  1.83663428e-01,  3.76528651e-02,
                    -1.73860639e-02,  1.83179498e-01,  4.10255790e-03,
                     9.65543687e-02,  7.96876550e-02,  2.19800651e-01,
                     2.27372974e-01, -1.51361659e-01,  2.04350561e-01,
                     1.18743330e-01, -3.37018371e-01,  1.12518296e-01],
                   [-3.69710326e-02,  5.37259430e-02, -5.55092096e-03,
                    -4.27453220e-03, -1.99901059e-01, -1.28307149e-01,
                     5.43928742e-02,  1.31580710e-01, -1.68890849e-01,
                    -3.78164351e-01, -6.06988072e-02,  2.29855090e-01,
                    -1.12242132e-01,  3.79682928e-02, -3.86388510e-01],
                   [ 2.35311255e-01,  1.66386098e-01, -8.38957876e-02,
                     2.03469038e-01, -2.03613058e-01, -7.19606429e-02,
                     1.12464979e-01,  2.45347440e-01,  2.39544034e-01,
                    -1.39806151e-01, -2.63890773e-02, -1.12561002e-01,
                    -2.70868719e-01, -4.71335649e-03,  1.29959971e-01],
                   [-5.57020754e-02, -3.46530616e-01,  2.98494935e-01,
                    -1.66801214e-01,  6.14305437e-02,  9.28813517e-02,
                     1.35294989e-01, -1.34979218e-01, -7.89761543e-04,
                    -2.50806570e-01,  1.01332352e-01,  1.00093633e-01,
                     1.26480818e-01,  1.32050708e-01,  2.52038181e-01],
                   [ 1.38003528e-02, -1.96471453e-01,  1.48797393e-01,
                     3.88497114e-02, -1.44033611e-01,  3.50036263e-01,
                    -3.26102078e-02, -1.19598106e-01, -3.50412369e-01,
                    -9.01353359e-02,  1.68153301e-01, -1.73634619e-01,
                    -2.64526069e-01,  1.89368397e-01, -3.03420037e-01],
                   [ 1.53409019e-01, -1.65645137e-01,  2.75816798e-01,
                    -2.61331946e-02,  9.50511992e-02, -1.15451679e-01,
                    -2.91420221e-02,  9.09120291e-02,  1.60791710e-01,
                     1.59705788e-01, -1.63535506e-01, -2.39277333e-02,
                     2.17306137e-01, -3.75522256e-01,  3.61660779e-01],
                   [ 2.75457501e-01, -5.11639714e-02,  3.02877873e-02,
                     3.83744717e-01,  1.89788371e-01, -3.05498004e-01,
                    -1.36434168e-01, -9.84745473e-02, -8.30809176e-02,
                    -1.72840640e-01,  2.08918750e-03,  2.70356536e-01,
                    -1.43099576e-02,  1.42251849e-02, -2.30559021e-01],
                   [-1.12811208e-01, -8.90487880e-02,  5.26717752e-02,
                    -3.34500819e-02,  1.79550871e-01,  1.52729020e-01,
                    -5.19486815e-02,  1.09064624e-01,  2.16731712e-01,
                    -5.77685237e-02,  2.93150067e-01, -2.72270977e-01,
                     2.27185652e-01, -4.16627526e-02,  8.24269503e-02],
                   [ 1.12723231e-01,  1.53931171e-01, -1.41348794e-01,
                    -1.82246700e-01, -2.60128081e-01,  2.28294432e-01,
                    -1.41973585e-01, -2.20511407e-01,  4.12718713e-01,
                     2.26580381e-01,  3.02141160e-02, -4.08996940e-02,
                    -1.77697018e-01,  2.13378891e-01,  2.34325364e-01],
                   [-8.89122784e-02, -1.33175939e-01,  1.13867760e-01,
                    -3.15929264e-01, -5.33000082e-02, -3.32492590e-02,
                     1.61763191e-01, -4.45095897e-02, -3.13157290e-02,
                     1.08764470e-02, -1.84109345e-01, -4.28237021e-02,
                    -6.48176223e-02, -5.03456593e-03,  7.79272169e-02],
                   [ 1.07352838e-01,  1.53558820e-01,  9.59271789e-02,
                     1.95107758e-02, -7.16409087e-03, -2.52965033e-01,
                    -2.34644741e-01,  3.56166244e-01,  1.75924093e-01,
                    -1.80997580e-01, -2.52380699e-01, -5.56088388e-02,
                    -2.03573480e-01, -1.34791806e-01,  1.43947557e-01],
                   [-3.20794344e-01,  1.52731538e-02,  3.17644477e-01,
                     3.08467388e-01,  6.71731085e-02, -2.07317039e-01,
                     6.92533553e-02,  8.92890543e-02, -5.71464747e-02,
                     6.77597970e-02, -2.40857795e-01,  9.92331654e-02,
                    -3.97704244e-01, -1.44478083e-01,  1.78948924e-01],
                   [ 1.78290546e-01, -3.73491168e-01, -3.44891042e-01,
                    -1.85131654e-01, -1.25298679e-01, -3.59349012e-01,
                    -2.19286203e-01,  4.03494358e-01, -5.62228560e-02,
                    -1.21513605e-01,  3.14101666e-01, -5.40824682e-02,
                     1.32633939e-01,  1.56066120e-02,  2.20855936e-01],
                   [-1.59709528e-01, -1.92597717e-01, -3.99962455e-01,
                    -1.37944147e-01,  2.50320375e-01, -3.02701652e-01,
                     2.37861603e-01,  3.54276657e-01, -2.65088379e-02,
                     2.30067194e-01,  1.21199116e-01, -1.71849012e-01,
                    -1.15864873e-02,  2.79654086e-01,  6.52542561e-02],
                   [-3.87499511e-01, -2.14815795e-01, -3.50072563e-01,
                    -3.74599099e-01,  3.15827131e-03, -3.80146921e-01,
                    -2.75051177e-01,  4.92296517e-02, -3.42986435e-02,
                     2.64439374e-01,  2.24484161e-01, -9.03718919e-02,
                    -1.66824907e-01,  2.33413532e-01, -2.70957470e-01],
                   [-2.09726006e-01, -1.01519674e-02,  1.65578052e-01,
                     2.08754063e-01, -1.90129966e-01, -3.17800641e-01,
                    -3.11262459e-02, -6.45868927e-02,  3.97725523e-01,
                    -2.66408682e-01,  3.11382055e-01, -6.38214052e-02,
                    -3.96969140e-01,  1.07675835e-01,  1.15448385e-02],
                   [ 1.90294176e-01,  2.62975395e-02,  1.02480963e-01,
                    -2.17438534e-01, -2.62192190e-01,  1.49379149e-01,
                     1.36339113e-01, -2.20250994e-01, -6.47595972e-02,
                     2.78646350e-01, -1.04613230e-01, -3.21846724e-01,
                    -3.60531211e-02,  4.10647094e-02, -2.01862305e-02],
                   [ 4.84227687e-02, -1.74249634e-01,  1.22636214e-01,
                     3.27275246e-01, -8.30558389e-02, -3.14863682e-01,
                     1.06451482e-01,  9.95579511e-02,  7.17695206e-02,
                    -2.05804944e-01,  4.14260328e-02, -2.82542408e-03,
                     1.59710690e-01,  1.95358038e-01, -2.18680993e-01],
                   [ 1.80844143e-01,  9.28273797e-02, -2.79074967e-01,
                    -3.21841598e-01,  8.46190900e-02, -1.31674752e-01,
                    -2.22147837e-01,  6.93762749e-02,  1.08454213e-01,
                    -1.54395998e-01, -2.52935737e-02,  3.96427214e-02,
                    -1.77300423e-02,  4.08218354e-02,  1.57031059e-01],
                   [ 1.72292829e-01, -2.74210989e-01,  3.91593874e-02,
                    -1.06435806e-01,  1.53482422e-01, -4.07753527e-01,
                    -1.45187899e-01, -1.97193578e-01,  1.51643381e-01,
                     8.71199816e-02, -1.80934519e-02,  3.16381007e-02,
                    -3.16615790e-01, -8.88970047e-02, -3.15811843e-01],
                   [-9.74183232e-02, -2.85981894e-02,  3.55937093e-01,
                    -2.71050155e-01,  3.29447448e-01,  8.03399384e-02,
                     3.03259671e-01, -1.92141250e-01, -1.79099917e-01,
                     2.89139271e-01,  9.35540348e-02, -2.52466857e-01,
                     1.29278764e-01,  3.87039840e-01, -3.96069586e-01],
                   [ 6.32186532e-02,  8.49261880e-02, -1.27169192e-02,
                     1.09564900e-01,  1.01468608e-01, -1.91385388e-01,
                     2.12272123e-01,  3.93263578e-01,  3.21644545e-03,
                     1.96522057e-01, -1.37627035e-01,  2.72568524e-01,
                    -9.58500952e-02, -5.87881505e-02,  3.29161406e-01]], dtype=float32),
        },
        Dense_2: {
            bias: array([-1.9683748e-02, -2.1595210e-03,  6.7055225e-06, -5.0753772e-02,
                    1.3499202e-01,  4.5890406e-02, -1.5943155e-02,  6.7055225e-06],
                  dtype=float32),
            kernel: array([[ 1.70742482e-01, -9.91618633e-02, -2.29330659e-02,
                     1.22218162e-01, -1.06615439e-01, -3.67971659e-02,
                     1.41981930e-01,  8.84072632e-02],
                   [ 1.80515304e-01,  1.18458852e-01,  5.27178884e-01,
                    -1.53531089e-01, -3.87934148e-02, -1.23133227e-01,
                     6.29240274e-03,  1.60064697e-02],
                   [ 1.46227553e-01,  5.25590360e-01, -3.44348907e-01,
                    -9.29221213e-02,  1.28932714e-01, -1.73384249e-02,
                     1.19381860e-01, -4.11042988e-01],
                   [ 2.49569118e-01,  2.14901835e-01,  3.24754208e-01,
                    -4.91983056e-01, -1.14351794e-01, -2.11403668e-02,
                     7.69451857e-02,  3.31384718e-01],
                   [ 4.12717015e-02, -1.21542990e-01, -3.10934186e-01,
                     3.62211466e-01, -2.74409771e-01, -5.17052352e-01,
                    -5.41649759e-02,  5.59365511e-01],
                   [-4.59556133e-02, -6.34765029e-02,  1.12813324e-01,
                     3.72479081e-01,  1.19602963e-01,  2.29919955e-01,
                    -1.38893351e-01,  4.47091460e-02],
                   [-6.79109395e-02, -3.46875697e-01, -4.84580278e-01,
                     1.12352774e-01, -1.92538321e-01, -3.71922582e-01,
                     5.66129565e-01, -2.42807910e-01],
                   [ 3.35851789e-01, -1.21225804e-01,  3.77379179e-01,
                     4.12614942e-01,  4.95961308e-02, -3.28645080e-01,
                     4.95315671e-01, -1.31850213e-01],
                   [ 3.26091319e-01,  6.93900436e-02,  2.03807250e-01,
                     1.04140580e-01,  3.88821214e-01,  1.70190245e-01,
                    -2.11963385e-01,  5.57184219e-04],
                   [ 1.63282499e-01, -1.22407228e-02, -4.44926322e-01,
                    -9.26266760e-02,  3.41380268e-01,  4.87684071e-01,
                    -1.46689430e-01, -2.39341646e-01],
                   [-2.73013204e-01,  1.98071614e-01,  1.63841352e-01,
                    -4.02920216e-01,  1.04726478e-01,  3.63073707e-01,
                    -5.87156415e-03, -2.55927563e-01],
                   [ 2.71851361e-01, -6.50364012e-02,  5.75378239e-02,
                     8.61035287e-02,  4.62561399e-02,  6.84097409e-02,
                    -3.49434376e-01,  3.20657223e-01],
                   [ 3.87204051e-01,  1.02552548e-01, -3.67724180e-01,
                    -1.37631521e-01, -2.76330113e-03, -9.74713266e-03,
                    -2.03522891e-02, -3.78593743e-01],
                   [ 1.42845362e-01, -1.62034005e-01, -4.73217815e-01,
                     2.21667886e-01,  9.57852453e-02, -3.21318150e-01,
                     1.93716571e-01, -1.64225832e-01],
                   [-8.96123946e-02,  1.43070787e-01, -3.05288434e-02,
                     4.30531144e-01, -3.90004218e-01,  5.45606494e-01,
                    -9.91108418e-02, -3.92888844e-01]], dtype=float32),
        },
        Dense_3: {
            bias: array([0.3012179], dtype=float32),
            kernel: array([[-0.04792835],
                   [ 0.6472556 ],
                   [ 0.19396764],
                   [-0.19879013],
                   [ 0.4868285 ],
                   [ 0.70990056],
                   [-0.05229847],
                   [-0.20487049]], dtype=float32),
        },
    },
})

最后,让我们在测试集上验证模型。

[8]:
# Evaluate the trained parameters on the held-out test split and print the
# AUC returned by `validate_model`. `breast_cancer`, `validate_model` and
# `params` are defined in earlier cells of this notebook.
test_split = breast_cancer(train=False)  # (features, labels) for the test set
X_test, y_test = test_split
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

auc=0.9921388797903701

实验到此结束。