Federated Learning with the PyTorch Backend

The following code is for demonstration purposes only. Do not use it directly in a production environment.

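In this tutorial, we walk you through federated learning on SecretFlow with the PyTorch backend:

- We use an image classification task as the example.
- We use PyTorch as the backend.
- We show how to use several FL strategies.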

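If you want to learn more about federated learning, the dataset, and related topics, see Horizontal Federated Learning: Image Classification.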

Let's get started.

[1]:
%load_ext autoreload
%autoreload 2
[2]:
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

sf.init(['alice', 'bob', 'charlie'], address='local')
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
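BaseModule: analogous to torch.nn.Module.
TorchModel: a wrapper class that bundles the model definition (model_fn), loss_fn, optim_fn, and metrics.
metric_wrapper: a wrapper for metrics.
optim_wrapper: a wrapper for optim_fn.
FLModel: the federated model. Use backend to specify which backend to use, and strategy to specify which federated learning strategy to use.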
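A likely reason for the wrappers is that the driver program only describes the optimizer and metrics, while each party has to construct its own instance next to its locally instantiated model. The sketch below illustrates that deferred-construction pattern with `functools.partial`. `make_deferred` and `FakeAdam` are hypothetical names (FakeAdam stands in for `torch.optim.Adam` so the sketch runs without PyTorch), and this is not SecretFlow's actual implementation of `optim_wrapper` or `metric_wrapper`.

```python
import functools


def make_deferred(builder, *args, **kwargs):
    """Hypothetical helper: bind constructor arguments now, build the object later."""
    return functools.partial(builder, *args, **kwargs)


class FakeAdam:
    """Stand-in for torch.optim.Adam so the sketch runs without PyTorch."""

    def __init__(self, params, lr=1e-3):
        self.params = list(params)
        self.lr = lr


# Configured once in the driver program; no model parameters exist yet.
optim_factory = make_deferred(FakeAdam, lr=1e-2)

# Later, each party builds its own optimizer around its local model's parameters.
alice_params = [0.1, 0.2]
bob_params = [0.3, 0.4]
alice_opt = optim_factory(alice_params)
bob_opt = optim_factory(bob_params)

print(alice_opt.lr, bob_opt.lr, alice_opt is bob_opt)
```

Each call to the factory yields a fresh, independent optimizer that shares the same hyperparameters, which is what a multi-party setup needs.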
[3]:
from secretflow.ml.nn.fl.backend.torch.utils import BaseModule, TorchModel
from secretflow.ml.nn.fl.utils import metric_wrapper, optim_wrapper
from secretflow.ml.nn import FLModel
from torchmetrics import Accuracy, Precision
from secretflow.security.aggregation import SecureAggregator
from secretflow.utils.simulation.datasets import load_mnist
from torch import nn, optim
from torch.nn import functional as F

When defining a model, we only need to inherit from BaseModule instead of nn.Module; everything else is the same as in PyTorch.

[4]:

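class ConvNet(BaseModule):
    """Small ConvNet for MNIST."""

    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        # 3 channels x 8 x 8: 28x28 -> 26x26 after conv1, -> 8x8 after 3x3 max-pooling
        self.fc_in_dim = 192
        self.fc = nn.Linear(self.fc_in_dim, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, self.fc_in_dim)
        x = self.fc(x)
        # Note: nn.CrossEntropyLoss applies log-softmax internally, so returning
        # raw logits is the more common PyTorch idiom.
        return F.softmax(x, dim=1)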
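The hard-coded `fc_in_dim = 192` in ConvNet follows from the standard output-size formula for convolution and pooling layers, and `x.view(-1, self.fc_in_dim)` only works if it is right. The sketch below reproduces that arithmetic in plain Python, assuming 1×28×28 MNIST inputs; `conv_out` is a hypothetical helper, not part of SecretFlow or PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d / MaxPool2d layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# MNIST images are 1 x 28 x 28.
h = conv_out(28, kernel=3)           # nn.Conv2d(1, 3, kernel_size=3): 28 -> 26
h = conv_out(h, kernel=3, stride=3)  # F.max_pool2d(x, 3): stride defaults to the kernel size, 26 -> 8
channels = 3                         # out_channels of conv1
fc_in_dim = channels * h * h
print(h, fc_in_dim)
```

If you change the kernel size or pooling window, recomputing `fc_in_dim` this way avoids a shape mismatch in the linear layer.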

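We can keep using the loss functions and optimizers defined in PyTorch. The only difference in SecretFlow is that the optimizer and the metrics need to be wrapped, with optim_wrapper and metric_wrapper respectively; the loss function class is passed in as-is.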

[5]:
(train_data, train_label), (test_data, test_label) = load_mnist(
    parts={alice: 0.4, bob: 0.6},
    normalized_x=True,
    categorical_y=True,
    is_torch=True,
)

loss_fn = nn.CrossEntropyLoss
optim_fn = optim_wrapper(optim.Adam, lr=1e-2)
model_def = TorchModel(
    model_fn=ConvNet,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    metrics=[
        metric_wrapper(Accuracy, task="multiclass", num_classes=10, average='micro'),
        metric_wrapper(Precision, task="multiclass", num_classes=10, average='micro'),
    ],
)
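The `parts={alice: 0.4, bob: 0.6}` argument gives the two parties unequal shares of the data, which is exactly the situation a sample-weighted strategy has to handle. The arithmetic below is a sketch that assumes the split is applied to the standard 60,000-image MNIST training set with simple rounding; whether `load_mnist` rounds exactly this way is an assumption.

```python
# Assumption: the 0.4 / 0.6 split is applied to the 60,000-image MNIST training set.
TRAIN_SIZE = 60_000
parts = {"alice": 0.4, "bob": 0.6}
batch_size = 32  # same value passed to fl_model.fit in cell [7]

counts = {}
for party, frac in parts.items():
    n = round(TRAIN_SIZE * frac)
    steps = -(-n // batch_size)  # ceiling division: a final partial batch still counts as a step
    counts[party] = (n, steps)
    print(f"{party}: {n} samples, {steps} batches per epoch")
```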
[6]:
device_list = [alice, bob]
server = charlie
aggregator = SecureAggregator(server, [alice, bob])

# Specify the FL model parameters
fl_model = FLModel(
    server=server,
    device_list=device_list,
    model=model_def,
    aggregator=aggregator,
    strategy='fed_avg_w', # fl strategy
    backend="torch", # backend support ['tensorflow', 'torch']
)
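With `strategy='fed_avg_w'` (judging by its name, FedAvg over model weights) and a `SecureAggregator`, the server ends up with an average of the parties' weights without seeing any single party's values. The toy sketch below shows the pairwise-masking idea behind secure aggregation combined with a sample-weighted average. It is an illustration under stated assumptions, not SecretFlow's protocol: weighting by sample count is assumed, a shared `random.Random` stands in for pairwise shared secrets, and floats are used where real protocols work in a finite ring so masks cancel exactly. `weighted_average` and `masked_contributions` are hypothetical helpers.

```python
import random


def weighted_average(updates, weights):
    """Plaintext FedAvg-style mean of per-party parameter vectors, weighted by `weights`."""
    total = sum(weights)
    return [sum(w * u[i] for u, w in zip(updates, weights)) / total
            for i in range(len(updates[0]))]


def masked_contributions(updates, weights, seed=0):
    """Toy pairwise masking: for each pair (i, j) with i < j, party i adds a
    random mask and party j subtracts the same mask, so every mask cancels
    when the server sums all contributions."""
    rng = random.Random(seed)  # stands in for secrets shared by each pair of parties
    n, dim = len(updates), len(updates[0])
    # Each party scales its own parameters by its weight before masking.
    out = [[weights[p] * x for x in updates[p]] for p in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            mask = [rng.uniform(-100.0, 100.0) for _ in range(dim)]
            out[i] = [a + m for a, m in zip(out[i], mask)]
            out[j] = [a - m for a, m in zip(out[j], mask)]
    return out


# Hypothetical parameter vectors for alice and bob, and their sample counts.
party_params = [[0.2, -1.0], [0.6, 1.0]]
sample_counts = [24000, 36000]

masked = masked_contributions(party_params, sample_counts)
# The server only sees `masked`; summing column-wise cancels the masks.
server_sum = [sum(col) for col in zip(*masked)]
secure_avg = [s / sum(sample_counts) for s in server_sum]
plain_avg = weighted_average(party_params, sample_counts)
print(secure_avg, plain_avg)
```

Both averages agree up to floating-point rounding, while each individual masked vector looks random to the server.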
[7]:
history = fl_model.fit(
    train_data,
    train_label,
    validation_data=(test_data, test_label),
    epochs=20,
    batch_size=32,
    aggregate_freq=1,
)
100%|██████████| 750/750 [00:16<00:00, 46.78it/s, epoch: 1/20 -  accuracy:0.9709533452987671  precision:0.8571249842643738  val_accuracy:0.9840199947357178  val_precision:0.8955000042915344 ]
100%|██████████| 125/125 [00:03<00:00, 31.28it/s, epoch: 2/20 -  accuracy:0.9825800061225891  precision:0.9190000295639038  val_accuracy:0.9850000143051147  val_precision:0.903249979019165 ]
100%|██████████| 125/125 [00:02<00:00, 41.70it/s, epoch: 3/20 -  accuracy:0.9850000143051147  precision:0.9302499890327454  val_accuracy:0.9856399893760681  val_precision:0.906499981880188 ]
100%|██████████| 125/125 [00:03<00:00, 41.66it/s, epoch: 4/20 -  accuracy:0.9859799742698669  precision:0.9334999918937683  val_accuracy:0.9861800074577332  val_precision:0.9085000157356262 ]
100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 5/20 -  accuracy:0.9870200157165527  precision:0.940500020980835  val_accuracy:0.9864799976348877  val_precision:0.9097499847412109 ]
100%|██████████| 125/125 [00:02<00:00, 41.67it/s, epoch: 6/20 -  accuracy:0.987779974937439  precision:0.9422500133514404  val_accuracy:0.9869400262832642  val_precision:0.9137499928474426 ]
100%|██████████| 125/125 [00:02<00:00, 41.92it/s, epoch: 7/20 -  accuracy:0.988099992275238  precision:0.9447500109672546  val_accuracy:0.9870200157165527  val_precision:0.9139999747276306 ]
100%|██████████| 125/125 [00:03<00:00, 41.55it/s, epoch: 8/20 -  accuracy:0.9887800216674805  precision:0.9477499723434448  val_accuracy:0.986739993095398  val_precision:0.9135000109672546 ]
100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 9/20 -  accuracy:0.9892399907112122  precision:0.9502500295639038  val_accuracy:0.9868199825286865  val_precision:0.9132500290870667 ]
100%|██████████| 125/125 [00:02<00:00, 41.81it/s, epoch: 10/20 -  accuracy:0.989359974861145  precision:0.9522500038146973  val_accuracy:0.9873600006103516  val_precision:0.9175000190734863 ]
100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 11/20 -  accuracy:0.9898999929428101  precision:0.953249990940094  val_accuracy:0.9874200224876404  val_precision:0.9194999933242798 ]
100%|██████████| 125/125 [00:02<00:00, 41.77it/s, epoch: 12/20 -  accuracy:0.990119993686676  precision:0.953499972820282  val_accuracy:0.9871600270271301  val_precision:0.9154999852180481 ]
100%|██████████| 125/125 [00:02<00:00, 41.87it/s, epoch: 13/20 -  accuracy:0.9906600117683411  precision:0.9570000171661377  val_accuracy:0.9876800179481506  val_precision:0.9202499985694885 ]
100%|██████████| 125/125 [00:03<00:00, 40.91it/s, epoch: 14/20 -  accuracy:0.9910399913787842  precision:0.9572499990463257  val_accuracy:0.9880200028419495  val_precision:0.9227499961853027 ]
100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 15/20 -  accuracy:0.9903600215911865  precision:0.9542499780654907  val_accuracy:0.9878000020980835  val_precision:0.9194999933242798 ]
100%|██████████| 125/125 [00:02<00:00, 41.68it/s, epoch: 16/20 -  accuracy:0.9914000034332275  precision:0.9585000276565552  val_accuracy:0.9878799915313721  val_precision:0.921750009059906 ]
100%|██████████| 125/125 [00:02<00:00, 42.21it/s, epoch: 17/20 -  accuracy:0.9915599822998047  precision:0.9597499966621399  val_accuracy:0.988099992275238  val_precision:0.921750009059906 ]
100%|██████████| 125/125 [00:03<00:00, 41.41it/s, epoch: 18/20 -  accuracy:0.9915800094604492  precision:0.9595000147819519  val_accuracy:0.9880399703979492  val_precision:0.921500027179718 ]
100%|██████████| 125/125 [00:02<00:00, 41.83it/s, epoch: 19/20 -  accuracy:0.9916200041770935  precision:0.9605000019073486  val_accuracy:0.9887400269508362  val_precision:0.9244999885559082 ]
100%|██████████| 125/125 [00:03<00:00, 41.34it/s, epoch: 20/20 -  accuracy:0.9922599792480469  precision:0.9637500047683716  val_accuracy:0.9883599877357483  val_precision:0.922249972820282 ]
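One plausible reading of the `fit` arguments is that `aggregate_freq=1` makes the parties synchronize after every local batch, while larger values let each party take several local steps between aggregations. The toy loop below sketches that trade-off in pure Python with a scalar model `y = w * x`: each round, every party starts from the global weight, runs `aggregate_freq` mini-batch SGD steps, and the server takes a sample-weighted average. `local_steps`, `fed_train`, and `make_party` are hypothetical, and this is not how `FLModel.fit` is implemented.

```python
import random


def local_steps(w, data, steps, lr, batch_size, rng):
    """Run `steps` mini-batch SGD steps on the loss mean((w * x - y) ** 2)."""
    for _ in range(steps):
        batch = rng.sample(data, batch_size)
        grad = sum(2 * (w * x - y) * x for x, y in batch) / batch_size
        w -= lr * grad
    return w


def fed_train(party_data, rounds, aggregate_freq, lr=0.1, batch_size=4, seed=0):
    """Toy federated loop: local steps per party, then a sample-weighted average."""
    rng = random.Random(seed)
    sizes = [len(d) for d in party_data]
    w_global = 0.0
    for _ in range(rounds):
        local = [local_steps(w_global, d, aggregate_freq, lr, batch_size, rng)
                 for d in party_data]
        w_global = sum(n * w for n, w in zip(sizes, local)) / sum(sizes)
    return w_global


# Hypothetical noise-free data y = 3x, split unevenly between two parties.
data_rng = random.Random(42)


def make_party(n):
    return [(x, 3.0 * x) for x in (data_rng.uniform(0.5, 1.5) for _ in range(n))]


alice_data, bob_data = make_party(40), make_party(60)

# Same total number of local steps (100), different communication patterns.
w_freq1 = fed_train([alice_data, bob_data], rounds=100, aggregate_freq=1)
w_freq5 = fed_train([alice_data, bob_data], rounds=20, aggregate_freq=5)
print(w_freq1, w_freq5)
```

Both settings approach the true weight 3.0 here, but `aggregate_freq=5` needs five times fewer aggregation rounds; on real, non-identically distributed data, fewer synchronizations usually lets party models drift further apart between rounds.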
[8]:
from matplotlib import pyplot as plt

# Draw accuracy values for training & validation
plt.plot(history.global_history['multiclassaccuracy'])
plt.plot(history.global_history['val_multiclassaccuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()

[Figure: FLModel accuracy, training vs. validation accuracy per epoch]