{ "cells": [ { "cell_type": "markdown", "id": "9f6ceddd", "metadata": {}, "source": [ "# Federated Learning with PyTorch Backend" ] }, { "cell_type": "markdown", "id": "3c2fb4f8", "metadata": {}, "source": [ "> The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production." ] }, { "cell_type": "markdown", "id": "13008203", "metadata": {}, "source": [ "In this tutorial, we will walk you through how to use the PyTorch backend in SecretFlow for federated learning.\n", "+ We will use an image classification task as an example\n", "+ We will use PyTorch as the backend\n", "+ We will show how to choose among the multiple FL strategies available\n", " \n", "If you want to learn more about federated learning, datasets, etc., you can refer to [Federated Learning for Image Classification](Federate_Learning_for_Image_Classification.ipynb)\n", " \n", "**Here we go!**" ] }, { "cell_type": "code", "execution_count": 1, "id": "4c69265d", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "id": "b00d46f2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-08-31 23:43:03.362818: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib\n" ] } ], "source": [ "import secretflow as sf\n", "\n", "# Shut down any SecretFlow runtime that is already running.\n", "sf.shutdown()\n", "\n", "# Start a local simulation with three parties: alice and bob hold the data, charlie will act as the server.\n", "sf.init(['alice', 'bob', 'charlie'], address='local')\n", "alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')" ] }, { "cell_type": "markdown", 
"id": "f58fe11d", "metadata": {}, "source": [ "+ `BaseModule`: similar to `torch.nn.Module`\n", "+ `TorchModel`: a wrapper class that bundles `model_fn`, `loss_fn`, `optim_fn`, and `metrics`\n", "+ `metric_wrapper`: wraps metrics for use on the workers\n", "+ `optim_wrapper`: wraps `optim_fn` for use on the workers\n", "+ `FLModel`: the federated model; use `backend` to specify which backend to use and `strategy` to specify which federated strategy to use" ] }, { "cell_type": "code", "execution_count": 3, "id": "de99cbf8", "metadata": {}, "outputs": [], "source": [ "from secretflow.ml.nn.fl.backend.torch.utils import BaseModule, TorchModel\n", "from secretflow.ml.nn.fl.utils import metric_wrapper, optim_wrapper\n", "from secretflow.ml.nn import FLModel\n", "from torchmetrics import Accuracy, Precision\n", "from secretflow.security.aggregation import SecureAggregator\n", "from secretflow.utils.simulation.datasets import load_mnist\n", "from torch import nn, optim\n", "from torch.nn import functional as F" ] }, { "cell_type": "markdown", "id": "b9c3ea64", "metadata": {}, "source": [ "When defining the model, we only need to inherit from `BaseModule` instead of `nn.Module`; everything else is the same as in PyTorch." ] }, { "cell_type": "code", "execution_count": 4, "id": "85d2028a", "metadata": {}, "outputs": [], "source": [ "\n", "class ConvNet(BaseModule):\n", "    \"\"\"Small ConvNet for MNIST.\"\"\"\n", "\n", "    def __init__(self):\n", "        super(ConvNet, self).__init__()\n", "        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)\n", "        # For 1x28x28 inputs: 3x3 conv -> 3x26x26, 3x3 max-pool -> 3x8x8, so 3 * 8 * 8 = 192 features\n", "        self.fc_in_dim = 192\n", "        self.fc = nn.Linear(self.fc_in_dim, 10)\n", "\n", "    def forward(self, x):\n", "        x = F.relu(F.max_pool2d(self.conv1(x), 3))\n", "        x = x.view(-1, self.fc_in_dim)\n", "        x = self.fc(x)\n", "        # Note: nn.CrossEntropyLoss applies log-softmax internally, so returning raw logits is the usual PyTorch convention.\n", "        return F.softmax(x, dim=1)\n" ] }, { "cell_type": "markdown", "id": "e62ce093", "metadata": {}, "source": [ "We can continue to use the loss functions and optimizers defined in PyTorch. The only difference is that the optimizer and the metrics must be wrapped with the wrappers provided by SecretFlow (`optim_wrapper` and `metric_wrapper`), while the loss function is passed as a class rather than an instance." ] }, { 
"cell_type": "code", "execution_count": 5, "id": "645e3fbc", "metadata": {}, "outputs": [], "source": [ "# Split MNIST between alice (40% of the samples) and bob (60%).\n", "(train_data, train_label), (test_data, test_label) = load_mnist(\n", "    parts={alice: 0.4, bob: 0.6},\n", "    normalized_x=True,\n", "    categorical_y=True,\n", "    is_torch=True,\n", ")\n", "\n", "# The loss is passed as a class; the optimizer and the metrics are wrapped.\n", "loss_fn = nn.CrossEntropyLoss\n", "optim_fn = optim_wrapper(optim.Adam, lr=1e-2)\n", "model_def = TorchModel(\n", "    model_fn=ConvNet,\n", "    loss_fn=loss_fn,\n", "    optim_fn=optim_fn,\n", "    metrics=[\n", "        metric_wrapper(Accuracy, task=\"multiclass\", num_classes=10, average='micro'),\n", "        metric_wrapper(Precision, task=\"multiclass\", num_classes=10, average='micro'),\n", "    ],\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "95fcf7b6", "metadata": {}, "outputs": [], "source": [ "device_list = [alice, bob]\n", "server = charlie\n", "aggregator = SecureAggregator(server, [alice, bob])\n", "\n", "# Specify the FL model parameters\n", "fl_model = FLModel(\n", "    server=server,\n", "    device_list=device_list,\n", "    model=model_def,\n", "    aggregator=aggregator,\n", "    strategy='fed_avg_w',  # FL strategy\n", "    backend=\"torch\",  # supported backends: 'tensorflow', 'torch'\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "c595099d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████▉| 749/750 [00:15<00:00, 47.41it/s]2022-08-31 23:43:34.759168: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib\n", "2022-08-31 23:43:34.759205: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)\n", "100%|██████████| 
750/750 [00:16<00:00, 46.78it/s, epoch: 1/20 - accuracy:0.9709533452987671 precision:0.8571249842643738 val_accuracy:0.9840199947357178 val_precision:0.8955000042915344 ]\n", "100%|██████████| 125/125 [00:03<00:00, 31.28it/s, epoch: 2/20 - accuracy:0.9825800061225891 precision:0.9190000295639038 val_accuracy:0.9850000143051147 val_precision:0.903249979019165 ]\n", "100%|██████████| 125/125 [00:02<00:00, 41.70it/s, epoch: 3/20 - accuracy:0.9850000143051147 precision:0.9302499890327454 val_accuracy:0.9856399893760681 val_precision:0.906499981880188 ]\n", "100%|██████████| 125/125 [00:03<00:00, 41.66it/s, epoch: 4/20 - accuracy:0.9859799742698669 precision:0.9334999918937683 val_accuracy:0.9861800074577332 val_precision:0.9085000157356262 ]\n", "100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 5/20 - accuracy:0.9870200157165527 precision:0.940500020980835 val_accuracy:0.9864799976348877 val_precision:0.9097499847412109 ]\n", "100%|██████████| 125/125 [00:02<00:00, 41.67it/s, epoch: 6/20 - accuracy:0.987779974937439 precision:0.9422500133514404 val_accuracy:0.9869400262832642 val_precision:0.9137499928474426 ]\n", "100%|██████████| 125/125 [00:02<00:00, 41.92it/s, epoch: 7/20 - accuracy:0.988099992275238 precision:0.9447500109672546 val_accuracy:0.9870200157165527 val_precision:0.9139999747276306 ]\n", "100%|██████████| 125/125 [00:03<00:00, 41.55it/s, epoch: 8/20 - accuracy:0.9887800216674805 precision:0.9477499723434448 val_accuracy:0.986739993095398 val_precision:0.9135000109672546 ]\n", "100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 9/20 - accuracy:0.9892399907112122 precision:0.9502500295639038 val_accuracy:0.9868199825286865 val_precision:0.9132500290870667 ]\n", "100%|██████████| 125/125 [00:02<00:00, 41.81it/s, epoch: 10/20 - accuracy:0.989359974861145 precision:0.9522500038146973 val_accuracy:0.9873600006103516 val_precision:0.9175000190734863 ]\n", "100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 11/20 - accuracy:0.9898999929428101 
precision:0.953249990940094 val_accuracy:0.9874200224876404 val_precision:0.9194999933242798 ]\n", "100%|██████████| 125/125 [00:02<00:00, 41.77it/s, epoch: 12/20 - accuracy:0.990119993686676 precision:0.953499972820282 val_accuracy:0.9871600270271301 val_precision:0.9154999852180481 ]\n", "100%|██████████| 125/125 [00:02<00:00, 41.87it/s, epoch: 13/20 - accuracy:0.9906600117683411 precision:0.9570000171661377 val_accuracy:0.9876800179481506 val_precision:0.9202499985694885 ]\n", "100%|██████████| 125/125 [00:03<00:00, 40.91it/s, epoch: 14/20 - accuracy:0.9910399913787842 precision:0.9572499990463257 val_accuracy:0.9880200028419495 val_precision:0.9227499961853027 ]\n", "100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 15/20 - accuracy:0.9903600215911865 precision:0.9542499780654907 val_accuracy:0.9878000020980835 val_precision:0.9194999933242798 ]\n", "100%|██████████| 125/125 [00:02<00:00, 41.68it/s, epoch: 16/20 - accuracy:0.9914000034332275 precision:0.9585000276565552 val_accuracy:0.9878799915313721 val_precision:0.921750009059906 ]\n", "100%|██████████| 125/125 [00:02<00:00, 42.21it/s, epoch: 17/20 - accuracy:0.9915599822998047 precision:0.9597499966621399 val_accuracy:0.988099992275238 val_precision:0.921750009059906 ]\n", "100%|██████████| 125/125 [00:03<00:00, 41.41it/s, epoch: 18/20 - accuracy:0.9915800094604492 precision:0.9595000147819519 val_accuracy:0.9880399703979492 val_precision:0.921500027179718 ]\n", "100%|██████████| 125/125 [00:02<00:00, 41.83it/s, epoch: 19/20 - accuracy:0.9916200041770935 precision:0.9605000019073486 val_accuracy:0.9887400269508362 val_precision:0.9244999885559082 ]\n", "100%|██████████| 125/125 [00:03<00:00, 41.34it/s, epoch: 20/20 - accuracy:0.9922599792480469 precision:0.9637500047683716 val_accuracy:0.9883599877357483 val_precision:0.922249972820282 ]\n" ] } ], "source": [ "history = fl_model.fit(\n", " train_data,\n", " train_label,\n", " validation_data=(test_data, test_label),\n", " epochs=20,\n", " 
batch_size=32,\n", " aggregate_freq=1,\n", " )" ] }, { "cell_type": "code", "execution_count": 8, "id": "55c13406", "metadata": {}, "outputs": [], "source": [ "from matplotlib import pyplot as plt\n", "\n", "# Draw accuracy values for training & validation\n", "plt.plot(history.global_history['multiclassaccuracy'])\n", "plt.plot(history.global_history['val_multiclassaccuracy'])\n", "plt.title('FLModel accuracy')\n", "plt.ylabel('Accuracy')\n", "plt.xlabel('Epoch')\n", "plt.legend(['Train', 'Valid'], loc='upper left')\n", "plt.show()\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('3.8')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" }, "vscode": { "interpreter": { "hash": "ae1fdd5fd034b7d694352220485921694ff89198520409089b4646721fce11ca" } } }, "nbformat": 4, "nbformat_minor": 5 }