在SPU中使用NumPy实现隐私保护的科学计算#
NumPy是最流行的科学计算工具之一。它非常常见,以至于我们可以在其他语言中找到许多类似于NumPy的工具,比如 xtensor 和 Gonum 。因此,我们不禁想到,是否可以在隐私保护的环境中使用类似于NumPy的API表达计算,因为每个人都喜欢NumPy。
幸运的是,借助 JAX NumPy软件包的强大功能,我们可以轻松实现这个目标。在本教程中,我们将介绍以下内容:1. JAX和SPU之间的关系2. 编写可编译的JAX程序3. 在SPU上执行JAX程序
JAX和SPU之间的关系#
TL;DR#
SPU实际上由两个组件组成 - 编译器和运行时。SPU运行时只能执行 PPHlo 。一个PPHlo内核的例子是 **pphlo.add** 。实际上,当逻辑非常简单和清晰时,我们只需将PPHlo程序直接输入SPU运行时以执行MPC程序,用于一些内部应用。
SPU编译器可以将 XLA 程序转换为 PPHlo 。您可以在 此文档 中检查“支持的”XLA操作。在大多数情况下,您会发现XLA操作与PPHlo操作非常相似。似乎我们仍然不能手动编写XLA程序。事实的确如此。如果您查看 这里 ,您应该会发现实际上有很多人工智能框架可以自动生成XLA程序,包括:
TensorFLow
Pytorch
JAX
让我们逐步看一下不同的程序!
JAX 程序#
下面是一个JAX程序,用于将一个数组和一个标量相加。如果您熟悉NumPy,这应该很容易理解。
[1]:
import jax
import numpy as np
def simple_add(x, y):
return jax.numpy.add(x, y)
simple_add(np.array([[1, 2], [3, 4]]), 4)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[1]:
Array([[5, 6],
[7, 8]], dtype=int32)
XLA 程序#
让我们来看看这个JAX程序的XLA程序是什么样子。JAX提供 xla_computation 函数将JAX程序转换为XLA程序。
[2]:
c = jax.xla_computation(simple_add)(np.array([[1, 2], [3, 4]]), 4)
c.as_hlo_text()
[2]:
'HloModule xla_computation_simple_add, entry_computation_layout={(s32[2,2]{1,0},s32[])->(s32[2,2]{1,0})}\n\nENTRY main.6 {\n Arg_0.1 = s32[2,2]{1,0} parameter(0)\n Arg_1.2 = s32[] parameter(1)\n broadcast.3 = s32[2,2]{1,0} broadcast(Arg_1.2), dimensions={}\n add.4 = s32[2,2]{1,0} add(Arg_0.1, broadcast.3)\n ROOT tuple.5 = (s32[2,2]{1,0}) tuple(add.4)\n}\n\n'
您应该知道以下事实:
在XLA程序中,每个参数的形状和数据类型都是固定的,如 s32[2,2]{1,0}
在 add 操作之前会插入一个隐式的 broadcast 操作。
PPHlo 程序#
最后,让我们查看这个XLA程序的PPHlo程序。 spu.compile 函数可以将XLA程序转换为PPHlo程序。
[3]:
import spu
pphlo = spu.compile(
c.as_serialized_hlo_module_proto(),
"hlo",
[spu.Visibility.VIS_SECRET, spu.Visibility.VIS_SECRET],
)
pphlo
[3]:
b'module @xla_computation_simple_add attributes {mhlo.cross_program_prefetches = [], mhlo.dynamic_parameter_bindings = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n func.func @main(%arg0: tensor<2x2x!pphlo.sec<i32>>, %arg1: tensor<!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>> {\n %0 = "pphlo.broadcast"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\n %1 = "pphlo.add"(%arg0, %0) : (tensor<2x2x!pphlo.sec<i32>>, tensor<2x2x!pphlo.sec<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\n return %1 : tensor<2x2x!pphlo.sec<i32>>\n }\n}\n'
您可能会发现,PPHlo程序与XLA程序几乎完全相同。区别是:
您必须向SPU编译器提供输入可见性,例如我们的情况下为 [spu.Visibility.VIS_SECRET, spu.Visibility.VIS_SECRET]
与XLA程序相比,Visibility是PPHlo程序中所有变量的额外属性。例如, tensor<2x2x!pphlo.sec> 表示这是一个密态的2x2 i32张量。
SPU编译器会在每一步中推导可见性,让我们修改输入可见性并查看会发生什么。
[4]:
pphlo = spu.compile(
c.as_serialized_hlo_module_proto(),
"hlo",
[spu.Visibility.VIS_SECRET, spu.Visibility.VIS_PUBLIC],
)
pphlo
[4]:
b'module @xla_computation_simple_add attributes {mhlo.cross_program_prefetches = [], mhlo.dynamic_parameter_bindings = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n func.func @main(%arg0: tensor<2x2x!pphlo.sec<i32>>, %arg1: tensor<!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.sec<i32>> {\n %0 = "pphlo.broadcast"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.pub<i32>>\n %1 = "pphlo.add"(%arg0, %0) : (tensor<2x2x!pphlo.sec<i32>>, tensor<2x2x!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.sec<i32>>\n return %1 : tensor<2x2x!pphlo.sec<i32>>\n }\n}\n'
从JAX到SPU#
这就是整个故事。1. 您可以使用Python编写JAX程序。2. 您可以使用JAX的第一方API(即jax.xla_computation)将JAX程序转换为XLA程序。3. SPU编译器可以将XLA程序转换为PPHlo程序——这是SPU运行时唯一能理解的语言。4. 将PPHlo程序发送到SPU运行时并执行。
在SecretFlow中,我们已经实现了一些辅助方法,这样您就可以一开始就编写JAX程序,我们会为您处理其余的步骤。
写一个可即时编译(Jittable)的JAX程序#
可即时编译(Jittable)指的是JAX程序可以即时编译(JIT)为XLA程序。因此,只有当一个JAX程序是可即时编译的,它才可能被SPU执行。
由于SPU不支持所有的XLA运算符,即使一个JAX程序是可即时编译的,SPU运行时仍然可能拒绝执行。
JAX NumPy库#
我们可以使用 JAX 中的这些 类似于 NumPy 的 API。JAX NumPy 的 API 和原始的 NumPy API 非常相似,但是有以下几点不同:- JAX NumPy 数组是不可变的,因此您必须使用 ndarray.at 来进行非就地数组更新。- 您必须提供一些额外的参数使方法调用可即时编译(稍后我们会讨论这个)。
实际上,SPU 不支持所有的 JAX NumPy 操作符,请查看这份 文档 .我们正在及时更新这份文档,并列出了每个操作符的当前状态。
接下来,我们将编写一些 JAX NumPy 程序。
欧几里得距离#
只需一行代码,我们就可以计算两个点的欧几里得距离。
[5]:
def euclidean_distance(p1, p2):
return jax.numpy.linalg.norm(p1 - p2)
我们可以用 jax.jit 来检查它是否是jittable的, 也可以用 jax.xla_computation
[6]:
euclidean_distance_jit = jax.jit(euclidean_distance)
print(euclidean_distance_jit(np.array([0, 0]), np.array([3, 4])))
# or
print(
(
jax.xla_computation(euclidean_distance)(np.array([0, 0]), np.array([3, 4]))
).as_hlo_text()
)
5.0
HloModule xla_computation_euclidean_distance, entry_computation_layout={(s32[2]{0},s32[2]{0})->(f32[])}
region_0.4 {
Arg_0.5 = f32[] parameter(0)
Arg_1.6 = f32[] parameter(1)
ROOT add.7 = f32[] add(Arg_0.5, Arg_1.6)
}
norm.8 {
Arg_0.9 = s32[2]{0} parameter(0)
convert.11 = f32[2]{0} convert(Arg_0.9)
multiply.12 = f32[2]{0} multiply(convert.11, convert.11)
constant.10 = f32[] constant(0)
reduce.13 = f32[] reduce(multiply.12, constant.10), dimensions={0}, to_apply=region_0.4
ROOT sqrt.14 = f32[] sqrt(reduce.13)
}
ENTRY main.17 {
Arg_0.1 = s32[2]{0} parameter(0)
Arg_1.2 = s32[2]{0} parameter(1)
subtract.3 = s32[2]{0} subtract(Arg_0.1, Arg_1.2)
call.15 = f32[] call(subtract.3), to_apply=norm.8
ROOT tuple.16 = (f32[]) tuple(call.15)
}
简单多边形面积#
给定一个简单多边形顶点的笛卡尔坐标列表,我们可以通过 Shoelace formula 计算面积
[7]:
import jax.numpy as jnp
def area_of_simple_polygon(vertices):
area = 0
for i in range(0, vertices.shape[0]):
a = jnp.expand_dims(vertices[i, :], axis=0)
b = jnp.expand_dims(vertices[(i + 1) % vertices.shape[0], :], axis=0)
x = jax.numpy.concatenate((a, b))
x_t = jax.numpy.transpose(x)
area += 0.5 * jax.numpy.linalg.det(x_t)
return area
vertices = np.array([[1, 6], [3, 1], [7, 2], [4, 4], [8, 5]])
area_of_simple_polygon(vertices)
[7]:
Array(16.5, dtype=float32)
让我们检查一下 area_of_simple_polygon 是否是 jittable 的。
[8]:
print(jax.xla_computation(area_of_simple_polygon)(vertices).as_hlo_text())
HloModule xla_computation_area_of_simple_polygon, entry_computation_layout={(s32[5,2]{1,0})->(f32[])}
det.7 {
Arg_0.8 = s32[2,2]{1,0} parameter(0)
convert.9 = f32[2,2]{1,0} convert(Arg_0.8)
slice.10 = f32[1,1]{1,0} slice(convert.9), slice={[0:1], [0:1]}
reshape.11 = f32[] reshape(slice.10)
slice.12 = f32[1,1]{1,0} slice(convert.9), slice={[1:2], [1:2]}
reshape.13 = f32[] reshape(slice.12)
multiply.14 = f32[] multiply(reshape.11, reshape.13)
slice.15 = f32[1,1]{1,0} slice(convert.9), slice={[0:1], [1:2]}
reshape.16 = f32[] reshape(slice.15)
slice.17 = f32[1,1]{1,0} slice(convert.9), slice={[1:2], [0:1]}
reshape.18 = f32[] reshape(slice.17)
multiply.19 = f32[] multiply(reshape.16, reshape.18)
ROOT subtract.20 = f32[] subtract(multiply.14, multiply.19)
}
det_0.27 {
Arg_0.28 = s32[2,2]{1,0} parameter(0)
convert.29 = f32[2,2]{1,0} convert(Arg_0.28)
slice.30 = f32[1,1]{1,0} slice(convert.29), slice={[0:1], [0:1]}
reshape.31 = f32[] reshape(slice.30)
slice.32 = f32[1,1]{1,0} slice(convert.29), slice={[1:2], [1:2]}
reshape.33 = f32[] reshape(slice.32)
multiply.34 = f32[] multiply(reshape.31, reshape.33)
slice.35 = f32[1,1]{1,0} slice(convert.29), slice={[0:1], [1:2]}
reshape.36 = f32[] reshape(slice.35)
slice.37 = f32[1,1]{1,0} slice(convert.29), slice={[1:2], [0:1]}
reshape.38 = f32[] reshape(slice.37)
multiply.39 = f32[] multiply(reshape.36, reshape.38)
ROOT subtract.40 = f32[] subtract(multiply.34, multiply.39)
}
det_1.48 {
Arg_0.49 = s32[2,2]{1,0} parameter(0)
convert.50 = f32[2,2]{1,0} convert(Arg_0.49)
slice.51 = f32[1,1]{1,0} slice(convert.50), slice={[0:1], [0:1]}
reshape.52 = f32[] reshape(slice.51)
slice.53 = f32[1,1]{1,0} slice(convert.50), slice={[1:2], [1:2]}
reshape.54 = f32[] reshape(slice.53)
multiply.55 = f32[] multiply(reshape.52, reshape.54)
slice.56 = f32[1,1]{1,0} slice(convert.50), slice={[0:1], [1:2]}
reshape.57 = f32[] reshape(slice.56)
slice.58 = f32[1,1]{1,0} slice(convert.50), slice={[1:2], [0:1]}
reshape.59 = f32[] reshape(slice.58)
multiply.60 = f32[] multiply(reshape.57, reshape.59)
ROOT subtract.61 = f32[] subtract(multiply.55, multiply.60)
}
det_2.69 {
Arg_0.70 = s32[2,2]{1,0} parameter(0)
convert.71 = f32[2,2]{1,0} convert(Arg_0.70)
slice.72 = f32[1,1]{1,0} slice(convert.71), slice={[0:1], [0:1]}
reshape.73 = f32[] reshape(slice.72)
slice.74 = f32[1,1]{1,0} slice(convert.71), slice={[1:2], [1:2]}
reshape.75 = f32[] reshape(slice.74)
multiply.76 = f32[] multiply(reshape.73, reshape.75)
slice.77 = f32[1,1]{1,0} slice(convert.71), slice={[0:1], [1:2]}
reshape.78 = f32[] reshape(slice.77)
slice.79 = f32[1,1]{1,0} slice(convert.71), slice={[1:2], [0:1]}
reshape.80 = f32[] reshape(slice.79)
multiply.81 = f32[] multiply(reshape.78, reshape.80)
ROOT subtract.82 = f32[] subtract(multiply.76, multiply.81)
}
det_3.90 {
Arg_0.91 = s32[2,2]{1,0} parameter(0)
convert.92 = f32[2,2]{1,0} convert(Arg_0.91)
slice.93 = f32[1,1]{1,0} slice(convert.92), slice={[0:1], [0:1]}
reshape.94 = f32[] reshape(slice.93)
slice.95 = f32[1,1]{1,0} slice(convert.92), slice={[1:2], [1:2]}
reshape.96 = f32[] reshape(slice.95)
multiply.97 = f32[] multiply(reshape.94, reshape.96)
slice.98 = f32[1,1]{1,0} slice(convert.92), slice={[0:1], [1:2]}
reshape.99 = f32[] reshape(slice.98)
slice.100 = f32[1,1]{1,0} slice(convert.92), slice={[1:2], [0:1]}
reshape.101 = f32[] reshape(slice.100)
multiply.102 = f32[] multiply(reshape.99, reshape.101)
ROOT subtract.103 = f32[] subtract(multiply.97, multiply.102)
}
ENTRY main.108 {
Arg_0.1 = s32[5,2]{1,0} parameter(0)
slice.3 = s32[1,2]{1,0} slice(Arg_0.1), slice={[0:1], [0:2]}
slice.4 = s32[1,2]{1,0} slice(Arg_0.1), slice={[1:2], [0:2]}
concatenate.5 = s32[2,2]{1,0} concatenate(slice.3, slice.4), dimensions={0}
transpose.6 = s32[2,2]{0,1} transpose(concatenate.5), dimensions={1,0}
call.21 = f32[] call(transpose.6), to_apply=det.7
constant.2 = f32[] constant(0.5)
multiply.22 = f32[] multiply(call.21, constant.2)
slice.23 = s32[1,2]{1,0} slice(Arg_0.1), slice={[1:2], [0:2]}
slice.24 = s32[1,2]{1,0} slice(Arg_0.1), slice={[2:3], [0:2]}
concatenate.25 = s32[2,2]{1,0} concatenate(slice.23, slice.24), dimensions={0}
transpose.26 = s32[2,2]{0,1} transpose(concatenate.25), dimensions={1,0}
call.41 = f32[] call(transpose.26), to_apply=det_0.27
multiply.42 = f32[] multiply(call.41, constant.2)
add.43 = f32[] add(multiply.22, multiply.42)
slice.44 = s32[1,2]{1,0} slice(Arg_0.1), slice={[2:3], [0:2]}
slice.45 = s32[1,2]{1,0} slice(Arg_0.1), slice={[3:4], [0:2]}
concatenate.46 = s32[2,2]{1,0} concatenate(slice.44, slice.45), dimensions={0}
transpose.47 = s32[2,2]{0,1} transpose(concatenate.46), dimensions={1,0}
call.62 = f32[] call(transpose.47), to_apply=det_1.48
multiply.63 = f32[] multiply(call.62, constant.2)
add.64 = f32[] add(add.43, multiply.63)
slice.65 = s32[1,2]{1,0} slice(Arg_0.1), slice={[3:4], [0:2]}
slice.66 = s32[1,2]{1,0} slice(Arg_0.1), slice={[4:5], [0:2]}
concatenate.67 = s32[2,2]{1,0} concatenate(slice.65, slice.66), dimensions={0}
transpose.68 = s32[2,2]{0,1} transpose(concatenate.67), dimensions={1,0}
call.83 = f32[] call(transpose.68), to_apply=det_2.69
multiply.84 = f32[] multiply(call.83, constant.2)
add.85 = f32[] add(add.64, multiply.84)
slice.86 = s32[1,2]{1,0} slice(Arg_0.1), slice={[4:5], [0:2]}
slice.87 = s32[1,2]{1,0} slice(Arg_0.1), slice={[0:1], [0:2]}
concatenate.88 = s32[2,2]{1,0} concatenate(slice.86, slice.87), dimensions={0}
transpose.89 = s32[2,2]{0,1} transpose(concatenate.88), dimensions={1,0}
call.104 = f32[] call(transpose.89), to_apply=det_3.90
multiply.105 = f32[] multiply(call.104, constant.2)
add.106 = f32[] add(add.85, multiply.105)
ROOT tuple.107 = (f32[]) tuple(add.106)
}
我们是否可以 Jit 所有程序#
绝对不是,请查看来自 JAX 的 此文档 !
导致程序无法运行的最常见原因是您的控制流依赖于 input 的值。例如,
[9]:
# Cited from JAX documentation.
# While loop conditioned on x and n.
def g(x, n):
i = 0
while i < n:
i += 1
return x + i
g_jit = jax.jit(g)
import traceback
try:
g_jit(10, 20) # Should raise an error.
except Exception:
traceback.print_exc()
Traceback (most recent call last):
File "/tmp/ipykernel_2314574/682526420.py", line 14, in <module>
g_jit(10, 20) # Should raise an error.
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 235, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/api.py", line 440, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 513, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 965, in _pjit_jaxpr
jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/pjit.py", line 923, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2033, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2050, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/ipykernel_2314574/682526420.py", line 6, in g
while i < n:
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/core.py", line 653, in __bool__
def __bool__(self): return self.aval._bool(self)
File "/home/fengjun.feng/miniconda3/envs/sf/lib/python3.8/site-packages/jax/_src/core.py", line 1344, in error
raise ConcretizationTypeError(arg, fname_context)
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function.
The error occurred while tracing the function g at /tmp/ipykernel_2314574/682526420.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument 'n'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
有两种可能的解决方案。1. 您可以用低级的 **jax.lax** API 替换控制流。您需要花一些时间弄清楚如何使用这些 API。
[10]:
def g_with_lax_control_flow(x, n):
def body_fun(i):
i += 1
return i
return x + jax.lax.while_loop(lambda i: i < n, body_fun, 0)
g_with_lax_control_flow_jit = jax.jit(g_with_lax_control_flow)
g_with_lax_control_flow_jit(10, 20) # good to go!
[10]:
Array(30, dtype=int32, weak_type=True)
另一种可能的解决方案是使用 static_argnames。
[11]:
g_with_static_argnames_jit = jax.jit(g, static_argnames=['n'])
g_with_static_argnames_jit(10, 20) # good to go!
[11]:
Array(30, dtype=int32, weak_type=True)
so which method we should choose when the program is unjittable?
那么当程序 unjittable 时,我们应该选择哪种方法呢?
首先使用 jax.lax API 重写控制流。虽然这里存在一些学习成本,但这是值得的。
如果受影响的输入值的可见性是 VIS_PUBLIC,如上例中的 n,您可以将其标记为 static_argnames,这些受影响的输入值将被编译到 XLA 程序中。
更多例子#
如果您想查看更多示例,请查看 SPU 存储库中的 Python 示例 。在大多数示例中,MPC 部分都是使用 jax.numpy 包编写的。您会发现我们正在大量使用 jax.lax API 和 static_argnames 来使 JAX 程序变得 jittable!
使用 SPU 执行 JAX 程序#
一旦您准备好 jittable 的 JAX 程序,我们就可以使用 SPU 执行它们!
(可选)SPU 模拟#
如果您希望快速尝试,我想向您介绍 spu.sim_jax。让我们展示它是如何工作的!
spu.sim_jax 仅在 spu v0.3.1b8 之后提供。
在这里,我们使用以下设置创建一个 SPU 模拟器:- 三方- 使用 ABY3 协议。在 此处 检查所有支持的协议。- 64位字段,SPU中的值在2^64环上表示。
但是,如果您只是想确认 JAX 程序是否可以被 SPU 执行,那么任何设置都应该没问题。不同的设置只会影响经过的时间和计算的精度。
[12]:
from spu.utils.simulation import Simulator, sim_jax
sim = Simulator.simple(3, spu.ProtocolKind.ABY3, spu.FieldType.FM64)
spu_euclidean_distance_fn = sim_jax(sim, euclidean_distance)
spu_euclidean_distance_fn(np.array([0, 0]), np.array([3, 4]))
[12]:
array(4.999962, dtype=float32)
如果您重复执行上面的代码,您可能会发现两次运行的结果略有不同,这是由于 MPC 计算中的随机性所致。
在使用 euclidean_distance 进行测试后,我们尝试使用 area_of_simple_polygon。
[13]:
spu_area_of_simple_polygon_fn=sim_jax(sim, area_of_simple_polygon)
spu_area_of_simple_polygon_fn(vertices)
[13]:
array(16.5, dtype=float32)
使用 SPU 设备运行#
最后,我们将使用 SecretFlow 运行 JAX 程序。
如果您查看了其他教程,我想您应该熟悉以下步骤。
在这里,我们创建了一个包含三个设备的本地独立 SecretFlow 集群:
两个PYU设备 - alice 和 bob
一个SPU设备
[14]:
import secretflow as sf
# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(parties=['alice', 'bob'], address='local')
alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
2023-03-21 21:03:27,671 INFO worker.py:1538 -- Started a local Ray instance.
我们先用 spu 设备尝试 euclidean_distance。
[15]:
p1 = sf.to(alice,np.array([0, 0]))
p2 = sf.to(bob,np.array([3, 4]))
distance = spu(euclidean_distance)(p1, p2)
sf.reveal(distance)
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2316721) INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=2316721) WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2316358) INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=2316358) WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[15]:
array(5., dtype=float32)
然后我们尝试**area_of_simple_polygon**。
[16]:
v = sf.to(alice, vertices)
area = spu(area_of_simple_polygon)(v)
sf.reveal(area)
[16]:
array(16.5, dtype=float32)
总结#
本教程到此结束。让我们总结一下使用 JAX NumPy APIS 进行隐私保护科学计算的步骤:
编写一个 jittable JAX NumPy 程序。您应该使用 jax.jit 或 jax.xla_computation 对其进行测试。
可选)尝试使用 SPU 模拟 执行 JAX 程序。
Run this JAX NumPy with SPU device in SecretFlow.
如果您发现您的 JAX 程序是 jittable 但在 SPU 编译器或运行时失败。请查看 JAX NumPy Operators Status 和 XLA Implementation Status 。或者您可以通过 GitHub Issues 直接联系我们。