This is the simulation code package for the following paper:
Zehong Lin, Hang Liu, and Ying-Jun Angela Zhang, “Relay-Assisted Cooperative Federated Learning,” IEEE Transactions on Wireless Communications, DOI: 10.1109/TWC.2022.3155596. [ArXiv Version]
The package, written in Python 3 and Matlab, reproduces the numerical results of the proposed algorithm in the above paper.
Federated learning (FL) has recently emerged as a promising technology to enable artificial intelligence (AI) at the network edge, where distributed mobile devices collaboratively train a shared AI model under the coordination of an edge server. To significantly improve the communication efficiency of FL, over-the-air computation allows a large number of mobile devices to concurrently upload their local models by exploiting the superposition property of wireless multi-access channels. Due to wireless channel fading, the model aggregation error at the edge server is dominated by the weakest channel among all devices, causing severe straggler issues. In this paper, we propose a relay-assisted cooperative FL scheme to effectively address the straggler issue. In particular, we deploy multiple half-duplex relays to cooperatively assist the devices in uploading the local model updates to the edge server. The nature of the over-the-air computation poses system objectives and constraints that are distinct from those in traditional relay communication systems. Moreover, the strong coupling between the design variables renders the optimization of such a system challenging. To tackle the issue, we propose an alternating-optimization-based algorithm to optimize the transceiver and relay operation with low complexity. Then, we analyze the model aggregation error in a single-relay case and show that our relay-assisted scheme achieves a smaller error than the one without relays provided that the relay transmit power and the relay channel gains are sufficiently large. The analysis provides critical insights on relay deployment in the implementation of cooperative FL. Extensive numerical results show that our design achieves faster convergence compared with state-of-the-art schemes.
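For intuition, the following NumPy sketch (an illustration written for this README, not code from the package) mimics plain uniform-forcing over-the-air aggregation and shows how the weakest channel dictates the aggregation error:

```python
import numpy as np

# Illustrative only: K devices upload d-dimensional updates over a fading
# multi-access channel; the power-scaling factor is set by the weakest channel.
rng = np.random.default_rng(0)
K, d, P, noise_var = 20, 1000, 1.0, 1e-2

g = rng.standard_normal((K, d))                                          # local model updates
h = (rng.standard_normal(K) + 1j * rng.standard_normal(K)) / np.sqrt(2)  # Rayleigh fading

# Uniform-forcing precoding: device k sends (eta / h_k) * g_k. Choosing
# eta from the weakest channel keeps every device within the power budget P.
eta = np.sqrt(P) * np.min(np.abs(h))
x = (eta / h)[:, None] * g

# Superposition over the multi-access channel aligns the updates coherently.
n = np.sqrt(noise_var / 2) * (rng.standard_normal(d) + 1j * rng.standard_normal(d))
y = (h[:, None] * x).sum(axis=0) + n

estimate = np.real(y) / (eta * K)                # estimated average update
mse = np.mean((estimate - g.mean(axis=0)) ** 2)
print(f"aggregation MSE: {mse:.3e}  (blows up as min_k |h_k| -> 0)")
```

Intuitively, the relay-assisted scheme in the paper targets exactly this bottleneck: by assisting the devices with weak links, the scaling factor (and hence the noise amplification) is no longer dictated by a single deep-faded channel.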
If you use this code in any way for research that results in publications, please cite our original article listed above.
This package is written in Matlab and Python 3. It requires the following libraries:
- Matlab and CVX
- Python >= 3.5
- torch
- torchvision
- scipy
- CUDA (if GPU is used)
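Assuming a pip-based environment, the Python dependencies can typically be installed with the line below (for a CUDA-enabled torch build, follow the install selector on pytorch.org):

```
pip install torch torchvision scipy
```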
- data/: Store the Fashion-MNIST dataset. When the code is run for the first time, it automatically downloads the dataset from the Internet.
- store/: Store output files (*.npz)
- matlab/: Data and code to be used in Matlab
- DATA/: Store files (*.mat) for channel models and optimization results in Matlab
- training_result/: Store files for training results (*.mat) to be plotted for presentation
- main_cmp.m: Initialize the simulation system and optimize the variables
- Setup_Init.m: Specify and initialize the system parameters
- AM.m: Alternating minimization algorithm proposed in the paper
- Single.m: Conventional over-the-air model aggregation scheme
- Xu.m: Existing relay-assisted scheme in Ref. [23]
- single_relay_channel.m: Construct the channel model for the single-relay case
- single_relay_channel_loc.m: Construct the channel model for the single-relay case with varying relay location
- cell_channel_model.m: Construct the channel model for the multi-relay case in a single-cell
- plot_figure.m: Plot the figure with varying transmission blocks from the training results stored in training_result/
- plot_Pr.m: Plot the figure with varying P_r from the training results stored in training_result/
- main.py: Initialize the simulation system, train the learning model, and store the result in store/ as an .npz file
- initial(): Initialize the parser function to read the user-input parameters
- learning_flow.py: Read the optimization result, initialize the learning model, and perform training and testing
- Learning_iter(): Given a learning model, compute the gradients, update the training models, and perform testing on top of train_script.py
- FedAvg_grad(): Given the aggregated model changes and the current model, update the global model by Eq. (5); see the sketch after this list
- Nets.py:
- CNNMnist(): Specify the convolutional neural network structure used for learning
- MLP(): Specify the multilayer perceptron structure used for learning
- AirComp.py:
- AM(): Given the local model changes, perform relay-assisted over-the-air model aggregation; see Section II-C
- Single(): Given the local model changes, perform conventional over-the-air model aggregation; see Section II-B
- Xu(): Given the local model changes, perform relay-assisted over-the-air model aggregation scheme proposed in Ref. [23]
- train_script.py:
- Load_fmnist_iid(): Download (if needed) and load the Fashion-MNIST data, and distribute them to the local devices
- Load_fmnist_noniid(): Download (if needed) and load the Fashion-MNIST data, and distribute them to the local devices following a non-iid distribution; see the partitioning sketch after this list
- local_update(): Given a learning model and the distributed training data, compute the local gradients/model changes
- test_model(): Given a learning model, test the accuracy/loss based on certain test images
- plot_result.py: Plot the figure with varying transmission blocks from the output files in store/, and process and store the training results in matlab/training_result/
- plot_Pr.py: Plot the figure with varying P_r from the output files in store/, and process and store the training results in matlab/training_result/
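As referenced in the FedAvg_grad() item above, here is a minimal sketch of a FedAvg-style global update. It assumes Eq. (5) takes the common form w <- w - lambda * Delta, where Delta is the aggregated model change and lambda is the learning rate (the `lr` parameter); the package's exact update may differ:

```python
import copy
import torch

def fedavg_grad(global_model, aggregated_change, lr=0.05):
    """Hypothetical sketch: apply the aggregated model change Delta to the
    current global model as w <- w - lr * Delta (assumed form of Eq. (5))."""
    new_model = copy.deepcopy(global_model)
    with torch.no_grad():
        for param, delta in zip(new_model.parameters(), aggregated_change):
            param -= lr * delta  # in-place update of each parameter tensor
    return new_model
```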
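Similarly, for the Load_fmnist_noniid() item above, the sketch below illustrates one common way to realize the `noniid_level` parameter: each device draws its samples from only `noniid_level` of the 10 Fashion-MNIST classes. This is a hypothetical illustration; the package's actual sharding may differ:

```python
import numpy as np

def noniid_partition(labels, K=20, noniid_level=2, num_classes=10, seed=1):
    """Assign each of K devices sample indices from `noniid_level` classes."""
    rng = np.random.default_rng(seed)
    by_class = [np.flatnonzero(labels == c) for c in range(num_classes)]
    parts = []
    for _ in range(K):
        # Each device is restricted to a random subset of the classes.
        classes = rng.choice(num_classes, size=noniid_level, replace=False)
        idx = np.concatenate([
            rng.choice(by_class[c], size=len(by_class[c]) // K, replace=False)
            for c in classes])
        parts.append(idx)
    return parts  # parts[k] = indices of the samples held by device k
```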
- Use the codes for channel models in matlab/ to obtain the channel coefficients.
- The main file for optimization in Matlab is matlab/main_cmp.m, which optimizes the variables of the proposed relay-assisted scheme and the benchmark schemes. Run matlab/main_cmp.m; the obtained optimization results are then used for FL.
- The main file for FL is main.py. It takes the following user-input parameters through a parser (also see the function initial() in main.py; a minimal parser sketch follows the example at the end of this section):
Parameter Name | Meaning | Default Value | Type/Range |
---|---|---|---|
K | total number of devices | 20 | int |
N | total number of relays | 1 | int |
PL | path loss exponent | 3.0 | float |
trial | total number of Monte Carlo trials | 50 | int |
SNR | negative of the noise variance in dB | 100 | float |
P_r | relay transmit power budget | 0.1 | float |
verbose | output no/important/detailed messages when running the scripts | 1 | 0, 1, 2 |
seed | random seed | 1 | int |
gpu | GPU index used for learning (if possible) | 0 | int |
local_ep | number of local epochs, E | 1 | int |
local_bs | local batch size, B, 0 for full batch | 0 | int |
lr | learning rate, lambda | 0.05 | float |
low_lr | learning rate lower bound, bar_lambda | 1e-5 | float |
gamma | learning rate decrease ratio, gamma | 0.9 | float |
step | learning rate decrease step, bar_T | 50 | int |
momentum | SGD momentum, used only for multiple local updates | 0.99 | float |
epochs | number of training rounds, T | 500 | int |
iid | 1 for iid, 0 for non-iid | 1 | 0, 1 |
noniid_level | number of classes at each device for non-iid | 2 | 2, 4, 6, 8, 10 |
V_idx | Variable index | 0 | int |
Here is an example of executing the scripts in a Linux terminal:
python main.py --gpu=0 --trial=50 --V_idx 0
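For reference, here is a minimal sketch of what initial() plausibly looks like, mirroring a few rows of the table above (names, defaults, and types follow the table; the actual parser in main.py may define more options):

```python
import argparse

def initial():
    # Sketch of the parameter parser described in main.py; only a subset
    # of the table's parameters is shown here.
    parser = argparse.ArgumentParser()
    parser.add_argument('--K', type=int, default=20, help='total number of devices')
    parser.add_argument('--N', type=int, default=1, help='total number of relays')
    parser.add_argument('--trial', type=int, default=50, help='number of Monte Carlo trials')
    parser.add_argument('--gpu', type=int, default=0, help='GPU index used for learning')
    parser.add_argument('--lr', type=float, default=0.05, help='learning rate, lambda')
    parser.add_argument('--epochs', type=int, default=500, help='number of training rounds, T')
    parser.add_argument('--iid', type=int, default=1, help='1 for iid, 0 for non-iid')
    parser.add_argument('--V_idx', type=int, default=0, help='variable index')
    return parser.parse_args()
```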