Neural Network with SPU
Please read the lab Logistic Regression On SPU first if you have not done so.
In lab Logistic Regression On SPU, we showed how to use SecretFlow/SPU to convert a plaintext JAX training program into a secure MPC training program.
The idea in this lab is much the same, but this time we will work with a neural network model.
We are going to use the same dataset and settings as lab Logistic Regression On SPU.
First, let’s build the plaintext model.
The following code is for demonstration only. Because of system security concerns it is NOT suitable for production; please DO NOT use it directly in production.
This tutorial needs more resources than 8c16g (8 CPU cores and 16 GB of memory), which is the minimum requirement of SecretFlow.
Train a Model with JAX/Flax
Load the Dataset
The data-loading code below comes from lab Logistic Regression On SPU, so we won’t explain it again.
import sys
!{sys.executable} -m pip install flax==0.6.0
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer
def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    scaler = Normalizer(norm='max')
    x, y = load_breast_cancer(return_X_y=True)
    x = scaler.fit_transform(x)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )
    if train:
        if party_id:
            # Vertical split: party 1 holds the first 15 feature columns and
            # party 2 holds the last 15 columns plus the labels, so that
            # concatenating (party 1, party 2) keeps the original column order.
            if party_id == 1:
                return x_train[:, :15], None
            else:
                return x_train[:, 15:], y_train
        else:
            return x_train, y_train
    else:
        return x_test, y_test
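In this lab the training data is partitioned vertically. Both parties hold the same rows, but each holds a different block of feature columns, and only party 2 holds the labels. `train_auto_grad` later rebuilds the full matrix with `jnp.concatenate((x1, x2), axis=1)`, so the column order after concatenation has to match the column order of `X_test` used for validation. The NumPy sketch below uses a hypothetical `split_columns` helper and random stand-in data. It shows that concatenating the blocks in the order they were cut restores the original matrix, while the reverse order silently permutes the features.

```python
import numpy as np


def split_columns(x, split_at):
    """Hypothetical helper: cut a feature matrix into two column blocks,
    as in a vertical (feature-wise) partition between two parties."""
    return x[:, :split_at], x[:, split_at:]


rng = np.random.default_rng(0)
x = rng.random((6, 30))  # random stand-in for a 6-row, 30-feature matrix

left, right = split_columns(x, 15)  # each party keeps one 15-column block

# Same order as the cut: the original column layout comes back.
restored = np.concatenate((left, right), axis=1)
print(np.array_equal(restored, x))  # True

# Opposite order: the shape still looks right, but column j now holds a
# different feature, so evaluation inputs would not match training inputs.
swapped = np.concatenate((right, left), axis=1)
print(swapped.shape, np.array_equal(swapped, x))  # (6, 30) False
```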
Define the Model
We are going to use a 4-layer MLP here: three hidden Dense layers (30, 15 and 8 units) with ReLU activations, followed by a 1-unit linear output layer.
from typing import Sequence
import flax.linen as nn
FEATURES = [30, 15, 8, 1]
class MLP(nn.Module):
    features: Sequence[int]
    @nn.compact
    def __call__(self, x):
        # Hidden layers: Dense followed by ReLU.
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        # Output layer: a single linear unit with no activation.
        x = nn.Dense(self.features[-1])(x)
        return x
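For reference, here is roughly what `MLP(FEATURES)` computes, written in plain NumPy. Each `nn.Dense` layer applies `x @ kernel + bias`, ReLU follows every layer except the last, and for a 30-feature input the kernels have shapes (30, 30), (30, 15), (15, 8) and (8, 1). The helper names `init_dense_params` and `mlp_forward` and the random-normal initializer are assumptions for illustration only; Flax uses its own default initializers.

```python
import numpy as np

FEATURES = [30, 15, 8, 1]


def init_dense_params(features, in_dim, seed=0):
    """Create one (kernel, bias) pair per layer. The random-normal kernel
    init is an illustrative assumption, not Flax's default."""
    rng = np.random.default_rng(seed)
    params = []
    for out_dim in features:
        kernel = rng.normal(scale=1.0 / np.sqrt(in_dim), size=(in_dim, out_dim))
        bias = np.zeros(out_dim)
        params.append((kernel, bias))
        in_dim = out_dim
    return params


def mlp_forward(params, x):
    """ReLU after every layer except the last, which stays linear."""
    for kernel, bias in params[:-1]:
        x = np.maximum(x @ kernel + bias, 0.0)
    kernel, bias = params[-1]
    return x @ kernel + bias


params = init_dense_params(FEATURES, in_dim=30)
out = mlp_forward(params, np.ones((10, 30)))
n_params = sum(k.size + b.size for k, b in params)
print(out.shape, n_params)  # (10, 1) 1532
```

Note the output shape `(batch, 1)`: the model emits a column rather than a flat vector, which matters whenever the output is compared element-wise with a label array of shape `(batch,)`.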
Next, we define the prediction, loss and training functions.
import jax
import jax.numpy as jnp
def predict(params, x):
    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,
    # which is not the case in a normal python program.
    from typing import Sequence
    import flax.linen as nn
    FEATURES = [30, 15, 8, 1]
    class MLP(nn.Module):
        features: Sequence[int]
        @nn.compact
        def __call__(self, x):
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
            x = nn.Dense(self.features[-1])(x)
            return x
    return MLP(FEATURES).apply(params, x)
def loss_func(params, x, y):
    # The model outputs shape (batch, 1); flatten it to (batch,) to match y.
    # Otherwise y - y_pred would broadcast to a (batch, batch) matrix.
    pred = predict(params, x).reshape(-1)
    def mse(y, pred):
        def squared_error(y, y_pred):
            return jnp.multiply(y - y_pred, y - y_pred) / 2.0
        return jnp.mean(squared_error(y, pred))
    return mse(y, pred)
def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    # Re-join the two parties' feature columns.
    x = jnp.concatenate((x1, x2), axis=1)
    # n_batch is the number of rows per mini-batch.
    xs = jnp.array_split(x, len(x) // n_batch, axis=0)
    ys = jnp.array_split(y, len(y) // n_batch, axis=0)
    def body_fun(_, loop_carry):
        params = loop_carry
        # One epoch: a plain gradient-descent step on every mini-batch.
        for (x, y) in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, x, y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params
    params = jax.lax.fori_loop(0, n_epochs, body_fun, params)
    return params
def model_init(n_batch=10):
    model = MLP(FEATURES)
    # The dummy input only fixes parameter shapes; FEATURES[0] (30) also
    # happens to equal the number of input features in the dataset.
    return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))
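`train_auto_grad` performs plain mini-batch gradient descent. Two details are worth making explicit. First, `n_batch` is the number of rows per batch, because the data is split into `len(x) // n_batch` chunks. Second, `jax.lax.fori_loop(lower, upper, body_fun, init_val)` has the same semantics as the Python loop in `fori_loop` below, which is the reference behavior given in the JAX documentation. The sketch reproduces that structure in NumPy on a linear least-squares model with a hand-written gradient. The linear model, the synthetic data and the helper `train_linear` are assumptions for illustration, not the MLP above.

```python
import numpy as np


def fori_loop(lower, upper, body_fun, init_val):
    """Reference semantics of jax.lax.fori_loop as a plain Python loop."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def train_linear(x, y, w, n_batch=10, n_epochs=10, step_size=0.01):
    # len(x) // n_batch chunks, so each chunk has about n_batch rows.
    xs = np.array_split(x, len(x) // n_batch, axis=0)
    ys = np.array_split(y, len(y) // n_batch, axis=0)

    def body_fun(_, w):
        for xb, yb in zip(xs, ys):
            err = xb @ w - yb  # both (batch,), so no accidental broadcasting
            # Gradient of mean((xb @ w - yb) ** 2 / 2) with respect to w.
            grad = xb.T @ err / len(xb)
            w = w - step_size * grad
        return w

    return fori_loop(0, n_epochs, body_fun, w)


rng = np.random.default_rng(0)
x = rng.random((455, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = x @ true_w

# 455 rows into 45 chunks: the first 5 chunks get 11 rows, the rest 10.
sizes = [len(b) for b in np.array_split(x, len(x) // 10, axis=0)]
print(len(sizes), sizes[:6])  # 45 [11, 11, 11, 11, 11, 10]

w0 = np.zeros(3)
w = train_linear(x, y, w0, n_epochs=50, step_size=0.1)
print(np.mean((x @ w - y) ** 2) < np.mean((x @ w0 - y) ** 2))  # True
```

Because the inner `for` loop runs in Python while the program is traced, each epoch is unrolled into one step per mini-batch. The number of batches therefore has to be known at trace time, which is consistent with passing `n_batch` as a static argument when the function is later run on SPU.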
Validate the Model
We use AUC as the validation metric.
from sklearn.metrics import roc_auc_score
def validate_model(params, X_test, y_test):
    y_pred = predict(params, X_test)
    return roc_auc_score(y_test, y_pred)
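AUC depends only on how the scores rank the examples. It equals the fraction of (positive, negative) pairs in which the positive example receives the higher score, with ties counting as half. This is why the raw, unbounded MLP outputs (there is no sigmoid) can be passed straight to `roc_auc_score`. The sketch below computes AUC this way with a hypothetical `auc_by_pairs` helper. It is quadratic in the number of examples and is meant for illustration, not as a replacement for `roc_auc_score`.

```python
import numpy as np


def auc_by_pairs(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly;
    ties count as half a correct pair."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score).ravel()  # also accepts (n, 1) model outputs
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive vs. every negative
    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size


print(auc_by_pairs([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Any strictly increasing transformation of the scores (for example, applying a sigmoid) leaves this value unchanged.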
Build Together
Let’s put everything together and train a plaintext NN model!
import jax
# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)
# Hyperparameters
n_batch = 10
n_epochs = 10
step_size = 0.01
# Train the model
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)
# Test the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')
auc=0.9921388797903701
Keep this AUC value in mind: we are now going to repeat the training with SPU. Let’s do the magic!
Train a Model with SPU
import secretflow as sf
# Shut down any SecretFlow runtime that is already running.
sf.shutdown()
sf.init(['alice', 'bob'], address='local')
alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)
device = spu
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)
params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(
x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)
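Conceptually, once the inputs are moved to the SPU device with `.to(device)`, they are held as secret shares rather than in the clear. Each value is encoded as a fixed-point integer and split into random shares held by different parties. The sketch below illustrates the idea with two-party additive secret sharing over the ring Z_{2^64}, assuming 18 fractional bits. This is a conceptual illustration under those assumptions, not SPU’s implementation. SPU’s actual protocol, ring size and precision depend on its runtime configuration, and the helpers `encode`, `decode`, `share` and `reconstruct` are hypothetical.

```python
import random
from typing import Tuple

FRAC_BITS = 18   # assumed number of fractional bits
RING = 1 << 64   # shares live in Z_{2^64}


def encode(v: float) -> int:
    """Fixed-point encode a float into the ring (negatives wrap around)."""
    return round(v * (1 << FRAC_BITS)) % RING


def decode(r: int) -> float:
    """Read a ring element back as a signed fixed-point number."""
    if r >= RING // 2:
        r -= RING
    return r / (1 << FRAC_BITS)


def share(secret: int, rng: random.Random) -> Tuple[int, int]:
    """Split a ring element into two additive shares; each share alone
    is uniformly random and reveals nothing about the secret."""
    s0 = rng.getrandbits(64)
    s1 = (secret - s0) % RING
    return s0, s1


def reconstruct(s0: int, s1: int) -> int:
    return (s0 + s1) % RING


rng = random.Random(42)
a0, a1 = share(encode(1.25), rng)  # e.g. a value contributed by alice
b0, b1 = share(encode(-0.5), rng)  # e.g. a value contributed by bob

# Addition needs no communication: each party adds the shares it holds.
c0, c1 = (a0 + b0) % RING, (a1 + b1) % RING
print(decode(reconstruct(c0, c1)))  # 0.75
```

Multiplications, by contrast, require interaction between the parties (for example, using precomputed Beaver triples) plus a truncation step to keep the fixed-point scale. This is one reason the secure training run is much slower than the plaintext one.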
Let’s check the parameters produced by the SPU program.
params = sf.reveal(params_spu)
print(params)
FrozenDict({
params: {
Dense_0: {
bias: array([-5.1895231e-03, 1.8576235e-03, 6.7055225e-06, 2.3722053e-03,
3.1328186e-02, 6.7055225e-06, -1.5438795e-02, 6.7055225e-06,
6.7055225e-06, -4.1383177e-02, 6.7055225e-06, -2.2077844e-02,
6.7055225e-06, 6.7055225e-06, -4.1123331e-03, 6.7055225e-06,
3.9519504e-02, -2.0013824e-02, 6.7055225e-06, -1.7026097e-02,
-2.0100027e-03, -1.2290031e-02, 3.5510257e-02, 6.7055225e-06,
1.0400787e-02, 6.7055225e-06, 6.7055225e-06, 6.7055225e-06,
6.7055225e-06, -2.4091452e-03], dtype=float32),
kernel: array([[-1.48707509e-01, -2.35312745e-01, -1.49370432e-01,
-1.55658722e-02, -1.33344889e-01, 1.91762865e-01,
-3.67970318e-02, -3.74531150e-02, -1.41760916e-01,
3.24019790e-02, 1.26554981e-01, -4.02508259e-01,
-1.68948472e-01, 2.13856786e-01, -1.38430491e-01,
1.05858222e-01, -1.16115242e-01, 3.86231482e-01,
5.96616715e-02, 6.30981326e-02, 7.79366642e-02,
-1.31276995e-02, -2.88023233e-01, -9.60215628e-02,
1.11036777e-01, -8.54356140e-02, 7.54796267e-02,
-4.11889851e-02, -3.82702172e-01, 2.37623483e-01],
[ 1.77950412e-01, 2.29401365e-01, -2.44401962e-01,
-1.48464620e-01, 3.36820692e-01, 2.58408487e-02,
-4.21461910e-02, 4.10513222e-01, 3.24393630e-01,
-1.64218560e-01, 8.17326903e-02, 5.25916368e-02,
3.11348379e-01, 2.92945385e-01, 1.22726962e-01,
-3.87524545e-01, -3.85492414e-01, -6.53883070e-02,
-2.59158403e-01, -3.32410604e-01, -3.15876693e-01,
-2.91079849e-01, -6.01480752e-02, 2.29791373e-01,
1.01019114e-01, -1.30936205e-02, 1.78829342e-01,
-2.32158348e-01, 3.82750481e-01, 3.79820764e-02],
[ 1.42894119e-01, -2.13525295e-02, 1.68190986e-01,
8.99314433e-02, -3.89296085e-01, -4.85262126e-02,
1.38709426e-01, -5.80671132e-02, 2.84729242e-01,
-1.26564085e-01, 2.57159889e-01, 9.64836925e-02,
1.16704285e-01, -1.90519944e-01, -3.98443341e-02,
-9.57311541e-02, 7.24758208e-02, 1.46388173e-01,
9.22411978e-02, 3.83930653e-02, -2.46246234e-01,
3.76523137e-02, -1.90257579e-02, -2.52097666e-01,
1.70252651e-01, 2.49821901e-01, -2.81400979e-03,
-9.84508097e-02, 2.07754672e-01, 8.81182998e-02],
[-2.25118160e-01, -4.46529686e-03, -4.55757827e-02,
-4.27887589e-02, -1.34788275e-01, -3.42829585e-01,
-9.25171375e-03, 9.01048332e-02, -2.82488763e-01,
2.23106414e-01, 1.17174536e-01, 4.47563082e-02,
-5.02136499e-02, 2.90469885e-01, -2.35435143e-01,
5.99489361e-02, 2.55998850e-01, 1.87139437e-01,
1.46078974e-01, 2.93805301e-01, 3.95906270e-02,
1.63574621e-01, -1.12822607e-01, -6.96890652e-02,
2.61215031e-01, -3.17869127e-01, 8.15212727e-02,
2.55623460e-01, -5.98048419e-02, 1.93995312e-01],
[-2.42546692e-01, 7.96881765e-02, -6.76873624e-02,
-1.17468625e-01, 1.87561423e-01, -6.13799393e-02,
-5.36612719e-02, -6.93449825e-02, 7.92452097e-02,
2.54104435e-02, -3.18573356e-01, 2.87047565e-01,
-6.02750182e-02, -3.01489800e-01, -1.76609412e-01,
7.97384530e-02, 1.61419287e-01, 3.27948928e-01,
2.05150560e-01, 3.03487569e-01, 2.71105886e-01,
2.76556283e-01, 7.07161725e-02, 2.08005011e-01,
-7.33365566e-02, -1.03249148e-01, 1.55347139e-02,
3.17582250e-01, 3.16770792e-01, -6.80912733e-02],
[ 7.31387287e-02, 6.06941879e-02, 5.30773848e-02,
1.78480223e-01, 1.88692018e-01, 1.77043870e-01,
8.07239711e-02, 1.48239732e-02, -4.42493856e-02,
6.05108887e-02, 2.88271666e-01, 5.82368821e-02,
2.64274329e-01, 1.20873421e-01, -2.19291598e-02,
-1.51663244e-01, -4.52437997e-02, 6.89959526e-03,
2.83989072e-01, 1.00405067e-01, -3.03579658e-01,
-3.13743889e-01, -1.75221011e-01, 1.12235680e-01,
1.15922809e-01, -2.76054680e-01, -6.86793476e-02,
6.13613129e-02, 3.08204144e-01, -2.80041754e-01],
[-2.58817077e-01, -1.54795498e-02, 2.77136236e-01,
-3.83958876e-01, 4.00266647e-01, -1.29193932e-01,
-2.80725807e-02, 1.75729364e-01, 1.30319521e-01,
1.56524405e-01, -2.62551606e-02, 2.93114007e-01,
3.55902165e-02, 3.58770192e-02, -2.96767652e-01,
-1.69694602e-01, 3.19763869e-02, 1.61003023e-01,
-1.75322652e-01, -8.83801430e-02, -4.86088842e-02,
-1.05747893e-01, 8.46082866e-02, -4.35629487e-02,
-2.53890634e-01, -9.85630006e-02, -4.16661799e-02,
-4.67738807e-02, -3.35340977e-01, -1.18303642e-01],
[ 2.78338730e-01, -6.96775764e-02, -2.48151869e-02,
4.41243351e-02, -8.29881430e-02, -1.39038265e-03,
8.48067403e-02, -2.48347864e-01, 1.53563991e-01,
-1.36309460e-01, -4.34331596e-03, -6.34510815e-03,
-7.79297352e-02, -1.02945447e-01, 2.40225092e-01,
-2.56992996e-01, 2.98645079e-01, -1.97924539e-01,
-2.75554240e-01, 1.29482746e-02, 1.26732409e-01,
1.27245054e-01, 1.59333318e-01, 1.43134385e-01,
6.40332401e-02, -1.11376449e-01, 5.90692759e-02,
-1.16492316e-01, 1.58721358e-02, -2.03465536e-01],
[ 6.27327412e-02, 8.74824524e-02, -9.10452157e-02,
1.81233525e-01, 1.06525749e-01, -1.38279110e-01,
3.20228636e-01, -1.48054436e-01, -3.18492800e-02,
3.50982606e-01, -1.33533657e-01, 1.76203504e-01,
5.12142032e-02, -4.60760295e-02, -1.33599430e-01,
-2.22507954e-01, 1.66286990e-01, 3.71264696e-01,
-1.16020367e-01, -1.61022544e-02, -8.04677159e-02,
1.58347130e-01, 3.06297302e-01, 6.74358159e-02,
8.46540183e-02, -9.24098492e-03, -3.53501737e-03,
-3.68553460e-01, -1.23027235e-01, 2.18146190e-01],
[ 2.83379406e-02, -1.24499649e-01, 1.79811046e-01,
1.53642461e-01, 5.48425168e-02, 1.91716999e-01,
-9.49321389e-02, 6.86786622e-02, -7.67818838e-02,
-1.93986297e-02, 5.70140183e-02, -3.93389583e-01,
5.28795272e-02, 3.79496545e-01, 2.46415466e-01,
-1.21219859e-01, 4.00151759e-02, -3.80354345e-01,
1.95414037e-01, -9.05120224e-02, 3.20608616e-01,
1.48523450e-02, -3.49244773e-02, 1.11090288e-01,
-3.37234616e-01, -3.06017160e-01, -1.13247246e-01,
1.59685776e-01, 6.75145537e-02, 1.00891382e-01],
[ 1.68052852e-01, 1.94981366e-01, -9.76378322e-02,
-1.45580143e-01, 1.01531252e-01, -3.17420304e-01,
-1.15842447e-01, -2.86557317e-01, -1.01209313e-01,
-1.30138457e-01, 1.97996482e-01, -6.92996681e-02,
1.83074176e-03, -6.13981634e-02, -2.38128498e-01,
1.41838133e-01, 4.12078500e-01, -1.11510590e-01,
-7.69597739e-02, -3.93857658e-02, -5.82328439e-02,
-2.56166637e-01, 1.75529957e-01, -5.77670783e-02,
4.62778807e-02, 1.20462373e-01, 3.14444661e-01,
-1.82372808e-01, -1.62539259e-01, -9.76682454e-02],
[-6.19109869e-02, -1.15577027e-01, 7.26505965e-02,
1.25303209e-01, 2.06839561e-01, 1.57670394e-01,
-8.05730969e-02, -1.94498345e-01, 2.13316083e-02,
2.35428169e-01, -1.77006692e-01, -3.51173997e-01,
-2.20170230e-01, 3.13621610e-02, 1.01004869e-01,
-4.00861502e-01, -1.33804008e-01, -6.59427792e-02,
-1.41224921e-01, -1.72024697e-01, -6.66107237e-02,
9.94113684e-02, -3.09019983e-02, 2.59390533e-01,
6.44736290e-02, -2.50633597e-01, -3.49195153e-02,
8.02383423e-02, 2.55563200e-01, -2.40900025e-01],
[ 8.98088515e-03, -4.00735557e-01, -6.30197525e-02,
6.18334711e-02, 3.73557806e-01, 3.17755789e-02,
2.75026381e-01, -2.88109362e-01, -2.02475563e-01,
1.61148801e-01, -2.17942014e-01, 1.06317997e-01,
2.66866386e-03, -2.73037195e-01, 7.52992183e-02,
7.77817965e-02, 2.63209343e-02, 9.45713371e-02,
-2.83376873e-01, 2.55718678e-02, 1.71330452e-01,
4.77515757e-02, -1.29865110e-02, -9.19252783e-02,
2.20202535e-01, -1.98967785e-01, 3.41536522e-01,
8.68079364e-02, -8.85302424e-02, -9.05130804e-03],
[ 1.20346710e-01, 1.25420868e-01, -3.62598658e-01,
2.23636329e-01, -7.36351609e-02, 1.05015352e-01,
4.35760617e-03, -8.31906050e-02, 2.28634894e-01,
-1.49365649e-01, 8.16624165e-02, -3.14154029e-01,
8.76248032e-02, -5.55702299e-02, 8.57660025e-02,
-2.31696367e-02, -2.23352492e-01, 2.85403579e-02,
1.04186803e-01, 9.74222422e-02, 8.70420933e-02,
1.58863813e-02, 1.73726633e-01, 8.37596357e-02,
-1.77873820e-02, -6.53727055e-02, 5.05047441e-02,
-2.29445338e-01, 5.72724342e-02, 2.50912070e-01],
[-3.97108495e-01, 1.00127161e-01, -7.08009750e-02,
-1.62648782e-01, -1.39103666e-01, 1.61610574e-01,
1.60218611e-01, -7.88767636e-03, 5.42842150e-02,
1.65928125e-01, 2.23704860e-01, 3.66966009e-01,
6.14991188e-02, -4.85754609e-02, 3.34523857e-01,
-7.26056099e-02, 1.93905726e-01, -6.00316077e-02,
-3.03020537e-01, 1.71825185e-01, 2.90646702e-01,
2.13969320e-01, 4.79212999e-02, 9.81050730e-02,
1.03307858e-01, 1.27324745e-01, -5.79782724e-02,
1.52468219e-01, -3.66631836e-01, 1.77991658e-01],
[ 1.05444938e-02, 8.89388323e-02, 2.68580914e-01,
3.43994796e-01, 7.01821893e-02, 3.07618678e-01,
-1.59211323e-01, 2.65379250e-03, 4.29789275e-02,
-6.48853630e-02, -1.19794458e-02, -1.01899371e-01,
3.90142053e-02, -2.12668777e-02, 1.34829685e-01,
-1.60709843e-01, -2.75420129e-01, -1.10718116e-01,
1.22148365e-01, -2.56418556e-01, -8.86735022e-02,
3.57692540e-02, -1.72454745e-01, 1.25590444e-01,
1.48163527e-01, 9.70243216e-02, 1.78432480e-01,
8.07075799e-02, 7.18790591e-02, 8.29363018e-02],
[ 1.46744028e-01, 1.35470942e-01, -5.01304418e-02,
-2.56607473e-01, -2.36472189e-01, 2.16720849e-01,
1.36735886e-01, -3.88280898e-02, 3.90521705e-01,
3.07039917e-03, 1.45443454e-01, -7.04357922e-02,
-1.51877582e-01, 6.67887330e-02, -2.40254313e-01,
3.11602056e-01, 6.78158849e-02, -2.53712773e-01,
-2.01759011e-01, -2.26584569e-01, 1.38092458e-01,
-1.41293734e-01, 3.44138265e-01, 1.29559129e-01,
2.85035908e-01, 6.18830621e-02, -2.29603484e-01,
2.91220188e-01, -8.08279067e-02, -3.44568819e-01],
[-1.89113468e-02, 1.27276510e-01, 1.18291497e-01,
-8.91744196e-02, -3.88738513e-02, -6.17537051e-02,
-1.13684982e-01, -6.69639111e-02, -3.41004312e-01,
-2.61811137e-01, -1.48398668e-01, -2.04433903e-01,
-3.67532670e-03, 5.23980856e-02, 6.42350912e-02,
8.27208012e-02, 6.50502890e-02, 2.10662231e-01,
-1.37932986e-01, 3.07054341e-01, -9.90331918e-02,
-1.52072370e-01, -9.87340361e-02, 3.40974480e-02,
9.54811275e-02, -2.24524289e-01, 3.59725446e-01,
2.40096241e-01, -3.08321267e-02, -1.03981838e-01],
[-1.48987636e-01, 1.81017682e-01, 1.63939446e-02,
-2.69196749e-01, -2.44855881e-04, -3.97299677e-02,
-4.50673252e-02, 3.15325707e-02, -1.54727101e-01,
-1.31139323e-01, 4.07651365e-02, 1.12029955e-01,
-1.95583731e-01, -1.76382348e-01, 1.18832231e-01,
-2.69036591e-01, 3.04102361e-01, -8.61337632e-02,
-7.33592659e-02, 2.46652514e-02, 2.24634409e-01,
2.91894495e-01, 5.09455353e-02, -3.00022811e-01,
-5.17807007e-02, -7.68917203e-02, -3.71362567e-01,
1.96652308e-01, -1.05256483e-01, -2.75513947e-01],
[ 5.38429469e-02, 3.15863788e-02, -4.09971178e-03,
-4.45098132e-02, -1.00758657e-01, -6.42608255e-02,
3.13616514e-01, -1.36063650e-01, 1.24328464e-01,
-1.09250084e-01, -3.94054949e-02, 2.20205531e-01,
-7.17411488e-02, 8.70945156e-02, 4.95519042e-02,
3.63173425e-01, 6.60566986e-03, -1.58391029e-01,
9.21002179e-02, -1.74151346e-01, -1.42024517e-01,
3.83424699e-01, 2.24799365e-02, 7.36033916e-03,
-2.80536860e-02, -1.58879921e-01, 3.91074270e-02,
-9.43726748e-02, 2.17871547e-01, 1.44041181e-02],
[-9.30076092e-02, -1.98024854e-01, -3.14120024e-01,
1.71728119e-01, 1.33051425e-01, -1.41134948e-01,
-2.13182554e-01, -1.62383363e-01, 9.43484604e-02,
1.46695167e-01, 1.86109841e-02, -2.21152157e-02,
-1.46707222e-01, 3.92628670e-01, -2.01349884e-01,
1.09045491e-01, -2.89507657e-02, -1.52118295e-01,
1.74316362e-01, 7.77858645e-02, 9.58564728e-02,
1.02938607e-01, 1.89557344e-01, -1.57446101e-01,
9.71454233e-02, -2.65448719e-01, -5.12915403e-02,
8.04104656e-02, 5.85189909e-02, 2.47814834e-01],
[ 5.50863743e-02, 2.30716497e-01, 2.78447568e-03,
-5.15906960e-02, -1.33431196e-01, 1.72307670e-01,
-3.83732319e-02, 1.72320321e-01, -1.20988593e-01,
-1.21835843e-01, -1.65673420e-01, -8.69580209e-02,
-1.52244121e-02, -3.16982597e-01, 1.96166813e-01,
-2.08498746e-01, 3.45463872e-01, 2.52553254e-01,
3.05833966e-02, -2.36524627e-01, -2.45504826e-02,
-7.39030093e-02, 1.80508122e-01, 8.00530612e-02,
2.32508481e-02, 5.16086072e-02, 8.30580592e-02,
-1.09614551e-01, 2.05054462e-01, 5.47491461e-02],
[-2.92942554e-01, 1.58341527e-02, -5.25936484e-04,
7.54894614e-02, 1.75417557e-01, 1.60730407e-01,
5.91714680e-03, -2.53270268e-02, -2.71934599e-01,
-2.63609916e-01, 1.75930902e-01, 2.68442690e-01,
-1.60669044e-01, 4.51646745e-03, -4.13359016e-01,
1.32156700e-01, 2.06499949e-01, -9.21603739e-02,
-3.21212083e-01, 2.94354409e-02, -3.51502746e-02,
-1.13733053e-01, 7.04959035e-03, 6.02722168e-02,
3.10117364e-01, 3.13739121e-01, 1.54746130e-01,
2.38439858e-01, 2.05240831e-01, 1.65418029e-01],
[-1.31756335e-01, -9.71664637e-02, -2.11081415e-01,
3.06450367e-01, 1.31590337e-01, 2.54443496e-01,
-2.31858999e-01, 2.65551567e-01, -2.02050045e-01,
2.71081388e-01, -1.37878060e-02, -1.70021936e-01,
-1.65380538e-03, 9.54623818e-02, 2.84351885e-01,
-1.01871341e-01, 2.01485753e-02, 9.12380219e-02,
-6.15977198e-02, 4.80147898e-02, -1.39563948e-01,
3.02652836e-01, 1.64852113e-01, -2.00138420e-01,
9.66768116e-02, -9.22664106e-02, -9.24981087e-02,
-2.47369722e-01, 2.81407475e-01, 1.83511779e-01],
[-1.95965677e-01, -2.62236357e-01, -2.39685029e-02,
1.40571550e-01, -5.11714965e-02, 9.83205438e-02,
1.00091219e-01, 8.76448601e-02, -2.09155291e-01,
-4.81763929e-02, 1.15129888e-01, -1.07420832e-02,
6.28655702e-02, -1.43948466e-01, 1.83107510e-01,
1.86010495e-01, -1.79242194e-02, -1.05096996e-02,
2.98826218e-01, 2.92384177e-02, 1.50228307e-01,
-2.57354379e-02, 1.05159089e-01, 3.26832682e-01,
-6.47494942e-02, -7.94630796e-02, -3.30955148e-01,
-3.33948135e-01, 1.46547139e-01, -1.86091036e-01],
[-4.33291495e-02, 1.88203737e-01, 3.16037238e-02,
1.19405270e-01, -2.26772234e-01, 9.43277180e-02,
-8.72166008e-02, 2.56006062e-01, -1.48900136e-01,
9.94460434e-02, 1.87725961e-01, -1.95279077e-01,
8.27598572e-02, -1.46700770e-01, -1.25416532e-01,
-1.37769252e-01, 9.57628340e-02, -2.98057973e-01,
1.05415285e-01, -1.18127242e-01, -2.35549033e-01,
-1.76974535e-02, -2.97596186e-01, 4.32238877e-02,
-4.16904092e-02, 4.33115363e-02, -1.08654559e-01,
3.52643073e-01, 2.74524629e-01, 1.66427344e-02],
[ 1.77624241e-01, -7.08051473e-02, -1.25589028e-01,
-1.33987144e-01, -2.28391200e-01, -2.04036236e-01,
7.88579136e-02, 1.33848041e-01, -3.16916883e-01,
-1.34889603e-01, -8.19702297e-02, 2.77304649e-02,
2.47642547e-02, 1.05886340e-01, -2.58318663e-01,
-2.43119746e-01, 3.77329141e-02, 5.44693917e-02,
-1.35345489e-01, -1.10017732e-01, -3.13927293e-01,
5.12518585e-02, -5.12987375e-04, -1.58919290e-01,
-1.70748994e-01, 2.36288786e-01, 8.46761465e-02,
1.05235279e-02, 8.87275785e-02, 1.64180309e-01],
[ 3.75420362e-01, 6.81641251e-02, 7.72201270e-02,
-4.03955430e-01, -5.49944937e-02, 8.78675282e-03,
3.32516432e-01, 1.84740856e-01, -7.99065679e-02,
1.99955359e-01, 1.41463295e-01, -1.58552006e-01,
-1.59612060e-01, -1.87727511e-01, -1.75989211e-01,
-1.34044841e-01, 2.13294059e-01, -1.30978391e-01,
1.06950596e-01, 2.87034482e-01, 1.33578539e-01,
3.30342472e-01, -2.68601805e-01, -2.23761916e-01,
2.93608367e-01, -3.48806530e-02, 1.48320645e-01,
1.26249835e-01, -2.08334908e-01, 5.82257062e-02],
[ 1.78356782e-01, -1.20771334e-01, -7.79799819e-02,
1.64743349e-01, 1.32354870e-01, 1.11938149e-01,
-2.97888666e-02, -1.83453709e-02, 3.70834023e-02,
-3.93104136e-01, 4.38099802e-02, 1.25917226e-01,
-2.03110918e-01, 1.49912372e-01, 8.95697773e-02,
-1.43052682e-01, 3.78358483e-01, 2.53705919e-01,
9.40865725e-02, 2.99577773e-01, 1.92300975e-03,
-3.35528851e-02, -3.18194628e-01, 8.42917114e-02,
-1.03216350e-01, 1.54624790e-01, 1.52044371e-01,
-3.53974104e-03, -1.56484321e-01, 3.17795128e-02],
[-3.34984601e-01, -1.87048599e-01, 1.26603231e-01,
2.71421999e-01, -4.17920053e-02, -1.65970325e-02,
-1.58861458e-01, 1.46432027e-01, 1.03171512e-01,
1.39130712e-01, 2.62030542e-01, 3.82863283e-02,
1.70419857e-01, 2.81390846e-01, 3.02026719e-02,
-2.17159256e-01, 5.98860085e-02, 2.09414154e-01,
2.78205037e-01, -3.02840352e-01, 2.17414171e-01,
6.87691420e-02, -1.62359178e-02, -9.31997299e-02,
1.67161644e-01, -5.67281246e-02, -1.67869329e-02,
-3.39672238e-01, 4.14884984e-02, 2.41749167e-01]], dtype=float32),
},
Dense_1: {
bias: array([ 6.7055225e-06, 6.7055225e-06, 6.7055225e-06, 6.7055225e-06,
6.7055225e-06, 6.7055225e-06, -4.6221808e-02, -3.4295321e-02,
4.9075842e-02, 6.9026187e-02, 5.4597139e-02, 6.7055225e-06,
6.7055225e-06, -2.3682773e-02, -1.4512062e-02], dtype=float32),
kernel: array([[-2.15783074e-01, -8.00836086e-02, -3.41679364e-01,
-3.61634940e-02, -4.04338688e-02, -1.92787528e-01,
7.80968368e-02, 3.84697020e-01, -2.70941556e-01,
3.09228301e-02, -1.12026677e-01, 1.21513918e-01,
3.84846568e-01, 1.29461914e-01, 3.02034169e-02],
[ 3.03589344e-01, 1.49001822e-01, 2.24479139e-02,
1.72641233e-01, 1.11690164e-03, -1.60670713e-01,
1.60862997e-01, -2.07054362e-01, 3.59371305e-03,
8.14582109e-02, -8.32385570e-02, -2.71430016e-01,
3.29026878e-01, 6.39057159e-03, -2.08910599e-01],
[ 1.42273605e-02, 2.52373248e-01, 2.65929043e-01,
-7.87675530e-02, 2.57075429e-02, 1.37467578e-01,
-3.03784698e-01, -3.00662845e-01, 2.28536919e-01,
7.39716142e-02, -5.44494987e-02, 6.82624280e-02,
-1.14752382e-01, -4.36386168e-02, -2.58043408e-03],
[-3.78126502e-02, 3.64015222e-01, -6.23669326e-02,
1.66056380e-01, -1.78249836e-01, 1.75972238e-01,
-2.43168324e-01, -3.16800296e-01, -6.96922392e-02,
-5.26114106e-02, 1.36705399e-01, -5.23618162e-02,
1.28879577e-01, 2.59115607e-01, -6.60537034e-02],
[ 1.62194669e-02, -5.36157191e-03, 6.25325292e-02,
3.77438962e-03, 5.34493476e-02, 9.57345217e-02,
-3.37895215e-01, 1.54839545e-01, -5.76823950e-05,
3.46475482e-01, -1.44219398e-02, -1.57836825e-02,
1.02470160e-01, -2.66014189e-02, 1.38210922e-01],
[-1.54244944e-01, -1.89163774e-01, -1.25277504e-01,
-1.53509825e-01, -2.08621144e-01, -3.57687324e-02,
1.82844996e-02, 1.68684870e-01, 5.94071597e-02,
3.75792086e-02, 7.39661902e-02, 3.35429460e-02,
6.90625012e-02, 1.51754677e-01, -2.60873646e-01],
[ 9.05385166e-02, 3.13364267e-01, -1.75755590e-01,
5.33922911e-02, -1.96638182e-01, 2.29208365e-01,
2.01122493e-01, 1.39216185e-01, 3.18074644e-01,
-2.20635474e-01, 1.32455468e-01, -4.49397564e-02,
-3.59446257e-02, 3.02191943e-01, -1.93763226e-02],
[-9.32638943e-02, -9.22463536e-02, 2.90949643e-01,
1.38086572e-01, 1.45970643e-01, 1.11227736e-01,
1.92929432e-01, -2.20268294e-01, 8.51592422e-03,
-1.14801973e-02, -1.08394876e-01, -4.92374301e-02,
1.53573290e-01, 3.89962226e-01, -1.53074652e-01],
[ 2.90715694e-03, 1.83663428e-01, 3.76528651e-02,
-1.73860639e-02, 1.83179498e-01, 4.10255790e-03,
9.65543687e-02, 7.96876550e-02, 2.19800651e-01,
2.27372974e-01, -1.51361659e-01, 2.04350561e-01,
1.18743330e-01, -3.37018371e-01, 1.12518296e-01],
[-3.69710326e-02, 5.37259430e-02, -5.55092096e-03,
-4.27453220e-03, -1.99901059e-01, -1.28307149e-01,
5.43928742e-02, 1.31580710e-01, -1.68890849e-01,
-3.78164351e-01, -6.06988072e-02, 2.29855090e-01,
-1.12242132e-01, 3.79682928e-02, -3.86388510e-01],
[ 2.35311255e-01, 1.66386098e-01, -8.38957876e-02,
2.03469038e-01, -2.03613058e-01, -7.19606429e-02,
1.12464979e-01, 2.45347440e-01, 2.39544034e-01,
-1.39806151e-01, -2.63890773e-02, -1.12561002e-01,
-2.70868719e-01, -4.71335649e-03, 1.29959971e-01],
[-5.57020754e-02, -3.46530616e-01, 2.98494935e-01,
-1.66801214e-01, 6.14305437e-02, 9.28813517e-02,
1.35294989e-01, -1.34979218e-01, -7.89761543e-04,
-2.50806570e-01, 1.01332352e-01, 1.00093633e-01,
1.26480818e-01, 1.32050708e-01, 2.52038181e-01],
[ 1.38003528e-02, -1.96471453e-01, 1.48797393e-01,
3.88497114e-02, -1.44033611e-01, 3.50036263e-01,
-3.26102078e-02, -1.19598106e-01, -3.50412369e-01,
-9.01353359e-02, 1.68153301e-01, -1.73634619e-01,
-2.64526069e-01, 1.89368397e-01, -3.03420037e-01],
[ 1.53409019e-01, -1.65645137e-01, 2.75816798e-01,
-2.61331946e-02, 9.50511992e-02, -1.15451679e-01,
-2.91420221e-02, 9.09120291e-02, 1.60791710e-01,
1.59705788e-01, -1.63535506e-01, -2.39277333e-02,
2.17306137e-01, -3.75522256e-01, 3.61660779e-01],
[ 2.75457501e-01, -5.11639714e-02, 3.02877873e-02,
3.83744717e-01, 1.89788371e-01, -3.05498004e-01,
-1.36434168e-01, -9.84745473e-02, -8.30809176e-02,
-1.72840640e-01, 2.08918750e-03, 2.70356536e-01,
-1.43099576e-02, 1.42251849e-02, -2.30559021e-01],
[-1.12811208e-01, -8.90487880e-02, 5.26717752e-02,
-3.34500819e-02, 1.79550871e-01, 1.52729020e-01,
-5.19486815e-02, 1.09064624e-01, 2.16731712e-01,
-5.77685237e-02, 2.93150067e-01, -2.72270977e-01,
2.27185652e-01, -4.16627526e-02, 8.24269503e-02],
[ 1.12723231e-01, 1.53931171e-01, -1.41348794e-01,
-1.82246700e-01, -2.60128081e-01, 2.28294432e-01,
-1.41973585e-01, -2.20511407e-01, 4.12718713e-01,
2.26580381e-01, 3.02141160e-02, -4.08996940e-02,
-1.77697018e-01, 2.13378891e-01, 2.34325364e-01],
[-8.89122784e-02, -1.33175939e-01, 1.13867760e-01,
-3.15929264e-01, -5.33000082e-02, -3.32492590e-02,
1.61763191e-01, -4.45095897e-02, -3.13157290e-02,
1.08764470e-02, -1.84109345e-01, -4.28237021e-02,
-6.48176223e-02, -5.03456593e-03, 7.79272169e-02],
[ 1.07352838e-01, 1.53558820e-01, 9.59271789e-02,
1.95107758e-02, -7.16409087e-03, -2.52965033e-01,
-2.34644741e-01, 3.56166244e-01, 1.75924093e-01,
-1.80997580e-01, -2.52380699e-01, -5.56088388e-02,
-2.03573480e-01, -1.34791806e-01, 1.43947557e-01],
[-3.20794344e-01, 1.52731538e-02, 3.17644477e-01,
3.08467388e-01, 6.71731085e-02, -2.07317039e-01,
6.92533553e-02, 8.92890543e-02, -5.71464747e-02,
6.77597970e-02, -2.40857795e-01, 9.92331654e-02,
-3.97704244e-01, -1.44478083e-01, 1.78948924e-01],
[ 1.78290546e-01, -3.73491168e-01, -3.44891042e-01,
-1.85131654e-01, -1.25298679e-01, -3.59349012e-01,
-2.19286203e-01, 4.03494358e-01, -5.62228560e-02,
-1.21513605e-01, 3.14101666e-01, -5.40824682e-02,
1.32633939e-01, 1.56066120e-02, 2.20855936e-01],
[-1.59709528e-01, -1.92597717e-01, -3.99962455e-01,
-1.37944147e-01, 2.50320375e-01, -3.02701652e-01,
2.37861603e-01, 3.54276657e-01, -2.65088379e-02,
2.30067194e-01, 1.21199116e-01, -1.71849012e-01,
-1.15864873e-02, 2.79654086e-01, 6.52542561e-02],
[-3.87499511e-01, -2.14815795e-01, -3.50072563e-01,
-3.74599099e-01, 3.15827131e-03, -3.80146921e-01,
-2.75051177e-01, 4.92296517e-02, -3.42986435e-02,
2.64439374e-01, 2.24484161e-01, -9.03718919e-02,
-1.66824907e-01, 2.33413532e-01, -2.70957470e-01],
[-2.09726006e-01, -1.01519674e-02, 1.65578052e-01,
2.08754063e-01, -1.90129966e-01, -3.17800641e-01,
-3.11262459e-02, -6.45868927e-02, 3.97725523e-01,
-2.66408682e-01, 3.11382055e-01, -6.38214052e-02,
-3.96969140e-01, 1.07675835e-01, 1.15448385e-02],
[ 1.90294176e-01, 2.62975395e-02, 1.02480963e-01,
-2.17438534e-01, -2.62192190e-01, 1.49379149e-01,
1.36339113e-01, -2.20250994e-01, -6.47595972e-02,
2.78646350e-01, -1.04613230e-01, -3.21846724e-01,
-3.60531211e-02, 4.10647094e-02, -2.01862305e-02],
[ 4.84227687e-02, -1.74249634e-01, 1.22636214e-01,
3.27275246e-01, -8.30558389e-02, -3.14863682e-01,
1.06451482e-01, 9.95579511e-02, 7.17695206e-02,
-2.05804944e-01, 4.14260328e-02, -2.82542408e-03,
1.59710690e-01, 1.95358038e-01, -2.18680993e-01],
[ 1.80844143e-01, 9.28273797e-02, -2.79074967e-01,
-3.21841598e-01, 8.46190900e-02, -1.31674752e-01,
-2.22147837e-01, 6.93762749e-02, 1.08454213e-01,
-1.54395998e-01, -2.52935737e-02, 3.96427214e-02,
-1.77300423e-02, 4.08218354e-02, 1.57031059e-01],
[ 1.72292829e-01, -2.74210989e-01, 3.91593874e-02,
-1.06435806e-01, 1.53482422e-01, -4.07753527e-01,
-1.45187899e-01, -1.97193578e-01, 1.51643381e-01,
8.71199816e-02, -1.80934519e-02, 3.16381007e-02,
-3.16615790e-01, -8.88970047e-02, -3.15811843e-01],
[-9.74183232e-02, -2.85981894e-02, 3.55937093e-01,
-2.71050155e-01, 3.29447448e-01, 8.03399384e-02,
3.03259671e-01, -1.92141250e-01, -1.79099917e-01,
2.89139271e-01, 9.35540348e-02, -2.52466857e-01,
1.29278764e-01, 3.87039840e-01, -3.96069586e-01],
[ 6.32186532e-02, 8.49261880e-02, -1.27169192e-02,
1.09564900e-01, 1.01468608e-01, -1.91385388e-01,
2.12272123e-01, 3.93263578e-01, 3.21644545e-03,
1.96522057e-01, -1.37627035e-01, 2.72568524e-01,
-9.58500952e-02, -5.87881505e-02, 3.29161406e-01]], dtype=float32),
},
Dense_2: {
bias: array([-1.9683748e-02, -2.1595210e-03, 6.7055225e-06, -5.0753772e-02,
1.3499202e-01, 4.5890406e-02, -1.5943155e-02, 6.7055225e-06],
dtype=float32),
kernel: array([[ 1.70742482e-01, -9.91618633e-02, -2.29330659e-02,
1.22218162e-01, -1.06615439e-01, -3.67971659e-02,
1.41981930e-01, 8.84072632e-02],
[ 1.80515304e-01, 1.18458852e-01, 5.27178884e-01,
-1.53531089e-01, -3.87934148e-02, -1.23133227e-01,
6.29240274e-03, 1.60064697e-02],
[ 1.46227553e-01, 5.25590360e-01, -3.44348907e-01,
-9.29221213e-02, 1.28932714e-01, -1.73384249e-02,
1.19381860e-01, -4.11042988e-01],
[ 2.49569118e-01, 2.14901835e-01, 3.24754208e-01,
-4.91983056e-01, -1.14351794e-01, -2.11403668e-02,
7.69451857e-02, 3.31384718e-01],
[ 4.12717015e-02, -1.21542990e-01, -3.10934186e-01,
3.62211466e-01, -2.74409771e-01, -5.17052352e-01,
-5.41649759e-02, 5.59365511e-01],
[-4.59556133e-02, -6.34765029e-02, 1.12813324e-01,
3.72479081e-01, 1.19602963e-01, 2.29919955e-01,
-1.38893351e-01, 4.47091460e-02],
[-6.79109395e-02, -3.46875697e-01, -4.84580278e-01,
1.12352774e-01, -1.92538321e-01, -3.71922582e-01,
5.66129565e-01, -2.42807910e-01],
[ 3.35851789e-01, -1.21225804e-01, 3.77379179e-01,
4.12614942e-01, 4.95961308e-02, -3.28645080e-01,
4.95315671e-01, -1.31850213e-01],
[ 3.26091319e-01, 6.93900436e-02, 2.03807250e-01,
1.04140580e-01, 3.88821214e-01, 1.70190245e-01,
-2.11963385e-01, 5.57184219e-04],
[ 1.63282499e-01, -1.22407228e-02, -4.44926322e-01,
-9.26266760e-02, 3.41380268e-01, 4.87684071e-01,
-1.46689430e-01, -2.39341646e-01],
[-2.73013204e-01, 1.98071614e-01, 1.63841352e-01,
-4.02920216e-01, 1.04726478e-01, 3.63073707e-01,
-5.87156415e-03, -2.55927563e-01],
[ 2.71851361e-01, -6.50364012e-02, 5.75378239e-02,
8.61035287e-02, 4.62561399e-02, 6.84097409e-02,
-3.49434376e-01, 3.20657223e-01],
[ 3.87204051e-01, 1.02552548e-01, -3.67724180e-01,
-1.37631521e-01, -2.76330113e-03, -9.74713266e-03,
-2.03522891e-02, -3.78593743e-01],
[ 1.42845362e-01, -1.62034005e-01, -4.73217815e-01,
2.21667886e-01, 9.57852453e-02, -3.21318150e-01,
1.93716571e-01, -1.64225832e-01],
[-8.96123946e-02, 1.43070787e-01, -3.05288434e-02,
4.30531144e-01, -3.90004218e-01, 5.45606494e-01,
-9.91108418e-02, -3.92888844e-01]], dtype=float32),
},
Dense_3: {
bias: array([0.3012179], dtype=float32),
kernel: array([[-0.04792835],
[ 0.6472556 ],
[ 0.19396764],
[-0.19879013],
[ 0.4868285 ],
[ 0.70990056],
[-0.05229847],
[-0.20487049]], dtype=float32),
},
},
})
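The printed tree has one `kernel`/`bias` pair per `Dense` layer. The shapes follow from `FEATURES = [30, 15, 8, 1]` applied to the 30 input features: for instance, `Dense_2` maps 15 hidden units to 8, and `Dense_3` maps those 8 to the single output. The plain-NumPy function below is a minimal sketch of how these arrays are used. It reproduces the forward pass of the `MLP` class (ReLU after every layer except the last, with no sigmoid on the output) and assumes the `{"params": {"Dense_i": {"kernel", "bias"}}}` layout shown in the printout. `init_like_flax` and `mlp_forward` are hypothetical helpers for illustration only. The random weights are not the trained ones, and this is not how Flax or SPU evaluates the model.

```python
import numpy as np

FEATURES = [30, 15, 8, 1]  # same layer widths as the MLP defined earlier


def init_like_flax(features, n_in, seed=0):
    """Hypothetical helper: random params laid out like the printed tree (illustration only)."""
    rng = np.random.default_rng(seed)
    layers = {}
    for i, feat in enumerate(features):
        layers[f"Dense_{i}"] = {
            "kernel": rng.normal(scale=0.3, size=(n_in, feat)).astype(np.float32),
            "bias": np.zeros(feat, dtype=np.float32),
        }
        n_in = feat  # next layer consumes this layer's output width
    return {"params": layers}


def mlp_forward(variables, x):
    """Hypothetical NumPy equivalent of MLP.__call__: Dense + ReLU, then a final Dense."""
    layers = variables["params"]
    # Order layers by their numeric suffix: Dense_0, Dense_1, ...
    names = sorted(layers, key=lambda name: int(name.split("_")[1]))
    for name in names[:-1]:
        x = np.maximum(x @ layers[name]["kernel"] + layers[name]["bias"], 0.0)
    last = layers[names[-1]]
    return x @ last["kernel"] + last["bias"]  # raw score, no sigmoid


variables = init_like_flax(FEATURES, n_in=30)
for name, layer in variables["params"].items():
    print(name, layer["kernel"].shape, layer["bias"].shape)
scores = mlp_forward(variables, np.ones((4, 30), dtype=np.float32))
print(scores.shape)
```

The printed shapes match the trained tree above: `(30, 30)`, `(30, 15)`, `(15, 8)` and `(8, 1)` kernels. In principle the same function could be applied to the trained `params` as a cross-check against `predict`, after converting the leaves to NumPy arrays.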
Lastly, let’s validate the model on the held-out test set.
[8]:
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')
auc=0.9921388797903701
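The reported metric is the area under the ROC curve. Assuming `validate_model` computes the standard ROC AUC (for example with `sklearn.metrics.roc_auc_score`), the value equals the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one, with ties counted as one half. Because AUC depends only on the ranking of the scores, the MLP's raw outputs can be evaluated directly without a sigmoid. The sketch below is a hypothetical `pairwise_auc` helper written for illustration. It is not the notebook's implementation.

```python
import numpy as np


def pairwise_auc(y_true, scores):
    """Hypothetical helper: ROC AUC as P(score_pos > score_neg), ties counted as 1/2."""
    y_true = np.asarray(y_true).ravel()
    scores = np.asarray(scores, dtype=np.float64).ravel()
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("AUC needs at least one positive and one negative sample")
    diff = pos[:, None] - neg[None, :]  # compare every positive-negative pair
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))


y = np.array([0, 0, 1, 1])
s = np.array([0.1, 0.4, 0.35, 0.8])
print(pairwise_auc(y, s))                     # 0.75
print(pairwise_auc(y, 1 / (1 + np.exp(-s))))  # 0.75, a sigmoid does not change the ranking
```

Comparing all pairs costs O(P·N), which is fine for a small test split like this one. For large datasets, a sort-based implementation such as scikit-learn's `roc_auc_score` is the usual choice.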
This is the end of the lab.