Federated Learning for Image Classification#
The following code is for demonstration only. Because of system security concerns, it is NOT suitable for production; please DO NOT use it directly in production.
In this tutorial, we use an image classification task to show how to complete a horizontal federated learning task in the SecretFlow framework. SecretFlow provides a user-friendly API that makes it easy to apply an existing Keras or PyTorch model to a federated learning scenario. In the rest of the tutorial, we show how to turn your existing model into a federated model in SecretFlow to complete multi-party federated modeling tasks.
What is Federated Learning#
Federated learning here refers specifically to horizontal federated learning. This mode applies when all participants run the same kind of business but reach different customer groups. In this case, samples from the various parties can be combined to train a joint model with better performance. For example, in a medical setting, each hospital has its own patient group, and the patients of hospitals in different regions barely overlap, but the types of examination records they keep (such as images and blood tests) are the same.

Training process:
1. Each participant downloads the latest model from the server.
2. Each participant trains the model on its own local data and uploads encrypted gradients (or encrypted parameters) to the server. The server securely aggregates the encrypted gradients (or parameters) uploaded by all parties and updates the model parameters with the aggregated result.
3. The server returns the updated model to each participant.
4. Each participant updates its local model and prepares for the next round of training.
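The training process above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not SecretFlow code: the helper names `local_update` and `federated_round` are hypothetical, the model is a toy linear regression, and the encryption and secure aggregation of step 2 are deliberately left out so that only the download, local training, and averaging loop is visible.

```python
import numpy as np

def local_update(weights, x, y, lr=0.1):
    """One gradient-descent step of linear regression on a party's local data."""
    grad = 2.0 * x.T @ (x @ weights - y) / len(x)
    return weights - lr * grad

def federated_round(global_weights, parties, lr=0.1):
    """Broadcast the weights, train locally, then average weighted by sample count."""
    updates, sizes = [], []
    for x, y in parties:
        updates.append(local_update(global_weights.copy(), x, y, lr))
        sizes.append(len(x))
    return np.average(np.stack(updates), axis=0, weights=sizes)

# Two parties hold different samples drawn from the same underlying task.
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
parties = []
for n in (50, 80):
    x = rng.normal(size=(n, 2))
    parties.append((x, x @ true_w))

# Repeat the round; the server-side weights move towards the shared solution.
w = np.zeros(2)
for _ in range(200):
    w = federated_round(w, parties)
```

Weighting the average by sample count means a party with more data has proportionally more influence on the aggregated model.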
Federated learning on SecretFlow#
[1]:
%load_ext autoreload
%autoreload 2
Create three entities in the SecretFlow environment: Alice, Bob and Charlie. All three are PYUs. Alice and Bob act as the clients, and Charlie acts as the server.
[2]:
import secretflow as sf
# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob', 'charlie'], address='local')
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
2022-08-18 17:13:51.247907: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib
[3]:
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
Prepare Data#
Alice and Bob each own half of the data.
[4]:
from secretflow.data.ndarray import load
from secretflow.utils.simulation.datasets import load_mnist
(x_train, y_train), (x_test, y_test) = load_mnist(parts=[alice, bob], normalized_x=True, categorical_y=True)
x_train, y_train, x_test and y_test are all FedNdarray objects. Let's take a look at the data held in a FedNdarray. A FedNdarray is a virtual ndarray built on a multi-party concept to protect data privacy. The underlying data is stored by each participant, and every FedNdarray operation is actually performed by each participant on its own local data. The server and the other clients never touch the original data. For demonstration purposes, we manually download the data to the driver. This data will be used later in the single-party model comparison.
[5]:
import numpy as np
from secretflow.utils.simulation.datasets import dataset
mnist = np.load(dataset('mnist'), allow_pickle=True)
image = mnist['x_train']
label = mnist['y_train']
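The FedNdarray idea described above (each party keeps its own block, operations run locally) can be illustrated with a small sketch. It is a conceptual stand-in only: the `partitions` dictionary and the `apply_per_party` helper are hypothetical and are not part of the SecretFlow API.

```python
import numpy as np

# Hypothetical stand-in for a horizontally partitioned array: one local block per party.
partitions = {
    "alice": np.arange(12, dtype=float).reshape(6, 2),
    "bob": np.arange(12, 20, dtype=float).reshape(4, 2),
}

def apply_per_party(parts, fn):
    """Run fn on every party's block separately; raw rows stay with their owner."""
    return {party: fn(block) for party, block in parts.items()}

# Each party computes only local summaries (column sums and row counts).
local_sums = apply_per_party(partitions, lambda block: block.sum(axis=0))
local_counts = apply_per_party(partitions, len)

# Only the summaries are combined to obtain the global column means.
global_mean = sum(local_sums.values()) / sum(local_counts.values())
```

The combined result equals the mean over the concatenated data, even though no party had to reveal individual rows in this sketch.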
Let's grab some samples from the dataset and see what the data looks like for both Alice and Bob.
[6]:
from matplotlib import pyplot as plt
figure = plt.figure(figsize=(20, 4))
j = 0
for example in image[:40]:
    plt.subplot(4, 10, j+1)
    plt.imshow(example, cmap='gray', aspect='equal')
    plt.axis('off')
    j += 1
[7]:
figure = plt.figure(figsize=(20, 4))
j = 0
for example in image[:40]:
    plt.subplot(4, 10, j+1)
    plt.imshow(example, cmap='gray', aspect='equal')
    plt.axis('off')
    j += 1
It can be seen from the above two examples that the data types and tasks of Alice and Bob are consistent, but the samples are different due to the different user groups they reach.
Define Model#
[8]:
def create_conv_model(input_shape, num_classes, name='model'):
    def create_model():
        from tensorflow import keras
        from tensorflow.keras import layers
        # Create model
        model = keras.Sequential(
            [
                keras.Input(shape=input_shape),
                layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
                layers.MaxPooling2D(pool_size=(2, 2)),
                layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                layers.MaxPooling2D(pool_size=(2, 2)),
                layers.Flatten(),
                layers.Dropout(0.5),
                layers.Dense(num_classes, activation="softmax"),
            ]
        )
        # Compile model
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=["accuracy"])
        return model
    return create_model
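To see what the layer stack above produces for a 28x28x1 input and 10 classes, the feature-map sizes and the trainable parameter count can be worked out by hand. The sketch below does that arithmetic in plain Python; the helpers `conv_out` and `pool_out` are hypothetical names, and the calculation assumes the Keras defaults of stride 1 with "valid" padding for the convolutions and non-overlapping pooling windows.

```python
def conv_out(size, kernel):
    """Spatial size after a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1

def pool_out(size, pool):
    """Spatial size after non-overlapping max pooling (floor division)."""
    return size // pool

side, channels = 28, 1
params = 0
for filters in (32, 64):
    # 3x3 kernel weights over all input channels, plus one bias per filter.
    params += (3 * 3 * channels + 1) * filters
    side = pool_out(conv_out(side, 3), 2)
    channels = filters

# Flatten feeds the dense layer; Dropout adds no parameters.
flat_features = side * side * channels
params += (flat_features + 1) * 10  # weights plus biases for 10 classes
```

Under these assumptions the spatial size goes 28, 26, 13, 11, 5, the flattened vector has 1600 features, and the model has 34826 trainable parameters. That total is also the number of values each client exchanges with the server when full weights are aggregated.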
Training FL Model#
Import packages
[9]:
from secretflow.security.aggregation import SPUAggregator, SecureAggregator
from secretflow.ml.nn import FLModel
Define Model
[10]:
num_classes = 10
input_shape = (28, 28, 1)
model = create_conv_model(input_shape, num_classes)
Define the list of devices that participate in training, that is, the PYUs of the participants prepared earlier.
[11]:
device_list = [alice, bob]
Define the aggregator. SecretFlow offers a variety of aggregation schemes. SecureAggregator and SPUAggregator can both be used for secure aggregation; for more information about aggregation, see Secure Aggregator.
[12]:
secure_aggregator = SecureAggregator(charlie, [alice, bob])
spu_aggregator = SPUAggregator(spu)
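The general idea behind secure aggregation can be shown with a toy pairwise-masking example. This is only a sketch of one well-known masking technique and makes no claim about how SecureAggregator or SPUAggregator is implemented; the function `mask_updates` is hypothetical, and a real protocol would derive the masks from secrets shared between pairs of parties rather than from a single seeded generator.

```python
import numpy as np

def mask_updates(updates, seed=0):
    """Add pairwise masks that cancel in the sum: party i adds a mask, party j subtracts it."""
    rng = np.random.default_rng(seed)
    masked = [u.astype(float).copy() for u in updates]
    for i in range(len(updates)):
        for j in range(i + 1, len(updates)):
            mask = rng.normal(size=updates[i].shape)
            masked[i] += mask
            masked[j] -= mask
    return masked

alice_update = np.array([0.5, -1.0, 2.0])
bob_update = np.array([1.5, 3.0, -2.0])

# The server only ever sees the masked vectors.
masked = mask_updates([alice_update, bob_update])

# The masks cancel when the masked vectors are summed, so the average is exact.
aggregated = sum(masked) / len(masked)
```

Each masked vector on its own looks like noise, yet the average equals the average of the original updates.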
Define FLModel
[13]:
fed_model = FLModel(server=charlie,
                    device_list=device_list,
                    model=model,
                    aggregator=secure_aggregator,
                    strategy="fed_avg_w",
                    backend="tensorflow")
Let's run the model
[14]:
history = fed_model.fit(x_train,
                        y_train,
                        validation_data=(x_test, y_test),
                        epochs=10,
                        sampler_method="batch",
                        batch_size=128,
                        aggregate_freq=1)
100%|█████████▉| 234/235 [00:04<00:00, 49.26it/s]2022-08-18 17:14:21.735534: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib
2022-08-18 17:14:21.735571: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
100%|██████████| 235/235 [00:18<00:00, 12.61it/s, epoch: 1/10 - loss:0.5259245038032532 accuracy:0.8461666703224182 val_loss:0.14433448016643524 val_accuracy:0.9577999711036682 ]
100%|██████████| 40/40 [00:04<00:00, 8.08it/s, epoch: 2/10 - loss:0.1685342937707901 accuracy:0.9504940509796143 val_loss:0.12423974275588989 val_accuracy:0.964900016784668 ]
100%|██████████| 40/40 [00:02<00:00, 13.97it/s, epoch: 3/10 - loss:0.1499660760164261 accuracy:0.9557806253433228 val_loss:0.11904746294021606 val_accuracy:0.9649999737739563 ]
100%|██████████| 40/40 [00:02<00:00, 14.88it/s, epoch: 4/10 - loss:0.14443494379520416 accuracy:0.9566205739974976 val_loss:0.1067579835653305 val_accuracy:0.9693999886512756 ]
100%|██████████| 40/40 [00:02<00:00, 14.78it/s, epoch: 5/10 - loss:0.1303529143333435 accuracy:0.9610671997070312 val_loss:0.09679792076349258 val_accuracy:0.9718999862670898 ]
100%|██████████| 40/40 [00:02<00:00, 14.61it/s, epoch: 6/10 - loss:0.11564651876688004 accuracy:0.9659584760665894 val_loss:0.09385047107934952 val_accuracy:0.9726999998092651 ]
100%|██████████| 40/40 [00:02<00:00, 14.48it/s, epoch: 7/10 - loss:0.10946863144636154 accuracy:0.9681274890899658 val_loss:0.09159354865550995 val_accuracy:0.972599983215332 ]
100%|██████████| 40/40 [00:02<00:00, 14.81it/s, epoch: 8/10 - loss:0.10932281613349915 accuracy:0.9678359627723694 val_loss:0.08414007723331451 val_accuracy:0.9754999876022339 ]
100%|██████████| 40/40 [00:02<00:00, 14.47it/s, epoch: 9/10 - loss:0.10051391273736954 accuracy:0.9696640372276306 val_loss:0.08173993974924088 val_accuracy:0.9753999710083008 ]
100%|██████████| 40/40 [00:02<00:00, 13.60it/s, epoch: 10/10 - loss:0.10390906035900116 accuracy:0.968478262424469 val_loss:0.07482419162988663 val_accuracy:0.9775000214576721 ]
[15]:
# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
# Draw loss for training & validation
plt.plot(history.global_history['loss'])
plt.plot(history.global_history['val_loss'])
plt.title('FLModel loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
[16]:
global_metric = fed_model.evaluate(x_test, y_test, batch_size=128)
print(global_metric)
([Mean(name='loss', total=748.24194, count=10000.0), Mean(name='accuracy', total=9775.0, count=10000.0)], {'alice': [Mean(name='loss', total=527.7763, count=5000.0), Mean(name='accuracy', total=4833.0, count=5000.0)], 'bob': [Mean(name='loss', total=220.46567, count=5000.0), Mean(name='accuracy', total=4942.0, count=5000.0)]})
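Each entry in the evaluation output above is a running total together with a count, so the readable metric is their ratio. The short sketch below redoes that arithmetic with the numbers copied from the output; the dictionary layout and the `mean_value` helper are hypothetical conveniences, not SecretFlow objects.

```python
# (total, count) pairs copied from the evaluate() output above.
per_party = {
    "alice": {"loss": (527.7763, 5000.0), "accuracy": (4833.0, 5000.0)},
    "bob": {"loss": (220.46567, 5000.0), "accuracy": (4942.0, 5000.0)},
}

def mean_value(total, count):
    """A Mean metric stores a running total and a count; the metric is their ratio."""
    return total / count

party_accuracy = {
    name: mean_value(*metrics["accuracy"]) for name, metrics in per_party.items()
}

# Pooling totals and counts across parties reproduces the global figure.
global_accuracy = mean_value(
    sum(m["accuracy"][0] for m in per_party.values()),
    sum(m["accuracy"][1] for m in per_party.values()),
)
```

With these numbers, Alice's test partition scores 4833/5000 and Bob's 4942/5000, and pooling them gives the global 9775/10000.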
Contrast experiment to local training#
Model#
The model structure is the same as that of the FL model above.
Data#
Here we use only one party's share of the horizontally split data: a total of 10,000 samples for Alice.
[17]:
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
def create_model():
    model = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )
    # Compile model
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=["accuracy"])
    return model
single_model = create_model()
[18]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
alice_x = image[:10000]
alice_y = label[:10000]
alice_y = OneHotEncoder(sparse=False).fit_transform(alice_y.reshape(-1, 1))
random_seed = 1234
alice_X_train, alice_X_test, alice_y_train, alice_y_test = train_test_split(
    alice_x, alice_y, test_size=0.33, random_state=random_seed)
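The one-hot step in the cell above turns integer digit labels into indicator vectors, which is the target format that the categorical cross-entropy loss expects. As a minimal NumPy sketch of the same transformation (the `one_hot` helper is hypothetical and assumes the labels are integers in the range 0 to `num_classes - 1`):

```python
import numpy as np

def one_hot(labels, num_classes):
    """Map integer class labels to the matching rows of an identity matrix."""
    labels = np.asarray(labels, dtype=int)
    return np.eye(num_classes, dtype=np.float32)[labels]

# Three example digit labels become three rows with a single 1 each.
encoded = one_hot([3, 0, 9], num_classes=10)
```

Unlike an encoder that infers the categories from the data, this version always produces `num_classes` columns, even if some digit happens to be absent from the sample.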
[19]:
single_model.fit(alice_X_train, alice_y_train, validation_data=(alice_X_test, alice_y_test), batch_size=128, epochs=10)
Epoch 1/10
53/53 [==============================] - 1s 16ms/step - loss: 8.1856 - accuracy: 0.5194 - val_loss: 0.5021 - val_accuracy: 0.8461
Epoch 2/10
53/53 [==============================] - 1s 13ms/step - loss: 0.7643 - accuracy: 0.7754 - val_loss: 0.3211 - val_accuracy: 0.9024
Epoch 3/10
53/53 [==============================] - 1s 13ms/step - loss: 0.5189 - accuracy: 0.8452 - val_loss: 0.2557 - val_accuracy: 0.9233
Epoch 4/10
53/53 [==============================] - 1s 14ms/step - loss: 0.3795 - accuracy: 0.8899 - val_loss: 0.1997 - val_accuracy: 0.9388
Epoch 5/10
53/53 [==============================] - 1s 13ms/step - loss: 0.3246 - accuracy: 0.9024 - val_loss: 0.1864 - val_accuracy: 0.9406
Epoch 6/10
53/53 [==============================] - 1s 13ms/step - loss: 0.2747 - accuracy: 0.9182 - val_loss: 0.1696 - val_accuracy: 0.9464
Epoch 7/10
53/53 [==============================] - 1s 12ms/step - loss: 0.2245 - accuracy: 0.9324 - val_loss: 0.1484 - val_accuracy: 0.9545
Epoch 8/10
53/53 [==============================] - 1s 13ms/step - loss: 0.2123 - accuracy: 0.9361 - val_loss: 0.1427 - val_accuracy: 0.9570
Epoch 9/10
53/53 [==============================] - 1s 13ms/step - loss: 0.1884 - accuracy: 0.9425 - val_loss: 0.1290 - val_accuracy: 0.9633
Epoch 10/10
53/53 [==============================] - 1s 13ms/step - loss: 0.1775 - accuracy: 0.9451 - val_loss: 0.1179 - val_accuracy: 0.9627
[19]:
<keras.callbacks.History at 0x7fcd9c53f7c0>
The two experiments above simulate a training problem in a typical horizontal federation scenario:
* Alice and Bob have the same type of data.
* Each side holds only a portion of the samples, but the training objective is the same.
If Alice trains the model using only her own data, she obtains a model with a training accuracy of 0.945 (0.963 on her held-out validation split). When Bob's data is combined through federated training, the model reaches an accuracy of about 0.978 on the test set. In addition, a model jointly trained on multi-party data is likely to generalize better.
Conclusion#
This tutorial introduces what federated learning is and how to perform horizontal federated learning in SecretFlow. The experimental results show that horizontal federation can improve model performance by expanding the sample size through multi-party training.
This tutorial uses SecureAggregator for demonstration. SecretFlow provides a variety of aggregation schemes; for more information, see Secure Aggregation.
Next, you can use your own data or model to explore federated learning.