使用pytorch,mnist、cifar数据集实现基础的联邦学习(平局聚合、不含加密过程)
python
pytorch
下载mnist、cifar10数据集
如果需要使用自己的数据即替换dataset即可
python main.py
在utils/conf.json中修改参数设置
- model_name:模型名称
- no_models:客户端数量
- type:数据集信息
- global_epochs:全局迭代次数,即服务端与客户端的通信迭代次数
- local_epochs:本地模型训练迭代次数
- k:每一轮迭代时,服务端会从所有客户端中挑选k个客户端参与训练。
- batch_size:本地训练每一轮的样本数
- lr,momentum,lambda:本地训练的超参数设置
服务端的主要功能是将被选择的客户端上传的本地模型进行模型聚合(如果需要完善其他的复杂功能如同态加密、服务端需要对各个客户端节点进行网络监控、对失败节点发出重连信号等等功能,可以采用FATE平台)
这里的模型是在本地模拟的,不涉及网络通信细节和失败故障等处理,因此不讨论这些功能细节,仅涉及模型聚合功能。
服务端的工作包括:
第一,将配置信息拷贝到服务端中;
第二,按照配置中的模型信息获取模型,这里我们使用torchvision 的models模块内置的ResNet-18模型。
第三,这里的模型定义模型聚合函数。采用经典的FedAvg 算法。
第四,定义模型评估函数。
客户端主要功能是接收服务端的下发指令和全局模型,利用本地数据进行局部模型训练。
客户端的工作包括:
第一,定义构造函数。首先,将配置信息拷贝到客户端中;然后,按照配置中的模型信息获取模型,通常由服务端将模型参数传递给客户端,客户端将该全局模型覆盖掉本地模型;最后,配置本地训练数据。
第二,定义模型本地训练函数。
首先,读取配置文件信息。
每一轮的迭代,服务端会从当前的客户端集合中随机挑选一部分参与本轮迭代训练,被选中的客户端调用本地训练接口local_train进行本地训练,最后服务端调用模型聚合函数model_aggregate来更新全局模型
见图片"figures/fig31.png"
-
联邦训练配置:一共10台客户端设备(no_models=10),每一轮任意挑选其中的5台参与训练(k=5), 每一次本地训练迭代次数为3次(local_epochs=3),全局迭代次数为20次(global_epochs=20)。
-
集中式训练配置:我们不需要单独编写集中式训练代码,只需要修改联邦学习配置既可使其等价于集中式训练。具体来说,我们将客户端设备no_models和每一轮挑选的参与训练设备数k都设为1即可。这样只有1台设备参与的联邦训练等价于集中式训练。其余参数配置信息与联邦学习训练一致。图中我们将局部迭代次数分别设置了1,2,3来进行比较。
见图片"figures/fig2.png"
图中的单点训练只的是在某一个客户端下,利用本地的数据进行模型训练的结果。
- 我们看到单点训练的模型效果(蓝色条)明显要低于联邦训练 的效果(绿色条和红色条),这也说明了仅仅通过单个客户端的数据,不能够 很好的学习到数据的全局分布特性,模型的泛化能力较差。
- 此外,对于每一轮 参与联邦训练的客户端数目(k 值)不同,其性能也会有一定的差别,k 值越大,每一轮参与训练的客户端数目越多,其性能也会越好,但每一轮的完成时间也会相对较长。