This work is licensed under a Creative Commons Attribution-NonCommercial 3.0 Unported License.
We developed a Federated Learning (FL) framework that lets international researchers train an AI diagnosis model on remote machines. Please follow this user guide to deploy the FL-empowered AI model at hospitals or other non-profit organizations.
Like most client/server (C/S) architectures, the framework consists of two parts: a Server side and a Client side. In a real scenario, taking hospitals as an example, the client part is deployed on the hospitals' machines (Clients), where the federated model is trained locally, while the server part runs on a central machine (Server).
Once the scripts are executed, each hospital trains its own model locally and transmits the parameters to the server, which aggregates all parameters collected from the clients. The server then distributes the newly aggregated parameters back to each client, sustaining the FL process. This process iterates for a pre-set number of rounds until the aggregated model reaches the desired accuracy.
Besides the above functionality, we equip our framework with additional features:
- Encrypted parameters: every Client can encrypt its locally trained model parameters with its own generated private key. The Server aggregates these parameters but cannot decrypt them.
- Weighted aggregation: we also add a weighted aggregation method to this framework; researchers can assign a different weight to each client for the final aggregation.
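To illustrate what weighted aggregation means here, below is a minimal, hypothetical sketch (the function name and plain-float "parameters" are our own illustration, not the framework's actual API, which operates on PyTorch state dicts):

```python
def weighted_aggregate(client_states, client_weights):
    """Hypothetical sketch of weighted aggregation: average each named
    parameter across clients, weighted by each client's assigned weight."""
    total = sum(client_weights)
    merged = {}
    for name in client_states[0]:
        merged[name] = sum(
            w * state[name] for state, w in zip(client_states, client_weights)
        ) / total
    return merged

# Two toy "models", each with a single scalar parameter.
states = [{"fc.weight": 1.0}, {"fc.weight": 3.0}]
print(weighted_aggregate(states, [1, 3]))  # {'fc.weight': 2.5}
```

A client with weight 3 pulls the merged parameter three times as strongly as a client with weight 1.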
Because of Encryption and Weighted Aggregation, it is not sufficient for the Server and Clients to exchange only the model parameters.
We define the file content format for this framework as follows:
The file transmitted from a Client contains:

- "encrypted_model_state": model parameters encrypted with the Client's private key
- "client_weight": this Client's weight in the FL aggregation process

The file transmitted from the Server contains:

- "model_state_dict": aggregated model parameters
- "client_weight_sum": the sum of all Clients' weights in the current FL process
- "client_num": the number of Clients in the current FL process

We provide pack/unpack functions in both the Server and Client classes. If you do not need Encryption or Weighted Aggregation, the file format can be redefined. All files are stored in .pth format.
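The pack/unpack pattern can be sketched as follows. This is a hypothetical illustration, not the framework's actual implementation: the real code stores `.pth` files via `torch.save`/`torch.load`, while plain `pickle` is used here so the sketch stays dependency-free, and the function names are our own.

```python
import pickle

def pack_client_payload(encrypted_state, client_weight, path):
    """Hypothetical sketch of the client-side pack step: bundle the
    encrypted parameters and the client's weight into one file."""
    payload = {"encrypted_model_state": encrypted_state,
               "client_weight": client_weight}
    with open(path, "wb") as f:
        pickle.dump(payload, f)

def unpack_client_payload(path):
    """Server-side counterpart: recover both fields from the file."""
    with open(path, "rb") as f:
        payload = pickle.load(f)
    return payload["encrypted_model_state"], payload["client_weight"]

pack_client_payload({"fc.weight": b"\x01\x02"}, 0.3, "client1.pth")
state, weight = unpack_client_payload("client1.pth")
print(weight)  # 0.3
```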
1.2.1 The server folder contains two main scripts, `server_main.py` and `fl_server.py`. In `fl_server.py` we define the `FL_Server` class, and in `server_main.py` we provide an example of using the `FL_Server` class.
1.2.2 Before starting the FL process, we need to set some configurations for the Server in `./server/config/config.json`; edit this `.json` file for your own setup:

```json
{
  "ip": "0.0.0.0",
  "send_port": 12341,
  "recv_port": 12342,
  "clients_path": "./config/clients.json",
  "model_path": "./model/model.py",
  "weight_path": "./model/merge_model/initial.pth",
  "merge_model_dir": "./model/merge_model/",
  "client_weight_dir": "./model/client_model/",
  "buff_size": 1024,
  "iteration": 120
}
```
- `ip`, `send_port`, `recv_port`: socket configurations.
- `clients_path`: path to `clients.json`, which contains each client's information (`username` and `password`).
- `weight_path`: path of the initial model weights sent to the clients in the FL process.
- `merge_model_dir`: the merged model is stored here after each aggregation.
- `client_weight_dir`: model parameters (`.pth` files) received from each client are saved here; after aggregation, all files stored here are removed.
- `iteration`: number of FL iterations.
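A server script might load and sanity-check this configuration as sketched below. This is a hypothetical illustration; the inline JSON stands in for reading `./server/config/config.json` with `json.load`, and the validation step is our own suggestion, not part of `FL_Server`:

```python
import json

# Stand-in for: with open("./server/config/config.json") as f: config = json.load(f)
raw = '''{"ip": "0.0.0.0", "send_port": 12341, "recv_port": 12342,
          "clients_path": "./config/clients.json",
          "weight_path": "./model/merge_model/initial.pth",
          "iteration": 120}'''
config = json.loads(raw)

# Fail early if a key the server relies on is missing from the file.
required = {"ip", "send_port", "recv_port", "clients_path",
            "weight_path", "iteration"}
missing = required - config.keys()
assert not missing, f"config.json is missing keys: {missing}"
print(config["recv_port"])  # 12342
```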
1.2.3 `./server/config/clients.json` stores the `username` and `password` of each Client. Clients register with the Server using this information; if it is wrong, the register request is refused and the client cannot join the FL process. An example is shown below:
```json
{
  "Bob": "123456",
  "Alan": "123456",
  "John": "123456"
}
```
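The registration check amounts to a lookup in this file, as in the following hypothetical sketch (the function name is our own; the real check lives inside the `FL_Server` class):

```python
import json

# In the real server this comes from the file at "clients_path".
clients = json.loads('{"Bob": "123456", "Alan": "123456", "John": "123456"}')

def accept_registration(username, password):
    """Accept a register request only if the (username, password) pair
    matches an entry in clients.json."""
    return clients.get(username) == password

print(accept_registration("Bob", "123456"))  # True
print(accept_registration("Eve", "123456"))  # False
```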
1.3.1 The `./client` folder also contains two main scripts: `fl_client.py`, in which we define the `FL_Client` class, and `client1_main.py`, in which we provide an example.
1.3.2 Client configurations are set in `./client/config/client1_config.json`:

```json
{
  "username": "Bob",
  "password": "123456",
  "ip": "127.0.0.1",
  "send_port": 12346,
  "recv_port": 12347,
  "server_ip": "127.0.0.1",
  "server_port": 12342,
  "buff_size": 1024,
  "model_path": "./model/model.py",
  "weight_path": "./model/initial.pth",
  "models_dir": "./model",
  "iteration": 120
}
```
For the client, pay attention to its `FL_Client.train_model_path` attribute, which stores the path of the model parameter file received from the Server. On initialization it takes the value of `"weight_path"` in `client1_config.json`, but the `FL_Client` class updates it while the program runs, so the user can always read the latest parameter file sent by the Server through this path.
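The pattern can be sketched with a minimal stand-in class. Everything below is hypothetical illustration only: `FakeClient` mimics how `train_model_path` starts at the configured `weight_path` and is repointed at each newly received file, and `pickle` stands in for `torch.save`/`torch.load`:

```python
import pickle

class FakeClient:
    """Minimal stand-in for FL_Client, only to show the train_model_path
    pattern; it is not the framework's actual class."""
    def __init__(self):
        # Initialized from "weight_path" in client1_config.json.
        self.train_model_path = "./model/initial.pth"

    def receive_round(self, round_idx, payload):
        # Simulate receiving a file from the server and updating the path.
        path = f"./round_{round_idx}.pth"
        with open(path, "wb") as f:
            pickle.dump(payload, f)
        self.train_model_path = path

client = FakeClient()
client.receive_round(1, {"model_state_dict": {"fc.weight": 0.5}})
# Training code always reads the newest parameters through this path.
with open(client.train_model_path, "rb") as f:
    print(pickle.load(f)["model_state_dict"])  # {'fc.weight': 0.5}
```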
1.3.3 Because the training process takes place on the client side, you also need to set your own training configuration. Ours is given below as an example:

```json
{
  "train_data_dir": "/home/ABC/EFG/dataset/",
  "train_df_csv": "./utils/HIJ.csv",
  "labels_train_df_csv": "./utils/HIJ.csv",
  "test_data_dir": "/home/ABC/EFG/dataset/",
  "test_df_csv": "./utils/test_norm.csv",
  "labels_test_df_csv": "./utils/test_norm.csv",
  "use_cuda": true,
  "epoch": 101,
  "train_batch_size": 50,
  "test_batch_size": 2,
  "num_workers": 12,
  "lr": 0.015,
  "momentum": 0.9
}
```
Developers can run `git clone https://github.com/HUST-EIC-AI-LAB/COVID-19-Fedrated-Learinig.git` to deploy their own FL task.
Some dependencies, e.g. PyTorch and CUDA, may need to be pre-installed before you can train on GPU. Run `pip install -r requirement.txt` to install the required dependencies.

Notice: in `requirement.txt`, the torch version matches CUDA 9.2. If you have problems using torch, they may be caused by a version mismatch between torch and CUDA; check your CUDA version with `cat /usr/local/cuda/version.txt` and download the matching version of PyTorch from the official website.
ATTENTION: `ninja` and `re2c` are needed to build the C++ extensions; install them as described on their GitHub pages.

- `ninja`: https://github.com/ninja-build/ninja
- `re2c`: https://github.com/skvadrik/re2c
Our encryption algorithm comes from https://github.com/lucifer2859/Paillier-LWE-based-PHE
3.1 We have reduced the operations required in the communication process as much as possible. However, the Client-side training process and the Server-side aggregation process still need to be customized by the researcher.
We provide a rough template for the communication process in `./server/server_main_raw.py` and `./client/client1_main_raw.py`; you can design your own federated learning process based on them.
In addition, we provide our federated learning code for COVID-19 prediction as an example, which includes the encryption and weighted aggregation features.
First run `server_main.py`, then run `client1_main.py` on each client machine. You can add any number of clients; you only need to modify the corresponding configuration file paths in `client1_main.py`, such as `client1_config.json` and `train_config_client1.json`:

```python
client = FL_Client('./config/client1_config.json')
client.start()
with open('./config/train_config_client1.json') as j:
    train_config = json.load(j)
```
After modifying the corresponding configuration files, run separately:

```shell
python server_main.py
python client1_main.py
python client2_main.py
......
```
3.2 Some tips
Our FL process offers some flexibility. On the server side, developers can wait for all registered clients to return their parameter files before starting aggregation. Alternatively, you can set a minimum number of clients `min_clients` and a maximum delay: when enough clients have returned their files, or when no new client returns a file within the predefined maximum delay, the server starts the aggregation process early and no longer accepts client requests.
The corresponding code can be found in the `server_main_raw.py` file.
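The early-aggregation rule described above can be sketched as a simple predicate. This is a hypothetical illustration of the logic (the function name and signature are our own, not those of `server_main_raw.py`):

```python
import time

def should_start_aggregation(received, min_clients, last_arrival, max_delay):
    """Hypothetical sketch: start aggregating once enough clients have
    returned files, or once no new file has arrived within max_delay
    seconds of the last arrival."""
    if received >= min_clients:
        return True
    return (time.time() - last_arrival) > max_delay

# Enough clients have returned their files:
print(should_start_aggregation(3, 3, time.time(), 60.0))        # True
# Too few clients, but the server has waited past the deadline:
print(should_start_aggregation(1, 3, time.time() - 120, 60.0))  # True
```

The server would evaluate this predicate in its receive loop and, once it returns `True`, stop accepting requests and aggregate whatever files it has.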
We have verified communication between different devices on a local area network and have run tests of the Federated Learning process. Our communication is based on sockets. To deploy this framework across devices in different network domains, developers may need to configure the corresponding ports and firewall settings to ensure successful communication.
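Before joining the FL process from another network, it can help to confirm the server port is reachable at all. A small, hypothetical helper (not part of the framework) for that check:

```python
import socket

def port_reachable(host, port, timeout=3.0):
    """Return True if a TCP connection to (host, port) succeeds within
    the timeout; useful for diagnosing firewall/NAT problems before
    starting a client."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

# E.g., check the server's recv_port from a client machine:
print(port_reachable("127.0.0.1", 12342))
```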