SPU 训练神经网络#
请先阅读 Logistic Regression On SPU 。
在实验 Logistic Regression On SPU 中,我们展示了如何使用 SecretFlow/SPU 将明文 JAX 训练程序转换为安全 MPC 训练程序。
在本实验中,思路与之非常相似,但这次我们将使用神经网络模型。
我们将使用相同的数据集和所有设置。
首先,让我们构建明文模型。
以下代码仅作为示例,请勿在生产环境直接使用。
本教程所需的资源多于 SecretFlow 的最低配置要求(8 核 CPU、16 GB 内存,即 8c16g)。
使用 JAX/FLAX 训练模型#
加载数据集#
以下内容复制自实验 Logistic Regression On SPU,此处不再重复解释。
[ ]:
import sys
# Install the pinned Flax version using the pip of the interpreter that runs
# this notebook kernel. The leading `!` is an IPython shell escape, not plain
# Python syntax.
!{sys.executable} -m pip install flax==0.6.0
[1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer
def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    """Load the max-normalized breast cancer dataset, optionally split by party.

    The 30 feature columns of the training split are divided vertically:
    party 1 holds columns 15..29 and no labels; any other truthy party id
    holds columns 0..14 together with the labels.

    NOTE(review): keep the tuple-style return annotation as is. It is
    presumably inspected by SecretFlow's ``PYU`` wrapper to infer that this
    function returns two values (``alice(breast_cancer)(...)`` is unpacked
    into two objects). Confirm before changing it.

    Args:
        party_id: ``1`` or ``2`` to get that party's vertical slice of the
            training features. ``None`` (or any falsy value) returns all
            training features. Ignored when ``train`` is False.
        train: if True, use the training split (80%); otherwise return the
            full test split (20%). The split uses ``random_state=42``.

    Returns:
        A pair ``(x, y)``. For party 1, ``y`` is ``None`` because that party
        owns no labels.
    """
    scaler = Normalizer(norm='max')
    x, y = load_breast_cancer(return_X_y=True)
    x = scaler.fit_transform(x)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if not train:
        return x_test, y_test
    if not party_id:
        return x_train, y_train
    if party_id == 1:
        # Party 1 has no labels. Return None explicitly: a bare `_` is only
        # defined inside IPython (last cell output) and is a NameError in a
        # regular Python program.
        return x_train[:, 15:], None
    return x_train[:, :15], y_train
定义模型#
[2]:
from typing import Sequence
import flax.linen as nn
# Output widths of the successive Dense layers: 30 -> 15 -> 8 -> 1.
# model_init also uses FEATURES[0] as the width of its dummy input batch.
FEATURES = [30, 15, 8, 1]


class MLP(nn.Module):
    """Multi-layer perceptron with ReLU hidden layers and a linear output layer.

    ``features`` gives the output width of each Dense layer in order. Every
    layer except the last is followed by a ReLU; the last layer's output is
    returned unchanged.
    """

    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        output_width = self.features[-1]
        h = x
        for width in hidden_widths:
            h = nn.Dense(width)(h)
            h = nn.relu(h)
        return nn.Dense(output_width)(h)
接下来,我们定义训练方法。
[3]:
import jax.numpy as jnp
def predict(params, x):
    """Apply the MLP with ``params`` to ``x`` and return its raw outputs.

    No activation is applied to the last Dense layer. Because FEATURES ends
    in 1, the output has a trailing dimension of size 1.
    """
    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,
    # which is not the case in a normal python program.
    # NOTE(review): presumably the local copy keeps this function
    # self-contained when it is serialized for remote/SPU execution from a
    # notebook. Keep it in sync with the module-level FEATURES and MLP.
    from typing import Sequence
    import flax.linen as nn
    FEATURES = [30, 15, 8, 1]
    class MLP(nn.Module):
        features: Sequence[int]
        @nn.compact
        def __call__(self, x):
            # ReLU after every hidden Dense layer.
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
            # Final linear layer, no activation.
            x = nn.Dense(self.features[-1])(x)
            return x
    return MLP(FEATURES).apply(params, x)
def loss_func(params, x, y):
    """Half mean squared error between the MLP outputs and the labels.

    Args:
        params: MLP parameters (e.g. as produced by ``model_init``).
        x: a batch of input features.
        y: the matching labels. The labels produced by ``breast_cancer`` are
            1-D, while ``predict`` returns shape ``(batch, 1)`` for this
            model, so the predictions are reshaped to ``y``'s shape first.

    Returns:
        Scalar ``mean((y - pred) ** 2) / 2``.
    """
    pred = predict(params, x)
    # Without this reshape, (batch,) - (batch, 1) broadcasts to a
    # (batch, batch) matrix pairing every label with every prediction, so the
    # loss would no longer compare each sample with its own prediction.
    pred = jnp.reshape(pred, jnp.shape(y))

    def squared_error(y_true, y_pred):
        diff = y_true - y_pred
        return jnp.multiply(diff, diff) / 2.0

    return jnp.mean(squared_error(y, pred))
def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    """Train the MLP with plain mini-batch gradient descent.

    The two parties' feature blocks are concatenated column-wise and split
    into mini-batches. Each epoch takes one gradient step per mini-batch.

    Args:
        x1: feature columns of party 1.
        x2: feature columns of party 2 (same row order as ``x1``).
        y: labels, one per row.
        params: initial MLP parameters.
        n_batch: approximate mini-batch *size* (rows per batch). The data is
            cut into ``max(1, len(x) // n_batch)`` nearly equal chunks.
        n_epochs: number of passes over all mini-batches.
        step_size: learning rate.

    Returns:
        The trained parameters, with the same tree structure as ``params``.
    """
    x = jnp.concatenate((x1, x2), axis=1)
    # array_split expects an integer number of sections. max(1, ...) avoids
    # asking for zero sections when there are fewer rows than n_batch.
    n_sections = max(1, len(x) // n_batch)
    xs = jnp.array_split(x, n_sections, axis=0)
    ys = jnp.array_split(y, n_sections, axis=0)
    # Only the gradient is used, so there is no need for value_and_grad.
    grad_fn = jax.grad(loss_func)

    def body_fun(_, loop_carry):
        params = loop_carry
        # Plain Python loop over the pre-split batches: one SGD update per batch.
        for (x, y) in zip(xs, ys):
            grads = grad_fn(params, x, y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params

    return jax.lax.fori_loop(0, n_epochs, body_fun, params)
def model_init(n_batch=10):
    """Create initial MLP parameters using the fixed PRNG key 1.

    An all-ones dummy batch of shape ``(n_batch, FEATURES[0])`` is fed to
    ``init`` so Flax can infer the layer input shapes.
    """
    dummy_batch = jnp.ones((n_batch, FEATURES[0]))
    rng = jax.random.PRNGKey(1)
    return MLP(FEATURES).init(rng, dummy_batch)
验证模型#
我们使用 AUC 作为验证指标。
[4]:
from sklearn.metrics import roc_auc_score
def validate_model(params, X_test, y_test):
    """Score the model on held-out data with ROC AUC.

    Args:
        params: trained MLP parameters.
        X_test: test features.
        y_test: binary ground-truth labels.

    Returns:
        The ROC AUC of the raw model outputs against ``y_test``.
    """
    scores = predict(params, X_test)
    return roc_auc_score(y_true=y_test, y_score=scores)
放在一起#
让我们把所有东西放在一起,训练一个明文 NN 模型!
[5]:
import jax
# Load the data. Party 1 gets feature columns 15..29 (its label slot is
# discarded); party 2 gets columns 0..14 plus the labels.
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)
# Hyperparameters
n_batch = 10  # rows per mini-batch (see train_auto_grad)
n_epochs = 10
step_size = 0.01  # learning rate
# Train the model in plaintext
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)
# Test the model: ROC AUC on the held-out split
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
auc=0.9921388797903701
以上,我们把所有东西放在一起,训练了一个明文神经网络模型。接下来,我们使用 SPU 训练同样的模型。
使用 SPU 训练模型#
[6]:
import secretflow as sf
# In case you have a running secretflow runtime already.
sf.shutdown()
# Start a local SecretFlow runtime with two parties.
sf.init(['alice', 'bob'], address='local')
alice, bob = sf.PYU('alice'), sf.PYU('bob')
# SPU device spanning both parties (cluster config from the testing helper).
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
# Each party loads its own vertical slice on its own PYU.
x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)
device = spu
# Move the party-held data and the initial parameters onto the SPU device.
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)
# Hyperparameters are marked static, presumably so they remain plain Python
# values during SPU compilation: train_auto_grad uses them for the
# batch-split count, the loop bound and the update scale.
params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(
    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)
2022-08-25 10:54:11.675312: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944990) 2022-08-25 10:54:17.053663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944994) 2022-08-25 10:54:17.053663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944991) 2022-08-25 10:54:17.053663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944992) 2022-08-25 10:54:17.064685: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944996) 2022-08-25 10:54:17.053663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(pid=2944988) 2022-08-25 10:54:17.011062: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(_run pid=2944995) 2022-08-25 10:54:17.145220: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(_run pid=2944993) 2022-08-25 10:54:17.154674: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/rh-ruby25/root/usr/local/lib64:/opt/rh/rh-ruby25/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(SPURuntime pid=2944988) I0825 10:54:18.801868 2944988 external/com_github_brpc_brpc/src/brpc/server.cpp:1066] Server[yacl::link::internal::ReceiverServiceImpl] is serving on port=44351.
(SPURuntime pid=2944988) I0825 10:54:18.801938 2944988 external/com_github_brpc_brpc/src/brpc/server.cpp:1069] Check out http://k69b13338.eu95sqa:44351 in web browser.
(_run pid=2944995) 2022-08-25 10:54:18,906,906 WARNING [xla_bridge.py:backends:265] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(SPURuntime pid=2944990) I0825 10:54:18.863563 2944990 external/com_github_brpc_brpc/src/brpc/server.cpp:1066] Server[yacl::link::internal::ReceiverServiceImpl] is serving on port=22227.
(SPURuntime pid=2944990) I0825 10:54:18.863633 2944990 external/com_github_brpc_brpc/src/brpc/server.cpp:1069] Check out http://k69b13338.eu95sqa:22227 in web browser.
(SPURuntime pid=2944988) I0825 10:54:18.902659 2945551 external/com_github_brpc_brpc/src/brpc/socket.cpp:2202] Checking Socket{id=0 addr=127.0.0.1:22227} (0x55f83a4820c0)
(SPURuntime pid=2944988) I0825 10:54:18.902798 2945551 external/com_github_brpc_brpc/src/brpc/socket.cpp:2262] Revived Socket{id=0 addr=127.0.0.1:22227} (0x55f83a4820c0) (Connectable)
(_run pid=2944995) [2022-08-25 10:54:18.906] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(_run pid=2944993) [2022-08-25 10:54:18.941] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(_run pid=2944993) 2022-08-25 10:54:18,941,941 WARNING [xla_bridge.py:backends:265] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2944996) 2022-08-25 10:54:18,952,952 WARNING [xla_bridge.py:backends:265] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
让我们查看 SPU 训练得到的模型参数。
[7]:
# Inspect the parameters trained on SPU in the previous cell. Reuse
# `params_spu` rather than calling train_auto_grad again: retraining would
# repeat the whole MPC run, feed the plaintext `init_params` instead of the
# SPU-placed `init_params_`, and throw away the result computed above.
params = sf.reveal(params_spu)
print(params)
(SPURuntime pid=2944988) [2022-08-25 10:54:29.343] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(SPURuntime pid=2944990) [2022-08-25 10:54:29.328] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
FrozenDict({
params: {
Dense_0: {
bias: array([-5.1895231e-03, 1.8576235e-03, 6.7055225e-06, 2.3722053e-03,
3.1328186e-02, 6.7055225e-06, -1.5438795e-02, 6.7055225e-06,
6.7055225e-06, -4.1383177e-02, 6.7055225e-06, -2.2077844e-02,
6.7055225e-06, 6.7055225e-06, -4.1123331e-03, 6.7055225e-06,
3.9519504e-02, -2.0013824e-02, 6.7055225e-06, -1.7026097e-02,
-2.0100027e-03, -1.2290031e-02, 3.5510257e-02, 6.7055225e-06,
1.0400787e-02, 6.7055225e-06, 6.7055225e-06, 6.7055225e-06,
6.7055225e-06, -2.4091452e-03], dtype=float32),
kernel: array([[-1.48707509e-01, -2.35312745e-01, -1.49370432e-01,
-1.55658722e-02, -1.33344889e-01, 1.91762865e-01,
-3.67970318e-02, -3.74531150e-02, -1.41760916e-01,
3.24019790e-02, 1.26554981e-01, -4.02508259e-01,
-1.68948472e-01, 2.13856786e-01, -1.38430491e-01,
1.05858222e-01, -1.16115242e-01, 3.86231482e-01,
5.96616715e-02, 6.30981326e-02, 7.79366642e-02,
-1.31276995e-02, -2.88023233e-01, -9.60215628e-02,
1.11036777e-01, -8.54356140e-02, 7.54796267e-02,
-4.11889851e-02, -3.82702172e-01, 2.37623483e-01],
[ 1.77950412e-01, 2.29401365e-01, -2.44401962e-01,
-1.48464620e-01, 3.36820692e-01, 2.58408487e-02,
-4.21461910e-02, 4.10513222e-01, 3.24393630e-01,
-1.64218560e-01, 8.17326903e-02, 5.25916368e-02,
3.11348379e-01, 2.92945385e-01, 1.22726962e-01,
-3.87524545e-01, -3.85492414e-01, -6.53883070e-02,
-2.59158403e-01, -3.32410604e-01, -3.15876693e-01,
-2.91079849e-01, -6.01480752e-02, 2.29791373e-01,
1.01019114e-01, -1.30936205e-02, 1.78829342e-01,
-2.32158348e-01, 3.82750481e-01, 3.79820764e-02],
[ 1.42894119e-01, -2.13525295e-02, 1.68190986e-01,
8.99314433e-02, -3.89296085e-01, -4.85262126e-02,
1.38709426e-01, -5.80671132e-02, 2.84729242e-01,
-1.26564085e-01, 2.57159889e-01, 9.64836925e-02,
1.16704285e-01, -1.90519944e-01, -3.98443341e-02,
-9.57311541e-02, 7.24758208e-02, 1.46388173e-01,
9.22411978e-02, 3.83930653e-02, -2.46246234e-01,
3.76523137e-02, -1.90257579e-02, -2.52097666e-01,
1.70252651e-01, 2.49821901e-01, -2.81400979e-03,
-9.84508097e-02, 2.07754672e-01, 8.81182998e-02],
[-2.25118160e-01, -4.46529686e-03, -4.55757827e-02,
-4.27887589e-02, -1.34788275e-01, -3.42829585e-01,
-9.25171375e-03, 9.01048332e-02, -2.82488763e-01,
2.23106414e-01, 1.17174536e-01, 4.47563082e-02,
-5.02136499e-02, 2.90469885e-01, -2.35435143e-01,
5.99489361e-02, 2.55998850e-01, 1.87139437e-01,
1.46078974e-01, 2.93805301e-01, 3.95906270e-02,
1.63574621e-01, -1.12822607e-01, -6.96890652e-02,
2.61215031e-01, -3.17869127e-01, 8.15212727e-02,
2.55623460e-01, -5.98048419e-02, 1.93995312e-01],
[-2.42546692e-01, 7.96881765e-02, -6.76873624e-02,
-1.17468625e-01, 1.87561423e-01, -6.13799393e-02,
-5.36612719e-02, -6.93449825e-02, 7.92452097e-02,
2.54104435e-02, -3.18573356e-01, 2.87047565e-01,
-6.02750182e-02, -3.01489800e-01, -1.76609412e-01,
7.97384530e-02, 1.61419287e-01, 3.27948928e-01,
2.05150560e-01, 3.03487569e-01, 2.71105886e-01,
2.76556283e-01, 7.07161725e-02, 2.08005011e-01,
-7.33365566e-02, -1.03249148e-01, 1.55347139e-02,
3.17582250e-01, 3.16770792e-01, -6.80912733e-02],
[ 7.31387287e-02, 6.06941879e-02, 5.30773848e-02,
1.78480223e-01, 1.88692018e-01, 1.77043870e-01,
8.07239711e-02, 1.48239732e-02, -4.42493856e-02,
6.05108887e-02, 2.88271666e-01, 5.82368821e-02,
2.64274329e-01, 1.20873421e-01, -2.19291598e-02,
-1.51663244e-01, -4.52437997e-02, 6.89959526e-03,
2.83989072e-01, 1.00405067e-01, -3.03579658e-01,
-3.13743889e-01, -1.75221011e-01, 1.12235680e-01,
1.15922809e-01, -2.76054680e-01, -6.86793476e-02,
6.13613129e-02, 3.08204144e-01, -2.80041754e-01],
[-2.58817077e-01, -1.54795498e-02, 2.77136236e-01,
-3.83958876e-01, 4.00266647e-01, -1.29193932e-01,
-2.80725807e-02, 1.75729364e-01, 1.30319521e-01,
1.56524405e-01, -2.62551606e-02, 2.93114007e-01,
3.55902165e-02, 3.58770192e-02, -2.96767652e-01,
-1.69694602e-01, 3.19763869e-02, 1.61003023e-01,
-1.75322652e-01, -8.83801430e-02, -4.86088842e-02,
-1.05747893e-01, 8.46082866e-02, -4.35629487e-02,
-2.53890634e-01, -9.85630006e-02, -4.16661799e-02,
-4.67738807e-02, -3.35340977e-01, -1.18303642e-01],
[ 2.78338730e-01, -6.96775764e-02, -2.48151869e-02,
4.41243351e-02, -8.29881430e-02, -1.39038265e-03,
8.48067403e-02, -2.48347864e-01, 1.53563991e-01,
-1.36309460e-01, -4.34331596e-03, -6.34510815e-03,
-7.79297352e-02, -1.02945447e-01, 2.40225092e-01,
-2.56992996e-01, 2.98645079e-01, -1.97924539e-01,
-2.75554240e-01, 1.29482746e-02, 1.26732409e-01,
1.27245054e-01, 1.59333318e-01, 1.43134385e-01,
6.40332401e-02, -1.11376449e-01, 5.90692759e-02,
-1.16492316e-01, 1.58721358e-02, -2.03465536e-01],
[ 6.27327412e-02, 8.74824524e-02, -9.10452157e-02,
1.81233525e-01, 1.06525749e-01, -1.38279110e-01,
3.20228636e-01, -1.48054436e-01, -3.18492800e-02,
3.50982606e-01, -1.33533657e-01, 1.76203504e-01,
5.12142032e-02, -4.60760295e-02, -1.33599430e-01,
-2.22507954e-01, 1.66286990e-01, 3.71264696e-01,
-1.16020367e-01, -1.61022544e-02, -8.04677159e-02,
1.58347130e-01, 3.06297302e-01, 6.74358159e-02,
8.46540183e-02, -9.24098492e-03, -3.53501737e-03,
-3.68553460e-01, -1.23027235e-01, 2.18146190e-01],
[ 2.83379406e-02, -1.24499649e-01, 1.79811046e-01,
1.53642461e-01, 5.48425168e-02, 1.91716999e-01,
-9.49321389e-02, 6.86786622e-02, -7.67818838e-02,
-1.93986297e-02, 5.70140183e-02, -3.93389583e-01,
5.28795272e-02, 3.79496545e-01, 2.46415466e-01,
-1.21219859e-01, 4.00151759e-02, -3.80354345e-01,
1.95414037e-01, -9.05120224e-02, 3.20608616e-01,
1.48523450e-02, -3.49244773e-02, 1.11090288e-01,
-3.37234616e-01, -3.06017160e-01, -1.13247246e-01,
1.59685776e-01, 6.75145537e-02, 1.00891382e-01],
[ 1.68052852e-01, 1.94981366e-01, -9.76378322e-02,
-1.45580143e-01, 1.01531252e-01, -3.17420304e-01,
-1.15842447e-01, -2.86557317e-01, -1.01209313e-01,
-1.30138457e-01, 1.97996482e-01, -6.92996681e-02,
1.83074176e-03, -6.13981634e-02, -2.38128498e-01,
1.41838133e-01, 4.12078500e-01, -1.11510590e-01,
-7.69597739e-02, -3.93857658e-02, -5.82328439e-02,
-2.56166637e-01, 1.75529957e-01, -5.77670783e-02,
4.62778807e-02, 1.20462373e-01, 3.14444661e-01,
-1.82372808e-01, -1.62539259e-01, -9.76682454e-02],
[-6.19109869e-02, -1.15577027e-01, 7.26505965e-02,
1.25303209e-01, 2.06839561e-01, 1.57670394e-01,
-8.05730969e-02, -1.94498345e-01, 2.13316083e-02,
2.35428169e-01, -1.77006692e-01, -3.51173997e-01,
-2.20170230e-01, 3.13621610e-02, 1.01004869e-01,
-4.00861502e-01, -1.33804008e-01, -6.59427792e-02,
-1.41224921e-01, -1.72024697e-01, -6.66107237e-02,
9.94113684e-02, -3.09019983e-02, 2.59390533e-01,
6.44736290e-02, -2.50633597e-01, -3.49195153e-02,
8.02383423e-02, 2.55563200e-01, -2.40900025e-01],
[ 8.98088515e-03, -4.00735557e-01, -6.30197525e-02,
6.18334711e-02, 3.73557806e-01, 3.17755789e-02,
2.75026381e-01, -2.88109362e-01, -2.02475563e-01,
1.61148801e-01, -2.17942014e-01, 1.06317997e-01,
2.66866386e-03, -2.73037195e-01, 7.52992183e-02,
7.77817965e-02, 2.63209343e-02, 9.45713371e-02,
-2.83376873e-01, 2.55718678e-02, 1.71330452e-01,
4.77515757e-02, -1.29865110e-02, -9.19252783e-02,
2.20202535e-01, -1.98967785e-01, 3.41536522e-01,
8.68079364e-02, -8.85302424e-02, -9.05130804e-03],
[ 1.20346710e-01, 1.25420868e-01, -3.62598658e-01,
2.23636329e-01, -7.36351609e-02, 1.05015352e-01,
4.35760617e-03, -8.31906050e-02, 2.28634894e-01,
-1.49365649e-01, 8.16624165e-02, -3.14154029e-01,
8.76248032e-02, -5.55702299e-02, 8.57660025e-02,
-2.31696367e-02, -2.23352492e-01, 2.85403579e-02,
1.04186803e-01, 9.74222422e-02, 8.70420933e-02,
1.58863813e-02, 1.73726633e-01, 8.37596357e-02,
-1.77873820e-02, -6.53727055e-02, 5.05047441e-02,
-2.29445338e-01, 5.72724342e-02, 2.50912070e-01],
[-3.97108495e-01, 1.00127161e-01, -7.08009750e-02,
-1.62648782e-01, -1.39103666e-01, 1.61610574e-01,
1.60218611e-01, -7.88767636e-03, 5.42842150e-02,
1.65928125e-01, 2.23704860e-01, 3.66966009e-01,
6.14991188e-02, -4.85754609e-02, 3.34523857e-01,
-7.26056099e-02, 1.93905726e-01, -6.00316077e-02,
-3.03020537e-01, 1.71825185e-01, 2.90646702e-01,
2.13969320e-01, 4.79212999e-02, 9.81050730e-02,
1.03307858e-01, 1.27324745e-01, -5.79782724e-02,
1.52468219e-01, -3.66631836e-01, 1.77991658e-01],
[ 1.05444938e-02, 8.89388323e-02, 2.68580914e-01,
3.43994796e-01, 7.01821893e-02, 3.07618678e-01,
-1.59211323e-01, 2.65379250e-03, 4.29789275e-02,
-6.48853630e-02, -1.19794458e-02, -1.01899371e-01,
3.90142053e-02, -2.12668777e-02, 1.34829685e-01,
-1.60709843e-01, -2.75420129e-01, -1.10718116e-01,
1.22148365e-01, -2.56418556e-01, -8.86735022e-02,
3.57692540e-02, -1.72454745e-01, 1.25590444e-01,
1.48163527e-01, 9.70243216e-02, 1.78432480e-01,
8.07075799e-02, 7.18790591e-02, 8.29363018e-02],
[ 1.46744028e-01, 1.35470942e-01, -5.01304418e-02,
-2.56607473e-01, -2.36472189e-01, 2.16720849e-01,
1.36735886e-01, -3.88280898e-02, 3.90521705e-01,
3.07039917e-03, 1.45443454e-01, -7.04357922e-02,
-1.51877582e-01, 6.67887330e-02, -2.40254313e-01,
3.11602056e-01, 6.78158849e-02, -2.53712773e-01,
-2.01759011e-01, -2.26584569e-01, 1.38092458e-01,
-1.41293734e-01, 3.44138265e-01, 1.29559129e-01,
2.85035908e-01, 6.18830621e-02, -2.29603484e-01,
2.91220188e-01, -8.08279067e-02, -3.44568819e-01],
[-1.89113468e-02, 1.27276510e-01, 1.18291497e-01,
-8.91744196e-02, -3.88738513e-02, -6.17537051e-02,
-1.13684982e-01, -6.69639111e-02, -3.41004312e-01,
-2.61811137e-01, -1.48398668e-01, -2.04433903e-01,
-3.67532670e-03, 5.23980856e-02, 6.42350912e-02,
8.27208012e-02, 6.50502890e-02, 2.10662231e-01,
-1.37932986e-01, 3.07054341e-01, -9.90331918e-02,
-1.52072370e-01, -9.87340361e-02, 3.40974480e-02,
9.54811275e-02, -2.24524289e-01, 3.59725446e-01,
2.40096241e-01, -3.08321267e-02, -1.03981838e-01],
[-1.48987636e-01, 1.81017682e-01, 1.63939446e-02,
-2.69196749e-01, -2.44855881e-04, -3.97299677e-02,
-4.50673252e-02, 3.15325707e-02, -1.54727101e-01,
-1.31139323e-01, 4.07651365e-02, 1.12029955e-01,
-1.95583731e-01, -1.76382348e-01, 1.18832231e-01,
-2.69036591e-01, 3.04102361e-01, -8.61337632e-02,
-7.33592659e-02, 2.46652514e-02, 2.24634409e-01,
2.91894495e-01, 5.09455353e-02, -3.00022811e-01,
-5.17807007e-02, -7.68917203e-02, -3.71362567e-01,
1.96652308e-01, -1.05256483e-01, -2.75513947e-01],
[ 5.38429469e-02, 3.15863788e-02, -4.09971178e-03,
-4.45098132e-02, -1.00758657e-01, -6.42608255e-02,
3.13616514e-01, -1.36063650e-01, 1.24328464e-01,
-1.09250084e-01, -3.94054949e-02, 2.20205531e-01,
-7.17411488e-02, 8.70945156e-02, 4.95519042e-02,
3.63173425e-01, 6.60566986e-03, -1.58391029e-01,
9.21002179e-02, -1.74151346e-01, -1.42024517e-01,
3.83424699e-01, 2.24799365e-02, 7.36033916e-03,
-2.80536860e-02, -1.58879921e-01, 3.91074270e-02,
-9.43726748e-02, 2.17871547e-01, 1.44041181e-02],
[-9.30076092e-02, -1.98024854e-01, -3.14120024e-01,
1.71728119e-01, 1.33051425e-01, -1.41134948e-01,
-2.13182554e-01, -1.62383363e-01, 9.43484604e-02,
1.46695167e-01, 1.86109841e-02, -2.21152157e-02,
-1.46707222e-01, 3.92628670e-01, -2.01349884e-01,
1.09045491e-01, -2.89507657e-02, -1.52118295e-01,
1.74316362e-01, 7.77858645e-02, 9.58564728e-02,
1.02938607e-01, 1.89557344e-01, -1.57446101e-01,
9.71454233e-02, -2.65448719e-01, -5.12915403e-02,
8.04104656e-02, 5.85189909e-02, 2.47814834e-01],
[ 5.50863743e-02, 2.30716497e-01, 2.78447568e-03,
-5.15906960e-02, -1.33431196e-01, 1.72307670e-01,
-3.83732319e-02, 1.72320321e-01, -1.20988593e-01,
-1.21835843e-01, -1.65673420e-01, -8.69580209e-02,
-1.52244121e-02, -3.16982597e-01, 1.96166813e-01,
-2.08498746e-01, 3.45463872e-01, 2.52553254e-01,
3.05833966e-02, -2.36524627e-01, -2.45504826e-02,
-7.39030093e-02, 1.80508122e-01, 8.00530612e-02,
2.32508481e-02, 5.16086072e-02, 8.30580592e-02,
-1.09614551e-01, 2.05054462e-01, 5.47491461e-02],
[-2.92942554e-01, 1.58341527e-02, -5.25936484e-04,
7.54894614e-02, 1.75417557e-01, 1.60730407e-01,
5.91714680e-03, -2.53270268e-02, -2.71934599e-01,
-2.63609916e-01, 1.75930902e-01, 2.68442690e-01,
-1.60669044e-01, 4.51646745e-03, -4.13359016e-01,
1.32156700e-01, 2.06499949e-01, -9.21603739e-02,
-3.21212083e-01, 2.94354409e-02, -3.51502746e-02,
-1.13733053e-01, 7.04959035e-03, 6.02722168e-02,
3.10117364e-01, 3.13739121e-01, 1.54746130e-01,
2.38439858e-01, 2.05240831e-01, 1.65418029e-01],
[-1.31756335e-01, -9.71664637e-02, -2.11081415e-01,
3.06450367e-01, 1.31590337e-01, 2.54443496e-01,
-2.31858999e-01, 2.65551567e-01, -2.02050045e-01,
2.71081388e-01, -1.37878060e-02, -1.70021936e-01,
-1.65380538e-03, 9.54623818e-02, 2.84351885e-01,
-1.01871341e-01, 2.01485753e-02, 9.12380219e-02,
-6.15977198e-02, 4.80147898e-02, -1.39563948e-01,
3.02652836e-01, 1.64852113e-01, -2.00138420e-01,
9.66768116e-02, -9.22664106e-02, -9.24981087e-02,
-2.47369722e-01, 2.81407475e-01, 1.83511779e-01],
[-1.95965677e-01, -2.62236357e-01, -2.39685029e-02,
1.40571550e-01, -5.11714965e-02, 9.83205438e-02,
1.00091219e-01, 8.76448601e-02, -2.09155291e-01,
-4.81763929e-02, 1.15129888e-01, -1.07420832e-02,
6.28655702e-02, -1.43948466e-01, 1.83107510e-01,
1.86010495e-01, -1.79242194e-02, -1.05096996e-02,
2.98826218e-01, 2.92384177e-02, 1.50228307e-01,
-2.57354379e-02, 1.05159089e-01, 3.26832682e-01,
-6.47494942e-02, -7.94630796e-02, -3.30955148e-01,
-3.33948135e-01, 1.46547139e-01, -1.86091036e-01],
[-4.33291495e-02, 1.88203737e-01, 3.16037238e-02,
1.19405270e-01, -2.26772234e-01, 9.43277180e-02,
-8.72166008e-02, 2.56006062e-01, -1.48900136e-01,
9.94460434e-02, 1.87725961e-01, -1.95279077e-01,
8.27598572e-02, -1.46700770e-01, -1.25416532e-01,
-1.37769252e-01, 9.57628340e-02, -2.98057973e-01,
1.05415285e-01, -1.18127242e-01, -2.35549033e-01,
-1.76974535e-02, -2.97596186e-01, 4.32238877e-02,
-4.16904092e-02, 4.33115363e-02, -1.08654559e-01,
3.52643073e-01, 2.74524629e-01, 1.66427344e-02],
[ 1.77624241e-01, -7.08051473e-02, -1.25589028e-01,
-1.33987144e-01, -2.28391200e-01, -2.04036236e-01,
7.88579136e-02, 1.33848041e-01, -3.16916883e-01,
-1.34889603e-01, -8.19702297e-02, 2.77304649e-02,
2.47642547e-02, 1.05886340e-01, -2.58318663e-01,
-2.43119746e-01, 3.77329141e-02, 5.44693917e-02,
-1.35345489e-01, -1.10017732e-01, -3.13927293e-01,
5.12518585e-02, -5.12987375e-04, -1.58919290e-01,
-1.70748994e-01, 2.36288786e-01, 8.46761465e-02,
1.05235279e-02, 8.87275785e-02, 1.64180309e-01],
[ 3.75420362e-01, 6.81641251e-02, 7.72201270e-02,
-4.03955430e-01, -5.49944937e-02, 8.78675282e-03,
3.32516432e-01, 1.84740856e-01, -7.99065679e-02,
1.99955359e-01, 1.41463295e-01, -1.58552006e-01,
-1.59612060e-01, -1.87727511e-01, -1.75989211e-01,
-1.34044841e-01, 2.13294059e-01, -1.30978391e-01,
1.06950596e-01, 2.87034482e-01, 1.33578539e-01,
3.30342472e-01, -2.68601805e-01, -2.23761916e-01,
2.93608367e-01, -3.48806530e-02, 1.48320645e-01,
1.26249835e-01, -2.08334908e-01, 5.82257062e-02],
[ 1.78356782e-01, -1.20771334e-01, -7.79799819e-02,
1.64743349e-01, 1.32354870e-01, 1.11938149e-01,
-2.97888666e-02, -1.83453709e-02, 3.70834023e-02,
-3.93104136e-01, 4.38099802e-02, 1.25917226e-01,
-2.03110918e-01, 1.49912372e-01, 8.95697773e-02,
-1.43052682e-01, 3.78358483e-01, 2.53705919e-01,
9.40865725e-02, 2.99577773e-01, 1.92300975e-03,
-3.35528851e-02, -3.18194628e-01, 8.42917114e-02,
-1.03216350e-01, 1.54624790e-01, 1.52044371e-01,
-3.53974104e-03, -1.56484321e-01, 3.17795128e-02],
[-3.34984601e-01, -1.87048599e-01, 1.26603231e-01,
2.71421999e-01, -4.17920053e-02, -1.65970325e-02,
-1.58861458e-01, 1.46432027e-01, 1.03171512e-01,
1.39130712e-01, 2.62030542e-01, 3.82863283e-02,
1.70419857e-01, 2.81390846e-01, 3.02026719e-02,
-2.17159256e-01, 5.98860085e-02, 2.09414154e-01,
2.78205037e-01, -3.02840352e-01, 2.17414171e-01,
6.87691420e-02, -1.62359178e-02, -9.31997299e-02,
1.67161644e-01, -5.67281246e-02, -1.67869329e-02,
-3.39672238e-01, 4.14884984e-02, 2.41749167e-01]], dtype=float32),
},
Dense_1: {
bias: array([ 6.7055225e-06, 6.7055225e-06, 6.7055225e-06, 6.7055225e-06,
6.7055225e-06, 6.7055225e-06, -4.6221808e-02, -3.4295321e-02,
4.9075842e-02, 6.9026187e-02, 5.4597139e-02, 6.7055225e-06,
6.7055225e-06, -2.3682773e-02, -1.4512062e-02], dtype=float32),
kernel: array([[-2.15783074e-01, -8.00836086e-02, -3.41679364e-01,
-3.61634940e-02, -4.04338688e-02, -1.92787528e-01,
7.80968368e-02, 3.84697020e-01, -2.70941556e-01,
3.09228301e-02, -1.12026677e-01, 1.21513918e-01,
3.84846568e-01, 1.29461914e-01, 3.02034169e-02],
[ 3.03589344e-01, 1.49001822e-01, 2.24479139e-02,
1.72641233e-01, 1.11690164e-03, -1.60670713e-01,
1.60862997e-01, -2.07054362e-01, 3.59371305e-03,
8.14582109e-02, -8.32385570e-02, -2.71430016e-01,
3.29026878e-01, 6.39057159e-03, -2.08910599e-01],
[ 1.42273605e-02, 2.52373248e-01, 2.65929043e-01,
-7.87675530e-02, 2.57075429e-02, 1.37467578e-01,
-3.03784698e-01, -3.00662845e-01, 2.28536919e-01,
7.39716142e-02, -5.44494987e-02, 6.82624280e-02,
-1.14752382e-01, -4.36386168e-02, -2.58043408e-03],
[-3.78126502e-02, 3.64015222e-01, -6.23669326e-02,
1.66056380e-01, -1.78249836e-01, 1.75972238e-01,
-2.43168324e-01, -3.16800296e-01, -6.96922392e-02,
-5.26114106e-02, 1.36705399e-01, -5.23618162e-02,
1.28879577e-01, 2.59115607e-01, -6.60537034e-02],
[ 1.62194669e-02, -5.36157191e-03, 6.25325292e-02,
3.77438962e-03, 5.34493476e-02, 9.57345217e-02,
-3.37895215e-01, 1.54839545e-01, -5.76823950e-05,
3.46475482e-01, -1.44219398e-02, -1.57836825e-02,
1.02470160e-01, -2.66014189e-02, 1.38210922e-01],
[-1.54244944e-01, -1.89163774e-01, -1.25277504e-01,
-1.53509825e-01, -2.08621144e-01, -3.57687324e-02,
1.82844996e-02, 1.68684870e-01, 5.94071597e-02,
3.75792086e-02, 7.39661902e-02, 3.35429460e-02,
6.90625012e-02, 1.51754677e-01, -2.60873646e-01],
[ 9.05385166e-02, 3.13364267e-01, -1.75755590e-01,
5.33922911e-02, -1.96638182e-01, 2.29208365e-01,
2.01122493e-01, 1.39216185e-01, 3.18074644e-01,
-2.20635474e-01, 1.32455468e-01, -4.49397564e-02,
-3.59446257e-02, 3.02191943e-01, -1.93763226e-02],
[-9.32638943e-02, -9.22463536e-02, 2.90949643e-01,
1.38086572e-01, 1.45970643e-01, 1.11227736e-01,
1.92929432e-01, -2.20268294e-01, 8.51592422e-03,
-1.14801973e-02, -1.08394876e-01, -4.92374301e-02,
1.53573290e-01, 3.89962226e-01, -1.53074652e-01],
[ 2.90715694e-03, 1.83663428e-01, 3.76528651e-02,
-1.73860639e-02, 1.83179498e-01, 4.10255790e-03,
9.65543687e-02, 7.96876550e-02, 2.19800651e-01,
2.27372974e-01, -1.51361659e-01, 2.04350561e-01,
1.18743330e-01, -3.37018371e-01, 1.12518296e-01],
[-3.69710326e-02, 5.37259430e-02, -5.55092096e-03,
-4.27453220e-03, -1.99901059e-01, -1.28307149e-01,
5.43928742e-02, 1.31580710e-01, -1.68890849e-01,
-3.78164351e-01, -6.06988072e-02, 2.29855090e-01,
-1.12242132e-01, 3.79682928e-02, -3.86388510e-01],
[ 2.35311255e-01, 1.66386098e-01, -8.38957876e-02,
2.03469038e-01, -2.03613058e-01, -7.19606429e-02,
1.12464979e-01, 2.45347440e-01, 2.39544034e-01,
-1.39806151e-01, -2.63890773e-02, -1.12561002e-01,
-2.70868719e-01, -4.71335649e-03, 1.29959971e-01],
[-5.57020754e-02, -3.46530616e-01, 2.98494935e-01,
-1.66801214e-01, 6.14305437e-02, 9.28813517e-02,
1.35294989e-01, -1.34979218e-01, -7.89761543e-04,
-2.50806570e-01, 1.01332352e-01, 1.00093633e-01,
1.26480818e-01, 1.32050708e-01, 2.52038181e-01],
[ 1.38003528e-02, -1.96471453e-01, 1.48797393e-01,
3.88497114e-02, -1.44033611e-01, 3.50036263e-01,
-3.26102078e-02, -1.19598106e-01, -3.50412369e-01,
-9.01353359e-02, 1.68153301e-01, -1.73634619e-01,
-2.64526069e-01, 1.89368397e-01, -3.03420037e-01],
[ 1.53409019e-01, -1.65645137e-01, 2.75816798e-01,
-2.61331946e-02, 9.50511992e-02, -1.15451679e-01,
-2.91420221e-02, 9.09120291e-02, 1.60791710e-01,
1.59705788e-01, -1.63535506e-01, -2.39277333e-02,
2.17306137e-01, -3.75522256e-01, 3.61660779e-01],
[ 2.75457501e-01, -5.11639714e-02, 3.02877873e-02,
3.83744717e-01, 1.89788371e-01, -3.05498004e-01,
-1.36434168e-01, -9.84745473e-02, -8.30809176e-02,
-1.72840640e-01, 2.08918750e-03, 2.70356536e-01,
-1.43099576e-02, 1.42251849e-02, -2.30559021e-01],
[-1.12811208e-01, -8.90487880e-02, 5.26717752e-02,
-3.34500819e-02, 1.79550871e-01, 1.52729020e-01,
-5.19486815e-02, 1.09064624e-01, 2.16731712e-01,
-5.77685237e-02, 2.93150067e-01, -2.72270977e-01,
2.27185652e-01, -4.16627526e-02, 8.24269503e-02],
[ 1.12723231e-01, 1.53931171e-01, -1.41348794e-01,
-1.82246700e-01, -2.60128081e-01, 2.28294432e-01,
-1.41973585e-01, -2.20511407e-01, 4.12718713e-01,
2.26580381e-01, 3.02141160e-02, -4.08996940e-02,
-1.77697018e-01, 2.13378891e-01, 2.34325364e-01],
[-8.89122784e-02, -1.33175939e-01, 1.13867760e-01,
-3.15929264e-01, -5.33000082e-02, -3.32492590e-02,
1.61763191e-01, -4.45095897e-02, -3.13157290e-02,
1.08764470e-02, -1.84109345e-01, -4.28237021e-02,
-6.48176223e-02, -5.03456593e-03, 7.79272169e-02],
[ 1.07352838e-01, 1.53558820e-01, 9.59271789e-02,
1.95107758e-02, -7.16409087e-03, -2.52965033e-01,
-2.34644741e-01, 3.56166244e-01, 1.75924093e-01,
-1.80997580e-01, -2.52380699e-01, -5.56088388e-02,
-2.03573480e-01, -1.34791806e-01, 1.43947557e-01],
[-3.20794344e-01, 1.52731538e-02, 3.17644477e-01,
3.08467388e-01, 6.71731085e-02, -2.07317039e-01,
6.92533553e-02, 8.92890543e-02, -5.71464747e-02,
6.77597970e-02, -2.40857795e-01, 9.92331654e-02,
-3.97704244e-01, -1.44478083e-01, 1.78948924e-01],
[ 1.78290546e-01, -3.73491168e-01, -3.44891042e-01,
-1.85131654e-01, -1.25298679e-01, -3.59349012e-01,
-2.19286203e-01, 4.03494358e-01, -5.62228560e-02,
-1.21513605e-01, 3.14101666e-01, -5.40824682e-02,
1.32633939e-01, 1.56066120e-02, 2.20855936e-01],
[-1.59709528e-01, -1.92597717e-01, -3.99962455e-01,
-1.37944147e-01, 2.50320375e-01, -3.02701652e-01,
2.37861603e-01, 3.54276657e-01, -2.65088379e-02,
2.30067194e-01, 1.21199116e-01, -1.71849012e-01,
-1.15864873e-02, 2.79654086e-01, 6.52542561e-02],
[-3.87499511e-01, -2.14815795e-01, -3.50072563e-01,
-3.74599099e-01, 3.15827131e-03, -3.80146921e-01,
-2.75051177e-01, 4.92296517e-02, -3.42986435e-02,
2.64439374e-01, 2.24484161e-01, -9.03718919e-02,
-1.66824907e-01, 2.33413532e-01, -2.70957470e-01],
[-2.09726006e-01, -1.01519674e-02, 1.65578052e-01,
2.08754063e-01, -1.90129966e-01, -3.17800641e-01,
-3.11262459e-02, -6.45868927e-02, 3.97725523e-01,
-2.66408682e-01, 3.11382055e-01, -6.38214052e-02,
-3.96969140e-01, 1.07675835e-01, 1.15448385e-02],
[ 1.90294176e-01, 2.62975395e-02, 1.02480963e-01,
-2.17438534e-01, -2.62192190e-01, 1.49379149e-01,
1.36339113e-01, -2.20250994e-01, -6.47595972e-02,
2.78646350e-01, -1.04613230e-01, -3.21846724e-01,
-3.60531211e-02, 4.10647094e-02, -2.01862305e-02],
[ 4.84227687e-02, -1.74249634e-01, 1.22636214e-01,
3.27275246e-01, -8.30558389e-02, -3.14863682e-01,
1.06451482e-01, 9.95579511e-02, 7.17695206e-02,
-2.05804944e-01, 4.14260328e-02, -2.82542408e-03,
1.59710690e-01, 1.95358038e-01, -2.18680993e-01],
[ 1.80844143e-01, 9.28273797e-02, -2.79074967e-01,
-3.21841598e-01, 8.46190900e-02, -1.31674752e-01,
-2.22147837e-01, 6.93762749e-02, 1.08454213e-01,
-1.54395998e-01, -2.52935737e-02, 3.96427214e-02,
-1.77300423e-02, 4.08218354e-02, 1.57031059e-01],
[ 1.72292829e-01, -2.74210989e-01, 3.91593874e-02,
-1.06435806e-01, 1.53482422e-01, -4.07753527e-01,
-1.45187899e-01, -1.97193578e-01, 1.51643381e-01,
8.71199816e-02, -1.80934519e-02, 3.16381007e-02,
-3.16615790e-01, -8.88970047e-02, -3.15811843e-01],
[-9.74183232e-02, -2.85981894e-02, 3.55937093e-01,
-2.71050155e-01, 3.29447448e-01, 8.03399384e-02,
3.03259671e-01, -1.92141250e-01, -1.79099917e-01,
2.89139271e-01, 9.35540348e-02, -2.52466857e-01,
1.29278764e-01, 3.87039840e-01, -3.96069586e-01],
[ 6.32186532e-02, 8.49261880e-02, -1.27169192e-02,
1.09564900e-01, 1.01468608e-01, -1.91385388e-01,
2.12272123e-01, 3.93263578e-01, 3.21644545e-03,
1.96522057e-01, -1.37627035e-01, 2.72568524e-01,
-9.58500952e-02, -5.87881505e-02, 3.29161406e-01]], dtype=float32),
},
Dense_2: {
bias: array([-1.9683748e-02, -2.1595210e-03, 6.7055225e-06, -5.0753772e-02,
1.3499202e-01, 4.5890406e-02, -1.5943155e-02, 6.7055225e-06],
dtype=float32),
kernel: array([[ 1.70742482e-01, -9.91618633e-02, -2.29330659e-02,
1.22218162e-01, -1.06615439e-01, -3.67971659e-02,
1.41981930e-01, 8.84072632e-02],
[ 1.80515304e-01, 1.18458852e-01, 5.27178884e-01,
-1.53531089e-01, -3.87934148e-02, -1.23133227e-01,
6.29240274e-03, 1.60064697e-02],
[ 1.46227553e-01, 5.25590360e-01, -3.44348907e-01,
-9.29221213e-02, 1.28932714e-01, -1.73384249e-02,
1.19381860e-01, -4.11042988e-01],
[ 2.49569118e-01, 2.14901835e-01, 3.24754208e-01,
-4.91983056e-01, -1.14351794e-01, -2.11403668e-02,
7.69451857e-02, 3.31384718e-01],
[ 4.12717015e-02, -1.21542990e-01, -3.10934186e-01,
3.62211466e-01, -2.74409771e-01, -5.17052352e-01,
-5.41649759e-02, 5.59365511e-01],
[-4.59556133e-02, -6.34765029e-02, 1.12813324e-01,
3.72479081e-01, 1.19602963e-01, 2.29919955e-01,
-1.38893351e-01, 4.47091460e-02],
[-6.79109395e-02, -3.46875697e-01, -4.84580278e-01,
1.12352774e-01, -1.92538321e-01, -3.71922582e-01,
5.66129565e-01, -2.42807910e-01],
[ 3.35851789e-01, -1.21225804e-01, 3.77379179e-01,
4.12614942e-01, 4.95961308e-02, -3.28645080e-01,
4.95315671e-01, -1.31850213e-01],
[ 3.26091319e-01, 6.93900436e-02, 2.03807250e-01,
1.04140580e-01, 3.88821214e-01, 1.70190245e-01,
-2.11963385e-01, 5.57184219e-04],
[ 1.63282499e-01, -1.22407228e-02, -4.44926322e-01,
-9.26266760e-02, 3.41380268e-01, 4.87684071e-01,
-1.46689430e-01, -2.39341646e-01],
[-2.73013204e-01, 1.98071614e-01, 1.63841352e-01,
-4.02920216e-01, 1.04726478e-01, 3.63073707e-01,
-5.87156415e-03, -2.55927563e-01],
[ 2.71851361e-01, -6.50364012e-02, 5.75378239e-02,
8.61035287e-02, 4.62561399e-02, 6.84097409e-02,
-3.49434376e-01, 3.20657223e-01],
[ 3.87204051e-01, 1.02552548e-01, -3.67724180e-01,
-1.37631521e-01, -2.76330113e-03, -9.74713266e-03,
-2.03522891e-02, -3.78593743e-01],
[ 1.42845362e-01, -1.62034005e-01, -4.73217815e-01,
2.21667886e-01, 9.57852453e-02, -3.21318150e-01,
1.93716571e-01, -1.64225832e-01],
[-8.96123946e-02, 1.43070787e-01, -3.05288434e-02,
4.30531144e-01, -3.90004218e-01, 5.45606494e-01,
-9.91108418e-02, -3.92888844e-01]], dtype=float32),
},
Dense_3: {
bias: array([0.3012179], dtype=float32),
kernel: array([[-0.04792835],
[ 0.6472556 ],
[ 0.19396764],
[-0.19879013],
[ 0.4868285 ],
[ 0.70990056],
[-0.05229847],
[-0.20487049]], dtype=float32),
},
},
})
最后,让我们在测试集上验证模型。
[8]:
# Evaluate the trained parameters on the held-out test split and report AUC.
X_test, y_test = breast_cancer(train=False)
auc = validate_model(
    params,
    X_test,
    y_test,
)
print(f'auc={auc}')
auc=0.9921388797903701
实验到此结束。