Privacy-Preserving Scientific Computing with NumPy in SPU#

NumPy is one of the most popular tools for scientific computing. It is so common that many other languages have NumPy equivalents, such as xtensor (C++) and Gonum (Go). Since everyone loves NumPy, it is natural to ask whether we can express computation with NumPy-like APIs in privacy-preserving settings.

Luckily, the JAX NumPy package makes this goal easy to reach. In this tutorial, we will go through:

  • The relation between JAX and SPU

  • Writing a jittable JAX program

  • Executing a JAX program with SPU

The relation between JAX and SPU#

TL;DR#

SPU consists of two components: the Compiler and the Runtime. The SPU Runtime can only execute PPHlo programs; one example of a PPHlo kernel is **pphlo.add**. In some internal applications where the logic is extremely simple and clear, we feed PPHlo programs directly to the SPU Runtime to execute MPC programs.

The SPU compiler translates XLA programs to PPHlo. You can check the “Supported” XLA ops in this documentation; in most cases, XLA ops are very similar to PPHlo ops. You may object that we still can’t write XLA programs by hand, and you are absolutely correct. If you check here, however, you will find that lots of AI frameworks can emit XLA programs without any effort on your part, including:

  • TensorFlow

  • PyTorch

  • JAX

Let’s go through each step to have a look at different programs!

JAX Program#

Below is a JAX program that adds a scalar to an array. It should make sense to you if you are familiar with NumPy.

[1]:
import jax
import numpy as np


def simple_add(x, y):
    return jax.numpy.add(x, y)


simple_add(np.array([[1, 2], [3, 4]]), 4)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[1]:
Array([[5, 6],
       [7, 8]], dtype=int32)

XLA Program#

Let’s check what the XLA program for this JAX program looks like. JAX provides xla_computation to convert JAX programs to XLA programs.

[2]:
c = jax.xla_computation(simple_add)(np.array([[1, 2], [3, 4]]), 4)

c.as_hlo_text()

[2]:
'HloModule xla_computation_simple_add, entry_computation_layout={(s32[2,2]{1,0},s32[])->(s32[2,2]{1,0})}\n\nENTRY main.6 {\n  Arg_0.1 = s32[2,2]{1,0} parameter(0)\n  Arg_1.2 = s32[] parameter(1)\n  broadcast.3 = s32[2,2]{1,0} broadcast(Arg_1.2), dimensions={}\n  add.4 = s32[2,2]{1,0} add(Arg_0.1, broadcast.3)\n  ROOT tuple.5 = (s32[2,2]{1,0}) tuple(add.4)\n}\n\n'

You should be aware of the following facts:

  • The shape and dtype are fixed in an XLA program, e.g. s32[2,2]{1,0} appears in each instruction.

  • An implicit broadcast op is inserted before the add op.
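The implicit broadcast can be reproduced by hand in plain NumPy. The following is a minimal sketch that mirrors the two HLO instructions above (broadcast.3 and add.4); it only illustrates the semantics and says nothing about how XLA implements them.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]], dtype=np.int32)
y = np.int32(4)

# Step 1 (broadcast.3): expand the scalar to the shape of x.
y_broadcast = np.broadcast_to(y, x.shape)

# Step 2 (add.4): element-wise add of two arrays with identical shapes.
result = x + y_broadcast

print(result)
```

The result matches the output of simple_add in cell [1].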

PPHlo Program#

Lastly, let’s check the PPHlo program for this XLA program. spu.compile could convert XLA programs to PPHlo programs.

[3]:
import spu

pphlo = spu.compile(
    c.as_serialized_hlo_module_proto(),
    "hlo",
    [spu.Visibility.VIS_SECRET, spu.Visibility.VIS_SECRET],
)

pphlo

[3]:
b'module @xla_computation_simple_add attributes {mhlo.cross_program_prefetches = [], mhlo.dynamic_parameter_bindings = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n  func.func @main(%arg0: tensor<2x2x!pphlo.sec<i32>>, %arg1: tensor<!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>> {\n    %0 = "pphlo.broadcast"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\n    %1 = "pphlo.add"(%arg0, %0) : (tensor<2x2x!pphlo.sec<i32>>, tensor<2x2x!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\n    return %1 : tensor<2x2x!pphlo.sec<i32>>\n  }\n}\n'

You may find the PPHlo program is almost identical to the XLA program. The only differences are:

  • You have to provide the input visibility to the SPU compiler, i.e. [spu.Visibility.VIS_SECRET, spu.Visibility.VIS_SECRET] in our case.

  • Compared to the XLA program, visibility is an extra attribute on every variable in the PPHlo program. For example, tensor<2x2x!pphlo.sec<i32>> means a secret 2x2 i32 tensor.

The SPU compiler deduces the visibility at each step. Let’s change the input visibility and see what happens.

[4]:
pphlo = spu.compile(
    c.as_serialized_hlo_module_proto(),
    "hlo",
    [spu.Visibility.VIS_SECRET, spu.Visibility.VIS_PUBLIC],
)

pphlo

[4]:
b'module @xla_computation_simple_add attributes {mhlo.cross_program_prefetches = [], mhlo.dynamic_parameter_bindings = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n  func.func @main(%arg0: tensor<2x2x!pphlo.sec<i32>>, %arg1: tensor<!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.sec<i32>> {\n    %0 = "pphlo.broadcast"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.pub<i32>>\n    %1 = "pphlo.add"(%arg0, %0) : (tensor<2x2x!pphlo.sec<i32>>, tensor<2x2x!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\n    return %1 : tensor<2x2x!pphlo.sec<i32>>\n  }\n}\n'
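The two compiled programs above suggest a simple rule: broadcasting a public value yields a public value, while adding a secret value and a public value yields a secret one. Below is a toy model of that propagation in plain Python. The Visibility enum and the result_visibility helper are hypothetical names introduced for illustration and are not part of the SPU API; the sketch also assumes that the rule observed here for broadcast and add carries over to other ops, which may not always be the case.

```python
from enum import Enum


class Visibility(Enum):  # hypothetical stand-in, not spu.Visibility
    PUBLIC = 0
    SECRET = 1


def result_visibility(*operands):
    """Result is secret if any operand is secret, otherwise public."""
    if any(v is Visibility.SECRET for v in operands):
        return Visibility.SECRET
    return Visibility.PUBLIC


# Mirrors cell [4]: %arg0 is secret, %arg1 is public.
arg0, arg1 = Visibility.SECRET, Visibility.PUBLIC
broadcast = result_visibility(arg1)         # pphlo.broadcast(%arg1)
added = result_visibility(arg0, broadcast)  # pphlo.add(%arg0, %0)

print(broadcast.name, added.name)
```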

From JAX to SPU#

So this is the whole story:

  1. You write a JAX program in Python.

  2. You turn the JAX program into an XLA program with the first-party API from JAX, i.e. jax.xla_computation.

  3. The SPU compiler translates the XLA program into a PPHlo program, the only language the SPU Runtime understands.

  4. Finally, the PPHlo program is sent to the SPU Runtimes and executed.

In SecretFlow, we have implemented some helper methods so that you only need to write the JAX program; we take care of the remaining steps for you.

Write a Jittable JAX Program#

Jittable means that a JAX program can be just-in-time (JIT) compiled into an XLA program. Only a jittable JAX program can possibly be executed by SPU.

Since SPU doesn’t support all XLA operators, the SPU compiler or runtime may still refuse a JAX program even if it is jittable.
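A quick way to test whether a function can be traced is to ask JAX for its jaxpr without actually running it. The helper below is a minimal sketch; is_traceable, branch_on_value and branch_free are hypothetical names. It only checks JAX tracing, so a True result says nothing about whether SPU supports the resulting ops.

```python
import jax
import jax.numpy as jnp


def is_traceable(fn, *example_args):
    """Return True if JAX can trace fn with abstract versions of the arguments."""
    try:
        jax.make_jaxpr(fn)(*example_args)
        return True
    except TypeError:
        # JAX's concretization errors are subclasses of TypeError.
        return False


def branch_on_value(x):
    # A Python `if` needs a concrete boolean, which is unavailable while tracing.
    if x > 0:
        return x
    return -x


def branch_free(x):
    # The same logic expressed as a traceable array operation.
    return jnp.where(x > 0, x, -x)


print(is_traceable(branch_on_value, 1.0))
print(is_traceable(branch_free, 1.0))
```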

JAX NumPy Package#

We can use the NumPy-like APIs from JAX. They are very similar to the original ones, except that:

  • JAX NumPy arrays are immutable, so you have to use ndarray.at instead of in-place array updates.

  • You have to provide some extra arguments to make some method calls jittable (we will discuss this later).

SPU doesn’t support all JAX NumPy operators either; please also check this documentation. We update this document promptly, and it lists the current status of each operator.

Next, we are going to write some JAX NumPy programs.
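As a small illustration of the immutability point above, the sketch below contrasts an in-place NumPy update with the functional .at[...].set(...) update in JAX. The function name set_first is made up for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy: arrays are mutable, so in-place assignment works.
a = np.zeros(3)
a[0] = 1.0


@jax.jit
def set_first(x, value):
    # JAX: `x[0] = value` is not allowed; .at returns an updated copy instead.
    return x.at[0].set(value)


b = jnp.zeros(3)
c = set_first(b, 1.0)

print(b)  # the original array is unchanged
print(c)
```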

Euclidean Distance#

With just one line of code, we can compute the Euclidean distance between two points.

[5]:
def euclidean_distance(p1, p2):
    return jax.numpy.linalg.norm(p1 - p2)

Let’s check whether it is jittable with jax.jit. You could also use jax.xla_computation for testing.

[6]:
euclidean_distance_jit = jax.jit(euclidean_distance)

print(euclidean_distance_jit(np.array([0, 0]), np.array([3, 4])))


# or
print(
    (
        jax.xla_computation(euclidean_distance)(np.array([0, 0]), np.array([3, 4]))
    ).as_hlo_text()
)

5.0
HloModule xla_computation_euclidean_distance, entry_computation_layout={(s32[2]{0},s32[2]{0})->(f32[])}

region_0.4 {
  Arg_0.5 = f32[] parameter(0)
  Arg_1.6 = f32[] parameter(1)
  ROOT add.7 = f32[] add(Arg_0.5, Arg_1.6)
}

norm.8 {
  Arg_0.9 = s32[2]{0} parameter(0)
  convert.11 = f32[2]{0} convert(Arg_0.9)
  multiply.12 = f32[2]{0} multiply(convert.11, convert.11)
  constant.10 = f32[] constant(0)
  reduce.13 = f32[] reduce(multiply.12, constant.10), dimensions={0}, to_apply=region_0.4
  ROOT sqrt.14 = f32[] sqrt(reduce.13)
}

ENTRY main.17 {
  Arg_0.1 = s32[2]{0} parameter(0)
  Arg_1.2 = s32[2]{0} parameter(1)
  subtract.3 = s32[2]{0} subtract(Arg_0.1, Arg_1.2)
  call.15 = f32[] call(subtract.3), to_apply=norm.8
  ROOT tuple.16 = (f32[]) tuple(call.15)
}
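The norm.8 computation in the HLO text above consists of just convert, multiply, reduce and sqrt. As a sanity check, the same steps can be written out in plain NumPy. This is an illustrative re-derivation by hand, not code generated by JAX.

```python
import numpy as np

p1 = np.array([0, 0], dtype=np.int32)
p2 = np.array([3, 4], dtype=np.int32)

diff = p1 - p2                      # subtract.3 (still int32)
diff_f32 = diff.astype(np.float32)  # convert.11
squares = diff_f32 * diff_f32       # multiply.12
total = squares.sum()               # reduce.13 with add as the reducer
distance = np.sqrt(total)           # sqrt.14

print(distance)
```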


Area of a Simple Polygon#

Given a list of Cartesian coordinates of the vertices of a simple polygon, we can calculate its area with the Shoelace formula.

[7]:
import jax.numpy as jnp


def area_of_simple_polygon(vertices):
    area = 0
    for i in range(0, vertices.shape[0]):
        a = jnp.expand_dims(vertices[i, :], axis=0)
        b = jnp.expand_dims(vertices[(i + 1) % vertices.shape[0], :], axis=0)
        x = jax.numpy.concatenate((a, b))
        x_t = jax.numpy.transpose(x)
        area += 0.5 * jax.numpy.linalg.det(x_t)
    return area


vertices = np.array([[1, 6], [3, 1], [7, 2], [4, 4], [8, 5]])

area_of_simple_polygon(vertices)

[7]:
Array(16.5, dtype=float32)

Let’s check whether area_of_simple_polygon is jittable.

[8]:
print(jax.xla_computation(area_of_simple_polygon)(vertices).as_hlo_text())

HloModule xla_computation_area_of_simple_polygon, entry_computation_layout={(s32[5,2]{1,0})->(f32[])}

det.7 {
  Arg_0.8 = s32[2,2]{1,0} parameter(0)
  convert.9 = f32[2,2]{1,0} convert(Arg_0.8)
  slice.10 = f32[1,1]{1,0} slice(convert.9), slice={[0:1], [0:1]}
  reshape.11 = f32[] reshape(slice.10)
  slice.12 = f32[1,1]{1,0} slice(convert.9), slice={[1:2], [1:2]}
  reshape.13 = f32[] reshape(slice.12)
  multiply.14 = f32[] multiply(reshape.11, reshape.13)
  slice.15 = f32[1,1]{1,0} slice(convert.9), slice={[0:1], [1:2]}
  reshape.16 = f32[] reshape(slice.15)
  slice.17 = f32[1,1]{1,0} slice(convert.9), slice={[1:2], [0:1]}
  reshape.18 = f32[] reshape(slice.17)
  multiply.19 = f32[] multiply(reshape.16, reshape.18)
  ROOT subtract.20 = f32[] subtract(multiply.14, multiply.19)
}

det_0.27 {
  Arg_0.28 = s32[2,2]{1,0} parameter(0)
  convert.29 = f32[2,2]{1,0} convert(Arg_0.28)
  slice.30 = f32[1,1]{1,0} slice(convert.29), slice={[0:1], [0:1]}
  reshape.31 = f32[] reshape(slice.30)
  slice.32 = f32[1,1]{1,0} slice(convert.29), slice={[1:2], [1:2]}
  reshape.33 = f32[] reshape(slice.32)
  multiply.34 = f32[] multiply(reshape.31, reshape.33)
  slice.35 = f32[1,1]{1,0} slice(convert.29), slice={[0:1], [1:2]}
  reshape.36 = f32[] reshape(slice.35)
  slice.37 = f32[1,1]{1,0} slice(convert.29), slice={[1:2], [0:1]}
  reshape.38 = f32[] reshape(slice.37)
  multiply.39 = f32[] multiply(reshape.36, reshape.38)
  ROOT subtract.40 = f32[] subtract(multiply.34, multiply.39)
}

det_1.48 {
  Arg_0.49 = s32[2,2]{1,0} parameter(0)
  convert.50 = f32[2,2]{1,0} convert(Arg_0.49)
  slice.51 = f32[1,1]{1,0} slice(convert.50), slice={[0:1], [0:1]}
  reshape.52 = f32[] reshape(slice.51)
  slice.53 = f32[1,1]{1,0} slice(convert.50), slice={[1:2], [1:2]}
  reshape.54 = f32[] reshape(slice.53)
  multiply.55 = f32[] multiply(reshape.52, reshape.54)
  slice.56 = f32[1,1]{1,0} slice(convert.50), slice={[0:1], [1:2]}
  reshape.57 = f32[] reshape(slice.56)
  slice.58 = f32[1,1]{1,0} slice(convert.50), slice={[1:2], [0:1]}
  reshape.59 = f32[] reshape(slice.58)
  multiply.60 = f32[] multiply(reshape.57, reshape.59)
  ROOT subtract.61 = f32[] subtract(multiply.55, multiply.60)
}

det_2.69 {
  Arg_0.70 = s32[2,2]{1,0} parameter(0)
  convert.71 = f32[2,2]{1,0} convert(Arg_0.70)
  slice.72 = f32[1,1]{1,0} slice(convert.71), slice={[0:1], [0:1]}
  reshape.73 = f32[] reshape(slice.72)
  slice.74 = f32[1,1]{1,0} slice(convert.71), slice={[1:2], [1:2]}
  reshape.75 = f32[] reshape(slice.74)
  multiply.76 = f32[] multiply(reshape.73, reshape.75)
  slice.77 = f32[1,1]{1,0} slice(convert.71), slice={[0:1], [1:2]}
  reshape.78 = f32[] reshape(slice.77)
  slice.79 = f32[1,1]{1,0} slice(convert.71), slice={[1:2], [0:1]}
  reshape.80 = f32[] reshape(slice.79)
  multiply.81 = f32[] multiply(reshape.78, reshape.80)
  ROOT subtract.82 = f32[] subtract(multiply.76, multiply.81)
}

det_3.90 {
  Arg_0.91 = s32[2,2]{1,0} parameter(0)
  convert.92 = f32[2,2]{1,0} convert(Arg_0.91)
  slice.93 = f32[1,1]{1,0} slice(convert.92), slice={[0:1], [0:1]}
  reshape.94 = f32[] reshape(slice.93)
  slice.95 = f32[1,1]{1,0} slice(convert.92), slice={[1:2], [1:2]}
  reshape.96 = f32[] reshape(slice.95)
  multiply.97 = f32[] multiply(reshape.94, reshape.96)
  slice.98 = f32[1,1]{1,0} slice(convert.92), slice={[0:1], [1:2]}
  reshape.99 = f32[] reshape(slice.98)
  slice.100 = f32[1,1]{1,0} slice(convert.92), slice={[1:2], [0:1]}
  reshape.101 = f32[] reshape(slice.100)
  multiply.102 = f32[] multiply(reshape.99, reshape.101)
  ROOT subtract.103 = f32[] subtract(multiply.97, multiply.102)
}

ENTRY main.108 {
  Arg_0.1 = s32[5,2]{1,0} parameter(0)
  slice.3 = s32[1,2]{1,0} slice(Arg_0.1), slice={[0:1], [0:2]}
  slice.4 = s32[1,2]{1,0} slice(Arg_0.1), slice={[1:2], [0:2]}
  concatenate.5 = s32[2,2]{1,0} concatenate(slice.3, slice.4), dimensions={0}
  transpose.6 = s32[2,2]{0,1} transpose(concatenate.5), dimensions={1,0}
  call.21 = f32[] call(transpose.6), to_apply=det.7
  constant.2 = f32[] constant(0.5)
  multiply.22 = f32[] multiply(call.21, constant.2)
  slice.23 = s32[1,2]{1,0} slice(Arg_0.1), slice={[1:2], [0:2]}
  slice.24 = s32[1,2]{1,0} slice(Arg_0.1), slice={[2:3], [0:2]}
  concatenate.25 = s32[2,2]{1,0} concatenate(slice.23, slice.24), dimensions={0}
  transpose.26 = s32[2,2]{0,1} transpose(concatenate.25), dimensions={1,0}
  call.41 = f32[] call(transpose.26), to_apply=det_0.27
  multiply.42 = f32[] multiply(call.41, constant.2)
  add.43 = f32[] add(multiply.22, multiply.42)
  slice.44 = s32[1,2]{1,0} slice(Arg_0.1), slice={[2:3], [0:2]}
  slice.45 = s32[1,2]{1,0} slice(Arg_0.1), slice={[3:4], [0:2]}
  concatenate.46 = s32[2,2]{1,0} concatenate(slice.44, slice.45), dimensions={0}
  transpose.47 = s32[2,2]{0,1} transpose(concatenate.46), dimensions={1,0}
  call.62 = f32[] call(transpose.47), to_apply=det_1.48
  multiply.63 = f32[] multiply(call.62, constant.2)
  add.64 = f32[] add(add.43, multiply.63)
  slice.65 = s32[1,2]{1,0} slice(Arg_0.1), slice={[3:4], [0:2]}
  slice.66 = s32[1,2]{1,0} slice(Arg_0.1), slice={[4:5], [0:2]}
  concatenate.67 = s32[2,2]{1,0} concatenate(slice.65, slice.66), dimensions={0}
  transpose.68 = s32[2,2]{0,1} transpose(concatenate.67), dimensions={1,0}
  call.83 = f32[] call(transpose.68), to_apply=det_2.69
  multiply.84 = f32[] multiply(call.83, constant.2)
  add.85 = f32[] add(add.64, multiply.84)
  slice.86 = s32[1,2]{1,0} slice(Arg_0.1), slice={[4:5], [0:2]}
  slice.87 = s32[1,2]{1,0} slice(Arg_0.1), slice={[0:1], [0:2]}
  concatenate.88 = s32[2,2]{1,0} concatenate(slice.86, slice.87), dimensions={0}
  transpose.89 = s32[2,2]{0,1} transpose(concatenate.88), dimensions={1,0}
  call.104 = f32[] call(transpose.89), to_apply=det_3.90
  multiply.105 = f32[] multiply(call.104, constant.2)
  add.106 = f32[] add(add.85, multiply.105)
  ROOT tuple.107 = (f32[]) tuple(add.106)
}
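Note that the HLO above contains one det computation per edge (det.7 through det_3.90): the Python for loop is unrolled while tracing, so the program grows with the number of vertices. A vectorized variant avoids this. The sketch below uses jnp.roll to pair each vertex with its successor; area_of_simple_polygon_vectorized is a name introduced here, and whether every op it lowers to is supported by SPU should be checked against the operator status documentation.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def area_of_simple_polygon_vectorized(vertices):
    x = vertices[:, 0].astype(jnp.float32)
    y = vertices[:, 1].astype(jnp.float32)
    # Pair vertex i with vertex (i + 1) % n.
    x_next = jnp.roll(x, -1)
    y_next = jnp.roll(y, -1)
    # Shoelace formula: half the sum of the 2x2 determinants of adjacent vertices.
    return 0.5 * jnp.sum(x * y_next - x_next * y)


vertices = np.array([[1, 6], [3, 1], [7, 2], [4, 4], [8, 5]])
print(area_of_simple_polygon_vectorized(vertices))  # 16.5
```

Like the loop version, this returns a signed area, so the sign depends on the orientation of the vertices.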


Could We Jit Anything?#

Absolutely not. Please check this documentation from JAX!

The most common reason a program is not jittable is that its control flow depends on the values of its inputs. For instance,

[9]:
# Cited from JAX documentation.
# While loop conditioned on x and n.

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

g_jit = jax.jit(g)

import traceback
try:
  g_jit(10, 20)  # Should raise an error.
except Exception:
    traceback.print_exc()
Traceback (most recent call last):
  File "/tmp/ipykernel_2314574/682526420.py", line 14, in <module>
    g_jit(10, 20)  # Should raise an error.
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 235, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/api.py", line 440, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 513, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 965, in _pjit_jaxpr
    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 923, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2033, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2050, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/tmp/ipykernel_2314574/682526420.py", line 6, in g
    while i < n:
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/core.py", line 653, in __bool__
    def __bool__(self): return self.aval._bool(self)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/core.py", line 1344, in error
    raise ConcretizationTypeError(arg, fname_context)
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function.
The error occurred while tracing the function g at /tmp/ipykernel_2314574/682526420.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument 'n'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

There are two possible solutions.

  1. You could replace the control flow with the low-level **jax.lax** APIs. You need to spend some time figuring out how to use these APIs.

[10]:
def g_with_lax_control_flow(x, n):
  def body_fun(i):
    i += 1
    return i
  return x + jax.lax.while_loop(lambda i: i < n, body_fun, 0)


g_with_lax_control_flow_jit = jax.jit(g_with_lax_control_flow)
g_with_lax_control_flow_jit(10, 20)  # good to go!
[10]:
Array(30, dtype=int32, weak_type=True)

  2. The other possible solution is to use static_argnames.

[11]:
g_with_static_argnames_jit = jax.jit(g, static_argnames=['n'])
g_with_static_argnames_jit(10, 20)  # good to go!
[11]:
Array(30, dtype=int32, weak_type=True)
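One consequence of static_argnames is worth knowing: a static argument becomes part of the compiled program, so each new value of n triggers a fresh trace and compilation. The sketch below counts traces with a Python side effect, which only runs during tracing; trace_log is a helper introduced for this illustration.

```python
import jax

trace_log = []  # filled only while JAX traces g, not on cached calls


def g(x, n):
    trace_log.append(n)
    i = 0
    while i < n:  # fine here: n is a concrete Python int when it is static
        i += 1
    return x + i


g_static = jax.jit(g, static_argnames=["n"])

g_static(10, n=20)  # first call with n=20: traced
g_static(11, n=20)  # same n, same argument types: served from the cache
g_static(10, n=30)  # new static value: traced again

print(trace_log)
```

If n takes many different values at runtime, the jax.lax approach avoids this repeated compilation.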

So which method should we choose when the program is not jittable?

This is our suggestion:

  • Try rewriting the control flow with jax.lax APIs first. There is a learning cost, but it is worth it.

  • If the affected input values are VIS_PUBLIC, like n in the example above, you can mark them as static_argnames; their values are then compiled into the XLA program.
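As a further sketch of the first suggestion, jax.lax.fori_loop covers counted loops whose trip count is itself a traced input. The function power below is a made-up example; it is only checked against plain JAX here, and support for the underlying loop op in SPU should be verified separately.

```python
import jax
import jax.numpy as jnp


@jax.jit
def power(x, n):
    x = jnp.asarray(x, dtype=jnp.float32)

    def body(i, acc):
        return acc * x

    # Loop bounds may be traced values; the carry must keep a fixed shape and dtype.
    return jax.lax.fori_loop(0, n, body, jnp.float32(1.0))


print(power(2.0, 10))
```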

More Examples#

If you would like to see more examples, please check the Python examples in the SPU repo. In most of them, the MPC part is written with the jax.numpy package, and you will find that we use jax.lax APIs and static_argnames heavily to make JAX programs jittable!

Execute JAX Program with SPU#

Once your jittable JAX program is ready, you can execute it with SPU!

(Optional) SPU Simulation#

If you want to give it a quick try, you can use spu.sim_jax. Let’s see how it works!

spu.sim_jax is only exposed after spu v0.3.1b8.

Here we create an SPU simulator with the following settings:

  • A world size of 3.

  • The ABY3 protocol. Check all supported protocols here.

  • A field of 64, which means the values in SPU are expressed in the 2^64 ring.

However, if you just want to confirm whether a JAX program can be executed by SPU, any settings should be fine. Different settings only affect the elapsed time and the precision of the computation.

[12]:
from spu.utils.simulation import Simulator, sim_jax

sim = Simulator.simple(3, spu.ProtocolKind.ABY3, spu.FieldType.FM64)

spu_euclidean_distance_fn = sim_jax(sim, euclidean_distance)

spu_euclidean_distance_fn(np.array([0, 0]), np.array([3, 4]))
[12]:
array(4.999962, dtype=float32)
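The small gap between 4.999962 and the exact answer 5.0 is consistent with fixed-point arithmetic, the usual way to represent real numbers inside an integer ring. The sketch below illustrates the idea in plain Python on the 2^64 ring. The choice of 18 fractional bits and the helper names encode, decode and fxp_mul are assumptions made for illustration; they are not taken from SPU's implementation or configuration.

```python
RING_BITS = 64   # matches the "field of 64" above
FRAC_BITS = 18   # assumed number of fractional bits, for illustration only
MODULUS = 1 << RING_BITS


def encode(value):
    """Map a real number to a ring element (two's complement for negatives)."""
    return round(value * (1 << FRAC_BITS)) % MODULUS


def decode(element):
    """Map a ring element back to a real number."""
    if element >= MODULUS // 2:  # upper half of the ring represents negatives
        element -= MODULUS
    return element / (1 << FRAC_BITS)


def fxp_mul(a, b):
    """Multiply two encoded values, then drop the extra fractional bits."""
    sa = a - MODULUS if a >= MODULUS // 2 else a
    sb = b - MODULUS if b >= MODULUS // 2 else b
    return ((sa * sb) >> FRAC_BITS) % MODULUS


x = encode(1.1)
product = decode(fxp_mul(x, x))
print(product, abs(product - 1.21))
```

The product is close to 1.21 but not equal to it, because both the encoding and the truncation after multiplication discard low-order bits.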

If you execute the code above repeatedly, you may find that the result differs slightly between runs. This is expected due to randomness in MPC computation.

After testing with euclidean_distance, let’s try area_of_simple_polygon.
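One common source of such run-to-run differences in MPC is that values are split into random shares, and some steps (for example truncation after a fixed-point multiplication) are carried out on the shares locally. The toy two-party sketch below shows that the reconstructed result can differ by one unit in the last place depending on the random shares. It is a generic textbook-style illustration with made-up helper names and an assumed 18 fractional bits; it is not the ABY3 protocol used above and not SPU's implementation.

```python
import random

RING_BITS = 64
FRAC_BITS = 18  # assumed, for illustration only
MODULUS = 1 << RING_BITS


def share(value, rng):
    """Split a ring element into two additive shares that sum to it mod 2^64."""
    s0 = rng.getrandbits(RING_BITS)
    s1 = (value - s0) % MODULUS
    return s0, s1


def truncate_locally(s0, s1):
    """Each party drops FRAC_BITS low bits of its own share, without communication."""
    t0 = s0 >> FRAC_BITS
    t1 = (MODULUS - ((MODULUS - s1) >> FRAC_BITS)) % MODULUS
    return t0, t1


def reconstruct(s0, s1):
    return (s0 + s1) % MODULUS


value = 123_456_789_012  # a positive ring element far below 2^64
exact = value >> FRAC_BITS

offsets = set()
for seed in range(20):
    s0, s1 = share(value, random.Random(seed))
    offsets.add(reconstruct(*truncate_locally(s0, s1)) - exact)

print(sorted(offsets))  # small offsets from the exact truncation
```

This local truncation is only accurate when the value is tiny compared to the ring size, which is the case here.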

[13]:
spu_area_of_simple_polygon_fn = sim_jax(sim, area_of_simple_polygon)

spu_area_of_simple_polygon_fn(vertices)
[13]:
array(16.5, dtype=float32)

Run with SPU Device#

Finally, we are going to run the JAX program with SecretFlow.

You should be familiar with the following steps if you have gone through the other tutorials.

Here we create a local standalone SecretFlow cluster with three devices:

  • Two PYU devices: alice and bob

  • An SPU device

[14]:
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

sf.init(parties=['alice', 'bob'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
# Note: this rebinds the name `spu`, shadowing the `spu` module imported earlier.
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
2023-03-21 21:03:27,671 INFO worker.py:1538 -- Started a local Ray instance.

Let’s try euclidean_distance with the SPU device first.

[15]:
p1 = sf.to(alice, np.array([0, 0]))
p2 = sf.to(bob, np.array([3, 4]))

distance = spu(euclidean_distance)(p1, p2)

sf.reveal(distance)
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=2316721) WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=2316358) WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[15]:
array(5., dtype=float32)

Then we try area_of_simple_polygon.

[16]:
v = sf.to(alice, vertices)
area = spu(area_of_simple_polygon)(v)

sf.reveal(area)

[16]:
array(16.5, dtype=float32)

Summary#

This is the end of the tutorial. Let’s summarize the steps to do privacy-preserving scientific computing with JAX NumPy APIs:

  1. Write a jittable JAX NumPy program. You should test it with jax.jit or jax.xla_computation.

  2. (Optional) Try the JAX program with SPU simulation.

  3. Run this JAX NumPy program with the SPU device in SecretFlow.

If your JAX program is jittable but fails in the SPU compiler or runtime, please check JAX NumPy Operators Status and XLA Implementation Status, or contact us directly via GitHub Issues.