Split Learning—Bank Marketing#
The following code is for demonstration only. Due to system security concerns, please DO NOT use it directly in production.
In this tutorial, we use a bank marketing model as an example to show how to accomplish split learning in vertical scenarios under the SecretFlow framework. SecretFlow provides a user-friendly API that makes it easy to apply your Keras or PyTorch model to split learning and complete joint modeling tasks in vertical scenarios.
Specifically, we will show how to turn an existing Keras model into a split learning model under SecretFlow to complete a federated multi-party modeling task.
What is Split Learning?#
The core idea of split learning is to split the network structure. Each device (silo) retains only a part of the network, and the sub-networks of all devices combine to form the complete model. During training, each device (silo) performs only the forward or backward computation of its local sub-network and transfers the result to the next device. The devices train jointly in this way until the model converges.

Alice uses its data to get hidden_0 through model_base_alice and sends it to Bob.
Bob uses its data to get hidden_1 through model_base_bob.
hidden_0 and hidden_1 are input to the AggLayer for aggregation, which outputs the aggregated hidden_merge.
Bob inputs hidden_merge to model_fuse, computes the gradient using the label, and sends it back.
The AggLayer splits the gradient into two parts, g_0 and g_1, which are sent to Alice and Bob respectively.
Alice and Bob then update their local base nets with g_0 and g_1 respectively.
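The six steps above can be sketched in plain Python. This is a toy illustration, not SecretFlow code: it assumes each base net is a single linear unit, the AggLayer is a plain concatenation, and the fuse model is a logistic regression. All function and variable names other than hidden_0, hidden_1, hidden_merge, g_0 and g_1 are hypothetical.

```python
import math


def dot(weights, values):
    return sum(w * v for w, v in zip(weights, values))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def split_learning_step(x_alice, x_bob, label, w_alice, w_bob, w_fuse, lr=0.1):
    """One training step on one sample (toy setup, hypothetical names)."""
    # Steps 1-2: each party computes its hidden value on its own features.
    hidden_0 = dot(w_alice, x_alice)
    hidden_1 = dot(w_bob, x_bob)
    # Step 3: the aggregation layer concatenates the hidden values.
    hidden_merge = [hidden_0, hidden_1]
    # Step 4: the label holder runs the fuse model and derives the gradient.
    pred = sigmoid(dot(w_fuse, hidden_merge))
    loss = -(label * math.log(pred) + (1 - label) * math.log(1 - pred))
    d_logit = pred - label  # gradient of sigmoid + binary cross-entropy
    # Step 5: the gradient w.r.t. hidden_merge is split into g_0 and g_1.
    g_0, g_1 = (d_logit * w for w in w_fuse)
    # Step 6: every party updates only the weights it owns.
    new_w_fuse = [w - lr * d_logit * h for w, h in zip(w_fuse, hidden_merge)]
    new_w_alice = [w - lr * g_0 * x for w, x in zip(w_alice, x_alice)]
    new_w_bob = [w - lr * g_1 * x for w, x in zip(w_bob, x_bob)]
    return loss, new_w_alice, new_w_bob, new_w_fuse


def train(steps=20):
    x_alice, x_bob, label = [1.0, 2.0], [0.5, -1.0, 3.0], 1
    w_alice, w_bob, w_fuse = [0.1, -0.2], [0.3, 0.1, -0.1], [0.5, 0.5]
    losses = []
    for _ in range(steps):
        loss, w_alice, w_bob, w_fuse = split_learning_step(
            x_alice, x_bob, label, w_alice, w_bob, w_fuse)
        losses.append(loss)
    return losses


losses = train()
print(losses[0], losses[-1])
```

Note that only hidden values and gradients cross the party boundary; the raw features x_alice and x_bob never leave their owner.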
Task#
Bank marketing covers the overall operation and sales activities through which a bank, in an ever-changing market environment, meets customer needs and achieves its business objectives. In the current big data environment, data analysis gives the banking industry a more effective means of analysis: customer demand analysis and an understanding of target market trends provide the basis and direction for broader market strategies.
The data comes from Kaggle and is a classic bank marketing dataset. It records the telephone direct marketing campaigns of a Portuguese banking institution. The target variable is whether the customer subscribes to a deposit product.
Data#
The dataset loaded in this tutorial contains 4521 samples, split into 3616 training samples and 905 test samples
Feature dim is 16, target is binary classification
We have cut the data in advance. Alice holds the 4-dimensional basic attribute features, Bob holds the 12-dimensional bank transaction features, and only Alice holds the corresponding label
Let’s start by looking at what the bank marketing data looks like.
The original data is divided between Bank Alice and Bank Bob and is stored at Alice and Bob respectively. The CSV files hold raw data that has only been split, without any preprocessing; we will use SecretFlow’s preprocessing tools to preprocess the federated data.
[1]:
%load_ext autoreload
%autoreload 2
import secretflow as sf
import matplotlib.pyplot as plt
sf.shutdown()
sf.init(['alice', 'bob'], address='local')
alice, bob = sf.PYU('alice'), sf.PYU('bob')
2023-04-24 10:20:58,140 INFO worker.py:1538 -- Started a local Ray instance.
prepare data#
[2]:
import pandas as pd
from secretflow.utils.simulation.datasets import dataset
df = pd.read_csv(dataset('bank_marketing'), sep=';')
We assume that Alice is a new bank that holds only basic user information and has purchased the labels for financial products from another bank.
[3]:
alice_data = df[["age", "job", "marital", "education", "y"]]
alice_data
[3]:
| | age | job | marital | education | y |
|---|---|---|---|---|---|
| 0 | 30 | unemployed | married | primary | no |
| 1 | 33 | services | married | secondary | no |
| 2 | 35 | management | single | tertiary | no |
| 3 | 30 | management | married | tertiary | no |
| 4 | 59 | blue-collar | married | secondary | no |
| ... | ... | ... | ... | ... | ... |
| 4516 | 33 | services | married | secondary | no |
| 4517 | 57 | self-employed | married | tertiary | no |
| 4518 | 57 | technician | married | secondary | no |
| 4519 | 28 | blue-collar | married | secondary | no |
| 4520 | 44 | entrepreneur | single | tertiary | no |
4521 rows × 5 columns
Bob is an established bank that holds each user’s account balance, housing and loan status, and recent marketing feedback.
[4]:
bob_data = df[["default", "balance", "housing", "loan", "contact",
"day","month","duration","campaign","pdays","previous","poutcome"]]
bob_data
[4]:
| | default | balance | housing | loan | contact | day | month | duration | campaign | pdays | previous | poutcome |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | no | 1787 | no | no | cellular | 19 | oct | 79 | 1 | -1 | 0 | unknown |
| 1 | no | 4789 | yes | yes | cellular | 11 | may | 220 | 1 | 339 | 4 | failure |
| 2 | no | 1350 | yes | no | cellular | 16 | apr | 185 | 1 | 330 | 1 | failure |
| 3 | no | 1476 | yes | yes | unknown | 3 | jun | 199 | 4 | -1 | 0 | unknown |
| 4 | no | 0 | yes | no | unknown | 5 | may | 226 | 1 | -1 | 0 | unknown |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4516 | no | -333 | yes | no | cellular | 30 | jul | 329 | 5 | -1 | 0 | unknown |
| 4517 | yes | -3313 | yes | yes | unknown | 9 | may | 153 | 1 | -1 | 0 | unknown |
| 4518 | no | 295 | no | no | cellular | 19 | aug | 151 | 11 | -1 | 0 | unknown |
| 4519 | no | 1137 | no | no | cellular | 6 | feb | 129 | 4 | 211 | 3 | other |
| 4520 | no | 1136 | yes | yes | cellular | 3 | apr | 345 | 2 | 249 | 7 | other |
4521 rows × 12 columns
Create Secretflow Environment#
Create two entities, Alice and Bob, in the SecretFlow environment. Both are PYU devices, created with sf.PYU in the first cell. Once you have constructed the two objects, you can start split learning.
Import Dependency#
[5]:
from secretflow.data.split import train_test_split
from secretflow.ml.nn import SLModel
Prepare Data#
Build Federated Table
A federated table is a virtual concept that spans multiple parties. We define VDataFrame for the vertical setting.
The data of all parties in a federated table is stored locally and is not allowed to leave its domain.
No one has access to a data store except the party that owns the data.
Any operation on the federated table is scheduled by the driver to each worker, and the execution instructions are delivered layer by layer down to the Python runtime of the specific worker. The framework ensures that data can be operated on only when worker.device and Object.device are the same.
Federated tables are designed to manage and manipulate multi-party data from a central perspective.
The interfaces of the federated table are aligned with pandas.DataFrame to reduce the cost of multi-party data operations.
The SecretFlow framework provides plaintext and ciphertext hybrid programming capabilities. Vertical federated tables are built using SPU, and MPC-PSI is used to safely compute the intersection and align the data of all parties.

VDataFrame provides a read_csv interface similar to that of pandas, except that it receives a dictionary that maps each party to the path of its data. We can use secretflow.data.vertical.read_csv to build the VDataFrame.
read_csv(filepath, delimiter, spu, keys, drop_keys)
filepath: Dictionary that maps each participant to the path of its file. Each path can be a relative or absolute path to a local file
spu: SPU device for PSI; if this parameter is not specified, the data must be prealigned
keys: Key column(s) used for the intersection
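To make the role of keys concrete, here is a plaintext sketch of the result that intersection and alignment produce. It is not the PSI protocol itself: a real PSI run reaches the same aligned result without revealing keys outside the intersection. The uid column and the helper name are hypothetical.

```python
def align_by_key(rows_alice, rows_bob, key):
    """Keep only rows whose key exists on both sides, in the same order."""
    index_alice = {row[key]: row for row in rows_alice}
    index_bob = {row[key]: row for row in rows_bob}
    shared = sorted(index_alice.keys() & index_bob.keys())
    return [index_alice[k] for k in shared], [index_bob[k] for k in shared]


alice_rows = [{"uid": 3, "age": 35}, {"uid": 1, "age": 30}, {"uid": 7, "age": 59}]
bob_rows = [{"uid": 1, "balance": 1787}, {"uid": 9, "balance": 0}, {"uid": 3, "balance": 1350}]

aligned_alice, aligned_bob = align_by_key(alice_rows, bob_rows, key="uid")
print([row["uid"] for row in aligned_alice])  # [1, 3]
print([row["uid"] for row in aligned_bob])    # [1, 3]
```

After alignment, row i on Alice's side and row i on Bob's side describe the same user, which is what vertical training requires.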
Create spu object
[6]:
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
[7]:
from secretflow.utils.simulation.datasets import load_bank_marketing
# Alice has the first four features,
# while bob has the left features
data = load_bank_marketing(parts={alice: (0, 4), bob: (4, 16)}, axis=1)
# Alice holds the label.
label = load_bank_marketing(parts={alice: (16, 17)}, axis=1)
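The parts argument assigns a column range to each party. The sketch below mimics that assignment on the column names alone, assuming the ranges are half-open like Python slices (an assumption that matches the 4, 12 and 1 column counts used in this tutorial). split_columns is a hypothetical helper, not a SecretFlow function.

```python
COLUMNS = [
    "age", "job", "marital", "education",
    "default", "balance", "housing", "loan", "contact", "day",
    "month", "duration", "campaign", "pdays", "previous", "poutcome",
    "y",
]


def split_columns(columns, parts):
    """Map each party to the columns in its half-open (start, end) range."""
    return {party: columns[start:end] for party, (start, end) in parts.items()}


features = split_columns(COLUMNS, {"alice": (0, 4), "bob": (4, 16)})
label_cols = split_columns(COLUMNS, {"alice": (16, 17)})

print(features["alice"])    # ['age', 'job', 'marital', 'education']
print(len(features["bob"]))  # 12
print(label_cols["alice"])   # ['y']
```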
data is a vertically federated table. Globally, it holds only the schema of all the data.
Let’s take a closer look at how VDataFrame manages data.
The age field belongs to Alice, so the corresponding column can be obtained from Alice’s partition, but trying to obtain age from Bob’s partition raises a KeyError. A Partition is a data fragment that we define. Each Partition belongs to its own device, and only that device can operate on the data.
[8]:
data['age'].partitions[alice].data
[8]:
<secretflow.device.device.pyu.PYUObject at 0x7fb54bc6d250>
[9]:
# You can uncomment this and you will get a KeyError.
# data['age'].partitions[bob]
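The ownership rule can be illustrated with a toy class. This is only a sketch of the lookup behaviour described above, with plain dictionaries standing in for devices and partitions; ToyVDataFrame is hypothetical and does not reflect how VDataFrame is implemented.

```python
class ToyVDataFrame:
    """Toy stand-in: maps each party to the columns it owns."""

    def __init__(self, partitions):
        self.partitions = partitions  # party -> {column name: values}

    def __getitem__(self, column):
        # Keep only the parties that actually own the requested column.
        owners = {
            party: {column: cols[column]}
            for party, cols in self.partitions.items()
            if column in cols
        }
        if not owners:
            raise KeyError(column)
        return ToyVDataFrame(owners)


vdf = ToyVDataFrame({
    "alice": {"age": [30, 33, 35]},
    "bob": {"balance": [1787, 4789, 1350]},
})

print(vdf["age"].partitions["alice"])  # {'age': [30, 33, 35]}
try:
    vdf["age"].partitions["bob"]
except KeyError as err:
    print("KeyError:", err)  # Bob has no partition holding 'age'
```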
Next, we preprocess the VDataFrame, using LabelEncoder and MinMaxScaler as examples. Both preprocessors have counterparts in scikit-learn and are used in a similar way.
[10]:
from secretflow.preprocessing.scaler import MinMaxScaler
from secretflow.preprocessing.encoder import LabelEncoder
[11]:
encoder = LabelEncoder()
data['job'] = encoder.fit_transform(data['job'])
data['marital'] = encoder.fit_transform(data['marital'])
data['education'] = encoder.fit_transform(data['education'])
data['default'] = encoder.fit_transform(data['default'])
data['housing'] = encoder.fit_transform(data['housing'])
data['loan'] = encoder.fit_transform(data['loan'])
data['contact'] = encoder.fit_transform(data['contact'])
data['poutcome'] = encoder.fit_transform(data['poutcome'])
data['month'] = encoder.fit_transform(data['month'])
label = encoder.fit_transform(label)
[12]:
print(f"label= {type(label)},\ndata = {type(data)}")
label= <class 'secretflow.data.vertical.dataframe.VDataFrame'>,
data = <class 'secretflow.data.vertical.dataframe.VDataFrame'>
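For intuition, label encoding in the scikit-learn sense maps the sorted distinct values of a column to the integers 0..n-1. The sketch below reproduces that rule in plain Python; it is an illustration of the concept, and label_encode is a hypothetical helper rather than the SecretFlow implementation.

```python
def label_encode(values):
    """Map each distinct value to its rank among the sorted distinct values."""
    classes = sorted(set(values))
    mapping = {cls: code for code, cls in enumerate(classes)}
    return [mapping[v] for v in values], classes


codes, classes = label_encode(["married", "single", "married", "divorced"])
print(classes)  # ['divorced', 'married', 'single']
print(codes)    # [1, 2, 1, 0]
```

Because the codes depend on the sorted order of the values seen during fitting, each column is fitted separately, as in the cell above.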
Scale the data to the [0, 1] range via MinMaxScaler
[13]:
scaler = MinMaxScaler()
data = scaler.fit_transform(data)
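Min-max scaling rescales each column independently with x' = (x - min) / (max - min). A minimal plain-Python sketch of that formula follows; min_max_scale is a hypothetical helper, and mapping a constant column to zeros is a choice made here to avoid division by zero.

```python
def min_max_scale(column):
    """Rescale one column to [0, 1]; a constant column maps to all zeros."""
    low, high = min(column), max(column)
    if high == low:
        return [0.0 for _ in column]
    return [(value - low) / (high - low) for value in column]


# A label-encoded column with four classes, e.g. codes 0..3.
print(min_max_scale([0, 1, 2, 3]))  # [0.0, 0.333..., 0.666..., 1.0]
```

This is why a label-encoded column with four classes shows up later as the values 0.0, 0.333333, 0.666667 and 1.0.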
Next we divide the data set into train-set and test-set
[14]:
from secretflow.data.split import train_test_split
random_state = 1234
train_data,test_data = train_test_split(data, train_size=0.8, random_state=random_state)
train_label,test_label = train_test_split(label, train_size=0.8, random_state=random_state)
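Features and labels are split in two separate calls, so both calls must select the same rows. Passing the same random_state is what keeps them aligned. The sketch below shows the idea with the standard library only; it is not how SecretFlow or scikit-learn implement the split, and split_indices is a hypothetical helper.

```python
import random


def split_indices(n_rows, train_size, random_state):
    """Shuffle row indices with a fixed seed, then cut at the train fraction."""
    indices = list(range(n_rows))
    random.Random(random_state).shuffle(indices)
    n_train = int(n_rows * train_size)
    return indices[:n_train], indices[n_train:]


# Same seed for the feature split and the label split.
train_rows_x, test_rows_x = split_indices(4521, 0.8, random_state=1234)
train_rows_y, test_rows_y = split_indices(4521, 0.8, random_state=1234)

print(len(train_rows_x), len(test_rows_x))  # 3616 905
print(train_rows_x == train_rows_y)         # True: rows stay aligned
```

With a different seed for the label call, row i of the features would no longer correspond to row i of the labels.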
Summary: At this point, we have completed the definition of the federated tables, the data preprocessing, and the split into training and test sets. The SecretFlow framework defines a set of operations on the federated table (its logical counterpart is pandas.DataFrame) and a set of preprocessing operations on it (their logical counterpart is sklearn). Refer to our documentation and API introduction to learn more about other features.
Introduce Model#
Local version: For this task, a basic DNN is sufficient. It takes the 16-dimensional features as input and outputs the probability of the positive and negative classes.
Federated version:
* Alice:
  - base_net: Takes the 4-dimensional features as input and passes them through a DNN to get hidden
  - fuse_net: Receives the hidden features calculated by Alice and Bob, feeds them to the fuse net for feature fusion, and completes the forward and backward passes
* Bob:
  - base_net: Takes the 12-dimensional features as input, gets hidden through a DNN, and then sends hidden to Alice to complete the subsequent operations
Define Model#
Next we create the federated model. SecretFlow defines SLTFModel and SLTorchModel (WIP) for building split learning in vertical scenarios. They offer a simple, extensible interface that makes it easy to turn your existing model into an SF-Model and then carry out vertical federated modeling.
Split learning breaks up a model so that one part is executed locally on each party’s data and the other part is executed on the label side. First, let’s define the locally executed model, base_model.
[15]:
def create_base_model(input_dim, output_dim, name='base_model'):
    # Create model
    def create_model():
        from tensorflow import keras
        from tensorflow.keras import layers
        import tensorflow as tf

        model = keras.Sequential(
            [
                keras.Input(shape=input_dim),
                layers.Dense(100, activation="relu"),
                layers.Dense(output_dim, activation="relu"),
            ]
        )
        # Compile model
        model.summary()
        model.compile(loss='binary_crossentropy',
                      optimizer='adam',
                      metrics=["accuracy", tf.keras.metrics.AUC()])
        return model

    return create_model
We use create_base_model to create the base models for Alice and Bob respectively.
[16]:
# prepare model
hidden_size = 64
# get the number of features of each party.
# When the input data changes, the network automatically adjusts to the input data
alice_input_feature_num = train_data.values.partition_shape()[alice][1]
bob_input_feature_num = train_data.values.partition_shape()[bob][1]
model_base_alice = create_base_model(alice_input_feature_num, hidden_size)
model_base_bob = create_base_model(bob_input_feature_num, hidden_size)
[17]:
model_base_alice()
model_base_bob()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 100) 500
dense_1 (Dense) (None, 64) 6464
=================================================================
Total params: 6,964
Trainable params: 6,964
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_2 (Dense) (None, 100) 1300
dense_3 (Dense) (None, 64) 6464
=================================================================
Total params: 7,764
Trainable params: 7,764
Non-trainable params: 0
_________________________________________________________________
[17]:
<keras.engine.sequential.Sequential at 0x7fb54bc6d610>
Next we define the model on the side that holds the label, i.e. the server-side model: fuse_model. In the definition of fuse_model, we need to correctly define the loss, optimizer, and metrics. This is compatible with all configurations of your existing Keras model.
[18]:
def create_fuse_model(input_dim, output_dim, party_nums, name='fuse_model'):
    def create_model():
        from tensorflow import keras
        from tensorflow.keras import layers
        import tensorflow as tf

        # input
        input_layers = []
        for i in range(party_nums):
            input_layers.append(keras.Input(input_dim,))

        merged_layer = layers.concatenate(input_layers)
        fuse_layer = layers.Dense(64, activation='relu')(merged_layer)
        output = layers.Dense(output_dim, activation='sigmoid')(fuse_layer)
        model = keras.Model(inputs=input_layers, outputs=output)
        model.summary()
        model.compile(loss='binary_crossentropy',
                      optimizer='adam',
                      metrics=["accuracy", tf.keras.metrics.AUC()])
        return model

    return create_model
[19]:
model_fuse = create_fuse_model(
input_dim=hidden_size, party_nums=2, output_dim=1)
[20]:
model_fuse()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_3 (InputLayer) [(None, 64)] 0 []
input_4 (InputLayer) [(None, 64)] 0 []
concatenate (Concatenate) (None, 128) 0 ['input_3[0][0]',
'input_4[0][0]']
dense_4 (Dense) (None, 64) 8256 ['concatenate[0][0]']
dense_5 (Dense) (None, 1) 65 ['dense_4[0][0]']
==================================================================================================
Total params: 8,321
Trainable params: 8,321
Non-trainable params: 0
__________________________________________________________________________________________________
[20]:
<keras.engine.functional.Functional at 0x7fb548506d00>
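The parameter counts in the model summaries can be checked by hand: a Dense layer with n_in inputs and n_out units has n_in * n_out weights plus n_out biases. The short check below reproduces the numbers printed above; dense_params is a hypothetical helper.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


alice_base = [dense_params(4, 100), dense_params(100, 64)]
bob_base = [dense_params(12, 100), dense_params(100, 64)]
# The fuse model concatenates two 64-dimensional hidden vectors into 128.
fuse = [dense_params(64 * 2, 64), dense_params(64, 1)]

print(alice_base, sum(alice_base))  # [500, 6464] 6964
print(bob_base, sum(bob_base))      # [1300, 6464] 7764
print(fuse, sum(fuse))              # [8256, 65] 8321
```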
Create Split Learning Model#
SecretFlow provides the split learning model SLModel. Initializing an SLModel requires only 3 parameters:
* base_model_dict: a dictionary that maps every client participating in training to its base_model
* device_y: the PYU device that holds the label
* model_fuse: the fusion model
Define base_model_dict
base_model_dict:Dict[PYU,model_fn]
[21]:
base_model_dict = {
alice: model_base_alice,
bob: model_base_bob
}
[22]:
from secretflow.security.privacy import DPStrategy, GaussianEmbeddingDP, LabelDP
# Define DP operations
train_batch_size = 128
gaussian_embedding_dp = GaussianEmbeddingDP(
noise_multiplier=0.5,
l2_norm_clip=1.0,
batch_size=train_batch_size,
num_samples=train_data.values.partition_shape()[alice][0],
is_secure_generator=False,
)
dp_strategy_alice = DPStrategy(embedding_dp=gaussian_embedding_dp)
label_dp = LabelDP(eps=64.0)
dp_strategy_bob = DPStrategy(label_dp=label_dp)
dp_strategy_dict = {alice: dp_strategy_alice, bob: dp_strategy_bob}
dp_spent_step_freq = 10
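For intuition about the two DP settings above, the sketch below shows the textbook mechanisms such parameters usually refer to: clipping a vector to an L2 norm bound and adding Gaussian noise, and randomized response on a binary label. It assumes these standard recipes and is not taken from the SecretFlow source; both function names are hypothetical.

```python
import math
import random


def gaussian_embedding_dp(embedding, l2_norm_clip, noise_multiplier, rng):
    """Clip the vector to l2_norm_clip, then add Gaussian noise (assumed recipe)."""
    norm = math.sqrt(sum(v * v for v in embedding))
    scale = min(1.0, l2_norm_clip / norm) if norm > 0 else 1.0
    stddev = noise_multiplier * l2_norm_clip
    return [v * scale + rng.gauss(0.0, stddev) for v in embedding]


def randomized_response(label, eps, rng):
    """Keep a binary label with probability e^eps / (e^eps + 1), else flip it."""
    keep_prob = math.exp(eps) / (math.exp(eps) + 1.0)
    return label if rng.random() < keep_prob else 1 - label


rng = random.Random(0)
noisy = gaussian_embedding_dp([3.0, 4.0], l2_norm_clip=1.0, noise_multiplier=0.5, rng=rng)
print(noisy)  # clipped to norm 1.0, then perturbed

# With eps=64 the keep probability is numerically 1, so labels are never flipped.
print(randomized_response(1, eps=64.0, rng=rng))  # 1
```

Under this reading, a smaller eps or a larger noise_multiplier gives stronger privacy at the cost of noisier training signals.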
[23]:
sl_model = SLModel(
base_model_dict=base_model_dict,
device_y=alice,
model_fuse=model_fuse,
dp_strategy_dict=dp_strategy_dict,)
INFO:root:Create proxy actor <class 'secretflow.ml.nn.sl.backend.tensorflow.sl_base.PYUSLTFModel'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.sl.backend.tensorflow.sl_base.PYUSLTFModel'> with party bob.
[24]:
sf.reveal(test_data.partitions[alice].data), sf.reveal(test_label.partitions[alice].data)
[24]:
( age job marital education
1426 0.279412 0.181818 0.5 0.333333
416 0.176471 0.636364 1.0 0.333333
3977 0.264706 0.000000 0.5 0.666667
2291 0.338235 0.000000 0.5 0.333333
257 0.132353 0.909091 1.0 0.333333
... ... ... ... ...
1508 0.264706 0.818182 1.0 0.333333
979 0.544118 0.090909 0.0 0.000000
3494 0.455882 0.090909 0.5 0.000000
42 0.485294 0.090909 0.5 0.333333
1386 0.455882 0.636364 0.5 0.333333
[905 rows x 4 columns],
y
1426 0
416 0
3977 0
2291 0
257 0
... ..
1508 0
979 0
3494 0
42 0
1386 0
[905 rows x 1 columns])
[25]:
sf.reveal(train_data.partitions[alice].data), sf.reveal(train_label.partitions[alice].data)
[25]:
( age job marital education
1106 0.235294 0.090909 0.5 0.333333
1309 0.176471 0.363636 0.5 0.333333
2140 0.411765 0.272727 1.0 0.666667
2134 0.573529 0.454545 0.5 0.333333
960 0.485294 0.818182 0.5 0.333333
... ... ... ... ...
664 0.397059 0.090909 1.0 0.333333
3276 0.235294 0.181818 0.5 0.666667
1318 0.220588 0.818182 0.5 0.333333
723 0.220588 0.636364 0.5 0.333333
2863 0.176471 0.363636 1.0 0.666667
[3616 rows x 4 columns],
y
1106 0
1309 0
2140 1
2134 0
960 0
... ..
664 0
3276 0
1318 0
723 0
2863 0
[3616 rows x 1 columns])
[26]:
history = sl_model.fit(train_data,
train_label,
validation_data=(test_data,test_label),
epochs=10,
batch_size=train_batch_size,
shuffle=True,
verbose=1,
validation_freq=1,
dp_spent_step_freq=dp_spent_step_freq,)
INFO:root:SL Train Params: {'self': <secretflow.ml.nn.sl.sl_model.SLModel object at 0x7fb548492910>, 'x': VDataFrame(partitions={alice: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fb54bc8ff10>), bob: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fb54bc8b970>)}, aligned=True), 'y': VDataFrame(partitions={alice: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fb54bc6cfd0>)}, aligned=True), 'batch_size': 128, 'epochs': 10, 'verbose': 1, 'callbacks': None, 'validation_data': (VDataFrame(partitions={alice: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fb54bc8f610>), bob: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fb54bc8b8e0>)}, aligned=True), VDataFrame(partitions={alice: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fb54bc6c970>)}, aligned=True)), 'shuffle': True, 'sample_weight': None, 'validation_freq': 1, 'dp_spent_step_freq': 10, 'dataset_builder': None, 'audit_log_dir': None, 'audit_log_params': {}, 'random_seed': 13780}
(PYUSLTFModel pid=240300) Model: "sequential"
(PYUSLTFModel pid=240300) _________________________________________________________________
(PYUSLTFModel pid=240300) Layer (type) Output Shape Param #
(PYUSLTFModel pid=240300) =================================================================
(PYUSLTFModel pid=240300) dense (Dense) (None, 100) 500
(PYUSLTFModel pid=240300)
(PYUSLTFModel pid=240300) dense_1 (Dense) (None, 64) 6464
(PYUSLTFModel pid=240300)
(PYUSLTFModel pid=240300) =================================================================
(PYUSLTFModel pid=240300) Total params: 6,964
(PYUSLTFModel pid=240300) Trainable params: 6,964
(PYUSLTFModel pid=240300) Non-trainable params: 0
(PYUSLTFModel pid=240300) _________________________________________________________________
(PYUSLTFModel pid=240300) Model: "model"
(PYUSLTFModel pid=240300) __________________________________________________________________________________________________
(PYUSLTFModel pid=240300) Layer (type) Output Shape Param # Connected to
(PYUSLTFModel pid=240300) ==================================================================================================
(PYUSLTFModel pid=240300) input_2 (InputLayer) [(None, 64)] 0 []
(PYUSLTFModel pid=240300)
(PYUSLTFModel pid=240300) input_3 (InputLayer) [(None, 64)] 0 []
(PYUSLTFModel pid=240300)
(PYUSLTFModel pid=240300) concatenate (Concatenate) (None, 128) 0 ['input_2[0][0]',
(PYUSLTFModel pid=240300) 'input_3[0][0]']
(PYUSLTFModel pid=240300)
(PYUSLTFModel pid=240300) dense_2 (Dense) (None, 64) 8256 ['concatenate[0][0]']
(PYUSLTFModel pid=240300)
(PYUSLTFModel pid=240300) dense_3 (Dense) (None, 1) 65 ['dense_2[0][0]']
(PYUSLTFModel pid=240300)
(PYUSLTFModel pid=240300) ==================================================================================================
(PYUSLTFModel pid=240300) Total params: 8,321
(PYUSLTFModel pid=240300) Trainable params: 8,321
(PYUSLTFModel pid=240300) Non-trainable params: 0
(PYUSLTFModel pid=240300) __________________________________________________________________________________________________
7%|▋ | 2/29 [00:00<00:02, 13.20it/s]
(PYUSLTFModel pid=240309) Model: "sequential"
(PYUSLTFModel pid=240309) _________________________________________________________________
(PYUSLTFModel pid=240309) Layer (type) Output Shape Param #
(PYUSLTFModel pid=240309) =================================================================
(PYUSLTFModel pid=240309) dense (Dense) (None, 100) 1300
(PYUSLTFModel pid=240309)
(PYUSLTFModel pid=240309) dense_1 (Dense) (None, 64) 6464
(PYUSLTFModel pid=240309)
(PYUSLTFModel pid=240309) =================================================================
(PYUSLTFModel pid=240309) Total params: 7,764
(PYUSLTFModel pid=240309) Trainable params: 7,764
(PYUSLTFModel pid=240309) Non-trainable params: 0
(PYUSLTFModel pid=240309) _________________________________________________________________
100%|██████████| 29/29 [00:03<00:00, 9.22it/s, epoch: 1/10 - train_loss:0.4876196086406708 train_accuracy:0.779037594795227 train_auc_1:0.523348331451416 val_loss:0.3805972635746002 val_accuracy:0.8729282021522522 val_auc_1:0.6476444602012634 ]
100%|██████████| 29/29 [00:00<00:00, 42.60it/s, epoch: 2/10 - train_loss:0.3252045810222626 train_accuracy:0.8960129022598267 train_auc_1:0.6373960375785828 val_loss:0.3656504452228546 val_accuracy:0.8729282021522522 val_auc_1:0.6615355014801025 ]
100%|██████████| 29/29 [00:00<00:00, 42.80it/s, epoch: 3/10 - train_loss:0.33422568440437317 train_accuracy:0.8832964897155762 train_auc_1:0.7038192749023438 val_loss:0.358445405960083 val_accuracy:0.8729282021522522 val_auc_1:0.6819758415222168 ]
100%|██████████| 29/29 [00:00<00:00, 42.65it/s, epoch: 4/10 - train_loss:0.31387007236480713 train_accuracy:0.88606196641922 train_auc_1:0.7519825100898743 val_loss:0.3427862823009491 val_accuracy:0.8729282021522522 val_auc_1:0.7419042587280273 ]
100%|██████████| 29/29 [00:00<00:00, 45.85it/s, epoch: 5/10 - train_loss:0.2894230782985687 train_accuracy:0.8866150379180908 train_auc_1:0.8085392713546753 val_loss:0.33072948455810547 val_accuracy:0.870718240737915 val_auc_1:0.7843313217163086 ]
100%|██████████| 29/29 [00:00<00:00, 44.84it/s, epoch: 6/10 - train_loss:0.27044418454170227 train_accuracy:0.8869742751121521 train_auc_1:0.8391960859298706 val_loss:0.3120502531528473 val_accuracy:0.8674033284187317 val_auc_1:0.8096477389335632 ]
100%|██████████| 29/29 [00:00<00:00, 42.93it/s, epoch: 7/10 - train_loss:0.25070708990097046 train_accuracy:0.8962942361831665 train_auc_1:0.8619815707206726 val_loss:0.31437328457832336 val_accuracy:0.8718231916427612 val_auc_1:0.838728666305542 ]
100%|██████████| 29/29 [00:00<00:00, 43.73it/s, epoch: 8/10 - train_loss:0.25882866978645325 train_accuracy:0.8933189511299133 train_auc_1:0.8460560441017151 val_loss:0.2909625768661499 val_accuracy:0.8773480653762817 val_auc_1:0.8433351516723633 ]
100%|██████████| 29/29 [00:00<00:00, 45.62it/s, epoch: 9/10 - train_loss:0.254334032535553 train_accuracy:0.8940818309783936 train_auc_1:0.8722440004348755 val_loss:0.2853069305419922 val_accuracy:0.8828729391098022 val_auc_1:0.8439790606498718 ]
100%|██████████| 29/29 [00:00<00:00, 50.14it/s, epoch: 10/10 - train_loss:0.24358023703098297 train_accuracy:0.8957411646842957 train_auc_1:0.8758358359336853 val_loss:0.2825777232646942 val_accuracy:0.8784530162811279 val_auc_1:0.8505613803863525 ]
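The 29 in each progress bar is the number of batches per epoch. Assuming the last, partially filled batch is kept, it follows from the training set size and the batch size:

```python
import math


def steps_per_epoch(n_rows, batch_size):
    """Number of batches needed to cover all rows, keeping the last partial batch."""
    return math.ceil(n_rows / batch_size)


print(steps_per_epoch(3616, 128))  # 29 batches for the 3616 training rows
print(steps_per_epoch(905, 128))   # 8 batches for the 905 test rows
```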
Let’s visualize the training process
[27]:
print(history)
print(history.keys())
{'train_loss': [0.4876196, 0.32520458, 0.33422568, 0.31387007, 0.28942308, 0.27044418, 0.2507071, 0.25882867, 0.25433403, 0.24358024], 'train_accuracy': [0.7790376, 0.8960129, 0.8832965, 0.88606197, 0.88661504, 0.8869743, 0.89629424, 0.89331895, 0.89408183, 0.89574116], 'train_auc_1': [0.52334833, 0.63739604, 0.7038193, 0.7519825, 0.8085393, 0.8391961, 0.8619816, 0.84605604, 0.872244, 0.87583584], 'val_loss': [0.38059726, 0.36565045, 0.3584454, 0.34278628, 0.33072948, 0.31205025, 0.31437328, 0.29096258, 0.28530693, 0.28257772], 'val_accuracy': [0.8729282, 0.8729282, 0.8729282, 0.8729282, 0.87071824, 0.8674033, 0.8718232, 0.87734807, 0.88287294, 0.878453], 'val_auc_1': [0.64764446, 0.6615355, 0.68197584, 0.74190426, 0.7843313, 0.80964774, 0.83872867, 0.84333515, 0.84397906, 0.8505614]}
dict_keys(['train_loss', 'train_accuracy', 'train_auc_1', 'val_loss', 'val_accuracy', 'val_auc_1'])
[28]:
# Plot the change of loss during training
plt.plot(history['train_loss'])
plt.plot(history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train','Val'], loc='upper right')
plt.show()
[29]:
# Plot the change of accuracy during training
plt.plot(history['train_accuracy'])
plt.plot(history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Val'], loc='upper left')
plt.show()
[30]:
# Plot the Area Under Curve (AUC) during training
plt.plot(history['train_auc_1'])
plt.plot(history['val_auc_1'])
plt.title('Model Area Under Curve')
plt.ylabel('Area Under Curve')
plt.xlabel('Epoch')
plt.legend(['Train', 'Val'], loc='upper left')
plt.show()
Let’s call the evaluation function
[31]:
global_metric = sl_model.evaluate(test_data, test_label, batch_size=128)
print(global_metric)
Evaluate Processing:: 100%|██████████| 8/8 [00:00<00:00, 175.93it/s, loss:0.28720757365226746 accuracy:0.8883978128433228 auc_1:0.8435608148574829]
{'loss': 0.28720757, 'accuracy': 0.8883978, 'auc_1': 0.8435608}
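Accuracy alone can be misleading on imbalanced labels like these, which is why AUC is reported as well. AUC is the probability that a randomly chosen positive sample is scored higher than a randomly chosen negative one. The sketch below computes it exactly by comparing all pairs; Keras approximates it with a fixed set of thresholds, so its values can differ slightly. roc_auc is a hypothetical helper.

```python
def roc_auc(labels, scores):
    """Exact pairwise AUC; ties between a positive and a negative count as 0.5."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in positives
        for n in negatives
    )
    return wins / (len(positives) * len(negatives))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
print(roc_auc([0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]))   # 0.5
```

A model that gives every sample the same score reaches an AUC of only 0.5, even though it can still have high accuracy when one class dominates.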
Contrast to local model#
The model structure is consistent with the split learning model above, but only Alice’s part of the model is used here. The model definition is given in the code below.
Data
The data is the same Kaggle bank marketing data. Here, we use only the data held by Alice, the new bank.
[32]:
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
from sklearn.model_selection import train_test_split
def create_model():
    model = keras.Sequential(
        [
            keras.Input(shape=4),
            layers.Dense(100, activation="relu"),
            layers.Dense(64, activation='relu'),
            layers.Dense(64, activation='relu'),
            layers.Dense(1, activation='sigmoid')
        ]
    )
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=["accuracy", tf.keras.metrics.AUC()])
    return model

single_model = create_model()
data process
[33]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import LabelEncoder
encoder = LabelEncoder()
single_part_data = alice_data.copy()
single_part_data['job'] = encoder.fit_transform(alice_data['job'])
single_part_data['marital'] = encoder.fit_transform(alice_data['marital'])
single_part_data['education'] = encoder.fit_transform(alice_data['education'])
single_part_data['y'] = encoder.fit_transform(alice_data['y'])
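`LabelEncoder` maps each distinct category to an integer index, assigned in sorted order of the category values. A library-free sketch of that mapping is shown below; `label_encode` is a hypothetical helper and the category values are made up for illustration.

```python
def label_encode(values):
    """Map each distinct value to its position in the sorted list of distinct values."""
    classes = sorted(set(values))
    index = {value: i for i, value in enumerate(classes)}
    return [index[value] for value in values], classes


codes, classes = label_encode(["married", "single", "divorced", "single"])
print(codes)    # [1, 2, 0, 2]
print(classes)  # ['divorced', 'married', 'single']
```

Because the cell above refits a single encoder for every column, the classes learned for earlier columns are overwritten. If the encoding needs to be inverted later, keep one encoder per column.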
[34]:
single_part_label = single_part_data['y']
single_part_data_no_label = single_part_data.drop(columns=['y'],inplace=False)
[35]:
scaler = MinMaxScaler()
single_part_data_no_label = scaler.fit_transform(single_part_data_no_label)
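With its default settings, `MinMaxScaler` rescales every column to the range [0, 1] using (x - min) / (max - min). A library-free sketch of that computation follows; `min_max_scale` is a hypothetical helper, the input numbers are made up, and mapping a constant column to 0.0 is a choice made for this sketch.

```python
def min_max_scale(rows):
    """Rescale each column of a list of rows to the range [0, 1]."""
    columns = list(zip(*rows))
    lows = [min(column) for column in columns]
    highs = [max(column) for column in columns]
    scaled = []
    for row in rows:
        scaled.append([
            (x - low) / (high - low) if high > low else 0.0
            for x, low, high in zip(row, lows, highs)
        ])
    return scaled


print(min_max_scale([[30, 0], [40, 2], [50, 1]]))
# [[0.0, 0.0], [0.5, 1.0], [1.0, 0.5]]
```

Note that the cell above fits the scaler before the train/test split, so the test rows also influence the minimum and maximum. A stricter setup would fit the scaler on the training rows only.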
[36]:
# Split features and labels in one call so that the rows stay aligned
train_data, test_data, train_label, test_label = train_test_split(
    single_part_data_no_label,
    single_part_label,
    train_size=0.8,
    random_state=random_state,
)
[37]:
test_data.shape
[37]:
(905, 4)
[38]:
history = single_model.fit(
    train_data,
    train_label,
    validation_data=(test_data, test_label),
    batch_size=128,
    epochs=10,
    shuffle=False,
)
Epoch 1/10
29/29 [==============================] - 2s 13ms/step - loss: 0.5258 - accuracy: 0.8653 - auc_3: 0.4494 - val_loss: 0.4046 - val_accuracy: 0.8729 - val_auc_3: 0.4320
Epoch 2/10
29/29 [==============================] - 0s 4ms/step - loss: 0.3747 - accuracy: 0.8877 - auc_3: 0.4590 - val_loss: 0.4003 - val_accuracy: 0.8729 - val_auc_3: 0.4279
Epoch 3/10
29/29 [==============================] - 0s 4ms/step - loss: 0.3684 - accuracy: 0.8877 - auc_3: 0.4383 - val_loss: 0.3941 - val_accuracy: 0.8729 - val_auc_3: 0.4223
Epoch 4/10
29/29 [==============================] - 0s 4ms/step - loss: 0.3623 - accuracy: 0.8877 - auc_3: 0.4465 - val_loss: 0.3904 - val_accuracy: 0.8729 - val_auc_3: 0.4248
Epoch 5/10
29/29 [==============================] - 0s 4ms/step - loss: 0.3600 - accuracy: 0.8877 - auc_3: 0.4533 - val_loss: 0.3877 - val_accuracy: 0.8729 - val_auc_3: 0.4401
Epoch 6/10
29/29 [==============================] - 0s 4ms/step - loss: 0.3578 - accuracy: 0.8877 - auc_3: 0.4655 - val_loss: 0.3857 - val_accuracy: 0.8729 - val_auc_3: 0.4659
Epoch 7/10
29/29 [==============================] - 0s 4ms/step - loss: 0.3562 - accuracy: 0.8877 - auc_3: 0.4869 - val_loss: 0.3841 - val_accuracy: 0.8729 - val_auc_3: 0.4851
Epoch 8/10
29/29 [==============================] - 0s 4ms/step - loss: 0.3550 - accuracy: 0.8877 - auc_3: 0.4975 - val_loss: 0.3828 - val_accuracy: 0.8729 - val_auc_3: 0.4969
Epoch 9/10
29/29 [==============================] - 0s 4ms/step - loss: 0.3539 - accuracy: 0.8877 - auc_3: 0.5105 - val_loss: 0.3816 - val_accuracy: 0.8729 - val_auc_3: 0.5166
Epoch 10/10
29/29 [==============================] - 0s 4ms/step - loss: 0.3528 - accuracy: 0.8877 - auc_3: 0.5216 - val_loss: 0.3807 - val_accuracy: 0.8729 - val_auc_3: 0.5241
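The training process of the local model can be visualized with the same plotting code as above. Note that Keras returns a History object, so the metrics are read from history.history, with keys such as loss and val_loss.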
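One way to reuse the earlier plotting cells unchanged is to rename the Keras metric keys to the `train_`/`val_` convention. A minimal sketch, assuming the input is the `history.history` dict of the object returned by `single_model.fit`; `to_prefixed_history` is a hypothetical helper, not part of Keras or SecretFlow.

```python
def to_prefixed_history(keras_history):
    """Rename Keras metric keys so training metrics carry a 'train_' prefix."""
    prefixed = {}
    for key, values in keras_history.items():
        if key.startswith("val_"):
            prefixed[key] = list(values)
        else:
            prefixed["train_" + key] = list(values)
    return prefixed


# Loss values of the first two epochs from the training log above.
local_history = to_prefixed_history({
    "loss": [0.5258, 0.3747],
    "val_loss": [0.4046, 0.4003],
})
print(sorted(local_history))  # ['train_loss', 'val_loss']
```

With the real run this would be called as `to_prefixed_history(history.history)`, after which `plt.plot(local_history['train_loss'])` works as in the earlier cells.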
Summary#
The two experiments above simulate a typical training problem in a vertical scenario. Alice and Bob hold the same group of samples, but each side has only part of the features. If Alice trains a model on her own data alone, the result has a validation accuracy of 0.8729 and an AUC of 0.5241. If Bob’s data is combined through split learning, the model reaches an accuracy of 0.8884 and an AUC of 0.8436.
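The gap between accuracy and AUC for the local model deserves a closer look. Its validation accuracy stays at 0.8729 in every epoch while its AUC hovers near 0.5, which is the pattern one would expect from a model that mostly predicts the majority class (the tutorial does not state this; it is an inference from the training log). The sketch below uses toy labels and scores, not the tutorial’s data, and hypothetical `accuracy` and `auc` helpers to show how two models can share the same accuracy while differing completely in AUC. The `auc` helper uses the exact pairwise definition, whereas `tf.keras.metrics.AUC` approximates the value over a set of thresholds.

```python
def accuracy(labels, scores, threshold=0.5):
    """Fraction of samples whose thresholded score matches the label."""
    predictions = [1 if score >= threshold else 0 for score in scores]
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def auc(labels, scores):
    """Probability that a random positive outranks a random negative (ties count 0.5)."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


labels = [0] * 9 + [1]                      # 90% negatives, 10% positives
constant_scores = [0.1] * 10                # same score for every sample
ranking_scores = [0.1] * 8 + [0.3, 0.4]     # the positive gets the highest score

print(accuracy(labels, constant_scores), auc(labels, constant_scores))  # 0.9 0.5
print(accuracy(labels, ranking_scores), auc(labels, ranking_scores))    # 0.9 1.0
```

On imbalanced labels, accuracy alone cannot separate these two cases, which is why the AUC difference between the local model and the split learning model is the more informative comparison.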
Conclusion#
This tutorial introduced split learning and showed how to run it in SecretFlow.
The experimental results show that split learning, through joint multi-party training, expands the feature dimension available to the model and significantly improves its performance.
This tutorial uses plaintext aggregation for demonstration and does not consider leakage from the hidden layer. SecretFlow provides AggLayer to avoid the leakage caused by transmitting hidden-layer outputs in plaintext, using MPC, TEE, HE, and DP. If you are interested, please refer to the relevant documentation.
Next, you may want to try different datasets. To do so, vertically shard the data first and then follow the flow of this tutorial.
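Vertical sharding means that every party keeps all rows but only its own columns, plus a shared identifier used to align samples. A library-free sketch with made-up rows and a hypothetical column assignment is shown below; `vertical_split` is not a SecretFlow function. With pandas, the same effect comes from selecting columns, for example `df[["id", "age", "job"]]`.

```python
def vertical_split(rows, columns_by_party, id_column):
    """Give each party every row, restricted to the id column and its own columns."""
    shards = {}
    for party, columns in columns_by_party.items():
        keep = [id_column] + list(columns)
        shards[party] = [{column: row[column] for column in keep} for row in rows]
    return shards


# Toy rows for illustration only.
rows = [
    {"id": 1, "age": 30, "job": "admin", "balance": 1200, "y": "no"},
    {"id": 2, "age": 45, "job": "technician", "balance": -50, "y": "yes"},
]
shards = vertical_split(
    rows,
    {"alice": ["age", "job"], "bob": ["balance", "y"]},
    id_column="id",
)
print(shards["alice"][0])  # {'id': 1, 'age': 30, 'job': 'admin'}
print(shards["bob"][1])    # {'id': 2, 'balance': -50, 'y': 'yes'}
```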