Federated Learning with PyTorch Backend
The following code is for demonstration only. Because of system security concerns, please DO NOT use it directly in production.
In this tutorial, we will walk you through how to use the PyTorch backend on SecretFlow for federated learning.
+ We will use an image classification task as the example.
+ We will use PyTorch as the backend.
+ We will show how to select among multiple FL strategies.
If you want to learn more about federated learning, datasets, etc., see Federated Learning for Image Classification.
Here we go!
[1]:
%load_ext autoreload
%autoreload 2
[2]:
import secretflow as sf
# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob', 'charlie'], address='local')
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
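The PyTorch backend builds on torch.nn.Module. The training configuration consists of loss_fn, optim_fn, and metrics, which are bundled into model_def. Use backend to specify which backend will be used, and use strategy to specify which federated strategy will be used.
[3]: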
from secretflow.ml.nn.fl.backend.torch.utils import BaseModule, TorchModel
from secretflow.ml.nn.fl.utils import metric_wrapper, optim_wrapper
from secretflow.ml.nn import FLModel
from torchmetrics import Accuracy, Precision
from secretflow.security.aggregation import SecureAggregator
from secretflow.utils.simulation.datasets import load_mnist
from torch import nn, optim
from torch.nn import functional as F
When defining the model, we only need to inherit from BaseModule instead of nn.Module; everything else is consistent with PyTorch.
[4]:
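class ConvNet(BaseModule):
    """Small ConvNet for MNIST."""

    def __init__(self):
        super(ConvNet, self).__init__()
        # 1 input channel (grayscale), 3 output channels, 3x3 kernel: 28x28 -> 26x26
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        # After 3x3 max pooling: 3 channels * 8 * 8 = 192 features
        self.fc_in_dim = 192
        self.fc = nn.Linear(self.fc_in_dim, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, self.fc_in_dim)
        x = self.fc(x)
        # Returns class probabilities; note that nn.CrossEntropyLoss
        # applies log-softmax to its input internally.
        return F.softmax(x, dim=1)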
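The value 192 for fc_in_dim follows from the layer shapes. As a quick check, the sketch below redoes the arithmetic in plain Python, assuming 28x28 MNIST inputs, no padding, and a pooling stride equal to the pooling kernel size (the PyTorch default). The helper conv_out is a hypothetical name introduced here for illustration only.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of one spatial dimension after a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                   # MNIST images are 28x28
side = conv_out(side, kernel=3)             # nn.Conv2d(1, 3, kernel_size=3)
side = conv_out(side, kernel=3, stride=3)   # F.max_pool2d(..., 3)
channels = 3                                # output channels of conv1
fc_in_dim = channels * side * side
print(fc_in_dim)  # 192
```

If you change the kernel size, the pooling window, or the number of channels, fc_in_dim has to be recomputed in the same way.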
We can continue to use the loss functions and optimizers defined in PyTorch. The only difference is that the optimizer and the metrics must be wrapped with the wrappers provided by SecretFlow, and the loss is passed as a class rather than as an instance.
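One plausible reason for the wrapping step is that each party builds its own model, so the optimizer cannot be constructed up front: the hyperparameters are known now, but the model parameters only exist later. The sketch below illustrates this deferred-construction idea in plain Python. It is an assumption about the general pattern, not SecretFlow's source code; defer_with_params and ToySGD are hypothetical names.

```python
def defer_with_params(factory, *args, **kwargs):
    """Hypothetical helper: bind hyperparameters now, supply parameters later."""

    def build(params):
        return factory(params, *args, **kwargs)

    return build


class ToySGD:
    """Stand-in for an optimizer class; not a real torch optimizer."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr


# Bind the learning rate without touching any model yet.
optim_fn = defer_with_params(ToySGD, lr=1e-2)

# Later, on each party, once its local model parameters exist:
opt = optim_fn([0.5, -1.0])
print(type(opt).__name__, opt.lr, opt.params)
```

Passing the loss as a class fits the same pattern: a class is already a zero-argument factory, so each party can instantiate it locally.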
[5]:
(train_data, train_label), (test_data, test_label) = load_mnist(
    parts={alice: 0.4, bob: 0.6},
    normalized_x=True,
    categorical_y=True,
    is_torch=True,
)
loss_fn = nn.CrossEntropyLoss
optim_fn = optim_wrapper(optim.Adam, lr=1e-2)
model_def = TorchModel(
    model_fn=ConvNet,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    metrics=[
        metric_wrapper(Accuracy, task="multiclass", num_classes=10, average='micro'),
        metric_wrapper(Precision, task="multiclass", num_classes=10, average='micro'),
    ],
)
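To get a feel for what parts={alice: 0.4, bob: 0.6} means in terms of data volume, the sketch below turns the fractions into sample counts. It assumes the standard 60,000-image MNIST training set and a contiguous split; how load_mnist actually assigns samples to parties is not shown in this tutorial, and split_counts is a hypothetical helper.

```python
def split_counts(total, parts):
    """Turn fractional shares into contiguous (start, end) index ranges."""
    ranges, start = {}, 0
    names = list(parts)
    for i, name in enumerate(names):
        if i == len(names) - 1:
            end = total  # last party takes the remainder so no sample is lost
        else:
            end = start + int(round(total * parts[name]))
        ranges[name] = (start, end)
        start = end
    return ranges


ranges = split_counts(60000, {"alice": 0.4, "bob": 0.6})
for name, (lo, hi) in ranges.items():
    n = hi - lo
    print(name, n, "samples,", n // 32, "full batches of 32")
```

Under these assumptions alice holds 24,000 samples, which is 750 batches at a batch size of 32. This is consistent with the 750/750 progress bar of the first training epoch shown later, although the tutorial does not state how that step count is derived.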
[6]:
device_list = [alice, bob]
server = charlie
aggregator = SecureAggregator(server, [alice, bob])
# specify params
fl_model = FLModel(
    server=server,
    device_list=device_list,
    model=model_def,
    aggregator=aggregator,
    strategy='fed_avg_w',  # fl strategy
    backend="torch",  # backend support ['tensorflow', 'torch']
)
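The strategy name fed_avg_w is not explained further in this tutorial. As a rough intuition, the sketch below shows the general federated averaging idea: each party trains locally, and the server combines the resulting model weights, weighting each party by its number of samples. This is a plain-Python illustration under that assumption, not SecretFlow's implementation, and it leaves out the masking that SecureAggregator adds. The function weighted_average and the toy weight lists are hypothetical.

```python
def weighted_average(client_weights, sample_counts):
    """Average per-client parameter lists, weighting each client by its sample count."""
    total = sum(sample_counts)
    n_params = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, sample_counts)) / total
        for i in range(n_params)
    ]


# Toy "models" with two parameters each.
alice_w = [1.0, 2.0]
bob_w = [3.0, 4.0]

# bob holds more data, so the result is pulled towards bob's weights.
global_w = weighted_average([alice_w, bob_w], [24000, 36000])
print(global_w)
```

After each aggregation, the averaged weights would be sent back to the parties as the starting point for the next round of local training.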
[7]:
history = fl_model.fit(
    train_data,
    train_label,
    validation_data=(test_data, test_label),
    epochs=20,
    batch_size=32,
    aggregate_freq=1,
)
100%|██████████| 750/750 [00:16<00:00, 46.78it/s, epoch: 1/20 - accuracy:0.9709533452987671 precision:0.8571249842643738 val_accuracy:0.9840199947357178 val_precision:0.8955000042915344 ]
100%|██████████| 125/125 [00:03<00:00, 31.28it/s, epoch: 2/20 - accuracy:0.9825800061225891 precision:0.9190000295639038 val_accuracy:0.9850000143051147 val_precision:0.903249979019165 ]
100%|██████████| 125/125 [00:02<00:00, 41.70it/s, epoch: 3/20 - accuracy:0.9850000143051147 precision:0.9302499890327454 val_accuracy:0.9856399893760681 val_precision:0.906499981880188 ]
100%|██████████| 125/125 [00:03<00:00, 41.66it/s, epoch: 4/20 - accuracy:0.9859799742698669 precision:0.9334999918937683 val_accuracy:0.9861800074577332 val_precision:0.9085000157356262 ]
100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 5/20 - accuracy:0.9870200157165527 precision:0.940500020980835 val_accuracy:0.9864799976348877 val_precision:0.9097499847412109 ]
100%|██████████| 125/125 [00:02<00:00, 41.67it/s, epoch: 6/20 - accuracy:0.987779974937439 precision:0.9422500133514404 val_accuracy:0.9869400262832642 val_precision:0.9137499928474426 ]
100%|██████████| 125/125 [00:02<00:00, 41.92it/s, epoch: 7/20 - accuracy:0.988099992275238 precision:0.9447500109672546 val_accuracy:0.9870200157165527 val_precision:0.9139999747276306 ]
100%|██████████| 125/125 [00:03<00:00, 41.55it/s, epoch: 8/20 - accuracy:0.9887800216674805 precision:0.9477499723434448 val_accuracy:0.986739993095398 val_precision:0.9135000109672546 ]
100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 9/20 - accuracy:0.9892399907112122 precision:0.9502500295639038 val_accuracy:0.9868199825286865 val_precision:0.9132500290870667 ]
100%|██████████| 125/125 [00:02<00:00, 41.81it/s, epoch: 10/20 - accuracy:0.989359974861145 precision:0.9522500038146973 val_accuracy:0.9873600006103516 val_precision:0.9175000190734863 ]
100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 11/20 - accuracy:0.9898999929428101 precision:0.953249990940094 val_accuracy:0.9874200224876404 val_precision:0.9194999933242798 ]
100%|██████████| 125/125 [00:02<00:00, 41.77it/s, epoch: 12/20 - accuracy:0.990119993686676 precision:0.953499972820282 val_accuracy:0.9871600270271301 val_precision:0.9154999852180481 ]
100%|██████████| 125/125 [00:02<00:00, 41.87it/s, epoch: 13/20 - accuracy:0.9906600117683411 precision:0.9570000171661377 val_accuracy:0.9876800179481506 val_precision:0.9202499985694885 ]
100%|██████████| 125/125 [00:03<00:00, 40.91it/s, epoch: 14/20 - accuracy:0.9910399913787842 precision:0.9572499990463257 val_accuracy:0.9880200028419495 val_precision:0.9227499961853027 ]
100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 15/20 - accuracy:0.9903600215911865 precision:0.9542499780654907 val_accuracy:0.9878000020980835 val_precision:0.9194999933242798 ]
100%|██████████| 125/125 [00:02<00:00, 41.68it/s, epoch: 16/20 - accuracy:0.9914000034332275 precision:0.9585000276565552 val_accuracy:0.9878799915313721 val_precision:0.921750009059906 ]
100%|██████████| 125/125 [00:02<00:00, 42.21it/s, epoch: 17/20 - accuracy:0.9915599822998047 precision:0.9597499966621399 val_accuracy:0.988099992275238 val_precision:0.921750009059906 ]
100%|██████████| 125/125 [00:03<00:00, 41.41it/s, epoch: 18/20 - accuracy:0.9915800094604492 precision:0.9595000147819519 val_accuracy:0.9880399703979492 val_precision:0.921500027179718 ]
100%|██████████| 125/125 [00:02<00:00, 41.83it/s, epoch: 19/20 - accuracy:0.9916200041770935 precision:0.9605000019073486 val_accuracy:0.9887400269508362 val_precision:0.9244999885559082 ]
100%|██████████| 125/125 [00:03<00:00, 41.34it/s, epoch: 20/20 - accuracy:0.9922599792480469 precision:0.9637500047683716 val_accuracy:0.9883599877357483 val_precision:0.922249972820282 ]
[8]:
from matplotlib import pyplot as plt
# Draw accuracy values for training & validation
plt.plot(history.global_history['multiclassaccuracy'])
plt.plot(history.global_history['val_multiclassaccuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()