{ "cells": [ { "cell_type": "markdown", "id": "cef3429e-11bc-4799-9ec7-4ad082fce0e7", "metadata": { "tags": [] }, "source": [ "# SPU Quickstart\n", "\n", "## Program with JAX\n", "\n", "SPU, as an [XLA](https://www.tensorflow.org/xla) backend, does not provide a high-level programming API by itself, instead, we can use any API that supports the XLA backend to program. In this tutorial, we use [JAX](https://github.com/google/jax) as the programming API and demonstrate how to run a JAX program on SPU.\n", "\n", "JAX is an AI framework from Google. Users can write the program in NumPy syntax, and let JAX translate it to GPU/TPU for acceleration, please read the following pages before you start:\n", "\n", "- [JAX Quickstart](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)\n", "- [How to Think in JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)\n", "- [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)\n", "\n", "Now we start to write some simple JAX code." ] }, { "cell_type": "code", "execution_count": 1, "id": "867c35e5-8c0a-4f76-9f77-6da8fc6d2e86", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "x = [59]\n", "y = [53]\n", "x>y = [ True]\n" ] } ], "source": [ "import numpy as np\n", "import jax.numpy as jnp\n", "\n", "def make_rand():\n", " np.random.seed()\n", " return np.random.randint(100, size=(1, ))\n", "\n", "def greater(x, y):\n", " return jnp.greater(x, y)\n", "\n", "x = make_rand()\n", "y = make_rand()\n", "ans = greater(x, y)\n", "\n", "print(f\"x = {x}\")\n", "print(f\"y = {y}\")\n", "print(f\"x>y = {ans}\")" ] }, { "cell_type": "markdown", "id": "4998ee4d-7069-4ebb-8f16-9f070f59ca83", "metadata": {}, "source": [ "The above code snippet creates two random variables and compares which one is greater. Yes, the code snippet is not interesting yet~" ] }, { "cell_type": "markdown", "id": "0c08907a-1aec-424f-8525-d3efdf2a9903", "metadata": { "tags": [] }, "source": [ "## Program with SPU\n", "\n", "Now, let's convert the above code to an SPU program.\n", "\n", "### A Quick introduction to device system\n", "\n", "MPC programs are \"designed\" to be used in distributed way. In this tutorial, we use SPU builtin distributed framework for demonstration.\n", "\n", "> Warn: it's for demonstration purpose only, you should use an industrial framework like SecretFlow in production.\n", "\n", "To start the ppd cluster. In a separate terminal, run\n", "\n", "```sh\n", "python -m spu.utils.distributed up\n", "```\n", "\n", "This command starts multi-processes to simulate parties that do not trust each other. Please keep the terminal alive." 
] }, { "cell_type": "code", "execution_count": 2, "id": "5affb5d2-53ed-4114-9178-160fe4a9853e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2022-05-12 14:57:18.017] [info] [thread_pool.cc:18] Create a fixed thread pool with size 64\n" ] } ], "source": [ "import spu.utils.distributed as ppd\n", "\n", "# initialize the distributed environment.\n", "ppd.init(ppd.SAMPLE_NODES_DEF, ppd.SAMPLE_DEVICES_DEF)" ] }, { "cell_type": "code", "execution_count": 3, "id": "8c59a098-0ed3-4a0f-9588-53a679902cb9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'node:0': '127.0.0.1:9327',\n", " 'node:1': '127.0.0.1:9328',\n", " 'node:2': '127.0.0.1:9329'}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ppd.current().nodes_def" ] }, { "cell_type": "code", "execution_count": 4, "id": "bee367ab-a76e-463c-b2f7-8a8aac6436b9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'SPU': SPU(SPU) hosted by: ['127.0.0.1:9327', '127.0.0.1:9328', '127.0.0.1:9329'],\n", " 'P1': PYU(P1) hosted by: 127.0.0.1:9327,\n", " 'P2': PYU(P2) hosted by: 127.0.0.1:9328}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ppd.current().devices" ] }, { "cell_type": "markdown", "id": "3c2f740f-9ca6-4cc0-a809-5d1b752e2110", "metadata": {}, "source": [ "`ppd.init` initializes the SPU device system on the given cluster.\n", "\n", "The cluster has three nodes; each node is a process that listens on a given port.\n", "\n", "The three physical nodes construct three virtual devices.\n", "\n", "- `P1` and `P2` are so-called `PYU Device`s: simple Python devices that can run Python programs.\n", "- `SPU` is a virtual device hosted by all three nodes, which use MPC protocols to compute securely.\n", "\n", "Virtually, the layout looks like the picture below.\n", "\n", "![alt text](../imgs/device/server_aided.svg)\n", "\n", "- On the left side are the three physical nodes; a circle means the node runs a `PYU Device`, and a triangle means the node runs an `SPU Device Slice`.\n", "- On the right side is the virtual device layout constructed from the physical nodes on the left.\n", "\n", "We can also check the details of the `SPU` device." ] }, { "cell_type": "code", "execution_count": 5, "id": "cdb5758d-183d-4b6a-9f37-7a51e290823c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "name: SPU\n", "hosted by: ['127.0.0.1:9327', '127.0.0.1:9328', '127.0.0.1:9329']\n", "internal addrs: ['127.0.0.1:9437', '127.0.0.1:9438', '127.0.0.1:9439']\n", "protocol: ABY3\n", "field: FM128\n", "enable_pphlo_profile: true\n", "\n" ] } ], "source": [ "print(ppd.device('SPU').details())" ] }, { "cell_type": "markdown", "id": "4172b68a-e926-4467-911b-e92d4386b0ee", "metadata": {}, "source": [ "The `SPU` device uses `ABY3` as its backend protocol and runs on the `Ring128` (`FM128`) field.\n", "\n", "### Move JAX program to SPU\n", "\n", "Now, let's move the JAX program from CPU to SPU."
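, "\n", "\n", "As a preview, the end-to-end pattern we will follow is sketched below. This is only a sketch: `producer_fn` and `compute_fn` are placeholder names, and the concrete versions (`make_rand` and `greater`) appear in the next cells.\n", "\n", "```python\n", "# Sketch of the device workflow used in the rest of this tutorial.\n", "x = ppd.device(\"P1\")(producer_fn)()      # private data, stays on P1\n", "y = ppd.device(\"P2\")(producer_fn)()      # private data, stays on P2\n", "z = ppd.device(\"SPU\")(compute_fn)(x, y)  # computed securely on SPU\n", "result = ppd.get(z)                      # reveal to the driver (this notebook)\n", "```"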
] }, { "cell_type": "code", "execution_count": 6, "id": "2ce13b7c-cf03-471b-8387-2609d8bce24b", "metadata": {}, "outputs": [], "source": [ "# run make_rand on P1; the value is visible to P1 only.\n", "x = ppd.device(\"P1\")(make_rand)()\n", "\n", "# run make_rand on P2; the value is visible to P2 only.\n", "y = ppd.device(\"P2\")(make_rand)()\n", "\n", "# run greater on SPU; it automatically fetches x/y from P1/P2 (as ciphertext) and computes the result securely.\n", "ans = ppd.device(\"SPU\")(greater)(x, y)" ] }, { "cell_type": "markdown", "id": "862ec305-6746-45e7-8773-b8ef41282374", "metadata": {}, "source": [ "`ppd.device(\"P1\")(make_rand)` converts a Python function to a `DeviceFunction` which will be called on the `P1` device.\n", "\n", "The terminal that started the cluster will print logs like the following, which show that the `make_rand` function is relocated to another node and executed there.\n", "\n", "```sh\n", "[2022-05-03 19:17:44,363] [Process-1] Run: make_rand from node:0\n", "[2022-05-03 19:17:44,373] [Process-2] Run: make_rand from node:1\n", "```" ] }, { "cell_type": "markdown", "id": "70d9041b-86ed-426e-bea2-215a8fd67eb6", "metadata": {}, "source": [ "The result of `make_rand` is also stored on `P1` and is invisible to other devices/nodes. For example, when printed, all the above objects are `DeviceObject`s; the plaintext values are invisible." ] }, { "cell_type": "code", "execution_count": 7, "id": "df54bb60-5a5c-4fa7-ae08-7cb4c4aa8c9a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceObject(140183709095536 at P1),\n", " DeviceObject(140183709050912 at P2),\n", " DeviceObject(140183709127968 at SPU))" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y, ans" ] }, { "cell_type": "markdown", "id": "85bc9088-f645-4263-a08b-fbebd754734d", "metadata": {}, "source": [ "And finally, we can reveal the result via `ppd.get`, which fetches the plaintext from the devices to this host (the notebook)." ] }, { "cell_type": "code", "execution_count": 8, "id": "e6d4e7d2-533a-41ce-b4ce-b09d8a969a55", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('x>y = ', array([False]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"x>y = \", ppd.get(ans)" ] }, { "cell_type": "markdown", "id": "454d4b43-a214-4b79-9fd4-bd41fac613fd", "metadata": {}, "source": [ "The result shows whether the random variable `x` from `P1` is greater than `y` from `P2` (here it is not); we can verify it by revealing the original inputs." ] }, { "cell_type": "code", "execution_count": 9, "id": "dc41de38-585e-4aeb-b481-d1b7ba8c3370", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([33]), array([58]), array([False]))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_revealed = ppd.get(x)\n", "y_revealed = ppd.get(y)\n", "x_revealed, y_revealed, np.greater(x_revealed, y_revealed)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "cafda5f8-25c3-4361-af44-88e25654bd1a", "metadata": {}, "source": [ "With the above code, we implement the classic [Yao's millionaires' problem](https://en.wikipedia.org/wiki/Yao%27s_Millionaires%27_problem) on SPU. Note:\n", "\n", "- SPU re-uses the `jax` API and translates it to an SPU executable; there is no `import spu.jax as jax` involved.\n", "- SPU hides the secure semantics; to compute a function *securely*, simply mark it to run on SPU. 
" ] }, { "cell_type": "markdown", "id": "42a54e49-3e84-4da2-9c3b-768a1a6951b4", "metadata": {}, "source": [ "## Logistic regression\n", "\n", "Now, let's check a more complicated example, privacy-preserving logistic regression.\n", "\n", "First, write the raw JAX program." ] }, { "cell_type": "code", "execution_count": 10, "id": "5e5188a2-ed81-402a-af7b-ae6b1c40b038", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn import metrics\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "def sigmoid(x):\n", " return 1 / (1 + jnp.exp(-x))\n", "\n", "def predict(x, w, b):\n", " return sigmoid(jnp.matmul(x, w) + b)\n", "\n", "def loss(x, y, w, b):\n", " pred = predict(x, w, b)\n", " label_prob = pred * y + (1 - pred) * (1 - y)\n", " return -jnp.mean(jnp.log(label_prob))\n", "\n", "def train(feature, label, n_epochs=10, n_iters=10, step_size=0.1):\n", " w = jnp.zeros(feature.shape[1])\n", " b = 0.0\n", "\n", " xs = jnp.array_split(feature, n_iters, axis=0)\n", " ys = jnp.array_split(label, n_iters, axis=0)\n", "\n", " def body_fun(_, loop_carry):\n", " w_, b_ = loop_carry\n", " for (x, y) in zip(xs, ys):\n", " grad = jax.grad(loss, argnums=(2, 3))(x, y, w_, b_)\n", " w_ -= grad[0] * step_size\n", " b_ -= grad[1] * step_size\n", "\n", " return w_, b_\n", "\n", " return jax.lax.fori_loop(0, n_epochs, body_fun, (w, b))\n", "\n", "def load_dataset():\n", " from sklearn.datasets import load_breast_cancer\n", " ds = load_breast_cancer()\n", " def normalize(x):\n", " return (x - np.min(x)) / (np.max(x) - np.min(x))\n", " return normalize(ds['data']), ds['target'].astype(dtype=np.float64)" ] }, { "cell_type": "markdown", "id": "49dc33e9-4d66-45fb-917a-0098a5ff7b09", "metadata": {}, "source": [ "Run the program on CPU, the result (AUC) works as expected." ] }, { "cell_type": "code", "execution_count": 11, "id": "04b0e68b-4f6b-416b-9a60-96df8c2ea352", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC= 0.9636779239997886\n" ] } ], "source": [ "x, y = load_dataset()\n", "w, b = jax.jit(train)(x, y)\n", "\n", "print(\"AUC=\", metrics.roc_auc_score(y, predict(x, w, b)))" ] }, { "cell_type": "markdown", "id": "52b4aeab-1619-4c10-8eef-37d23dbf6ef1", "metadata": {}, "source": [ "Now, use `ppd.device` to make the above code run on SPU." ] }, { "cell_type": "code", "execution_count": 12, "id": "e0403dd8-5c41-415a-afdf-c9f47b7522cf", "metadata": {}, "outputs": [], "source": [ "# load features on Alice\n", "X, _ = ppd.device(\"P1\")(load_dataset)()\n", "\n", "# load labels on Bob\n", "_, Y = ppd.device(\"P2\")(load_dataset)()\n", "\n", "# run the algorithm on SPU\n", "W, B = ppd.device(\"SPU\")(train)(X, Y)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "88210bba-9c79-4645-99f9-bf7b6e61220f", "metadata": {}, "source": [ "`P1` loads the features(X) only, `P2` loads the labels(Y) only, and for convenience, P1/P2 uses the same dataset, but only loads partial (either feature or label). Now `P1 and P2` want to train the model without telling each other the privacy data, so they use SPU to run the `train` function.\n", "\n", "It takes a little while to run the above program since privacy preserving program runs much slower than plaintext version.\n", "\n", "The parameters W and bias B are also located at SPU (no one knows the result). Finally, let's reveal the parameters to check correctness." 
] }, { "cell_type": "code", "execution_count": 13, "id": "39c95152-d12c-46df-8c0e-5b8cfa7590d8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC= 0.9636779239997886\n" ] } ], "source": [ "w_ = ppd.get(W)\n", "b_ = ppd.get(B)\n", "\n", "print(\"AUC=\", metrics.roc_auc_score(y, predict(x, w_, b_)))" ] }, { "attachments": {}, "cell_type": "markdown", "id": "cea959e1-f2cf-49f5-9cab-cec0e49877a5", "metadata": {}, "source": [ "For this simple dataset, the AUC metric is exactly the same; however, since SPU uses fixed-point arithmetic, which is not as accurate as IEEE floating-point arithmetic, the trained parameters are not exactly the same." ] }, { "cell_type": "code", "execution_count": 14, "id": "5df1d557-ee85-45d1-ac02-0c574a18fcb5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU result: [-9.63748374e-04 5.85974485e-04 -7.72659713e-03 -1.89471960e-01\n", " 7.29127760e-06 -2.13097555e-05 -5.01855429e-05 -2.71105491e-05\n", " 1.41814962e-05 8.52456287e-06 -1.20480654e-04 1.68244922e-04\n", " -8.66169750e-04 -2.22656243e-02 1.16955653e-06 -2.46253649e-06\n", " -4.29665033e-06 -1.23381062e-06 2.82509859e-06 2.76519387e-07\n", " -1.99214974e-03 2.93411314e-04 -1.48938354e-02 -3.46419692e-01\n", " 6.96186044e-06 -7.03605692e-05 -1.17880481e-04 -4.30178479e-05\n", " 1.00183543e-05 4.66188339e-06]\n" ] } ], "source": [ "print(\"CPU result: \", w)" ] }, { "cell_type": "code", "execution_count": 15, "id": "698fe8fe-33f0-4f63-b5ee-0cf68241d2e1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SPU result: [-9.54568386e-04 5.94928861e-04 -7.72123039e-03 -1.89489499e-01\n", " 1.71959400e-05 -1.12503767e-05 -4.02480364e-05 -1.71363354e-05\n", " 2.40653753e-05 1.85817480e-05 -1.10551715e-04 1.78068876e-04\n", " -8.56399536e-04 -2.22569108e-02 1.10864639e-05 7.45058060e-06\n", " 5.69224358e-06 8.76188278e-06 1.27851963e-05 1.01476908e-05\n", " -1.98297203e-03 3.01986933e-04 -1.48890316e-02 -3.46443683e-01\n", " 1.69873238e-05 -6.04242086e-05 -1.07973814e-04 -3.31550837e-05\n", " 1.99228525e-05 1.44690275e-05]\n" ] } ], "source": [ "print(\"SPU result: \", w_)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "3037a3c8-2f70-4bab-9789-50591848dfd4", "metadata": {}, "source": [ "## Visibility inference\n", "\n", "The SPU compiler and runtime work together to protect private information.\n", "\n", "When an object is transferred from a PYU to the SPU device, the data is first encrypted (secret shared) and then sent to the SPU hosts.\n", "\n", "The SPU compiler deduces the visibility of the entire program, including all nodes in the compute DAG, from the inputs' visibility with a very simple rule: for each SPU instruction, when any input is a secret, the output is a secret. 
In this way, the `secure semantics` are propagated through the entire DAG.\n", "\n", "For example:" ] }, { "cell_type": "code", "execution_count": 16, "id": "a06fe664-8609-4763-b568-bccdc88d25b7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DeviceFunction(140181491712736 at SPU)\n" ] } ], "source": [ "@ppd.device(\"SPU\")\n", "def sigmoid(x):\n", "    return 1 / (1 + jnp.exp(-x))\n", "\n", "print(sigmoid)" ] }, { "cell_type": "markdown", "id": "78a76c60-a2df-4922-955b-be968885485c", "metadata": {}, "source": [ "It shows that the `ppd.device`-decorated `sigmoid` is a `DeviceFunction` which can be launched on SPU.\n", "\n", "We can print the SPU bytecode via" ] }, { "cell_type": "code", "execution_count": 17, "id": "b44ee111-b188-4c00-96f9-d7847ebe253a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "module @xla_computation_sigmoid.11 {\n", " func @main(%arg0: tensor<3x3x!pphlo.pub>) -> tensor<3x3x!pphlo.pub> {\n", " %0 = \"pphlo.constant\"() {value = dense<1.000000e+00> : tensor<3x3xf32>} : () -> tensor<3x3x!pphlo.pub>\n", " %1 = \"pphlo.negate\"(%arg0) : (tensor<3x3x!pphlo.pub>) -> tensor<3x3x!pphlo.pub>\n", " %2 = \"pphlo.exponential\"(%1) : (tensor<3x3x!pphlo.pub>) -> tensor<3x3x!pphlo.pub>\n", " %3 = \"pphlo.add\"(%2, %0) : (tensor<3x3x!pphlo.pub>, tensor<3x3x!pphlo.pub>) -> tensor<3x3x!pphlo.pub>\n", " %4 = \"pphlo.divide\"(%0, %3) : (tensor<3x3x!pphlo.pub>, tensor<3x3x!pphlo.pub>) -> tensor<3x3x!pphlo.pub>\n", " return %4 : tensor<3x3x!pphlo.pub>\n", " }\n", "}\n", "\n" ] } ], "source": [ "print(sigmoid.dump_pphlo(np.random.rand(3, 3)))" ] }, { "cell_type": "markdown", "id": "5a694008-5745-4002-af70-1fef76d8442c", "metadata": {}, "source": [ "It shows that the function's type signature is:\n", "```\n", "(tensor<3x3x!pphlo.pub>) -> tensor<3x3x!pphlo.pub>\n", "```\n", "\n", "Note that since the input is a random value from the driver (this notebook), which is not private information by default, the input type is `tensor<3x3x!pphlo.pub>`, i.e. a `3x3 public f32 tensor`.\n", "\n", "The compiler deduces the visibility types for the whole program and finds that the output should also be `tensor<3x3x!pphlo.pub>`, i.e. a `3x3 public f32 tensor`.\n", "\n", "Now let's generate input from `P1` and run the program again."
] }, { "cell_type": "code", "execution_count": 18, "id": "7bad0106-133c-4465-b770-b7ad0fe4760b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "module @xla_computation_sigmoid.12 {\n", " func @main(%arg0: tensor<1x!pphlo.sec>) -> tensor<1x!pphlo.sec> {\n", " %0 = \"pphlo.constant\"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1x!pphlo.pub>\n", " %1 = \"pphlo.negate\"(%arg0) : (tensor<1x!pphlo.sec>) -> tensor<1x!pphlo.sec>\n", " %2 = \"pphlo.convert\"(%1) : (tensor<1x!pphlo.sec>) -> tensor<1x!pphlo.sec>\n", " %3 = \"pphlo.exponential\"(%2) : (tensor<1x!pphlo.sec>) -> tensor<1x!pphlo.sec>\n", " %4 = \"pphlo.add\"(%3, %0) : (tensor<1x!pphlo.sec>, tensor<1x!pphlo.pub>) -> tensor<1x!pphlo.sec>\n", " %5 = \"pphlo.divide\"(%0, %4) : (tensor<1x!pphlo.pub>, tensor<1x!pphlo.sec>) -> tensor<1x!pphlo.sec>\n", " return %5 : tensor<1x!pphlo.sec>\n", " }\n", "}\n", "\n" ] } ], "source": [ "X = ppd.device(\"P1\")(make_rand)()\n", "\n", "print(sigmoid.dump_pphlo(X))" ] }, { "cell_type": "markdown", "id": "4b6fb169-de48-48f7-bc11-21d0da1658c6", "metadata": {}, "source": [ "Since the input comes from `P1` and is therefore private, the function signature becomes:\n", "\n", "```\n", "(tensor<1x!pphlo.sec>) -> tensor<1x!pphlo.sec>\n", "```\n", "\n", "This means the function accepts one `secret i32` value and outputs one `secret f32` value; inside the compiled function, every internal instruction's visibility type is also correctly deduced.\n", "\n", "With this JIT (just-in-time) type deduction, SPU protects the clients' privacy." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 5 }