联邦学习作为一个机器学习框架,能够使不同地点的设备在满足用户隐私等要求下,一起进行数据建模和模型构建。最近,人们希望机器能够像人类一样不断学习新的知识而不会对旧的知识产生灾难性遗忘。而在现实生活中对一个新知识的学习往往是通过和其他人一起互相交流学习而来,例如:书籍,视频等。因此,本课题将利用联邦学习的框架和持续学习相结合,使得各个设备能够互相学习,借鉴彼此已有的知识增强训练效果。
相关准备
- Windows
- Python 3.7+
- PyTorch 1.9+
- CUDA 10.2+
准备虚拟环境
-
准备conda环境并进行激活.
conda create -n FedCL python=3.7 conda active FedCL
-
在官网安装对应版本的pytorch
可以在终端直接使用官方提供的code
-
安装FedCL所需要的包
git clone https://github.com/luopanyaxin/Federated-Continual-Learning pip install -r requirements.txt
Note: 安装前请确保电脑上是否有显卡且显卡算力和pytorch版本匹配
FedKNOW 根据下面的命令进行运行:
python main_WEIT.py --dataset [dataset] --model [mdoel] --num_users [num_users]
--shard_per_user [shard_per_user] --frac [frac] --local_bs [local_bs] --lr [lr]
--task [task] --epoch [epoch] --local_ep [local_ep] --gpu [gpu]
参数解释:
dataset
: 数据集,例如:cifar100
,FC100
,CORe50
,SVHN
,mnist
model
: 网络模型,例如:6-Layers CNN
,ResNet18
num_users
: 客户端数量shard_per_user
: 每个客户端拥有的类frac
:每一轮参与训练的客户端local_bs
:每一个客户端的batch_sizelr
:学习率task
:任务数epochs
: 客户端和服务器通信的总次数local_ep
:本地客户端迭代数gpu
:GPU ID
完整的参数信息解释在utils/option.py
。
数据集:
- Cifar100: Cifar100共包含50000条训练数据和10000条测试数据,共100个类。在持续学习中,我将其分为10个任务,每个任务包含10类。
- FC100: FC100共包含50000条训练数据和10000条测试数据,共100个类。在持续学习中,我将其分为10个任务,每个任务包含10类。
- MNIST: MNIST数据集包含60000条训练数据和10000条测试数据,共10个类。在持续学习中,将这些数据按照随机序列进行重新排序生成5个任务,每个任务包括10类。
- SVHN: SVHN数据集包含73257条训练数据和26032条测试数据,共10个类,为了使各个类的数据数量一致,在训练集中每个类选取4500个数据样本,在测试集中选取1500个测试样本。在持续学习中,将这些数据按照随机序列进行重新排序生成5个任务,每个任务包括10类。 在联邦学习中,每个任务数据我们利用non-iid的方式分配给20个客户端。
模型:
- 6_layer CNN: 参考AGS-CL实现的模型,共包含6个卷积层和2个全连接层。
- ResNet : 参考pytorch官方提供的resnet18进行的改动。
Baseline:
- GEM: 持续学习中的经典算法,通过存储部分样本后通过梯度的旋转来防止遗忘。
- Co2L: 2021年提出的最新夫人持续学习算法,通过使用对比学习目标学习表征,再通过自监督蒸馏方式来保留表征方式来防止遗忘。
- FedAvg: 联邦学习经典算法,将各个客户端的参数加权平均算法。
- APFL: 个性化联邦学习算法,通过设定参数比例来权衡各个客户端模型参数聚合参数,防止数据异构导致的算法发散。
- FedRep: 个性化联邦学习,划分每个模型参与全局聚合的层以及本地训练的层,之后通过冻结梯度的方式依次更新对应部分的层参数,防止数据异构导致的算法发散。
实验设置表:
Devices | Models and data | Baselines |
---|---|---|
Windows |
6-layer CNN on CIFAR100 6-layer CNN on FC100 6-layer CNN on MNIST ResNet18 on SVHN |
GEM Co2L FedAvg APFL FedRep |
- 6-layer CNN on Cifar100
python main_WEIT.py --epochs=150 --round=15 --num_users=20 --frac=0.4 shard_per_user=5 per--model=6_layer_CNN --dataset=cifar100 --num_classes=100 --task=10 lr=0.001
- 6-layer CNN on FC100
python main_WEIT.py --epochs=150 --round=15 --num_users=20 --frac=0.4 shard_per_user=5 per--model=6_layer_CNN --dataset=FC100 --num_classes=100 --task=10 lr=0.001
- 6-layer CNN on MNIST
python main_WEIT.py --epochs=75 --round=15 --num_users=20 --frac=0.4 shard_per_user=5 per--model=6_layer_CNN --dataset=mnist --num_classes=100 --task=5 lr=0.001
- ResNet18 on SVHN
python main_WEIT.py --epochs=75 --round=15 --num_users=20 --frac=0.4 shard_per_user=5 per--model=6_layer_CNN --dataset=SVHN --num_classes=100 --task=5 lr=0.001