Logistic Regression with SPU

The following code is for demonstration only. Do not use it directly in production.

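SPU is a domain-specific compiler and runtime suite that provides provably secure computation services. The SPU compiler uses XLA as its front-end IR, which lets it support several AI frameworks (such as TensorFlow, JAX, and PyTorch), and it translates XLA into an IR that the SPU runtime can interpret. The SPU team currently strongly recommends JAX as the front end.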

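Learning objectives

After completing this lab, you will know how to:

  • Write a logistic regression training program with JAX.

  • Easily convert a JAX program into an SPU (MPC) program.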

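In this lab we use the Breast Cancer dataset. The task is to decide from 30 features whether a tumor is malignant or benign. In the MPC program, two parties train the model jointly, and each party contributes half of the features (15 each).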
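To make the feature split concrete, here is a minimal sketch of a vertical (feature-wise) partition. The small matrix `x` is a hypothetical stand-in for the real dataset, and the variable names are chosen for illustration only.

```python
import numpy as np

# Hypothetical stand-in for the 30-feature dataset: 4 samples, 30 columns.
x = np.arange(4 * 30, dtype=np.float64).reshape(4, 30)

# Vertical split: both parties hold the same rows (samples),
# but each one sees only 15 of the 30 columns (features).
x_party1 = x[:, :15]
x_party2 = x[:, 15:]

# Joining the halves along the feature axis restores the full matrix.
x_joined = np.concatenate([x_party1, x_party2], axis=1)

print(x_party1.shape, x_party2.shape, x_joined.shape)
print(np.array_equal(x, x_joined))
```

The same slicing and concatenation pattern appears later in `breast_cancer` and `train_step`.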

First, let's set MPC semantics aside and write the logistic regression training program directly in JAX.

Train a model with JAX

Load the dataset

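The breast_cancer function below min-max scales the whole dataset and then splits it into training and test subsets. If train is True, it returns the training subset; in that case, to simulate training on a vertically partitioned dataset, you can also pass the party_id argument: party 1 receives the first 15 features, and party 2 receives the remaining 15 features together with the labels. Otherwise, it returns the test subset.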

[1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    x, y = load_breast_cancer(return_X_y=True)
    # Min-max scale all values into [0, 1] using the global minimum and maximum.
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if train:
        if party_id:
            if party_id == 1:
                # Party 1 holds the first 15 features and no labels.
                return x_train[:, :15], None
            else:
                # The other party holds the remaining 15 features and the labels.
                return x_train[:, 15:], y_train
        else:
            return x_train, y_train
    else:
        return x_test, y_test

Define the model

First, let's define the loss function, which in our case is the negative log-likelihood.

[2]:
import jax.numpy as jnp


def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))


# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)


# Training loss is the negative log-likelihood of the training examples.
def loss(W, b, inputs, targets):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.mean(jnp.log(label_probs))
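As a sanity check on this loss, the sketch below re-implements it in NumPy (a stand-in for `jax.numpy`, used here only to keep the example dependency-light) and compares it with the textbook binary cross-entropy. The helper names `nll` and `cross_entropy` and the random data are assumptions made for illustration.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def nll(W, b, inputs, targets):
    # Same formula as `loss` above, written with NumPy.
    preds = sigmoid(inputs @ W + b)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -np.mean(np.log(label_probs))


def cross_entropy(W, b, inputs, targets):
    # Textbook binary cross-entropy, for comparison.
    preds = sigmoid(inputs @ W + b)
    return -np.mean(targets * np.log(preds) + (1 - targets) * np.log(1 - preds))


rng = np.random.default_rng(0)
inputs = rng.normal(size=(8, 3))
targets = np.array([0, 1, 1, 0, 1, 0, 0, 1], dtype=np.float64)
W = rng.normal(size=3)
b = 0.5

# For 0/1 labels the two formulations agree.
same = np.isclose(nll(W, b, inputs, targets), cross_entropy(W, b, inputs, targets))

# With all-zero parameters every prediction is 0.5, so the loss is ln(2).
zero_init_loss = nll(np.zeros(3), 0.0, inputs, targets)

print(same, zero_init_loss)
```

Because the tutorial starts from `W = 0` and `b = 0`, the initial training loss is ln(2) regardless of the data.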

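Next, let's define a single training step with an SGD optimizer. As a reminder, x1 holds the 15 features from one party, and x2 holds the other 15 features from the other party.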

[3]:
from jax import grad


def train_step(W, b, x1, x2, y, learning_rate):
    x = jnp.concatenate([x1, x2], axis=1)
    Wb_grad = grad(loss, (0, 1))(W, b, x, y)
    W -= learning_rate * Wb_grad[0]
    b -= learning_rate * Wb_grad[1]
    return W, b
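`grad` derives the gradients automatically. For intuition, the sketch below writes out the closed-form gradient of the negative log-likelihood in NumPy and checks it against central finite differences. The helpers `analytic_grads` and `numeric_grad_W` and the random data are hypothetical, and this is not what JAX does internally.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def nll(W, b, x, y):
    preds = sigmoid(x @ W + b)
    label_probs = preds * y + (1 - preds) * (1 - y)
    return -np.mean(np.log(label_probs))


def analytic_grads(W, b, x, y):
    # d(loss)/dW = x^T (p - y) / n and d(loss)/db = mean(p - y)
    residual = sigmoid(x @ W + b) - y
    return x.T @ residual / len(y), residual.mean()


def numeric_grad_W(W, b, x, y, eps=1e-6):
    # Central finite differences, one coordinate of W at a time.
    grad_W = np.zeros_like(W)
    for i in range(len(W)):
        step = np.zeros_like(W)
        step[i] = eps
        grad_W[i] = (nll(W + step, b, x, y) - nll(W - step, b, x, y)) / (2 * eps)
    return grad_W


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 4))
y = (rng.random(16) > 0.5).astype(np.float64)
W = rng.normal(size=4) * 0.1
b = 0.0

grad_W, grad_b = analytic_grads(W, b, x, y)
gradients_match = np.allclose(grad_W, numeric_grad_W(W, b, x, y), atol=1e-6)

# One SGD step, mirroring train_step: move against the gradient.
learning_rate = 1e-2
W_next = W - learning_rate * grad_W
b_next = b - learning_rate * grad_b

print(gradients_match)
```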

Finally, let's wrap everything in a fit method that runs the given number of epochs and returns the trained model parameters W and b.

[4]:
def fit(W, b, x1, x2, y, epochs=1, learning_rate=1e-2):
    for _ in range(epochs):
        W, b = train_step(W, b, x1, x2, y, learning_rate=learning_rate)
    return W, b

Validate the model

We can use AUC to validate a binary classification model.

[5]:
from sklearn.metrics import roc_auc_score


def validate_model(W, b, X_test, y_test):
    y_pred = predict(W, b, X_test)
    return roc_auc_score(y_test, y_pred)
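For readers unfamiliar with the metric, AUC can be read as the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one. The pure-Python sketch below computes it by brute force over all positive/negative pairs. `pairwise_auc` is a hypothetical helper written for illustration, not the algorithm scikit-learn uses.

```python
def pairwise_auc(y_true, y_score):
    # Split the scores by label.
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]

    # Count pairs where the positive sample is ranked above the negative one;
    # ties count as half a win.
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
auc = pairwise_auc(y_true, y_score)
print(auc)  # 0.75
```

Three of the four positive/negative pairs are ordered correctly here, which gives 0.75.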

Try it!

Let's put everything together and train an LR model!

[6]:
%matplotlib inline

# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)

# Initial parameters and hyperparameters
W = jnp.zeros((30,))
b = 0.0
epochs = 10
learning_rate = 1e-2

# Train the model
W, b = fit(W, b, x1, x2, y, epochs=epochs, learning_rate=learning_rate)

# Validate the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(W, b, X_test, y_test)
print(f'auc={auc}')

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
auc=0.9878807730101539

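Keep this AUC in mind, because we want to achieve something similar with SPU!

Train a model with SPU

In this part, we show how to run a similar training securely with MPC!

Initialize the environment

We will initialize three virtual devices in our physical environment:

  • alice, bob: two PYU devices for local plaintext computation.

  • spu: an SPU device made up of alice and bob, used for secure MPC computation.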

[7]:
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

sf.init(['alice', 'bob'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))

INFO:root:Run secretflow in simulation mode.
2023-03-14 12:59:36,141 INFO worker.py:1538 -- Started a local Ray instance.

Load the dataset

We instruct alice and bob to each load their own training subset.

[8]:
x1, _ = alice(breast_cancer)(party_id=1)
x2, y = bob(breast_cancer)(party_id=2)

x1, x2, y

[8]:
(<secretflow.device.device.pyu.PYUObject at 0x7f74b6586a30>,
 <secretflow.device.device.pyu.PYUObject at 0x7f74b58fd6d0>,
 <secretflow.device.device.pyu.PYUObject at 0x7f74b58fd8e0>)

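Before training, we need to pass the hyperparameters and all the data to the SPU device. SecretFlow provides two methods:

  • secretflow.to: transfers a Python object or a DeviceObject to a specific device.

  • DeviceObject.to: transfers a DeviceObject to a specific device.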

[9]:
device = spu

W = jnp.zeros((30,))
b = 0.0

W_, b_, x1_, x2_, y_ = (
    sf.to(alice, W).to(device),
    sf.to(alice, b).to(device),
    x1.to(device),
    x2.to(device),
    y.to(device),
)

(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=1517660) WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Train the model

Now we are ready to train an LR model with SPU. After training, the model parameters W_ and b_ are SPU objects that remain secret.

[10]:
W_, b_ = device(
    fit,
    static_argnames=['epochs'],
    num_returns_policy=sf.device.SPUCompilerNumReturnsPolicy.FROM_USER,
    user_specified_num_returns=2,
)(W_, b_, x1_, x2_, y_, epochs=10, learning_rate=1e-2)

W_, b_

[10]:
(<secretflow.device.device.spu.SPUObject at 0x7f7608ad88e0>,
 <secretflow.device.device.spu.SPUObject at 0x7f7608ad83a0>)
(SPURuntime pid=1525949) 2023-03-14 12:59:46.548 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
(SPURuntime pid=1525959) 2023-03-14 12:59:46.548 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127

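Reveal the results

To inspect the trained model, we need to convert the SPUObject (secret) into a Python object (plaintext). SecretFlow provides sf.reveal, which converts any DeviceObject into a Python object.

Use sf.reveal with care, because it may leak secrets.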
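To give a feel for why revealing is the sensitive step, here is a toy sketch of two-party additive secret sharing, one common building block of MPC. It is an illustration under assumptions only: the ring size `MODULUS` and the helpers `share`, `add_shared`, and `reveal` are hypothetical, and the sketch does not describe SPU's actual protocols or data encoding.

```python
import secrets

MODULUS = 2**64  # hypothetical ring size for this toy example


def share(value):
    # Split a value into two shares that sum to it modulo MODULUS.
    # Each share on its own is a uniformly random number.
    share_a = secrets.randbelow(MODULUS)
    share_b = (value - share_a) % MODULUS
    return share_a, share_b


def add_shared(x, y):
    # Addition needs no communication: each party adds its own shares.
    return (x[0] + y[0]) % MODULUS, (x[1] + y[1]) % MODULUS


def reveal(share_a, share_b):
    # Only combining both shares recovers the plaintext.
    return (share_a + share_b) % MODULUS


x = share(20)
y = share(22)
z = add_shared(x, y)
print(reveal(*z))  # 42
```

In this toy model neither party learns anything from its own share, and the plaintext appears only once the shares are combined, which is why a reveal should be limited to values that are meant to become public.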

Finally, let's validate the model with AUC.

[11]:
auc = validate_model(sf.reveal(W_), sf.reveal(b_), X_test, y_test)
print(f'auc={auc}')

auc=0.987880773010154

You can see that the model trained by the SPU program reaches essentially the same AUC as the one trained by the JAX program.
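The two printed values differ only in the last digits, so a tolerance-based comparison is more appropriate than strict equality. The sketch below compares the two AUC values copied from the outputs above; the tolerance `1e-9` is an arbitrary choice for illustration.

```python
import math

# AUC values copied from the JAX run and the SPU run above.
jax_auc = 0.9878807730101539
spu_auc = 0.987880773010154

# Compare with a relative tolerance instead of ==.
aucs_agree = math.isclose(jax_auc, spu_auc, rel_tol=1e-9)
print(aucs_agree)  # True
```

MPC runtimes commonly compute in fixed-point arithmetic, so small numerical deviations from a plaintext floating-point run would not be surprising.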

This concludes the lab.