Privacy-Preserving Scientific Computing with NumPy in SPU#

NumPy is one of the most popular tools for scientific computing. It is so common that we can find many NumPy-like tools in other languages, such as xtensor and Gonum. So we can't help but wonder whether we could express computation with NumPy-like APIs in a privacy-preserving setting, since everyone loves NumPy.

Luckily, with the power of the JAX NumPy package, we can achieve this goal easily. In this tutorial, we will cover:

  1. The relationship between JAX and SPU

  2. Write a jittable JAX program

  3. Execute a JAX program with SPU

The Relationship Between JAX and SPU#

TL;DR#

SPU actually consists of two components - the compiler and the runtime. The SPU runtime can only execute PPHlo. One example of a PPHlo kernel is pphlo.add. In fact, for some internal applications, we just feed PPHlo programs directly to the SPU runtime to run MPC programs when the logic is simple and clear.

The SPU compiler translates XLA programs into PPHlo. You can check the "supported" XLA operations in this documentation. In most cases you will find that XLA operations are pretty similar to PPHlo operations. It seems we still cannot write XLA programs by hand, and that is indeed the case. If you check here, you will find quite a few AI frameworks that can generate XLA programs automatically, including:

  • TensorFlow

  • PyTorch

  • JAX

Let's walk through the different programs step by step!

JAX Program#

Below is a JAX program that adds an array and a scalar. It should be easy to understand if you are familiar with NumPy.

[1]:
import jax
import numpy as np


def simple_add(x, y):
    return jax.numpy.add(x, y)


simple_add(np.array([[1, 2], [3, 4]]), 4)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[1]:
Array([[5, 6],
       [7, 8]], dtype=int32)

XLA Program#

Let's check out the XLA program for this JAX program. JAX provides the xla_computation function to convert a JAX program into an XLA program.

[2]:
c = jax.xla_computation(simple_add)(np.array([[1, 2], [3, 4]]), 4)

c.as_hlo_text()

[2]:
'HloModule xla_computation_simple_add, entry_computation_layout={(s32[2,2]{1,0},s32[])->(s32[2,2]{1,0})}\n\nENTRY main.6 {\n  Arg_0.1 = s32[2,2]{1,0} parameter(0)\n  Arg_1.2 = s32[] parameter(1)\n  broadcast.3 = s32[2,2]{1,0} broadcast(Arg_1.2), dimensions={}\n  add.4 = s32[2,2]{1,0} add(Arg_0.1, broadcast.3)\n  ROOT tuple.5 = (s32[2,2]{1,0}) tuple(add.4)\n}\n\n'

You should be aware of the following facts:

  • The shape and dtype of every argument are fixed in the XLA program, e.g. s32[2,2]{1,0} (a quick check follows the list).

  • An implicit broadcast op is inserted before the add op.
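
To see the first point concretely, we can trace the same function with a differently shaped input (a small check reusing simple_add from above; no new APIs are involved):

c2 = jax.xla_computation(simple_add)(np.array([1, 2, 3]), 4)
# The first line of the HLO text is the HloModule header; its
# entry_computation_layout now mentions s32[3] instead of s32[2,2],
# i.e. the argument shapes are baked into the XLA program.
print(c2.as_hlo_text().splitlines()[0])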

PPHlo Program#

Finally, let's check the PPHlo program for this XLA program. The spu.compile function converts an XLA program into a PPHlo program.

[3]:
import spu

pphlo = spu.compile(
    c.as_serialized_hlo_module_proto(),
    "hlo",
    [spu.Visibility.VIS_SECRET, spu.Visibility.VIS_SECRET],
)

pphlo

[3]:
b'module @xla_computation_simple_add attributes {mhlo.cross_program_prefetches = [], mhlo.dynamic_parameter_bindings = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n  func.func @main(%arg0: tensor<2x2x!pphlo.sec<i32>>, %arg1: tensor<!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>> {\n    %0 = "pphlo.broadcast"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\n    %1 = "pphlo.add"(%arg0, %0) : (tensor<2x2x!pphlo.sec<i32>>, tensor<2x2x!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\n    return %1 : tensor<2x2x!pphlo.sec<i32>>\n  }\n}\n'

You may find that the PPHlo program is almost identical to the XLA program. The differences are:

  • You have to provide the input visibilities to the SPU compiler, e.g. [spu.Visibility.VIS_SECRET, spu.Visibility.VIS_SECRET] in our case.

  • Compared with the XLA program, visibility is an extra attribute of every variable in the PPHlo program. For instance, tensor<2x2x!pphlo.sec<i32>> means this is a secret 2x2 i32 tensor.

The SPU compiler deduces the visibility at every step. Let's modify the input visibility and see what happens.

[4]:
pphlo = spu.compile(
    c.as_serialized_hlo_module_proto(),
    "hlo",
    [spu.Visibility.VIS_SECRET, spu.Visibility.VIS_PUBLIC],
)

pphlo

[4]:
b'module @xla_computation_simple_add attributes {mhlo.cross_program_prefetches = [], mhlo.dynamic_parameter_bindings = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n  func.func @main(%arg0: tensor<2x2x!pphlo.sec<i32>>, %arg1: tensor<!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.sec<i32>> {\n    %0 = "pphlo.broadcast"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.pub<i32>>\n    %1 = "pphlo.add"(%arg0, %0) : (tensor<2x2x!pphlo.sec<i32>>, tensor<2x2x!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\n    return %1 : tensor<2x2x!pphlo.sec<i32>>\n  }\n}\n'

From JAX to SPU#

This is the whole story:

  1. You write a JAX program in Python.

  2. You convert the JAX program into an XLA program with JAX's first-party API, i.e. jax.xla_computation.

  3. The SPU compiler translates the XLA program into a PPHlo program - the only language the SPU runtime understands.

  4. The PPHlo program is sent to the SPU runtime and executed.

In SecretFlow, we have implemented some helper methods, so you could just write a JAX program at the very beginning and we will handle the remaining steps for you.

Write a Jittable JAX Program#

Jittable means that a JAX program can be just-in-time (JIT) compiled into an XLA program. So only when a JAX program is jittable may it be executed by SPU.

Since SPU doesn't support all XLA operators, even if a JAX program is jittable, the SPU runtime may still refuse to execute it.
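
If you want a quick programmatic check for jittability, a tiny helper like the one below works (check_jittable is a hypothetical name used only for illustration; it reuses jax, np and simple_add from the cells above):

def check_jittable(fn, *args):
    # Try to lower the function to XLA; if tracing fails (e.g. because of
    # value-dependent Python control flow), the program is not jittable.
    try:
        jax.xla_computation(fn)(*args)
        return True
    except Exception:
        return False


print(check_jittable(simple_add, np.array([[1, 2], [3, 4]]), 4))  # True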

JAX NumPy Package#

We could use these NumPy-like APIs from JAX. The JAX NumPy APIs are very similar to the original NumPy APIs, with a few differences (a minimal sketch of the first point follows):

  • JAX NumPy arrays are immutable, so you have to use ndarray.at for out-of-place array updates.

  • You have to provide some extra arguments to make a method call jittable (we will discuss this later).
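
A minimal sketch of an out-of-place update with ndarray.at:

import jax.numpy as jnp

x = jnp.zeros((2, 2))
y = x.at[0, 0].set(1.0)  # out-of-place update: returns a new array
print(x)  # x itself is unchanged
print(y)  # only y carries the update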

Actually, SPU doesn't support all JAX NumPy operators; please check this documentation. We keep it up to date, and it lists the current status of each operator.

Next, let's write some JAX NumPy programs.

Euclidean Distance#

Only one line of code is needed to compute the Euclidean distance between two points.

[5]:
def euclidean_distance(p1, p2):
    return jax.numpy.linalg.norm(p1 - p2)

We could check whether it is jittable with jax.jit, or with jax.xla_computation.

[6]:
euclidean_distance_jit = jax.jit(euclidean_distance)

print(euclidean_distance_jit(np.array([0, 0]), np.array([3, 4])))


# or
print(
    (
        jax.xla_computation(euclidean_distance)(np.array([0, 0]), np.array([3, 4]))
    ).as_hlo_text()
)

5.0
HloModule xla_computation_euclidean_distance, entry_computation_layout={(s32[2]{0},s32[2]{0})->(f32[])}

region_0.4 {
  Arg_0.5 = f32[] parameter(0)
  Arg_1.6 = f32[] parameter(1)
  ROOT add.7 = f32[] add(Arg_0.5, Arg_1.6)
}

norm.8 {
  Arg_0.9 = s32[2]{0} parameter(0)
  convert.11 = f32[2]{0} convert(Arg_0.9)
  multiply.12 = f32[2]{0} multiply(convert.11, convert.11)
  constant.10 = f32[] constant(0)
  reduce.13 = f32[] reduce(multiply.12, constant.10), dimensions={0}, to_apply=region_0.4
  ROOT sqrt.14 = f32[] sqrt(reduce.13)
}

ENTRY main.17 {
  Arg_0.1 = s32[2]{0} parameter(0)
  Arg_1.2 = s32[2]{0} parameter(1)
  subtract.3 = s32[2]{0} subtract(Arg_0.1, Arg_1.2)
  call.15 = f32[] call(subtract.3), to_apply=norm.8
  ROOT tuple.16 = (f32[]) tuple(call.15)
}


Area of a Simple Polygon#

Given a list of Cartesian coordinates of the vertices of a simple polygon, we can compute its area via the Shoelace formula.
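
For an n-vertex polygon with vertices $(x_i, y_i)$, the signed-area form of the formula is

$$A = \frac{1}{2}\sum_{i=0}^{n-1} \det\begin{pmatrix} x_i & x_{i+1} \\ y_i & y_{i+1} \end{pmatrix} = \frac{1}{2}\sum_{i=0}^{n-1}\left(x_i y_{i+1} - x_{i+1} y_i\right),$$

with indices taken modulo n (take the absolute value if the vertex orientation is unknown). The code below accumulates exactly these 2x2 determinants.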

[7]:
import jax.numpy as jnp


def area_of_simple_polygon(vertices):
    area = 0
    for i in range(0, vertices.shape[0]):
        a = jnp.expand_dims(vertices[i, :], axis=0)
        b = jnp.expand_dims(vertices[(i + 1) % vertices.shape[0], :], axis=0)
        x = jax.numpy.concatenate((a, b))
        x_t = jax.numpy.transpose(x)
        area += 0.5 * jax.numpy.linalg.det(x_t)
    return area


vertices = np.array([[1, 6], [3, 1], [7, 2], [4, 4], [8, 5]])

area_of_simple_polygon(vertices)

[7]:
Array(16.5, dtype=float32)

Let's check whether area_of_simple_polygon is jittable.

[8]:
print(jax.xla_computation(area_of_simple_polygon)(vertices).as_hlo_text())

HloModule xla_computation_area_of_simple_polygon, entry_computation_layout={(s32[5,2]{1,0})->(f32[])}

det.7 {
  Arg_0.8 = s32[2,2]{1,0} parameter(0)
  convert.9 = f32[2,2]{1,0} convert(Arg_0.8)
  slice.10 = f32[1,1]{1,0} slice(convert.9), slice={[0:1], [0:1]}
  reshape.11 = f32[] reshape(slice.10)
  slice.12 = f32[1,1]{1,0} slice(convert.9), slice={[1:2], [1:2]}
  reshape.13 = f32[] reshape(slice.12)
  multiply.14 = f32[] multiply(reshape.11, reshape.13)
  slice.15 = f32[1,1]{1,0} slice(convert.9), slice={[0:1], [1:2]}
  reshape.16 = f32[] reshape(slice.15)
  slice.17 = f32[1,1]{1,0} slice(convert.9), slice={[1:2], [0:1]}
  reshape.18 = f32[] reshape(slice.17)
  multiply.19 = f32[] multiply(reshape.16, reshape.18)
  ROOT subtract.20 = f32[] subtract(multiply.14, multiply.19)
}

det_0.27 {
  Arg_0.28 = s32[2,2]{1,0} parameter(0)
  convert.29 = f32[2,2]{1,0} convert(Arg_0.28)
  slice.30 = f32[1,1]{1,0} slice(convert.29), slice={[0:1], [0:1]}
  reshape.31 = f32[] reshape(slice.30)
  slice.32 = f32[1,1]{1,0} slice(convert.29), slice={[1:2], [1:2]}
  reshape.33 = f32[] reshape(slice.32)
  multiply.34 = f32[] multiply(reshape.31, reshape.33)
  slice.35 = f32[1,1]{1,0} slice(convert.29), slice={[0:1], [1:2]}
  reshape.36 = f32[] reshape(slice.35)
  slice.37 = f32[1,1]{1,0} slice(convert.29), slice={[1:2], [0:1]}
  reshape.38 = f32[] reshape(slice.37)
  multiply.39 = f32[] multiply(reshape.36, reshape.38)
  ROOT subtract.40 = f32[] subtract(multiply.34, multiply.39)
}

det_1.48 {
  Arg_0.49 = s32[2,2]{1,0} parameter(0)
  convert.50 = f32[2,2]{1,0} convert(Arg_0.49)
  slice.51 = f32[1,1]{1,0} slice(convert.50), slice={[0:1], [0:1]}
  reshape.52 = f32[] reshape(slice.51)
  slice.53 = f32[1,1]{1,0} slice(convert.50), slice={[1:2], [1:2]}
  reshape.54 = f32[] reshape(slice.53)
  multiply.55 = f32[] multiply(reshape.52, reshape.54)
  slice.56 = f32[1,1]{1,0} slice(convert.50), slice={[0:1], [1:2]}
  reshape.57 = f32[] reshape(slice.56)
  slice.58 = f32[1,1]{1,0} slice(convert.50), slice={[1:2], [0:1]}
  reshape.59 = f32[] reshape(slice.58)
  multiply.60 = f32[] multiply(reshape.57, reshape.59)
  ROOT subtract.61 = f32[] subtract(multiply.55, multiply.60)
}

det_2.69 {
  Arg_0.70 = s32[2,2]{1,0} parameter(0)
  convert.71 = f32[2,2]{1,0} convert(Arg_0.70)
  slice.72 = f32[1,1]{1,0} slice(convert.71), slice={[0:1], [0:1]}
  reshape.73 = f32[] reshape(slice.72)
  slice.74 = f32[1,1]{1,0} slice(convert.71), slice={[1:2], [1:2]}
  reshape.75 = f32[] reshape(slice.74)
  multiply.76 = f32[] multiply(reshape.73, reshape.75)
  slice.77 = f32[1,1]{1,0} slice(convert.71), slice={[0:1], [1:2]}
  reshape.78 = f32[] reshape(slice.77)
  slice.79 = f32[1,1]{1,0} slice(convert.71), slice={[1:2], [0:1]}
  reshape.80 = f32[] reshape(slice.79)
  multiply.81 = f32[] multiply(reshape.78, reshape.80)
  ROOT subtract.82 = f32[] subtract(multiply.76, multiply.81)
}

det_3.90 {
  Arg_0.91 = s32[2,2]{1,0} parameter(0)
  convert.92 = f32[2,2]{1,0} convert(Arg_0.91)
  slice.93 = f32[1,1]{1,0} slice(convert.92), slice={[0:1], [0:1]}
  reshape.94 = f32[] reshape(slice.93)
  slice.95 = f32[1,1]{1,0} slice(convert.92), slice={[1:2], [1:2]}
  reshape.96 = f32[] reshape(slice.95)
  multiply.97 = f32[] multiply(reshape.94, reshape.96)
  slice.98 = f32[1,1]{1,0} slice(convert.92), slice={[0:1], [1:2]}
  reshape.99 = f32[] reshape(slice.98)
  slice.100 = f32[1,1]{1,0} slice(convert.92), slice={[1:2], [0:1]}
  reshape.101 = f32[] reshape(slice.100)
  multiply.102 = f32[] multiply(reshape.99, reshape.101)
  ROOT subtract.103 = f32[] subtract(multiply.97, multiply.102)
}

ENTRY main.108 {
  Arg_0.1 = s32[5,2]{1,0} parameter(0)
  slice.3 = s32[1,2]{1,0} slice(Arg_0.1), slice={[0:1], [0:2]}
  slice.4 = s32[1,2]{1,0} slice(Arg_0.1), slice={[1:2], [0:2]}
  concatenate.5 = s32[2,2]{1,0} concatenate(slice.3, slice.4), dimensions={0}
  transpose.6 = s32[2,2]{0,1} transpose(concatenate.5), dimensions={1,0}
  call.21 = f32[] call(transpose.6), to_apply=det.7
  constant.2 = f32[] constant(0.5)
  multiply.22 = f32[] multiply(call.21, constant.2)
  slice.23 = s32[1,2]{1,0} slice(Arg_0.1), slice={[1:2], [0:2]}
  slice.24 = s32[1,2]{1,0} slice(Arg_0.1), slice={[2:3], [0:2]}
  concatenate.25 = s32[2,2]{1,0} concatenate(slice.23, slice.24), dimensions={0}
  transpose.26 = s32[2,2]{0,1} transpose(concatenate.25), dimensions={1,0}
  call.41 = f32[] call(transpose.26), to_apply=det_0.27
  multiply.42 = f32[] multiply(call.41, constant.2)
  add.43 = f32[] add(multiply.22, multiply.42)
  slice.44 = s32[1,2]{1,0} slice(Arg_0.1), slice={[2:3], [0:2]}
  slice.45 = s32[1,2]{1,0} slice(Arg_0.1), slice={[3:4], [0:2]}
  concatenate.46 = s32[2,2]{1,0} concatenate(slice.44, slice.45), dimensions={0}
  transpose.47 = s32[2,2]{0,1} transpose(concatenate.46), dimensions={1,0}
  call.62 = f32[] call(transpose.47), to_apply=det_1.48
  multiply.63 = f32[] multiply(call.62, constant.2)
  add.64 = f32[] add(add.43, multiply.63)
  slice.65 = s32[1,2]{1,0} slice(Arg_0.1), slice={[3:4], [0:2]}
  slice.66 = s32[1,2]{1,0} slice(Arg_0.1), slice={[4:5], [0:2]}
  concatenate.67 = s32[2,2]{1,0} concatenate(slice.65, slice.66), dimensions={0}
  transpose.68 = s32[2,2]{0,1} transpose(concatenate.67), dimensions={1,0}
  call.83 = f32[] call(transpose.68), to_apply=det_2.69
  multiply.84 = f32[] multiply(call.83, constant.2)
  add.85 = f32[] add(add.64, multiply.84)
  slice.86 = s32[1,2]{1,0} slice(Arg_0.1), slice={[4:5], [0:2]}
  slice.87 = s32[1,2]{1,0} slice(Arg_0.1), slice={[0:1], [0:2]}
  concatenate.88 = s32[2,2]{1,0} concatenate(slice.86, slice.87), dimensions={0}
  transpose.89 = s32[2,2]{0,1} transpose(concatenate.88), dimensions={1,0}
  call.104 = f32[] call(transpose.89), to_apply=det_3.90
  multiply.105 = f32[] multiply(call.104, constant.2)
  add.106 = f32[] add(add.85, multiply.105)
  ROOT tuple.107 = (f32[]) tuple(add.106)
}


Can We Jit Anything?#

Definitely not. Please check this documentation from JAX.

The most common cause of unjittable programs is that your control flow depends on the value of an input. For instance,

[9]:
# Cited from JAX documentation.
# While loop conditioned on x and n.

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

g_jit = jax.jit(g)

import traceback
try:
  g_jit(10, 20)  # Should raise an error.
except Exception:
    traceback.print_exc()
Traceback (most recent call last):
  File "/tmp/ipykernel_2314574/682526420.py", line 14, in <module>
    g_jit(10, 20)  # Should raise an error.
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 235, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/api.py", line 440, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 513, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 965, in _pjit_jaxpr
    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 923, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2033, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2050, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/tmp/ipykernel_2314574/682526420.py", line 6, in g
    while i < n:
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/core.py", line 653, in __bool__
    def __bool__(self): return self.aval._bool(self)
  File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/core.py", line 1344, in error
    raise ConcretizationTypeError(arg, fname_context)
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function.
The error occurred while tracing the function g at /tmp/ipykernel_2314574/682526420.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument 'n'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

There are two possible solutions.

  1. You could replace the control flow with the low-level jax.lax APIs. You need to spend some time figuring out how to use these APIs.

[10]:
def g_with_lax_control_flow(x, n):
  def body_fun(i):
    i += 1
    return i
  return x + jax.lax.while_loop(lambda i: i < n, body_fun, 0)


g_with_lax_control_flow_jit = jax.jit(g_with_lax_control_flow)
g_with_lax_control_flow_jit(10, 20)  # good to go!
[10]:
Array(30, dtype=int32, weak_type=True)
  2. The other possible solution is to use static_argnames.

[11]:
g_with_static_argnames_jit = jax.jit(g, static_argnames=['n'])
g_with_static_argnames_jit(10, 20)  # good to go!
[11]:
Array(30, dtype=int32, weak_type=True)

So which method should we choose when a program is unjittable?

  • Rewrite the control flow with the jax.lax APIs first. There is some learning cost here, but it is worth it.

  • If the affected input values are of VIS_PUBLIC visibility, like n in the example above, you could mark them as static_argnames, and they will be compiled into the XLA program (a quick check follows).
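
To see what the second option does, we can inspect the XLA program with the static_argnums parameter of jax.xla_computation (the positional analogue of static_argnames in jax.jit; a small sketch reusing g from above):

# With n marked static, the Python `while` loop runs at trace time with the
# concrete value 20, so n no longer appears as a parameter in the XLA program;
# only x remains as an input and 20 becomes a constant.
print(jax.xla_computation(g, static_argnums=(1,))(10, 20).as_hlo_text())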

More Examples#

If you would like to check out more examples, please take a look at the Python examples in the SPU repository. In most examples, the MPC part is written with the jax.numpy package, and you will find that we use the jax.lax APIs and static_argnames heavily to make JAX programs jittable!
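
One pattern you will see often in those examples is replacing a value-dependent Python if with jax.lax.cond, so that the branch is expressed inside XLA (a small illustrative sketch, not taken from the repository):

def step(x):
    # branch on a traced value without a Python `if`
    return jax.lax.cond(x > 0, lambda v: v + 1.0, lambda v: v - 1.0, x)


print(jax.jit(step)(3.0))  # 4.0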

Execute a JAX Program with SPU#

Once you have a jittable JAX program ready, we can execute it with SPU!

(Optional) SPU Simulation#

If you would like a quick try, I would like to introduce spu.sim_jax to you. Let's see how it works!

spu.sim_jax is only available since spu v0.3.1b8.

Here we create an SPU simulator with the following settings:

  • 3 parties

  • the ABY3 protocol (check all supported protocols here)

  • a 64-bit field, i.e. values in SPU are represented in the ring 2^64

However, if you just want to confirm whether a JAX program could be executed by SPU, any setting should be fine. Different settings only affect the elapsed time and the precision of the computation.

[12]:
from spu.utils.simulation import Simulator, sim_jax

sim = Simulator.simple(3, spu.ProtocolKind.ABY3, spu.FieldType.FM64)

spu_euclidean_distance_fn = sim_jax(sim, euclidean_distance)

spu_euclidean_distance_fn(np.array([0, 0]), np.array([3, 4]))
[12]:
array(4.999962, dtype=float32)

If you run the code above repeatedly, you may find that the results of the two runs differ slightly, due to the randomness in the MPC computation.
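
A minimal check of this (reusing spu_euclidean_distance_fn from the cell above):

r1 = spu_euclidean_distance_fn(np.array([0, 0]), np.array([3, 4]))
r2 = spu_euclidean_distance_fn(np.array([0, 0]), np.array([3, 4]))
# The two outputs may differ by a tiny amount due to fixed-point encoding and
# the randomness inside the MPC protocol.
print(r1, r2, abs(r1 - r2))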

After testing with euclidean_distance, let's try area_of_simple_polygon.

[13]:
spu_area_of_simple_polygon_fn = sim_jax(sim, area_of_simple_polygon)

spu_area_of_simple_polygon_fn(vertices)
[13]:
array(16.5, dtype=float32)
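
As noted above, the simulator configuration only affects the elapsed time and the precision; the logical result stays the same. For instance, we could rebuild the simulator with a different setup (an illustrative sketch; it assumes the SEMI2K protocol and the FM128 field are available in your spu build):

# 2 parties, SEMI2K protocol, values represented in the ring 2^128
sim2 = Simulator.simple(2, spu.ProtocolKind.SEMI2K, spu.FieldType.FM128)
spu_area_semi2k_fn = sim_jax(sim2, area_of_simple_polygon)
print(spu_area_semi2k_fn(vertices))  # roughly 16.5 again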

Run with SPU Device#

Finally, let's run the JAX programs with SecretFlow.

I guess you should be pretty familiar with the following steps if you have checked out other tutorials.

Here we create a local standalone SecretFlow cluster with three devices:

  • Two PYU devices - alice and bob

  • One SPU device

[14]:
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

sf.init(parties=['alice', 'bob'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
2023-03-21 21:03:27,671 INFO worker.py:1538 -- Started a local Ray instance.

Let's try euclidean_distance with the spu device first.

[15]:
p1 = sf.to(alice, np.array([0, 0]))
p2 = sf.to(bob, np.array([3, 4]))

distance = spu(euclidean_distance)(p1, p2)

sf.reveal(distance)
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=2316721) WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=2316358) WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[15]:
array(5., dtype=float32)

Then let's try area_of_simple_polygon.

[16]:
v = sf.to(alice, vertices)
area = spu(area_of_simple_polygon)(v)

sf.reveal(area)

[16]:
array(16.5, dtype=float32)

Summary#

This is the end of the tutorial. Let's summarize the steps to do privacy-preserving scientific computing with the JAX NumPy APIs:

  1. Write a jittable JAX NumPy program. You should test it with jax.jit or jax.xla_computation.

  2. (Optional) Try to execute the JAX program with SPU simulation.

  3. Run the JAX NumPy program with an SPU device in SecretFlow.

If you find that your JAX program is jittable but fails in the SPU compiler or runtime, please check JAX NumPy Operators Status and XLA Implementation Status, or contact us directly via GitHub Issues.