Federated Learning with PyTorch Backend#
The following code is for demonstration only. Do NOT use it directly in production.
In this tutorial, we will walk you through how to do federated learning on SecretFlow with the PyTorch backend.

- We will use an image classification task as the example
- We will use PyTorch as the backend
- We will show how to use multiple FL strategies
If you want to learn more about federated learning, the dataset, etc., you can move to Federate Learning for Image Classification.
Let's get started!
[1]:
%load_ext autoreload
%autoreload 2
[2]:
import secretflow as sf
# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob', 'charlie'], address='local')
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
2022-08-31 23:43:03.362818: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib
- BaseModule: similar to torch.nn.Module
- TorchModel: a wrapper class that bundles loss_fn, optim_fn, metric_wrapper (a wrapper for metrics), and optim_wrapper (a wrapper for optim_fn)
- FLModel: the federated model; use backend to specify which backend will be used, and strategy to specify which federated strategy will be used

[3]:
from secretflow.ml.nn.fl.backend.torch.utils import BaseModule, TorchModel
from secretflow.ml.nn.fl.utils import metric_wrapper, optim_wrapper
from secretflow.ml.nn import FLModel
from torchmetrics import Accuracy, Precision
from secretflow.security.aggregation import SecureAggregator
from secretflow.utils.simulation.datasets import load_mnist
from torch import nn, optim
from torch.nn import functional as F
When defining the model, we only need to inherit from BaseModule instead of nn.Module; everything else is the same as in plain PyTorch.
[4]:
class ConvNet(BaseModule):
    """Small ConvNet for MNIST."""

    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc_in_dim = 192
        self.fc = nn.Linear(self.fc_in_dim, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, self.fc_in_dim)
        x = self.fc(x)
        return F.softmax(x, dim=1)
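As a quick sanity check (an illustration added here, not part of the original notebook), we can verify where fc_in_dim = 192 comes from: a 1×28×28 MNIST image passes through the 3×3 convolution to 3×26×26, then through max-pooling with kernel size 3 to 3×8×8, and 3 * 8 * 8 = 192.
[ ]:
import torch

# Hypothetical local sanity check: feed a dummy batch through the network
# to confirm the flattened feature size and the output shape.
net = ConvNet()
dummy = torch.randn(1, 1, 28, 28)  # one fake MNIST image
features = F.relu(F.max_pool2d(net.conv1(dummy), 3))
print(features.shape)    # torch.Size([1, 3, 8, 8]) -> 3 * 8 * 8 = 192
print(net(dummy).shape)  # torch.Size([1, 10]), per-class probabilities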
We can keep using the loss function and optimizer as defined in PyTorch; the only difference in SecretFlow is that we need to wrap them with wrappers.
[5]:
(train_data, train_label), (test_data, test_label) = load_mnist(
    parts={alice: 0.4, bob: 0.6},
    normalized_x=True,
    categorical_y=True,
    is_torch=True,
)
loss_fn = nn.CrossEntropyLoss
optim_fn = optim_wrapper(optim.Adam, lr=1e-2)
model_def = TorchModel(
    model_fn=ConvNet,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    metrics=[
        metric_wrapper(Accuracy, task="multiclass", num_classes=10, average='micro'),
        metric_wrapper(Precision, task="multiclass", num_classes=10, average='micro'),
    ],
)
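Note that load_mnist returns federated ndarrays here: each party holds only its own shard of the data, and the driver sees just metadata. As a hedged illustration (assuming the partition_shape() helper on the returned FedNdarray), you can inspect the split:
[ ]:
# Each party's shard stays on its own PYU device.
# With the 0.4/0.6 split of the 60k MNIST training images we would expect
# roughly {alice: (24000, 1, 28, 28), bob: (36000, 1, 28, 28)}.
print(train_data.partition_shape())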
[6]:
device_list = [alice, bob]
server = charlie
aggregator = SecureAggregator(server, [alice, bob])
# specify params
fl_model = FLModel(
    server=server,
    device_list=device_list,
    model=model_def,
    aggregator=aggregator,
    strategy='fed_avg_w',  # fl strategy
    backend="torch",  # backend support ['tensorflow', 'torch']
)
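Conceptually, SecureAggregator lets charlie obtain the sum of the parties' model updates without seeing any individual update. The sketch below is a plain-NumPy toy, not SecretFlow's actual protocol; it only illustrates the core masking idea, namely that pairwise random masks cancel in the sum.
[ ]:
import numpy as np

# Toy sketch of mask-based secure aggregation (illustrative only).
rng = np.random.default_rng(0)
w_alice = np.array([1.0, 2.0])    # alice's local update
w_bob = np.array([3.0, 4.0])      # bob's local update
mask = rng.normal(size=2)         # pairwise mask agreed between alice and bob
upload_alice = w_alice + mask     # what the server receives from alice
upload_bob = w_bob - mask         # what the server receives from bob
print(upload_alice + upload_bob)  # [4. 6.] == w_alice + w_bob; masks cancel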
[7]:
history = fl_model.fit(
    train_data,
    train_label,
    validation_data=(test_data, test_label),
    epochs=20,
    batch_size=32,
    aggregate_freq=1,
)
100%|█████████▉| 749/750 [00:15<00:00, 47.41it/s]2022-08-31 23:43:34.759168: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib
2022-08-31 23:43:34.759205: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
100%|██████████| 750/750 [00:16<00:00, 46.78it/s, epoch: 1/20 - accuracy:0.9709533452987671 precision:0.8571249842643738 val_accuracy:0.9840199947357178 val_precision:0.8955000042915344 ]
100%|██████████| 125/125 [00:03<00:00, 31.28it/s, epoch: 2/20 - accuracy:0.9825800061225891 precision:0.9190000295639038 val_accuracy:0.9850000143051147 val_precision:0.903249979019165 ]
100%|██████████| 125/125 [00:02<00:00, 41.70it/s, epoch: 3/20 - accuracy:0.9850000143051147 precision:0.9302499890327454 val_accuracy:0.9856399893760681 val_precision:0.906499981880188 ]
100%|██████████| 125/125 [00:03<00:00, 41.66it/s, epoch: 4/20 - accuracy:0.9859799742698669 precision:0.9334999918937683 val_accuracy:0.9861800074577332 val_precision:0.9085000157356262 ]
100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 5/20 - accuracy:0.9870200157165527 precision:0.940500020980835 val_accuracy:0.9864799976348877 val_precision:0.9097499847412109 ]
100%|██████████| 125/125 [00:02<00:00, 41.67it/s, epoch: 6/20 - accuracy:0.987779974937439 precision:0.9422500133514404 val_accuracy:0.9869400262832642 val_precision:0.9137499928474426 ]
100%|██████████| 125/125 [00:02<00:00, 41.92it/s, epoch: 7/20 - accuracy:0.988099992275238 precision:0.9447500109672546 val_accuracy:0.9870200157165527 val_precision:0.9139999747276306 ]
100%|██████████| 125/125 [00:03<00:00, 41.55it/s, epoch: 8/20 - accuracy:0.9887800216674805 precision:0.9477499723434448 val_accuracy:0.986739993095398 val_precision:0.9135000109672546 ]
100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 9/20 - accuracy:0.9892399907112122 precision:0.9502500295639038 val_accuracy:0.9868199825286865 val_precision:0.9132500290870667 ]
100%|██████████| 125/125 [00:02<00:00, 41.81it/s, epoch: 10/20 - accuracy:0.989359974861145 precision:0.9522500038146973 val_accuracy:0.9873600006103516 val_precision:0.9175000190734863 ]
100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 11/20 - accuracy:0.9898999929428101 precision:0.953249990940094 val_accuracy:0.9874200224876404 val_precision:0.9194999933242798 ]
100%|██████████| 125/125 [00:02<00:00, 41.77it/s, epoch: 12/20 - accuracy:0.990119993686676 precision:0.953499972820282 val_accuracy:0.9871600270271301 val_precision:0.9154999852180481 ]
100%|██████████| 125/125 [00:02<00:00, 41.87it/s, epoch: 13/20 - accuracy:0.9906600117683411 precision:0.9570000171661377 val_accuracy:0.9876800179481506 val_precision:0.9202499985694885 ]
100%|██████████| 125/125 [00:03<00:00, 40.91it/s, epoch: 14/20 - accuracy:0.9910399913787842 precision:0.9572499990463257 val_accuracy:0.9880200028419495 val_precision:0.9227499961853027 ]
100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 15/20 - accuracy:0.9903600215911865 precision:0.9542499780654907 val_accuracy:0.9878000020980835 val_precision:0.9194999933242798 ]
100%|██████████| 125/125 [00:02<00:00, 41.68it/s, epoch: 16/20 - accuracy:0.9914000034332275 precision:0.9585000276565552 val_accuracy:0.9878799915313721 val_precision:0.921750009059906 ]
100%|██████████| 125/125 [00:02<00:00, 42.21it/s, epoch: 17/20 - accuracy:0.9915599822998047 precision:0.9597499966621399 val_accuracy:0.988099992275238 val_precision:0.921750009059906 ]
100%|██████████| 125/125 [00:03<00:00, 41.41it/s, epoch: 18/20 - accuracy:0.9915800094604492 precision:0.9595000147819519 val_accuracy:0.9880399703979492 val_precision:0.921500027179718 ]
100%|██████████| 125/125 [00:02<00:00, 41.83it/s, epoch: 19/20 - accuracy:0.9916200041770935 precision:0.9605000019073486 val_accuracy:0.9887400269508362 val_precision:0.9244999885559082 ]
100%|██████████| 125/125 [00:03<00:00, 41.34it/s, epoch: 20/20 - accuracy:0.9922599792480469 precision:0.9637500047683716 val_accuracy:0.9883599877357483 val_precision:0.922249972820282 ]
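After training finishes, you can also score the aggregated global model on the held-out test set. A minimal sketch, assuming FLModel.evaluate follows the same argument style as fit:
[ ]:
# Evaluate the global model on the test partitions.
global_metrics = fl_model.evaluate(test_data, test_label, batch_size=32)
print(global_metrics)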
[8]:
from matplotlib import pyplot as plt
# Draw accuracy values for training & validation
plt.plot(history.global_history['multiclassaccuracy'])
plt.plot(history.global_history['val_multiclassaccuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
