{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Neural Network with SPU\n", "\n", "> Please read the lab [Logistic Regression On SPU](./lr_with_spu.ipynb) first if you haven't already.\n", "\n", "In the lab [Logistic Regression On SPU](./lr_with_spu.ipynb), we showed how to use SecretFlow/SPU to convert a plaintext JAX training program into a secure MPC training program.\n", "\n", "This lab follows the same idea, but this time we work with a neural network model.\n", "\n", "We use the same dataset and settings as the lab [Logistic Regression On SPU](./lr_with_spu.ipynb).\n", "\n", "First, let's build the plaintext model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ">The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> This tutorial needs more resources than 8c16g (8 CPU cores and 16 GB of RAM), which is the minimum requirement of SecretFlow." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train a Model with JAX/FLAX\n", "\n", "### Load the Dataset\n", "\n", "The code below is copied from the lab [Logistic Regression On SPU](./lr_with_spu.ipynb); please refer to that lab for an explanation."
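, "\n",
 "\n",
 "The loader below partitions the data vertically: the 30 features are split into two blocks of 15 columns held by different parties, and only one party holds the labels. The next cell is a minimal standalone sketch of such a split (the variable names are illustrative, not SecretFlow APIs). Re-joining the blocks along `axis=1` reproduces the original feature matrix only when they are concatenated in their original column order. That is also the column order the test features use during validation." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "from sklearn.datasets import load_breast_cancer\n",
 "from sklearn.model_selection import train_test_split\n",
 "\n",
 "# Same 80/20 split and seed as the breast_cancer() loader (no scaling here).\n",
 "x, y = load_breast_cancer(return_X_y=True)\n",
 "x_train, x_test, y_train, y_test = train_test_split(\n",
 "    x, y, test_size=0.2, random_state=42\n",
 ")\n",
 "\n",
 "# Vertical partition: two disjoint blocks of 15 feature columns each.\n",
 "left_block, right_block = x_train[:, :15], x_train[:, 15:]\n",
 "\n",
 "# Re-joining in the original order recovers x_train exactly;\n",
 "# the reversed order yields a column permutation instead.\n",
 "same_order = np.concatenate((left_block, right_block), axis=1)\n",
 "reversed_order = np.concatenate((right_block, left_block), axis=1)\n",
 "print(x_train.shape, left_block.shape, right_block.shape)\n",
 "print(np.array_equal(same_order, x_train), np.array_equal(reversed_order, x_train))"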
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "!{sys.executable} -m pip install flax==0.6.0" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import load_breast_cancer\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import Normalizer\n", "\n", "\n", "def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):\n", "    scaler = Normalizer(norm='max')\n", "    x, y = load_breast_cancer(return_X_y=True)\n", "    x = scaler.fit_transform(x)\n", "    x_train, x_test, y_train, y_test = train_test_split(\n", "        x, y, test_size=0.2, random_state=42\n", "    )\n", "\n", "    if train:\n", "        if party_id:\n", "            if party_id == 1:\n", "                # Party 1 holds features only, no labels.\n", "                return x_train[:, 15:], None\n", "            else:\n", "                return x_train[:, :15], y_train\n", "        else:\n", "            return x_train, y_train\n", "    else:\n", "        return x_test, y_test\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the Model\n", "\n", "We use a 4-layer [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) with [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activations on the hidden layers. `FEATURES = [30, 15, 8, 1]` lists the output width of each `Dense` layer." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from typing import Sequence\n", "import flax.linen as nn\n", "\n", "\n", "FEATURES = [30, 15, 8, 1]\n", "\n", "\n", "class MLP(nn.Module):\n", "    features: Sequence[int]\n", "\n", "    @nn.compact\n", "    def __call__(self, x):\n", "        for feat in self.features[:-1]:\n", "            x = nn.relu(nn.Dense(feat)(x))\n", "        x = nn.Dense(self.features[-1])(x)\n", "        return x\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we define the prediction, loss, and training functions."
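, "\n",
 "\n",
 "The parameter update in `train_auto_grad` below is plain mini-batch gradient descent. `jax.value_and_grad` returns the loss and a gradient pytree with the same structure as the parameters, and `jax.tree_util.tree_map` applies `p - step_size * g` leaf by leaf. The next cell is a minimal standalone sketch of a single such step. It uses the same halved squared-error loss on a made-up toy model (one dense layer with hypothetical data), not the MLP itself." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import jax\n",
 "import jax.numpy as jnp\n",
 "\n",
 "# Hypothetical toy parameters, shaped like a single Dense layer (2 -> 1).\n",
 "toy_params = {'kernel': jnp.array([[0.5], [-0.5]]), 'bias': jnp.array([0.0])}\n",
 "toy_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])\n",
 "toy_y = jnp.array([1.0, 0.0])\n",
 "toy_step = 0.01\n",
 "\n",
 "\n",
 "def toy_loss(params, x, y):\n",
 "    # Predictions have shape (n, 1); reshape to (n,) to match the labels.\n",
 "    pred = (x @ params['kernel'] + params['bias']).reshape(y.shape)\n",
 "    # Half of the mean squared error.\n",
 "    return jnp.mean((y - pred) ** 2 / 2.0)\n",
 "\n",
 "\n",
 "loss_before, grads = jax.value_and_grad(toy_loss)(toy_params, toy_x, toy_y)\n",
 "# One SGD step, applied to every leaf of the parameter pytree.\n",
 "new_params = jax.tree_util.tree_map(lambda p, g: p - toy_step * g, toy_params, grads)\n",
 "loss_after = toy_loss(new_params, toy_x, toy_y)\n",
 "print(loss_before, loss_after)"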
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "\n", "def predict(params, x):\n", "    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,\n", "    # which is not the case in a normal python program.\n", "    from typing import Sequence\n", "    import flax.linen as nn\n", "\n", "    FEATURES = [30, 15, 8, 1]\n", "\n", "    class MLP(nn.Module):\n", "        features: Sequence[int]\n", "\n", "        @nn.compact\n", "        def __call__(self, x):\n", "            for feat in self.features[:-1]:\n", "                x = nn.relu(nn.Dense(feat)(x))\n", "            x = nn.Dense(self.features[-1])(x)\n", "            return x\n", "\n", "    return MLP(FEATURES).apply(params, x)\n", "\n", "\n", "def loss_func(params, x, y):\n", "    # predict() returns shape (n, 1); reshape to y's shape (n,) so that\n", "    # y - pred is element-wise instead of broadcasting to (n, n).\n", "    pred = jnp.reshape(predict(params, x), y.shape)\n", "\n", "    def mse(y, pred):\n", "        def squared_error(y, y_pred):\n", "            return jnp.multiply(y - y_pred, y - y_pred) / 2.0\n", "\n", "        return jnp.mean(squared_error(y, pred))\n", "\n", "    return mse(y, pred)\n", "\n", "\n", "def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):\n", "    # x1 holds feature columns 15: and x2 holds columns :15 (see breast_cancer),\n", "    # so join them as (x2, x1) to match the column order of the test set.\n", "    x = jnp.concatenate((x2, x1), axis=1)\n", "    # Integer division: the number of mini-batches must be a Python int.\n", "    xs = jnp.array_split(x, len(x) // n_batch, axis=0)\n", "    ys = jnp.array_split(y, len(y) // n_batch, axis=0)\n", "\n", "    def body_fun(_, loop_carry):\n", "        params = loop_carry\n", "        for (x, y) in zip(xs, ys):\n", "            _, grads = jax.value_and_grad(loss_func)(params, x, y)\n", "            params = jax.tree_util.tree_map(\n", "                lambda p, g: p - step_size * g, params, grads\n", "            )\n", "        return params\n", "\n", "    params = jax.lax.fori_loop(0, n_epochs, body_fun, params)\n", "    return params\n", "\n", "\n", "def model_init(n_batch=10):\n", "    model = MLP(FEATURES)\n", "    return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Validate the Model\n", "\n", "We use AUC (the area under the ROC curve) as the validation metric."
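, "\n",
 "\n",
 "AUC equals the fraction of (positive, negative) sample pairs in which the positive sample receives the higher score, with ties counting as one half. It therefore depends only on the ranking of the predictions, so the raw outputs of the final `Dense(1)` layer can be passed to `roc_auc_score` without a sigmoid. A small illustration with made-up scores:" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "from sklearn.metrics import roc_auc_score\n",
 "\n",
 "# Made-up labels and scores for illustration only.\n",
 "toy_true = np.array([0, 0, 1, 1])\n",
 "toy_scores = np.array([0.1, 0.4, 0.35, 0.8])\n",
 "\n",
 "toy_auc = roc_auc_score(toy_true, toy_scores)\n",
 "\n",
 "# Same value from the pairwise-ranking definition (no ties here).\n",
 "pos, neg = toy_scores[toy_true == 1], toy_scores[toy_true == 0]\n",
 "pairwise_auc = np.mean(pos[:, None] > neg[None, :])\n",
 "\n",
 "# A strictly increasing transform such as a sigmoid keeps the ranking, hence the AUC.\n",
 "sigmoid_auc = roc_auc_score(toy_true, 1.0 / (1.0 + np.exp(-toy_scores)))\n",
 "print(toy_auc, pairwise_auc, sigmoid_auc)"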
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import roc_auc_score\n", "\n", "\n", "def validate_model(params, X_test, y_test):\n", "    y_pred = predict(params, X_test)\n", "    return roc_auc_score(y_test, y_pred)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Put It Together\n", "\n", "Let's put everything together and train a plaintext NN model!" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "auc=0.9921388797903701\n" ] } ], "source": [ "import jax\n", "\n", "# Load the data\n", "x1, _ = breast_cancer(party_id=1, train=True)\n", "x2, y = breast_cancer(party_id=2, train=True)\n", "\n", "\n", "# Hyperparameters\n", "n_batch = 10\n", "n_epochs = 10\n", "step_size = 0.01\n", "\n", "\n", "# Train the model\n", "init_params = model_init(n_batch)\n", "params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)\n", "\n", "# Test the model\n", "X_test, y_test = breast_cancer(train=False)\n", "auc = validate_model(params, X_test, y_test)\n", "print(f'auc={auc}')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just keep this AUC value in mind; next we repeat the training with SPU. Let's do that magic!"
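, "\n",
 "\n",
 "In the next section, the SPU call passes `static_argnames=['n_batch', 'n_epochs', 'step_size']`, which mirrors the keyword of the same name in `jax.jit` (we assume the semantics are analogous). The reason shows up in plain JAX: `train_auto_grad` uses `n_batch` to decide how many mini-batches `jnp.array_split` produces, and that number must be a concrete Python integer while the function is traced. The next cell is a minimal standalone sketch with a hypothetical `batch_means` function:" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import jax\n",
 "import jax.numpy as jnp\n",
 "\n",
 "\n",
 "def batch_means(x, n_batch):\n",
 "    # The number of chunks must be a concrete Python int at trace time.\n",
 "    chunks = jnp.array_split(x, len(x) // n_batch, axis=0)\n",
 "    return jnp.stack([c.mean() for c in chunks])\n",
 "\n",
 "\n",
 "toy_x = jnp.arange(12.0)\n",
 "\n",
 "# With n_batch marked static, its value is baked into the compiled program.\n",
 "means = jax.jit(batch_means, static_argnames=['n_batch'])(toy_x, n_batch=4)\n",
 "print(means)\n",
 "\n",
 "# Without it, n_batch is traced and len(x) // n_batch is no longer a Python int.\n",
 "try:\n",
 "    jax.jit(batch_means)(toy_x, 4)\n",
 "    traced_ok = True\n",
 "except Exception as err:\n",
 "    traced_ok = False\n",
 "    print(type(err).__name__)"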
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train a Model with SPU" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(_run pid=2944996)\u001b[0m 2022-08-25 10:54:18,952,952 WARNING [xla_bridge.py:backends:265] No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] } ], "source": [ "import secretflow as sf\n", "\n", "# Shut down any SecretFlow runtime that may already be running.\n", "sf.shutdown()\n", "\n", "sf.init(['alice', 'bob'], address='local')\n", "\n", "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n", "spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))\n", "\n", "x1, _ = alice(breast_cancer)(party_id=1, train=True)\n", "x2, y = bob(breast_cancer)(party_id=2, train=True)\n", "init_params = model_init(n_batch)\n", "\n", "\n", "device = spu\n", "x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)\n", "init_params_ = sf.to(alice, init_params).to(device)\n", "\n", "params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(\n", "    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size\n", ")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check the parameters produced by the SPU program." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FrozenDict({\n", " params: {\n", " Dense_0: {\n", " bias: array([-5.1895231e-03, 1.8576235e-03, 6.7055225e-06, 2.3722053e-03,\n", " 3.1328186e-02, 6.7055225e-06, -1.5438795e-02, 6.7055225e-06,\n", " 6.7055225e-06, -4.1383177e-02, 6.7055225e-06, -2.2077844e-02,\n", " 6.7055225e-06, 6.7055225e-06, -4.1123331e-03, 6.7055225e-06,\n", " 3.9519504e-02, -2.0013824e-02, 6.7055225e-06, -1.7026097e-02,\n", " -2.0100027e-03, -1.2290031e-02, 3.5510257e-02, 6.7055225e-06,\n", " 1.0400787e-02, 6.7055225e-06, 6.7055225e-06, 6.7055225e-06,\n", " 6.7055225e-06, -2.4091452e-03], dtype=float32),\n", "
kernel: array([[-1.48707509e-01, -2.35312745e-01, -1.49370432e-01,\n", " -1.55658722e-02, -1.33344889e-01, 1.91762865e-01,\n", " -3.67970318e-02, -3.74531150e-02, -1.41760916e-01,\n", " 3.24019790e-02, 1.26554981e-01, -4.02508259e-01,\n", " -1.68948472e-01, 2.13856786e-01, -1.38430491e-01,\n", " 1.05858222e-01, -1.16115242e-01, 3.86231482e-01,\n", " 5.96616715e-02, 6.30981326e-02, 7.79366642e-02,\n", " -1.31276995e-02, -2.88023233e-01, -9.60215628e-02,\n", " 1.11036777e-01, -8.54356140e-02, 7.54796267e-02,\n", " -4.11889851e-02, -3.82702172e-01, 2.37623483e-01],\n", " [ 1.77950412e-01, 2.29401365e-01, -2.44401962e-01,\n", " -1.48464620e-01, 3.36820692e-01, 2.58408487e-02,\n", " -4.21461910e-02, 4.10513222e-01, 3.24393630e-01,\n", " -1.64218560e-01, 8.17326903e-02, 5.25916368e-02,\n", " 3.11348379e-01, 2.92945385e-01, 1.22726962e-01,\n", " -3.87524545e-01, -3.85492414e-01, -6.53883070e-02,\n", " -2.59158403e-01, -3.32410604e-01, -3.15876693e-01,\n", " -2.91079849e-01, -6.01480752e-02, 2.29791373e-01,\n", " 1.01019114e-01, -1.30936205e-02, 1.78829342e-01,\n", " -2.32158348e-01, 3.82750481e-01, 3.79820764e-02],\n", " [ 1.42894119e-01, -2.13525295e-02, 1.68190986e-01,\n", " 8.99314433e-02, -3.89296085e-01, -4.85262126e-02,\n", " 1.38709426e-01, -5.80671132e-02, 2.84729242e-01,\n", " -1.26564085e-01, 2.57159889e-01, 9.64836925e-02,\n", " 1.16704285e-01, -1.90519944e-01, -3.98443341e-02,\n", " -9.57311541e-02, 7.24758208e-02, 1.46388173e-01,\n", " 9.22411978e-02, 3.83930653e-02, -2.46246234e-01,\n", " 3.76523137e-02, -1.90257579e-02, -2.52097666e-01,\n", " 1.70252651e-01, 2.49821901e-01, -2.81400979e-03,\n", " -9.84508097e-02, 2.07754672e-01, 8.81182998e-02],\n", " [-2.25118160e-01, -4.46529686e-03, -4.55757827e-02,\n", " -4.27887589e-02, -1.34788275e-01, -3.42829585e-01,\n", " -9.25171375e-03, 9.01048332e-02, -2.82488763e-01,\n", " 2.23106414e-01, 1.17174536e-01, 4.47563082e-02,\n", " -5.02136499e-02, 2.90469885e-01, -2.35435143e-01,\n", " 5.99489361e-02, 
2.55998850e-01, 1.87139437e-01,\n", " 1.46078974e-01, 2.93805301e-01, 3.95906270e-02,\n", " 1.63574621e-01, -1.12822607e-01, -6.96890652e-02,\n", " 2.61215031e-01, -3.17869127e-01, 8.15212727e-02,\n", " 2.55623460e-01, -5.98048419e-02, 1.93995312e-01],\n", " [-2.42546692e-01, 7.96881765e-02, -6.76873624e-02,\n", " -1.17468625e-01, 1.87561423e-01, -6.13799393e-02,\n", " -5.36612719e-02, -6.93449825e-02, 7.92452097e-02,\n", " 2.54104435e-02, -3.18573356e-01, 2.87047565e-01,\n", " -6.02750182e-02, -3.01489800e-01, -1.76609412e-01,\n", " 7.97384530e-02, 1.61419287e-01, 3.27948928e-01,\n", " 2.05150560e-01, 3.03487569e-01, 2.71105886e-01,\n", " 2.76556283e-01, 7.07161725e-02, 2.08005011e-01,\n", " -7.33365566e-02, -1.03249148e-01, 1.55347139e-02,\n", " 3.17582250e-01, 3.16770792e-01, -6.80912733e-02],\n", " [ 7.31387287e-02, 6.06941879e-02, 5.30773848e-02,\n", " 1.78480223e-01, 1.88692018e-01, 1.77043870e-01,\n", " 8.07239711e-02, 1.48239732e-02, -4.42493856e-02,\n", " 6.05108887e-02, 2.88271666e-01, 5.82368821e-02,\n", " 2.64274329e-01, 1.20873421e-01, -2.19291598e-02,\n", " -1.51663244e-01, -4.52437997e-02, 6.89959526e-03,\n", " 2.83989072e-01, 1.00405067e-01, -3.03579658e-01,\n", " -3.13743889e-01, -1.75221011e-01, 1.12235680e-01,\n", " 1.15922809e-01, -2.76054680e-01, -6.86793476e-02,\n", " 6.13613129e-02, 3.08204144e-01, -2.80041754e-01],\n", " [-2.58817077e-01, -1.54795498e-02, 2.77136236e-01,\n", " -3.83958876e-01, 4.00266647e-01, -1.29193932e-01,\n", " -2.80725807e-02, 1.75729364e-01, 1.30319521e-01,\n", " 1.56524405e-01, -2.62551606e-02, 2.93114007e-01,\n", " 3.55902165e-02, 3.58770192e-02, -2.96767652e-01,\n", " -1.69694602e-01, 3.19763869e-02, 1.61003023e-01,\n", " -1.75322652e-01, -8.83801430e-02, -4.86088842e-02,\n", " -1.05747893e-01, 8.46082866e-02, -4.35629487e-02,\n", " -2.53890634e-01, -9.85630006e-02, -4.16661799e-02,\n", " -4.67738807e-02, -3.35340977e-01, -1.18303642e-01],\n", " [ 2.78338730e-01, -6.96775764e-02, -2.48151869e-02,\n", " 
4.41243351e-02, -8.29881430e-02, -1.39038265e-03,\n", " 8.48067403e-02, -2.48347864e-01, 1.53563991e-01,\n", " -1.36309460e-01, -4.34331596e-03, -6.34510815e-03,\n", " -7.79297352e-02, -1.02945447e-01, 2.40225092e-01,\n", " -2.56992996e-01, 2.98645079e-01, -1.97924539e-01,\n", " -2.75554240e-01, 1.29482746e-02, 1.26732409e-01,\n", " 1.27245054e-01, 1.59333318e-01, 1.43134385e-01,\n", " 6.40332401e-02, -1.11376449e-01, 5.90692759e-02,\n", " -1.16492316e-01, 1.58721358e-02, -2.03465536e-01],\n", " [ 6.27327412e-02, 8.74824524e-02, -9.10452157e-02,\n", " 1.81233525e-01, 1.06525749e-01, -1.38279110e-01,\n", " 3.20228636e-01, -1.48054436e-01, -3.18492800e-02,\n", " 3.50982606e-01, -1.33533657e-01, 1.76203504e-01,\n", " 5.12142032e-02, -4.60760295e-02, -1.33599430e-01,\n", " -2.22507954e-01, 1.66286990e-01, 3.71264696e-01,\n", " -1.16020367e-01, -1.61022544e-02, -8.04677159e-02,\n", " 1.58347130e-01, 3.06297302e-01, 6.74358159e-02,\n", " 8.46540183e-02, -9.24098492e-03, -3.53501737e-03,\n", " -3.68553460e-01, -1.23027235e-01, 2.18146190e-01],\n", " [ 2.83379406e-02, -1.24499649e-01, 1.79811046e-01,\n", " 1.53642461e-01, 5.48425168e-02, 1.91716999e-01,\n", " -9.49321389e-02, 6.86786622e-02, -7.67818838e-02,\n", " -1.93986297e-02, 5.70140183e-02, -3.93389583e-01,\n", " 5.28795272e-02, 3.79496545e-01, 2.46415466e-01,\n", " -1.21219859e-01, 4.00151759e-02, -3.80354345e-01,\n", " 1.95414037e-01, -9.05120224e-02, 3.20608616e-01,\n", " 1.48523450e-02, -3.49244773e-02, 1.11090288e-01,\n", " -3.37234616e-01, -3.06017160e-01, -1.13247246e-01,\n", " 1.59685776e-01, 6.75145537e-02, 1.00891382e-01],\n", " [ 1.68052852e-01, 1.94981366e-01, -9.76378322e-02,\n", " -1.45580143e-01, 1.01531252e-01, -3.17420304e-01,\n", " -1.15842447e-01, -2.86557317e-01, -1.01209313e-01,\n", " -1.30138457e-01, 1.97996482e-01, -6.92996681e-02,\n", " 1.83074176e-03, -6.13981634e-02, -2.38128498e-01,\n", " 1.41838133e-01, 4.12078500e-01, -1.11510590e-01,\n", " -7.69597739e-02, -3.93857658e-02, 
-5.82328439e-02,\n", " -2.56166637e-01, 1.75529957e-01, -5.77670783e-02,\n", " 4.62778807e-02, 1.20462373e-01, 3.14444661e-01,\n", " -1.82372808e-01, -1.62539259e-01, -9.76682454e-02],\n", " [-6.19109869e-02, -1.15577027e-01, 7.26505965e-02,\n", " 1.25303209e-01, 2.06839561e-01, 1.57670394e-01,\n", " -8.05730969e-02, -1.94498345e-01, 2.13316083e-02,\n", " 2.35428169e-01, -1.77006692e-01, -3.51173997e-01,\n", " -2.20170230e-01, 3.13621610e-02, 1.01004869e-01,\n", " -4.00861502e-01, -1.33804008e-01, -6.59427792e-02,\n", " -1.41224921e-01, -1.72024697e-01, -6.66107237e-02,\n", " 9.94113684e-02, -3.09019983e-02, 2.59390533e-01,\n", " 6.44736290e-02, -2.50633597e-01, -3.49195153e-02,\n", " 8.02383423e-02, 2.55563200e-01, -2.40900025e-01],\n", " [ 8.98088515e-03, -4.00735557e-01, -6.30197525e-02,\n", " 6.18334711e-02, 3.73557806e-01, 3.17755789e-02,\n", " 2.75026381e-01, -2.88109362e-01, -2.02475563e-01,\n", " 1.61148801e-01, -2.17942014e-01, 1.06317997e-01,\n", " 2.66866386e-03, -2.73037195e-01, 7.52992183e-02,\n", " 7.77817965e-02, 2.63209343e-02, 9.45713371e-02,\n", " -2.83376873e-01, 2.55718678e-02, 1.71330452e-01,\n", " 4.77515757e-02, -1.29865110e-02, -9.19252783e-02,\n", " 2.20202535e-01, -1.98967785e-01, 3.41536522e-01,\n", " 8.68079364e-02, -8.85302424e-02, -9.05130804e-03],\n", " [ 1.20346710e-01, 1.25420868e-01, -3.62598658e-01,\n", " 2.23636329e-01, -7.36351609e-02, 1.05015352e-01,\n", " 4.35760617e-03, -8.31906050e-02, 2.28634894e-01,\n", " -1.49365649e-01, 8.16624165e-02, -3.14154029e-01,\n", " 8.76248032e-02, -5.55702299e-02, 8.57660025e-02,\n", " -2.31696367e-02, -2.23352492e-01, 2.85403579e-02,\n", " 1.04186803e-01, 9.74222422e-02, 8.70420933e-02,\n", " 1.58863813e-02, 1.73726633e-01, 8.37596357e-02,\n", " -1.77873820e-02, -6.53727055e-02, 5.05047441e-02,\n", " -2.29445338e-01, 5.72724342e-02, 2.50912070e-01],\n", " [-3.97108495e-01, 1.00127161e-01, -7.08009750e-02,\n", " -1.62648782e-01, -1.39103666e-01, 1.61610574e-01,\n", " 1.60218611e-01, 
-7.88767636e-03, 5.42842150e-02,\n", " 1.65928125e-01, 2.23704860e-01, 3.66966009e-01,\n", " 6.14991188e-02, -4.85754609e-02, 3.34523857e-01,\n", " -7.26056099e-02, 1.93905726e-01, -6.00316077e-02,\n", " -3.03020537e-01, 1.71825185e-01, 2.90646702e-01,\n", " 2.13969320e-01, 4.79212999e-02, 9.81050730e-02,\n", " 1.03307858e-01, 1.27324745e-01, -5.79782724e-02,\n", " 1.52468219e-01, -3.66631836e-01, 1.77991658e-01],\n", " [ 1.05444938e-02, 8.89388323e-02, 2.68580914e-01,\n", " 3.43994796e-01, 7.01821893e-02, 3.07618678e-01,\n", " -1.59211323e-01, 2.65379250e-03, 4.29789275e-02,\n", " -6.48853630e-02, -1.19794458e-02, -1.01899371e-01,\n", " 3.90142053e-02, -2.12668777e-02, 1.34829685e-01,\n", " -1.60709843e-01, -2.75420129e-01, -1.10718116e-01,\n", " 1.22148365e-01, -2.56418556e-01, -8.86735022e-02,\n", " 3.57692540e-02, -1.72454745e-01, 1.25590444e-01,\n", " 1.48163527e-01, 9.70243216e-02, 1.78432480e-01,\n", " 8.07075799e-02, 7.18790591e-02, 8.29363018e-02],\n", " [ 1.46744028e-01, 1.35470942e-01, -5.01304418e-02,\n", " -2.56607473e-01, -2.36472189e-01, 2.16720849e-01,\n", " 1.36735886e-01, -3.88280898e-02, 3.90521705e-01,\n", " 3.07039917e-03, 1.45443454e-01, -7.04357922e-02,\n", " -1.51877582e-01, 6.67887330e-02, -2.40254313e-01,\n", " 3.11602056e-01, 6.78158849e-02, -2.53712773e-01,\n", " -2.01759011e-01, -2.26584569e-01, 1.38092458e-01,\n", " -1.41293734e-01, 3.44138265e-01, 1.29559129e-01,\n", " 2.85035908e-01, 6.18830621e-02, -2.29603484e-01,\n", " 2.91220188e-01, -8.08279067e-02, -3.44568819e-01],\n", " [-1.89113468e-02, 1.27276510e-01, 1.18291497e-01,\n", " -8.91744196e-02, -3.88738513e-02, -6.17537051e-02,\n", " -1.13684982e-01, -6.69639111e-02, -3.41004312e-01,\n", " -2.61811137e-01, -1.48398668e-01, -2.04433903e-01,\n", " -3.67532670e-03, 5.23980856e-02, 6.42350912e-02,\n", " 8.27208012e-02, 6.50502890e-02, 2.10662231e-01,\n", " -1.37932986e-01, 3.07054341e-01, -9.90331918e-02,\n", " -1.52072370e-01, -9.87340361e-02, 3.40974480e-02,\n", " 9.54811275e-02, 
-2.24524289e-01, 3.59725446e-01,\n", " 2.40096241e-01, -3.08321267e-02, -1.03981838e-01],\n", " [-1.48987636e-01, 1.81017682e-01, 1.63939446e-02,\n", " -2.69196749e-01, -2.44855881e-04, -3.97299677e-02,\n", " -4.50673252e-02, 3.15325707e-02, -1.54727101e-01,\n", " -1.31139323e-01, 4.07651365e-02, 1.12029955e-01,\n", " -1.95583731e-01, -1.76382348e-01, 1.18832231e-01,\n", " -2.69036591e-01, 3.04102361e-01, -8.61337632e-02,\n", " -7.33592659e-02, 2.46652514e-02, 2.24634409e-01,\n", " 2.91894495e-01, 5.09455353e-02, -3.00022811e-01,\n", " -5.17807007e-02, -7.68917203e-02, -3.71362567e-01,\n", " 1.96652308e-01, -1.05256483e-01, -2.75513947e-01],\n", " [ 5.38429469e-02, 3.15863788e-02, -4.09971178e-03,\n", " -4.45098132e-02, -1.00758657e-01, -6.42608255e-02,\n", " 3.13616514e-01, -1.36063650e-01, 1.24328464e-01,\n", " -1.09250084e-01, -3.94054949e-02, 2.20205531e-01,\n", " -7.17411488e-02, 8.70945156e-02, 4.95519042e-02,\n", " 3.63173425e-01, 6.60566986e-03, -1.58391029e-01,\n", " 9.21002179e-02, -1.74151346e-01, -1.42024517e-01,\n", " 3.83424699e-01, 2.24799365e-02, 7.36033916e-03,\n", " -2.80536860e-02, -1.58879921e-01, 3.91074270e-02,\n", " -9.43726748e-02, 2.17871547e-01, 1.44041181e-02],\n", " [-9.30076092e-02, -1.98024854e-01, -3.14120024e-01,\n", " 1.71728119e-01, 1.33051425e-01, -1.41134948e-01,\n", " -2.13182554e-01, -1.62383363e-01, 9.43484604e-02,\n", " 1.46695167e-01, 1.86109841e-02, -2.21152157e-02,\n", " -1.46707222e-01, 3.92628670e-01, -2.01349884e-01,\n", " 1.09045491e-01, -2.89507657e-02, -1.52118295e-01,\n", " 1.74316362e-01, 7.77858645e-02, 9.58564728e-02,\n", " 1.02938607e-01, 1.89557344e-01, -1.57446101e-01,\n", " 9.71454233e-02, -2.65448719e-01, -5.12915403e-02,\n", " 8.04104656e-02, 5.85189909e-02, 2.47814834e-01],\n", " [ 5.50863743e-02, 2.30716497e-01, 2.78447568e-03,\n", " -5.15906960e-02, -1.33431196e-01, 1.72307670e-01,\n", " -3.83732319e-02, 1.72320321e-01, -1.20988593e-01,\n", " -1.21835843e-01, -1.65673420e-01, -8.69580209e-02,\n", " 
-1.52244121e-02, -3.16982597e-01, 1.96166813e-01,\n", " -2.08498746e-01, 3.45463872e-01, 2.52553254e-01,\n", " 3.05833966e-02, -2.36524627e-01, -2.45504826e-02,\n", " -7.39030093e-02, 1.80508122e-01, 8.00530612e-02,\n", " 2.32508481e-02, 5.16086072e-02, 8.30580592e-02,\n", " -1.09614551e-01, 2.05054462e-01, 5.47491461e-02],\n", " [-2.92942554e-01, 1.58341527e-02, -5.25936484e-04,\n", " 7.54894614e-02, 1.75417557e-01, 1.60730407e-01,\n", " 5.91714680e-03, -2.53270268e-02, -2.71934599e-01,\n", " -2.63609916e-01, 1.75930902e-01, 2.68442690e-01,\n", " -1.60669044e-01, 4.51646745e-03, -4.13359016e-01,\n", " 1.32156700e-01, 2.06499949e-01, -9.21603739e-02,\n", " -3.21212083e-01, 2.94354409e-02, -3.51502746e-02,\n", " -1.13733053e-01, 7.04959035e-03, 6.02722168e-02,\n", " 3.10117364e-01, 3.13739121e-01, 1.54746130e-01,\n", " 2.38439858e-01, 2.05240831e-01, 1.65418029e-01],\n", " [-1.31756335e-01, -9.71664637e-02, -2.11081415e-01,\n", " 3.06450367e-01, 1.31590337e-01, 2.54443496e-01,\n", " -2.31858999e-01, 2.65551567e-01, -2.02050045e-01,\n", " 2.71081388e-01, -1.37878060e-02, -1.70021936e-01,\n", " -1.65380538e-03, 9.54623818e-02, 2.84351885e-01,\n", " -1.01871341e-01, 2.01485753e-02, 9.12380219e-02,\n", " -6.15977198e-02, 4.80147898e-02, -1.39563948e-01,\n", " 3.02652836e-01, 1.64852113e-01, -2.00138420e-01,\n", " 9.66768116e-02, -9.22664106e-02, -9.24981087e-02,\n", " -2.47369722e-01, 2.81407475e-01, 1.83511779e-01],\n", " [-1.95965677e-01, -2.62236357e-01, -2.39685029e-02,\n", " 1.40571550e-01, -5.11714965e-02, 9.83205438e-02,\n", " 1.00091219e-01, 8.76448601e-02, -2.09155291e-01,\n", " -4.81763929e-02, 1.15129888e-01, -1.07420832e-02,\n", " 6.28655702e-02, -1.43948466e-01, 1.83107510e-01,\n", " 1.86010495e-01, -1.79242194e-02, -1.05096996e-02,\n", " 2.98826218e-01, 2.92384177e-02, 1.50228307e-01,\n", " -2.57354379e-02, 1.05159089e-01, 3.26832682e-01,\n", " -6.47494942e-02, -7.94630796e-02, -3.30955148e-01,\n", " -3.33948135e-01, 1.46547139e-01, -1.86091036e-01],\n", " 
[-4.33291495e-02, 1.88203737e-01, 3.16037238e-02,\n", " 1.19405270e-01, -2.26772234e-01, 9.43277180e-02,\n", " -8.72166008e-02, 2.56006062e-01, -1.48900136e-01,\n", " 9.94460434e-02, 1.87725961e-01, -1.95279077e-01,\n", " 8.27598572e-02, -1.46700770e-01, -1.25416532e-01,\n", " -1.37769252e-01, 9.57628340e-02, -2.98057973e-01,\n", " 1.05415285e-01, -1.18127242e-01, -2.35549033e-01,\n", " -1.76974535e-02, -2.97596186e-01, 4.32238877e-02,\n", " -4.16904092e-02, 4.33115363e-02, -1.08654559e-01,\n", " 3.52643073e-01, 2.74524629e-01, 1.66427344e-02],\n", " [ 1.77624241e-01, -7.08051473e-02, -1.25589028e-01,\n", " -1.33987144e-01, -2.28391200e-01, -2.04036236e-01,\n", " 7.88579136e-02, 1.33848041e-01, -3.16916883e-01,\n", " -1.34889603e-01, -8.19702297e-02, 2.77304649e-02,\n", " 2.47642547e-02, 1.05886340e-01, -2.58318663e-01,\n", " -2.43119746e-01, 3.77329141e-02, 5.44693917e-02,\n", " -1.35345489e-01, -1.10017732e-01, -3.13927293e-01,\n", " 5.12518585e-02, -5.12987375e-04, -1.58919290e-01,\n", " -1.70748994e-01, 2.36288786e-01, 8.46761465e-02,\n", " 1.05235279e-02, 8.87275785e-02, 1.64180309e-01],\n", " [ 3.75420362e-01, 6.81641251e-02, 7.72201270e-02,\n", " -4.03955430e-01, -5.49944937e-02, 8.78675282e-03,\n", " 3.32516432e-01, 1.84740856e-01, -7.99065679e-02,\n", " 1.99955359e-01, 1.41463295e-01, -1.58552006e-01,\n", " -1.59612060e-01, -1.87727511e-01, -1.75989211e-01,\n", " -1.34044841e-01, 2.13294059e-01, -1.30978391e-01,\n", " 1.06950596e-01, 2.87034482e-01, 1.33578539e-01,\n", " 3.30342472e-01, -2.68601805e-01, -2.23761916e-01,\n", " 2.93608367e-01, -3.48806530e-02, 1.48320645e-01,\n", " 1.26249835e-01, -2.08334908e-01, 5.82257062e-02],\n", " [ 1.78356782e-01, -1.20771334e-01, -7.79799819e-02,\n", " 1.64743349e-01, 1.32354870e-01, 1.11938149e-01,\n", " -2.97888666e-02, -1.83453709e-02, 3.70834023e-02,\n", " -3.93104136e-01, 4.38099802e-02, 1.25917226e-01,\n", " -2.03110918e-01, 1.49912372e-01, 8.95697773e-02,\n", " -1.43052682e-01, 3.78358483e-01, 
2.53705919e-01,\n", " 9.40865725e-02, 2.99577773e-01, 1.92300975e-03,\n", " -3.35528851e-02, -3.18194628e-01, 8.42917114e-02,\n", " -1.03216350e-01, 1.54624790e-01, 1.52044371e-01,\n", " -3.53974104e-03, -1.56484321e-01, 3.17795128e-02],\n", " [-3.34984601e-01, -1.87048599e-01, 1.26603231e-01,\n", " 2.71421999e-01, -4.17920053e-02, -1.65970325e-02,\n", " -1.58861458e-01, 1.46432027e-01, 1.03171512e-01,\n", " 1.39130712e-01, 2.62030542e-01, 3.82863283e-02,\n", " 1.70419857e-01, 2.81390846e-01, 3.02026719e-02,\n", " -2.17159256e-01, 5.98860085e-02, 2.09414154e-01,\n", " 2.78205037e-01, -3.02840352e-01, 2.17414171e-01,\n", " 6.87691420e-02, -1.62359178e-02, -9.31997299e-02,\n", " 1.67161644e-01, -5.67281246e-02, -1.67869329e-02,\n", " -3.39672238e-01, 4.14884984e-02, 2.41749167e-01]], dtype=float32),\n", " },\n", " Dense_1: {\n", " bias: array([ 6.7055225e-06, 6.7055225e-06, 6.7055225e-06, 6.7055225e-06,\n", " 6.7055225e-06, 6.7055225e-06, -4.6221808e-02, -3.4295321e-02,\n", " 4.9075842e-02, 6.9026187e-02, 5.4597139e-02, 6.7055225e-06,\n", " 6.7055225e-06, -2.3682773e-02, -1.4512062e-02], dtype=float32),\n", " kernel: array([[-2.15783074e-01, -8.00836086e-02, -3.41679364e-01,\n", " -3.61634940e-02, -4.04338688e-02, -1.92787528e-01,\n", " 7.80968368e-02, 3.84697020e-01, -2.70941556e-01,\n", " 3.09228301e-02, -1.12026677e-01, 1.21513918e-01,\n", " 3.84846568e-01, 1.29461914e-01, 3.02034169e-02],\n", " [ 3.03589344e-01, 1.49001822e-01, 2.24479139e-02,\n", " 1.72641233e-01, 1.11690164e-03, -1.60670713e-01,\n", " 1.60862997e-01, -2.07054362e-01, 3.59371305e-03,\n", " 8.14582109e-02, -8.32385570e-02, -2.71430016e-01,\n", " 3.29026878e-01, 6.39057159e-03, -2.08910599e-01],\n", " [ 1.42273605e-02, 2.52373248e-01, 2.65929043e-01,\n", " -7.87675530e-02, 2.57075429e-02, 1.37467578e-01,\n", " -3.03784698e-01, -3.00662845e-01, 2.28536919e-01,\n", " 7.39716142e-02, -5.44494987e-02, 6.82624280e-02,\n", " -1.14752382e-01, -4.36386168e-02, -2.58043408e-03],\n", " [-3.78126502e-02, 
3.64015222e-01, -6.23669326e-02,\n", " 1.66056380e-01, -1.78249836e-01, 1.75972238e-01,\n", " -2.43168324e-01, -3.16800296e-01, -6.96922392e-02,\n", " -5.26114106e-02, 1.36705399e-01, -5.23618162e-02,\n", " 1.28879577e-01, 2.59115607e-01, -6.60537034e-02],\n", " [ 1.62194669e-02, -5.36157191e-03, 6.25325292e-02,\n", " 3.77438962e-03, 5.34493476e-02, 9.57345217e-02,\n", " -3.37895215e-01, 1.54839545e-01, -5.76823950e-05,\n", " 3.46475482e-01, -1.44219398e-02, -1.57836825e-02,\n", " 1.02470160e-01, -2.66014189e-02, 1.38210922e-01],\n", " [-1.54244944e-01, -1.89163774e-01, -1.25277504e-01,\n", " -1.53509825e-01, -2.08621144e-01, -3.57687324e-02,\n", " 1.82844996e-02, 1.68684870e-01, 5.94071597e-02,\n", " 3.75792086e-02, 7.39661902e-02, 3.35429460e-02,\n", " 6.90625012e-02, 1.51754677e-01, -2.60873646e-01],\n", " [ 9.05385166e-02, 3.13364267e-01, -1.75755590e-01,\n", " 5.33922911e-02, -1.96638182e-01, 2.29208365e-01,\n", " 2.01122493e-01, 1.39216185e-01, 3.18074644e-01,\n", " -2.20635474e-01, 1.32455468e-01, -4.49397564e-02,\n", " -3.59446257e-02, 3.02191943e-01, -1.93763226e-02],\n", " [-9.32638943e-02, -9.22463536e-02, 2.90949643e-01,\n", " 1.38086572e-01, 1.45970643e-01, 1.11227736e-01,\n", " 1.92929432e-01, -2.20268294e-01, 8.51592422e-03,\n", " -1.14801973e-02, -1.08394876e-01, -4.92374301e-02,\n", " 1.53573290e-01, 3.89962226e-01, -1.53074652e-01],\n", " [ 2.90715694e-03, 1.83663428e-01, 3.76528651e-02,\n", " -1.73860639e-02, 1.83179498e-01, 4.10255790e-03,\n", " 9.65543687e-02, 7.96876550e-02, 2.19800651e-01,\n", " 2.27372974e-01, -1.51361659e-01, 2.04350561e-01,\n", " 1.18743330e-01, -3.37018371e-01, 1.12518296e-01],\n", " [-3.69710326e-02, 5.37259430e-02, -5.55092096e-03,\n", " -4.27453220e-03, -1.99901059e-01, -1.28307149e-01,\n", " 5.43928742e-02, 1.31580710e-01, -1.68890849e-01,\n", " -3.78164351e-01, -6.06988072e-02, 2.29855090e-01,\n", " -1.12242132e-01, 3.79682928e-02, -3.86388510e-01],\n", " [ 2.35311255e-01, 1.66386098e-01, -8.38957876e-02,\n", " 
2.03469038e-01, -2.03613058e-01, -7.19606429e-02,\n", " 1.12464979e-01, 2.45347440e-01, 2.39544034e-01,\n", " -1.39806151e-01, -2.63890773e-02, -1.12561002e-01,\n", " -2.70868719e-01, -4.71335649e-03, 1.29959971e-01],\n", " [-5.57020754e-02, -3.46530616e-01, 2.98494935e-01,\n", " -1.66801214e-01, 6.14305437e-02, 9.28813517e-02,\n", " 1.35294989e-01, -1.34979218e-01, -7.89761543e-04,\n", " -2.50806570e-01, 1.01332352e-01, 1.00093633e-01,\n", " 1.26480818e-01, 1.32050708e-01, 2.52038181e-01],\n", " [ 1.38003528e-02, -1.96471453e-01, 1.48797393e-01,\n", " 3.88497114e-02, -1.44033611e-01, 3.50036263e-01,\n", " -3.26102078e-02, -1.19598106e-01, -3.50412369e-01,\n", " -9.01353359e-02, 1.68153301e-01, -1.73634619e-01,\n", " -2.64526069e-01, 1.89368397e-01, -3.03420037e-01],\n", " [ 1.53409019e-01, -1.65645137e-01, 2.75816798e-01,\n", " -2.61331946e-02, 9.50511992e-02, -1.15451679e-01,\n", " -2.91420221e-02, 9.09120291e-02, 1.60791710e-01,\n", " 1.59705788e-01, -1.63535506e-01, -2.39277333e-02,\n", " 2.17306137e-01, -3.75522256e-01, 3.61660779e-01],\n", " [ 2.75457501e-01, -5.11639714e-02, 3.02877873e-02,\n", " 3.83744717e-01, 1.89788371e-01, -3.05498004e-01,\n", " -1.36434168e-01, -9.84745473e-02, -8.30809176e-02,\n", " -1.72840640e-01, 2.08918750e-03, 2.70356536e-01,\n", " -1.43099576e-02, 1.42251849e-02, -2.30559021e-01],\n", " [-1.12811208e-01, -8.90487880e-02, 5.26717752e-02,\n", " -3.34500819e-02, 1.79550871e-01, 1.52729020e-01,\n", " -5.19486815e-02, 1.09064624e-01, 2.16731712e-01,\n", " -5.77685237e-02, 2.93150067e-01, -2.72270977e-01,\n", " 2.27185652e-01, -4.16627526e-02, 8.24269503e-02],\n", " [ 1.12723231e-01, 1.53931171e-01, -1.41348794e-01,\n", " -1.82246700e-01, -2.60128081e-01, 2.28294432e-01,\n", " -1.41973585e-01, -2.20511407e-01, 4.12718713e-01,\n", " 2.26580381e-01, 3.02141160e-02, -4.08996940e-02,\n", " -1.77697018e-01, 2.13378891e-01, 2.34325364e-01],\n", " [-8.89122784e-02, -1.33175939e-01, 1.13867760e-01,\n", " -3.15929264e-01, -5.33000082e-02, 
-3.32492590e-02,\n", " 1.61763191e-01, -4.45095897e-02, -3.13157290e-02,\n", " 1.08764470e-02, -1.84109345e-01, -4.28237021e-02,\n", " -6.48176223e-02, -5.03456593e-03, 7.79272169e-02],\n", " [ 1.07352838e-01, 1.53558820e-01, 9.59271789e-02,\n", " 1.95107758e-02, -7.16409087e-03, -2.52965033e-01,\n", " -2.34644741e-01, 3.56166244e-01, 1.75924093e-01,\n", " -1.80997580e-01, -2.52380699e-01, -5.56088388e-02,\n", " -2.03573480e-01, -1.34791806e-01, 1.43947557e-01],\n", " [-3.20794344e-01, 1.52731538e-02, 3.17644477e-01,\n", " 3.08467388e-01, 6.71731085e-02, -2.07317039e-01,\n", " 6.92533553e-02, 8.92890543e-02, -5.71464747e-02,\n", " 6.77597970e-02, -2.40857795e-01, 9.92331654e-02,\n", " -3.97704244e-01, -1.44478083e-01, 1.78948924e-01],\n", " [ 1.78290546e-01, -3.73491168e-01, -3.44891042e-01,\n", " -1.85131654e-01, -1.25298679e-01, -3.59349012e-01,\n", " -2.19286203e-01, 4.03494358e-01, -5.62228560e-02,\n", " -1.21513605e-01, 3.14101666e-01, -5.40824682e-02,\n", " 1.32633939e-01, 1.56066120e-02, 2.20855936e-01],\n", " [-1.59709528e-01, -1.92597717e-01, -3.99962455e-01,\n", " -1.37944147e-01, 2.50320375e-01, -3.02701652e-01,\n", " 2.37861603e-01, 3.54276657e-01, -2.65088379e-02,\n", " 2.30067194e-01, 1.21199116e-01, -1.71849012e-01,\n", " -1.15864873e-02, 2.79654086e-01, 6.52542561e-02],\n", " [-3.87499511e-01, -2.14815795e-01, -3.50072563e-01,\n", " -3.74599099e-01, 3.15827131e-03, -3.80146921e-01,\n", " -2.75051177e-01, 4.92296517e-02, -3.42986435e-02,\n", " 2.64439374e-01, 2.24484161e-01, -9.03718919e-02,\n", " -1.66824907e-01, 2.33413532e-01, -2.70957470e-01],\n", " [-2.09726006e-01, -1.01519674e-02, 1.65578052e-01,\n", " 2.08754063e-01, -1.90129966e-01, -3.17800641e-01,\n", " -3.11262459e-02, -6.45868927e-02, 3.97725523e-01,\n", " -2.66408682e-01, 3.11382055e-01, -6.38214052e-02,\n", " -3.96969140e-01, 1.07675835e-01, 1.15448385e-02],\n", " [ 1.90294176e-01, 2.62975395e-02, 1.02480963e-01,\n", " -2.17438534e-01, -2.62192190e-01, 1.49379149e-01,\n", " 
1.36339113e-01, -2.20250994e-01, -6.47595972e-02,\n", " 2.78646350e-01, -1.04613230e-01, -3.21846724e-01,\n", " -3.60531211e-02, 4.10647094e-02, -2.01862305e-02],\n", " [ 4.84227687e-02, -1.74249634e-01, 1.22636214e-01,\n", " 3.27275246e-01, -8.30558389e-02, -3.14863682e-01,\n", " 1.06451482e-01, 9.95579511e-02, 7.17695206e-02,\n", " -2.05804944e-01, 4.14260328e-02, -2.82542408e-03,\n", " 1.59710690e-01, 1.95358038e-01, -2.18680993e-01],\n", " [ 1.80844143e-01, 9.28273797e-02, -2.79074967e-01,\n", " -3.21841598e-01, 8.46190900e-02, -1.31674752e-01,\n", " -2.22147837e-01, 6.93762749e-02, 1.08454213e-01,\n", " -1.54395998e-01, -2.52935737e-02, 3.96427214e-02,\n", " -1.77300423e-02, 4.08218354e-02, 1.57031059e-01],\n", " [ 1.72292829e-01, -2.74210989e-01, 3.91593874e-02,\n", " -1.06435806e-01, 1.53482422e-01, -4.07753527e-01,\n", " -1.45187899e-01, -1.97193578e-01, 1.51643381e-01,\n", " 8.71199816e-02, -1.80934519e-02, 3.16381007e-02,\n", " -3.16615790e-01, -8.88970047e-02, -3.15811843e-01],\n", " [-9.74183232e-02, -2.85981894e-02, 3.55937093e-01,\n", " -2.71050155e-01, 3.29447448e-01, 8.03399384e-02,\n", " 3.03259671e-01, -1.92141250e-01, -1.79099917e-01,\n", " 2.89139271e-01, 9.35540348e-02, -2.52466857e-01,\n", " 1.29278764e-01, 3.87039840e-01, -3.96069586e-01],\n", " [ 6.32186532e-02, 8.49261880e-02, -1.27169192e-02,\n", " 1.09564900e-01, 1.01468608e-01, -1.91385388e-01,\n", " 2.12272123e-01, 3.93263578e-01, 3.21644545e-03,\n", " 1.96522057e-01, -1.37627035e-01, 2.72568524e-01,\n", " -9.58500952e-02, -5.87881505e-02, 3.29161406e-01]], dtype=float32),\n", " },\n", " Dense_2: {\n", " bias: array([-1.9683748e-02, -2.1595210e-03, 6.7055225e-06, -5.0753772e-02,\n", " 1.3499202e-01, 4.5890406e-02, -1.5943155e-02, 6.7055225e-06],\n", " dtype=float32),\n", " kernel: array([[ 1.70742482e-01, -9.91618633e-02, -2.29330659e-02,\n", " 1.22218162e-01, -1.06615439e-01, -3.67971659e-02,\n", " 1.41981930e-01, 8.84072632e-02],\n", " [ 1.80515304e-01, 1.18458852e-01, 
5.27178884e-01,\n", " -1.53531089e-01, -3.87934148e-02, -1.23133227e-01,\n", " 6.29240274e-03, 1.60064697e-02],\n", " [ 1.46227553e-01, 5.25590360e-01, -3.44348907e-01,\n", " -9.29221213e-02, 1.28932714e-01, -1.73384249e-02,\n", " 1.19381860e-01, -4.11042988e-01],\n", " [ 2.49569118e-01, 2.14901835e-01, 3.24754208e-01,\n", " -4.91983056e-01, -1.14351794e-01, -2.11403668e-02,\n", " 7.69451857e-02, 3.31384718e-01],\n", " [ 4.12717015e-02, -1.21542990e-01, -3.10934186e-01,\n", " 3.62211466e-01, -2.74409771e-01, -5.17052352e-01,\n", " -5.41649759e-02, 5.59365511e-01],\n", " [-4.59556133e-02, -6.34765029e-02, 1.12813324e-01,\n", " 3.72479081e-01, 1.19602963e-01, 2.29919955e-01,\n", " -1.38893351e-01, 4.47091460e-02],\n", " [-6.79109395e-02, -3.46875697e-01, -4.84580278e-01,\n", " 1.12352774e-01, -1.92538321e-01, -3.71922582e-01,\n", " 5.66129565e-01, -2.42807910e-01],\n", " [ 3.35851789e-01, -1.21225804e-01, 3.77379179e-01,\n", " 4.12614942e-01, 4.95961308e-02, -3.28645080e-01,\n", " 4.95315671e-01, -1.31850213e-01],\n", " [ 3.26091319e-01, 6.93900436e-02, 2.03807250e-01,\n", " 1.04140580e-01, 3.88821214e-01, 1.70190245e-01,\n", " -2.11963385e-01, 5.57184219e-04],\n", " [ 1.63282499e-01, -1.22407228e-02, -4.44926322e-01,\n", " -9.26266760e-02, 3.41380268e-01, 4.87684071e-01,\n", " -1.46689430e-01, -2.39341646e-01],\n", " [-2.73013204e-01, 1.98071614e-01, 1.63841352e-01,\n", " -4.02920216e-01, 1.04726478e-01, 3.63073707e-01,\n", " -5.87156415e-03, -2.55927563e-01],\n", " [ 2.71851361e-01, -6.50364012e-02, 5.75378239e-02,\n", " 8.61035287e-02, 4.62561399e-02, 6.84097409e-02,\n", " -3.49434376e-01, 3.20657223e-01],\n", " [ 3.87204051e-01, 1.02552548e-01, -3.67724180e-01,\n", " -1.37631521e-01, -2.76330113e-03, -9.74713266e-03,\n", " -2.03522891e-02, -3.78593743e-01],\n", " [ 1.42845362e-01, -1.62034005e-01, -4.73217815e-01,\n", " 2.21667886e-01, 9.57852453e-02, -3.21318150e-01,\n", " 1.93716571e-01, -1.64225832e-01],\n", " [-8.96123946e-02, 1.43070787e-01, 
-3.05288434e-02,\n", " 4.30531144e-01, -3.90004218e-01, 5.45606494e-01,\n", " -9.91108418e-02, -3.92888844e-01]], dtype=float32),\n", " },\n", " Dense_3: {\n", " bias: array([0.3012179], dtype=float32),\n", " kernel: array([[-0.04792835],\n", " [ 0.6472556 ],\n", " [ 0.19396764],\n", " [-0.19879013],\n", " [ 0.4868285 ],\n", " [ 0.70990056],\n", " [-0.05229847],\n", " [-0.20487049]], dtype=float32),\n", " },\n", " },\n", "})\n" ] } ], "source": [ "params_spu = spu(train_auto_grad)(x1_, x2_, y_, init_params)\n", "params = sf.reveal(params_spu)\n", "print(params)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, let's validate the model." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "auc=0.9921388797903701\n" ] } ], "source": [ "X_test, y_test = breast_cancer(train=False)\n", "auc = validate_model(params, X_test, y_test)\n", "print(f'auc={auc}')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the end of the lab." ] } ], "metadata": { "kernelspec": { "display_name": "sf", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "db45a4cb4cd37a8de684dfb7fcf899b68fccb8bd32d97c5ad13e5de1245c0986" } } }, "nbformat": 4, "nbformat_minor": 2 }