# Source code for secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w
#!/usr/bin/env python3
# *_* coding: utf-8 *_*
# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Tuple
import numpy as np
import torch
from secretflow.device import PYUObject, proxy
from secretflow.ml.nn.fl.backend.torch.fl_base import BaseTorchModel
from secretflow.ml.nn.fl.strategy_dispatcher import register_strategy
class FedAvgW(BaseTorchModel):
    """
    FedAvgW: A naive implementation of FedAvg, where the clients upload their trained model
    weights to the server for averaging and update their local models via the aggregated weights
    from the server in each federated round.
    """

    def train_step(
        self,
        weights: np.ndarray,
        cur_steps: int,
        train_steps: int,
        **kwargs,
    ) -> Tuple[np.ndarray, int]:
        """Accept the parameter-server weights, then run local training.

        Args:
            weights: global weights from the parameter server. When ``None``
                the local model weights are left untouched.
            cur_steps: current train step (not read by this strategy).
            train_steps: number of local training steps (batches) to run.
            kwargs: strategy-specific parameters; only ``dp_strategy`` is read.

        Returns:
            A tuple ``(model_weights, num_sample)`` where ``model_weights`` is
            whatever ``self.model.get_weights(return_numpy=True)`` yields
            (optionally passed through ``dp_strategy.model_gdp``) and
            ``num_sample`` is the total number of rows seen over all steps.

        Raises:
            AssertionError: if ``self.model`` is ``None``.
            ValueError: if a batch from ``self.train_iter`` does not contain
                exactly two or three elements.
        """
        assert self.model is not None, "Model cannot be none, please give model define"
        self.model.train()
        if weights is not None:
            self.model.update_weights(weights)

        num_sample = 0
        dp_strategy = kwargs.get('dp_strategy', None)
        logs = {}
        # Stays None when train_steps is 0, so the loss is only reported when
        # at least one step actually produced one.
        loss = None

        for _ in range(train_steps):
            self.optimizer.zero_grad()
            iter_data = next(self.train_iter)
            if len(iter_data) == 2:
                x, y = iter_data
            elif len(iter_data) == 3:
                # The third element is accepted but not used by this strategy.
                x, y, _ = iter_data
            else:
                # Fail loudly instead of silently reusing the previous batch.
                raise ValueError(
                    "Expected a batch of (x, y) or (x, y, s_w), "
                    f"got {len(iter_data)} elements"
                )

            x = x.float()
            num_sample += x.shape[0]

            # Reduce the labels to a 1-D target tensor for the loss.
            if len(y.shape) == 1:
                y_t = y
            elif y.shape[-1] == 1:
                y_t = torch.squeeze(y, -1).long()
            else:
                # Multi-column labels: take the index of the largest entry
                # (presumably one-hot encoded -- confirm against the data loader).
                y_t = y.argmax(dim=-1)

            y_pred = self.model(x)

            # do back propagation
            loss = self.loss(y_pred, y_t)
            loss.backward()
            self.optimizer.step()
            for m in self.metrics:
                m.update(y_pred, y_t)

        # Only the loss of the last executed step is logged.
        if loss is not None:
            logs['train-loss'] = loss.item()

        self.logs = self.transform_metrics(logs)
        self.epoch_logs = copy.deepcopy(self.logs)

        model_weights = self.model.get_weights(return_numpy=True)

        # DP operation
        if dp_strategy is not None and dp_strategy.model_gdp is not None:
            model_weights = dp_strategy.model_gdp(model_weights)

        return model_weights, num_sample
@register_strategy(
    strategy_name='fed_avg_w',
    backend='torch',
)
@proxy(PYUObject)
class PYUFedAvgW(FedAvgW):
    """FedAvgW wrapped with ``proxy(PYUObject)`` and registered as the
    ``'fed_avg_w'`` strategy for the ``'torch'`` backend.

    Adds no behavior of its own; everything is inherited from ``FedAvgW``.
    """