Logistic Regression with SPU#
The following code is a demo only. It is NOT for production due to system security concerns; please DO NOT use it directly in production.
SPU is a domain-specific compiler and runtime suite that provides provably secure computation services. The SPU compiler uses XLA as its front-end IR, which supports diverse AI frameworks (such as TensorFlow, JAX, and PyTorch). The SPU compiler translates XLA into an IR which can be interpreted by the SPU runtime. Currently, the SPU team highly recommends using JAX as the frontend.
Learning Objectives:#
After doing this lab, you'll know how to:
- Write a Logistic Regression training program with JAX.
- Convert a JAX program to an SPU (MPC) program painlessly.
In this lab, we select Breast Cancer as the dataset. We need to decide whether a tumor is malignant or benign based on 30 features. In the MPC program, two parties train the model jointly, and each party provides half of the features (15).
But first, let's forget the MPC setting for a moment and write a Logistic Regression training program with plain JAX.
Train a model with JAX#
Load the Dataset#
We are going to split the whole dataset into train and test subsets after min-max normalization, with the helper breast_cancer:
- If train is True, it returns the train subsets. To simulate training on a vertically split dataset, party_id selects which half of the features to return.
- Otherwise, it returns the test subsets.
[1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    x, y = load_breast_cancer(return_X_y=True)
    # Min-max normalization over the whole feature matrix.
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if train:
        if party_id:
            if party_id == 1:
                # Party 1 holds the first 15 features and no labels.
                return x_train[:, :15], None
            else:
                # Party 2 holds the remaining 15 features and the labels.
                return x_train[:, 15:], y_train
        else:
            return x_train, y_train
    else:
        return x_test, y_test
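As a quick sanity check (an illustrative addition, not part of the original lab), you can inspect the shapes each party receives; the Breast Cancer dataset has 569 samples, so with a 0.2 test split each party should hold 15 of the 30 feature columns over 455 train rows:
# Illustrative check (x1_demo etc. are hypothetical names): each party sees
# only its own 15 feature columns, and only party 2 sees the labels.
x1_demo, _ = breast_cancer(party_id=1, train=True)
x2_demo, y_demo = breast_cancer(party_id=2, train=True)
print(x1_demo.shape, x2_demo.shape, y_demo.shape)  # expect (455, 15) (455, 15) (455,)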
Define the Model#
First, let’s define the loss function, which is a negative log-likelihood in our case.
[2]:
import jax.numpy as jnp


def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))


# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)


# Training loss is the negative log-likelihood of the training examples.
def loss(W, b, inputs, targets):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.mean(jnp.log(label_probs))
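As an aside (a minimal sketch, not from the original lab), you can check that the loss behaves sensibly on toy inputs: with zero weights, every prediction is 0.5, so the loss equals -log(0.5) ≈ 0.693 regardless of the labels:
# Illustrative check: with W = 0 and b = 0, predict() outputs 0.5 everywhere,
# so the negative log-likelihood equals log(2) for any labels.
toy_x = jnp.ones((4, 30))
toy_y = jnp.array([0.0, 1.0, 1.0, 1.0])
print(loss(jnp.zeros((30,)), 0.0, toy_x, toy_y))  # ~0.6931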
Second, let’s define a single training step with an SGD optimizer. Just to remind you, x1 represents the 15 features from one party, while x2 represents the other 15 features from the other party.
[3]:
from jax import grad


def train_step(W, b, x1, x2, y, learning_rate):
    # Each party contributes half of the features; concatenate them back.
    x = jnp.concatenate([x1, x2], axis=1)
    Wb_grad = grad(loss, (0, 1))(W, b, x, y)
    W -= learning_rate * Wb_grad[0]
    b -= learning_rate * Wb_grad[1]
    return W, b
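If you are new to jax.grad, a quick finite-difference check on the bias term (an illustrative addition reusing the toy data from above, not part of the original lab) confirms the gradient it returns:
# Hypothetical check: compare grad(loss) w.r.t. b against a central finite
# difference. eps is kept large-ish because JAX defaults to float32.
eps = 1e-2
W0, b0 = jnp.zeros((30,)), 0.0
g_b = grad(loss, argnums=1)(W0, b0, toy_x, toy_y)
fd_b = (loss(W0, b0 + eps, toy_x, toy_y) - loss(W0, b0 - eps, toy_x, toy_y)) / (2 * eps)
print(g_b, fd_b)  # both should be close to -0.25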
Last, let’s put everything together as a fit method which runs the training loop and returns the trained model parameters.
[4]:
def fit(W, b, x1, x2, y, epochs=1, learning_rate=1e-2):
    for _ in range(epochs):
        W, b = train_step(W, b, x1, x2, y, learning_rate=learning_rate)
    return W, b
Validate the Model#
We can use the AUC to validate a binary classification model.
[5]:
from sklearn.metrics import roc_auc_score
def validate_model(W, b, X_test, y_test):
y_pred = predict(W, b, X_test)
return roc_auc_score(y_test, y_pred)
Have a try!#
Let’s put everything together and train an LR model!
[6]:
%matplotlib inline

# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)

# Hyperparameters
W = jnp.zeros((30,))
b = 0.0
epochs = 10
learning_rate = 1e-2

# Train the model
W, b = fit(W, b, x1, x2, y, epochs=epochs, learning_rate=learning_rate)

# Validate the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(W, b, X_test, y_test)
print(f'auc={auc}')
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
auc=0.9878807730101539
Just remember the AUC here, since we would like to achieve a similar result with SPU!
Train a Model with SPU#
In this part, we are going to show you how to do similar training with MPC securely!
Init the Environment#
We are going to init three virtual devices in our physical environment.
- alice, bob: two PYU devices for local plaintext computation.
- spu: an SPU device composed of alice and bob for MPC secure computation.
[7]:
import secretflow as sf
# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob'], address='local')
alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
INFO:root:Run secretflow in simulation mode.
2023-03-14 12:59:36,141 INFO worker.py:1538 -- Started a local Ray instance.
Load the Dataset#
We instruct alice and bob to load the train subset respectively.
[8]:
x1, _ = alice(breast_cancer)(party_id=1)
x2, y = bob(breast_cancer)(party_id=2)
x1, x2, y
[8]:
(<secretflow.device.device.pyu.PYUObject at 0x7f74b6586a30>,
<secretflow.device.device.pyu.PYUObject at 0x7f74b58fd6d0>,
<secretflow.device.device.pyu.PYUObject at 0x7f74b58fd8e0>)
Before training, we need to pass the hyperparameters and all data to the SPU device. SecretFlow provides two methods:
- secretflow.to: transfer a PythonObject or DeviceObject to a specific device.
- DeviceObject.to: transfer a DeviceObject to a specific device.
[9]:
device = spu
W = jnp.zeros((30,))
b = 0.0
W_, b_, x1_, x2_, y_ = (
sf.to(alice, W).to(device),
sf.to(alice, b).to(device),
x1.to(device),
x2.to(device),
y.to(device),
)
(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=1517660) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=1517660) WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Train the model#
Now we are ready to train an LR model with SPU. After training, the model parameters are SPUObjects, which are still secret.
[10]:
# Run `fit` on the SPU device. `epochs` is marked as a static (compile-time)
# argument, and we tell the SPU compiler that `fit` returns exactly two values.
W_, b_ = device(
    fit,
    static_argnames=['epochs'],
    num_returns_policy=sf.device.SPUCompilerNumReturnsPolicy.FROM_USER,
    user_specified_num_returns=2,
)(W_, b_, x1_, x2_, y_, epochs=10, learning_rate=1e-2)
W_, b_
[10]:
(<secretflow.device.device.spu.SPUObject at 0x7f7608ad88e0>,
<secretflow.device.device.spu.SPUObject at 0x7f7608ad83a0>)
(SPURuntime pid=1525949) 2023-03-14 12:59:46.548 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
(SPURuntime pid=1525959) 2023-03-14 12:59:46.548 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
Reveal the result#
In order to check the trained model, we need to convert the SPUObject (secret) to a Python object (plaintext). SecretFlow provides sf.reveal to convert any DeviceObject to a Python object.
Be careful with sf.reveal, since it may result in secret leakage.
Finally, let’s validate the model with AUC.
[11]:
auc = validate_model(sf.reveal(W_), sf.reveal(b_), X_test, y_test)
print(f'auc={auc}')
auc=0.987880773010154
You may find that the model from the SPU training program achieves the same AUC as the JAX program.
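If you kept the plaintext-trained parameters around (say, saved as W_jax and b_jax before W and b were re-initialized for SPU; hypothetical names, not part of the original lab), you could also compare the revealed parameters directly:
# Hypothetical comparison, assuming W_jax/b_jax hold the plaintext JAX result.
# SPU uses fixed-point arithmetic, so expect the values to be close, not identical.
print(np.isclose(sf.reveal(W_), W_jax, atol=1e-2).all())
print(np.isclose(sf.reveal(b_), b_jax, atol=1e-2))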
This is the end of the lab.