This is an implementation of the following paper:
Yue Zhao, Meng Li, Liangzhen Lai, Naveen Suda, Damon Civin, Vikas Chandra. Federated Learning with Non-IID Data. arXiv:1806.00582.
TL;DR: Previous federated optimization algorithms (such as FedAvg and FedProx) converge to stationary points of a mismatched objective function due to heterogeneity in the data distribution. In this paper, the authors propose a data-sharing strategy that improves training on non-IID data by creating a small subset of data that is globally shared among all the edge devices.
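The data-sharing idea can be pictured with a minimal sketch. It assumes each client holds a list of sample indices and that the globally shared subset is a separate list of indices; every client then receives its own random `alpha` fraction of that shared pool on top of its local data. The function `share_global_subset` and its parameters are hypothetical and are not this repository's API.

```python
import random
from typing import Dict, List


def share_global_subset(
    client_indices: Dict[int, List[int]],
    global_indices: List[int],
    alpha: float,
    seed: int = 0,
) -> Dict[int, List[int]]:
    """Hypothetical helper: append a random `alpha` fraction of the shared pool
    to every client's local index list (independent draw per client)."""
    rng = random.Random(seed)
    k = int(round(alpha * len(global_indices)))  # shared samples per client
    augmented = {}
    for cid, idxs in client_indices.items():
        shared = rng.sample(global_indices, k)  # sample without replacement
        augmented[cid] = list(idxs) + shared  # local data first, shared data after
    return augmented
```

For example, with two clients holding three samples each and a shared pool of four samples, `alpha=0.5` gives each client two extra samples drawn from the pool.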
Abstract: Federated learning enables resource-constrained edge compute devices, such as mobile phones and IoT devices, to learn a shared model for prediction, while keeping the training data local. This decentralized approach to train models provides privacy, security, regulatory and economic benefits. In this work, we focus on the statistical challenge of federated learning when local data is non-IID. We first show that the accuracy of federated learning reduces significantly, by up to ~55% for neural networks trained for highly skewed non-IID data, where each client device trains only on a single class of data. We further show that this accuracy reduction can be explained by the weight divergence, which can be quantified by the earth mover’s distance (EMD) between the distribution over classes on each device and the population distribution. As a solution, we propose a strategy to improve training on non-IID data by creating a small subset of data which is globally shared between all the edge devices. Experiments show that accuracy can be increased by ~30% for the CIFAR-10 dataset with only 5% globally shared data.
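The divergence measure mentioned in the abstract compares each device's class distribution with the population's. The sketch below takes the per-device distance as the sum of absolute differences between class probabilities, which is our reading of the paper's definition rather than code from this repository; the helper names are hypothetical.

```python
from collections import Counter
from typing import List, Sequence


def class_distribution(labels: Sequence[int], num_classes: int) -> List[float]:
    """Empirical probability of each class label 0..num_classes-1."""
    counts = Counter(labels)
    n = len(labels)
    return [counts.get(c, 0) / n for c in range(num_classes)]


def emd(client_labels: Sequence[int], population_labels: Sequence[int], num_classes: int) -> float:
    """Sum over classes of |p_client(y=i) - p_population(y=i)| (assumed definition)."""
    p_k = class_distribution(client_labels, num_classes)
    p = class_distribution(population_labels, num_classes)
    return sum(abs(a - b) for a, b in zip(p_k, p))
```

With a balanced 10-class population, a client holding only one class (the most skewed case described above) scores about 1.8 under this definition, while a client whose classes are also balanced scores 0.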
The implementation runs on:
- Python 3.8
- PyTorch 1.6.0
- CUDA 10.1
- cuDNN 7.6.5
Currently, this repository supports the following federated learning algorithms:
- FedAvg (McMahan et al., AISTATS 2017): local solver is vanilla SGD; aggregates cumulative local model changes
- FedProx (Li et al., MLSys 2020): local solver is proximal SGD; aggregates cumulative local model changes
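The difference between the two local solvers, and the shared aggregation step, can be illustrated on plain Python lists standing in for model weights. This is a minimal sketch, not this repository's implementation. It assumes the FedProx local objective adds `(mu/2) * ||w - w_global||^2`, so setting `mu = 0` recovers the vanilla SGD step used by FedAvg. It also assumes that the server weights each client's cumulative change `w_k - w_global` by its sample count. The names `local_step` and `aggregate` are hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def local_step(w: Sequence[float], grad: Sequence[float], w_global: Sequence[float],
               lr: float, mu: float) -> Vector:
    """One proximal SGD step on f(w) + (mu/2)||w - w_global||^2.
    With mu = 0 this is a plain SGD step (FedAvg's local solver)."""
    return [wi - lr * (gi + mu * (wi - wg)) for wi, gi, wg in zip(w, grad, w_global)]


def aggregate(w_global: Sequence[float], local_models: Sequence[Sequence[float]],
              num_samples: Sequence[int]) -> Vector:
    """Apply the sample-weighted average of cumulative local changes (w_k - w_global)."""
    total = sum(num_samples)
    delta = [0.0] * len(w_global)
    for w_k, n_k in zip(local_models, num_samples):
        for i, (a, b) in enumerate(zip(w_k, w_global)):
            delta[i] += (n_k / total) * (a - b)
    return [wg + d for wg, d in zip(w_global, delta)]
```

Because the weights sum to one, applying the averaged change is equivalent to taking a weighted average of the local models when every client participates. The proximal term only pulls each local model back toward `w_global` during local training.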
An example launch script is shown below.
python main.py \
--all_clients \
--fed fedavg \
--gpu 0 \
--seed 1 \
--sampling noniid \
--sys_homo \
--num_channels 3 \
--dataset cifar
Explanations of arguments:
- fed: federated optimization algorithm
- mu: proximal-term coefficient for FedProx
- sampling: sampling method
- alpha: random portion of the global dataset
- dataset: name of the dataset
- rounds: total number of communication rounds
- sys_homo: flag; no system heterogeneity