SPU Quickstart

Program with JAX

SPU is an XLA backend and does not provide a high-level programming API by itself. Instead, we can program it through any API that supports the XLA backend. In this tutorial, we use JAX as the programming API and demonstrate how to run a JAX program on SPU.

JAX is an AI framework from Google. Users write programs in NumPy syntax and let JAX compile them for GPU/TPU acceleration. Please read the following pages before you start:

Now let's write some simple JAX code.

[1]:
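import numpy as np
import jax.numpy as jnp

def make_rand():
    # Reseed from OS entropy so that each process (and each call) draws fresh values.
    np.random.seed()
    # A 1-element integer array with a value in [0, 100).
    return np.random.randint(100, size=(1, ))

def greater(x, y):
    # Element-wise x > y, written with jax.numpy so JAX can trace and compile it.
    return jnp.greater(x, y)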

x = make_rand()
y = make_rand()
ans = greater(x, y)

print(f"x = {x}")
print(f"y = {y}")
print(f"x>y = {ans}")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
x = [59]
y = [53]
x>y = [ True]

The snippet above creates two random variables and checks which one is greater. Admittedly, it is not very interesting yet~

Program with SPU

Now, let’s convert the above code to an SPU program.

A quick introduction to the device system

MPC programs are designed to run in a distributed setting. In this tutorial, we use SPU's built-in distributed framework for demonstration.

Warning: this framework is for demonstration purposes only. Use an industrial framework such as SecretFlow in production.

To start the ppd cluster, run the following in a separate terminal:

python -m spu.utils.distributed up

This command starts multiple processes to simulate parties that do not trust each other. Keep this terminal running.

[2]:
import spu.utils.distributed as ppd

# initialize the distributed environment.
ppd.init(ppd.SAMPLE_NODES_DEF, ppd.SAMPLE_DEVICES_DEF)
[2022-05-12 14:57:18.017] [info] [thread_pool.cc:18] Create a fixed thread pool with size 64
[3]:
ppd.current().nodes_def
[3]:
{'node:0': '127.0.0.1:9327',
 'node:1': '127.0.0.1:9328',
 'node:2': '127.0.0.1:9329'}
[4]:
ppd.current().devices
[4]:
{'SPU': SPU(SPU) hosted by: ['127.0.0.1:9327', '127.0.0.1:9328', '127.0.0.1:9329'],
 'P1': PYU(P1) hosted by: 127.0.0.1:9327,
 'P2': PYU(P2) hosted by: 127.0.0.1:9328}

ppd.init initializes the SPU device system on the given cluster.

The cluster has three nodes. Each node is a process that listens on a given port.

The three physical nodes host three virtual devices:

  • P1 and P2 are so-called PYU devices. A PYU device is a simple Python device that can run Python programs.

  • SPU is a virtual device hosted by all three nodes. It uses MPC protocols to compute securely.

Conceptually, the layout looks like the picture below.


  • On the left side are the three physical nodes. A circle means the node runs a PYU device, and a triangle means the node runs an SPU device slice.

  • On the right side is the virtual device layout constructed from the physical nodes on the left.

We can also inspect the details of the SPU device.

[5]:
print(ppd.device('SPU').details())
name: SPU
hosted by: ['127.0.0.1:9327', '127.0.0.1:9328', '127.0.0.1:9329']
internal addrs: ['127.0.0.1:9437', '127.0.0.1:9438', '127.0.0.1:9439']
protocol: ABY3
field: FM128
enable_pphlo_profile: true

The SPU device uses ABY3 as its backend protocol and computes over the 128-bit ring (FM128).
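The pure-Python sketch below makes "ABY3 on a 128-bit ring" more concrete. It shows 2-out-of-3 replicated secret sharing over Z_{2^128}, the sharing scheme that ABY3 is built on. It is a simplified illustration, not SPU's implementation, and the names share, reconstruct and add are hypothetical. Real ABY3 also needs correlated randomness, binary shares and a multiplication protocol, all of which are omitted here.

```python
import secrets

# Simplified 2-out-of-3 replicated secret sharing over Z_{2^128} (illustration only).
MOD = 1 << 128


def share(x):
    """Split x into x0 + x1 + x2 (mod 2**128); party i receives (x_i, x_{i+1})."""
    x0 = secrets.randbelow(MOD)
    x1 = secrets.randbelow(MOD)
    x2 = (x - x0 - x1) % MOD
    parts = [x0, x1, x2]
    return [(parts[i], parts[(i + 1) % 3]) for i in range(3)]


def reconstruct(i, share_i, j, share_j):
    """Any two distinct parties together hold all three components."""
    known = {i: share_i[0], (i + 1) % 3: share_i[1],
             j: share_j[0], (j + 1) % 3: share_j[1]}
    if len(known) != 3:
        raise ValueError("need shares from two distinct parties")
    return sum(known.values()) % MOD


def add(a, b):
    """Addition is local: each party adds its own components, no communication."""
    return [((a0 + b0) % MOD, (a1 + b1) % MOD)
            for (a0, a1), (b0, b1) in zip(a, b)]


s = share(42)
print(s[0])                           # one party alone sees two random-looking numbers
print(reconstruct(0, s[0], 2, s[2]))  # -> 42
t = add(share(42), share(100))
print(reconstruct(1, t[1], 2, t[2]))  # -> 142
```

A single party's pair is uniformly random and reveals nothing about x, but any two parties can reconstruct it. This is why the SPU device needs all three nodes to host it.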

Move JAX program to SPU

Now, let’s move the JAX program from CPU to SPU.

[6]:
# run make_rand on P1, the value is visible for P1 only.
x = ppd.device("P1")(make_rand)()

# run make_rand on P2, the value is visible for P2 only.
y = ppd.device("P2")(make_rand)()

# run greater on SPU; it automatically fetches x/y from P1/P2 (as ciphertext) and computes the result securely.
ans = ppd.device("SPU")(greater)(x, y)

ppd.device("P1")(make_rand) converts a Python function into a DeviceFunction that runs on the P1 device.

The terminal running the cluster prints logs like the following. They show that make_rand was shipped to another node and executed there.

[2022-05-03 19:17:44,363] [Process-1] Run: make_rand from node:0
[2022-05-03 19:17:44,373] [Process-2] Run: make_rand from node:1

The result of each make_rand call stays on the device that produced it (x on P1, y on P2) and is invisible to other devices and nodes. For example, printing these objects shows only DeviceObject handles, and the plaintext values stay hidden.

[7]:
x, y, ans
[7]:
(DeviceObject(140183709095536 at P1),
 DeviceObject(140183709050912 at P2),
 DeviceObject(140183709127968 at SPU))

Finally, we can reveal the result via ppd.get, which fetches the plaintext from the devices to this host (the notebook).

[8]:
"x>y = ", ppd.get(ans)
[8]:
('x>y = ', array([False]))

The result shows that the random variable x from P1 is not greater than y from P2. We can verify this by revealing the original inputs.

[9]:
x_revealed = ppd.get(x)
y_revealed = ppd.get(y)
x_revealed, y_revealed, np.greater(x_revealed, y_revealed)
[9]:
(array([33]), array([58]), array([False]))

With the code above, we have implemented the classic Yao's millionaires' problem on SPU. Note:

  • SPU reuses the JAX API and translates JAX programs into SPU executables. There is no need for a special import such as import spu.jax as jax.

  • SPU hides the secure semantics. To compute a function securely, simply mark it to run on SPU.

Logistic regression

Now let's look at a more complicated example: privacy-preserving logistic regression.

First, write the raw JAX program.

[10]:
import numpy as np
from sklearn import metrics

import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

def predict(x, w, b):
    return sigmoid(jnp.matmul(x, w) + b)

def loss(x, y, w, b):
    pred = predict(x, w, b)
    label_prob = pred * y + (1 - pred) * (1 - y)
    return -jnp.mean(jnp.log(label_prob))

def train(feature, label, n_epochs=10, n_iters=10, step_size=0.1):
    w = jnp.zeros(feature.shape[1])
    b = 0.0

    xs = jnp.array_split(feature, n_iters, axis=0)
    ys = jnp.array_split(label, n_iters, axis=0)

    def body_fun(_, loop_carry):
        w_, b_ = loop_carry
        for (x, y) in zip(xs, ys):
            grad = jax.grad(loss, argnums=(2, 3))(x, y, w_, b_)
            w_ -= grad[0] * step_size
            b_ -= grad[1] * step_size

        return w_, b_

    return jax.lax.fori_loop(0, n_epochs, body_fun, (w, b))

def load_dataset():
    from sklearn.datasets import load_breast_cancer
    ds = load_breast_cancer()
    def normalize(x):
        return (x - np.min(x)) / (np.max(x) - np.min(x))
    return normalize(ds['data']), ds['target'].astype(dtype=np.float64)

When we run the program on CPU, the result (AUC) is as expected.

[11]:
x, y = load_dataset()
w, b = jax.jit(train)(x, y)

print("AUC=", metrics.roc_auc_score(y, predict(x, w, b)))
AUC= 0.9636779239997886
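The train function relies on jax.grad(loss, argnums=(2, 3)) to differentiate the loss with respect to w and b. For labels in {0, 1}, this loss is the usual binary cross-entropy, and its gradient has a simple closed form: dL/dw = Xᵀ(σ(Xw + b) − y) / n and dL/db = mean(σ(Xw + b) − y). The NumPy sketch below is an illustration only and is not part of SPU or the tutorial. It checks that closed form against central finite differences on small random data, which shows what each update inside body_fun computes. The helper names analytic_grad and numeric_grad are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1 / (1 + np.exp(-z))


def loss(x, y, w, b):
    # Same loss as the tutorial, written with NumPy instead of jax.numpy.
    pred = sigmoid(x @ w + b)
    label_prob = pred * y + (1 - pred) * (1 - y)
    return -np.mean(np.log(label_prob))


def analytic_grad(x, y, w, b):
    # Closed form for labels in {0, 1}: X^T (p - y) / n and mean(p - y).
    err = sigmoid(x @ w + b) - y
    return x.T @ err / len(y), np.mean(err)


def numeric_grad(x, y, w, b, eps=1e-6):
    # Central finite differences, one coordinate of w at a time, then b.
    gw = np.zeros_like(w)
    for k in range(len(w)):
        d = np.zeros_like(w)
        d[k] = eps
        gw[k] = (loss(x, y, w + d, b) - loss(x, y, w - d, b)) / (2 * eps)
    gb = (loss(x, y, w, b + eps) - loss(x, y, w, b - eps)) / (2 * eps)
    return gw, gb


rng = np.random.default_rng(0)
x = rng.random((8, 3))
y = rng.integers(0, 2, size=8).astype(np.float64)
w = rng.normal(size=3)
b = 0.1

gw_a, gb_a = analytic_grad(x, y, w, b)
gw_n, gb_n = numeric_grad(x, y, w, b)
print(np.allclose(gw_a, gw_n, atol=1e-6), np.isclose(gb_a, gb_n, atol=1e-6))
```

Each step in train then moves w and b against these gradients, scaled by step_size.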

Now use ppd.device to run the same code on SPU.

[12]:
# load features on P1 (Alice)
X, _ = ppd.device("P1")(load_dataset)()

# load labels on P2 (Bob)
_, Y = ppd.device("P2")(load_dataset)()

# run the algorithm on SPU
W, B = ppd.device("SPU")(train)(X, Y)

P1 loads only the features (X), and P2 loads only the labels (Y). For convenience, P1 and P2 load the same dataset, but each keeps only its own part (features or labels). P1 and P2 want to train the model without revealing their private data to each other, so they use SPU to run the train function.

The program above takes a while to run, because a privacy-preserving program runs much slower than its plaintext version.

The weights W and the bias B also reside on SPU, so no single party knows them. Finally, let's reveal the parameters to check correctness.

[13]:
w_ = ppd.get(W)
b_ = ppd.get(B)

print("AUC=", metrics.roc_auc_score(y, predict(x, w_, b_)))
AUC= 0.9636779239997886

For this simple dataset, the AUC is exactly the same. However, SPU uses fixed-point arithmetic, which is less accurate than IEEE floating-point arithmetic, so the trained parameters are not exactly the same.

[14]:
print("CPU result: ", w)
CPU result:  [-9.63748374e-04  5.85974485e-04 -7.72659713e-03 -1.89471960e-01
  7.29127760e-06 -2.13097555e-05 -5.01855429e-05 -2.71105491e-05
  1.41814962e-05  8.52456287e-06 -1.20480654e-04  1.68244922e-04
 -8.66169750e-04 -2.22656243e-02  1.16955653e-06 -2.46253649e-06
 -4.29665033e-06 -1.23381062e-06  2.82509859e-06  2.76519387e-07
 -1.99214974e-03  2.93411314e-04 -1.48938354e-02 -3.46419692e-01
  6.96186044e-06 -7.03605692e-05 -1.17880481e-04 -4.30178479e-05
  1.00183543e-05  4.66188339e-06]
[15]:
print("SPU result: ", w_)
SPU result:  [-9.54568386e-04  5.94928861e-04 -7.72123039e-03 -1.89489499e-01
  1.71959400e-05 -1.12503767e-05 -4.02480364e-05 -1.71363354e-05
  2.40653753e-05  1.85817480e-05 -1.10551715e-04  1.78068876e-04
 -8.56399536e-04 -2.22569108e-02  1.10864639e-05  7.45058060e-06
  5.69224358e-06  8.76188278e-06  1.27851963e-05  1.01476908e-05
 -1.98297203e-03  3.01986933e-04 -1.48890316e-02 -3.46443683e-01
  1.69873238e-05 -6.04242086e-05 -1.07973814e-04 -3.31550837e-05
  1.99228525e-05  1.44690275e-05]
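A common way to represent real numbers in an MPC ring such as FM128 is fixed-point encoding. Each real is scaled by 2^f, rounded to an integer, and wrapped into Z_{2^128}. The sketch below illustrates where the small differences between the two weight vectors can come from. It is not SPU's code: FRAC_BITS = 18 is a hypothetical choice rather than SPU's actual setting, and encode, decode and fxp_mul are illustrative names.

```python
# Illustrative fixed-point encoding over the ring Z_{2^128}.
# FRAC_BITS = 18 is a hypothetical value for this sketch, not SPU's setting.
RING_BITS = 128
MOD = 1 << RING_BITS
FRAC_BITS = 18
SCALE = 1 << FRAC_BITS


def encode(v: float) -> int:
    """Scale by 2**FRAC_BITS, round to the nearest integer, wrap into the ring."""
    return round(v * SCALE) % MOD


def to_signed(r: int) -> int:
    """Interpret the upper half of the ring as negative numbers (two's complement)."""
    return r - MOD if r >= MOD // 2 else r


def decode(r: int) -> float:
    return to_signed(r) / SCALE


def fxp_mul(a: int, b: int) -> int:
    """The raw product carries 2 * FRAC_BITS fractional bits, so shift right by
    FRAC_BITS (Python's >> on negative ints floors toward -inf)."""
    return ((to_signed(a) * to_signed(b)) >> FRAC_BITS) % MOD


for v in [0.1, -3.25, 2.76519387e-07]:
    print(v, "->", decode(encode(v)))
print(decode(fxp_mul(encode(1.5), encode(-2.0))))
```

Values that are multiples of 2^-FRAC_BITS, such as -3.25, round-trip exactly. Others, such as 0.1, pick up a rounding error of at most 2^-(FRAC_BITS+1). A weight as tiny as 2.76519387e-07 (taken from the CPU result above) rounds to 0 at this resolution. Over a whole training run, such per-operation errors accumulate, so the final parameters drift slightly even though the AUC is unchanged.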

Visibility inference

The SPU compiler and runtime work together to protect private information.

When an object is transferred from a PYU to the SPU device, the data is first encrypted (secret-shared) and then sent to the SPU hosts.

The SPU compiler deduces the visibility of the entire program, including every node in the compute DAG, from the inputs' visibility. It uses a very simple rule: for each SPU instruction, if any input is secret, the output is secret. In this way, the secure semantics propagate through the entire DAG.

For example,

[16]:
@ppd.device("SPU")
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

print(sigmoid)
DeviceFunction(140181491712736 at SPU)

This shows that the ppd.device-decorated sigmoid is a DeviceFunction that can be launched on SPU.

We can print the SPU bytecode (PPHLO) via dump_pphlo:

[17]:
print(sigmoid.dump_pphlo(np.random.rand(3, 3)))
module @xla_computation_sigmoid.11 {
  func @main(%arg0: tensor<3x3x!pphlo.pub<f32>>) -> tensor<3x3x!pphlo.pub<f32>> {
    %0 = "pphlo.constant"() {value = dense<1.000000e+00> : tensor<3x3xf32>} : () -> tensor<3x3x!pphlo.pub<f32>>
    %1 = "pphlo.negate"(%arg0) : (tensor<3x3x!pphlo.pub<f32>>) -> tensor<3x3x!pphlo.pub<f32>>
    %2 = "pphlo.exponential"(%1) : (tensor<3x3x!pphlo.pub<f32>>) -> tensor<3x3x!pphlo.pub<f32>>
    %3 = "pphlo.add"(%2, %0) : (tensor<3x3x!pphlo.pub<f32>>, tensor<3x3x!pphlo.pub<f32>>) -> tensor<3x3x!pphlo.pub<f32>>
    %4 = "pphlo.divide"(%0, %3) : (tensor<3x3x!pphlo.pub<f32>>, tensor<3x3x!pphlo.pub<f32>>) -> tensor<3x3x!pphlo.pub<f32>>
    return %4 : tensor<3x3x!pphlo.pub<f32>>
  }
}

It shows that the function type signature is:

(tensor<3x3x!pphlo.pub<f32>>) -> tensor<3x3x!pphlo.pub<f32>>

Note that the input is generated randomly by the driver (this notebook), so it is not private by default. The input type is therefore tensor<3x3x!pphlo.pub<f32>>, which means the function accepts a 3x3 public f32 tensor.

The compiler deduces the visibility types of the whole program and finds that the output is also tensor<3x3x!pphlo.pub<f32>>, a 3x3 public f32 tensor.

Now let's generate the input on P1 and dump the program again.

[18]:
X = ppd.device("P1")(make_rand)()

print(sigmoid.dump_pphlo(X))
module @xla_computation_sigmoid.12 {
  func @main(%arg0: tensor<1x!pphlo.sec<i32>>) -> tensor<1x!pphlo.sec<f32>> {
    %0 = "pphlo.constant"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1x!pphlo.pub<f32>>
    %1 = "pphlo.negate"(%arg0) : (tensor<1x!pphlo.sec<i32>>) -> tensor<1x!pphlo.sec<i32>>
    %2 = "pphlo.convert"(%1) : (tensor<1x!pphlo.sec<i32>>) -> tensor<1x!pphlo.sec<f32>>
    %3 = "pphlo.exponential"(%2) : (tensor<1x!pphlo.sec<f32>>) -> tensor<1x!pphlo.sec<f32>>
    %4 = "pphlo.add"(%3, %0) : (tensor<1x!pphlo.sec<f32>>, tensor<1x!pphlo.pub<f32>>) -> tensor<1x!pphlo.sec<f32>>
    %5 = "pphlo.divide"(%0, %4) : (tensor<1x!pphlo.pub<f32>>, tensor<1x!pphlo.sec<f32>>) -> tensor<1x!pphlo.sec<f32>>
    return %5 : tensor<1x!pphlo.sec<f32>>
  }
}

Since the input comes from P1 and is therefore private, the function signature becomes:

(tensor<1x!pphlo.sec<i32>>) -> tensor<1x!pphlo.sec<f32>>

This means the function accepts one secret i32 element and outputs one secret f32 element. Inside the compiled function, the visibility type of every internal instruction is also deduced correctly.
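The propagation rule (any secret input makes the output secret) is simple enough to sketch in a few lines of Python. The toy below is not SPU's compiler. Vis, infer, deduce and SIGMOID are hypothetical names, and SIGMOID merely transcribes the instructions of the sigmoid module dumped for the P1 input. Constants have no operands, so they stay public.

```python
from enum import Enum


class Vis(Enum):
    PUBLIC = "pub"
    SECRET = "sec"


def infer(operand_vis):
    # The rule from the text: any secret operand makes the result secret.
    return Vis.SECRET if any(v is Vis.SECRET for v in operand_vis) else Vis.PUBLIC


def deduce(program, arg_vis):
    """program: list of (result, op, operands) in topological order.
    Returns a dict mapping every value name to its deduced visibility."""
    env = {"arg0": arg_vis}
    for result, _op, operands in program:
        env[result] = infer([env[o] for o in operands])
    return env


# Hypothetical transcription of the dumped sigmoid module.
SIGMOID = [
    ("%0", "constant", []),
    ("%1", "negate", ["arg0"]),
    ("%2", "convert", ["%1"]),
    ("%3", "exponential", ["%2"]),
    ("%4", "add", ["%3", "%0"]),
    ("%5", "divide", ["%0", "%4"]),
]

for name, vis in deduce(SIGMOID, Vis.SECRET).items():
    print(name, vis.value)
```

With a secret arg0, only the constant %0 stays public, which matches the pub/sec annotations in the dump. With a public arg0, every value stays public, as in the first dump.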

With this just-in-time (JIT) type deduction, SPU protects the clients' privacy.