The FL training process consists of two iterative phases, i.e., local training and global aggregation, so learning performance depends on both the effectiveness of the locally trained parameters and their smooth aggregation. Neither requirement is easy to satisfy in edge environments: edge users often have limited bandwidth and insufficient data, which causes inefficient parameter aggregation, excessive training time, and reduced model accuracy. FL inherently entails a large number of communication rounds and a large amount of labeled training data, both of which are often unavailable to edge users. These challenges are particularly salient under the combined effect of a long training process and unfavorable factors such as non-IID and unbalanced data, limited communication bandwidth, and unreliable, limited device availability.
We revisit how FL mines distributed data across iterative training rounds and exploit emerging foundation models (FMs) to optimize FL training. We investigate the behavior of FMs in a standard FL setting using popular off-the-shelf models, e.g., CLIP, and existing methods for FM adaptation. We propose PromptFL, a framework that replaces federated model training with federated prompt training, i.e., FL clients train prompts instead of a model, which simultaneously exploits the scarce local data and reduces the aggregation overhead. PromptFL ships an off-the-shelf public CLIP to users and applies continuous prompts (a.k.a. soft prompts) for FM adaptation, which requires very few data samples from edge users. The framework is technically simple yet effective.
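The core mechanism can be illustrated with a minimal, self-contained PyTorch sketch (not the repository's implementation): a small frozen backbone stands in for CLIP, each client optimizes only a tensor of continuous context vectors (the soft prompt) on its few-shot local data, and the server aggregates the prompt tensors with FedAvg. All names (`FrozenBackbone`, `PromptLearner`, `local_train`, `fedavg`), the dimensions, and the mean-pooled prompt/class combination are illustrative assumptions.

```python
# Minimal sketch of prompt-only federated training over a frozen backbone.
# This is a toy illustration of the idea, not the PromptFL codebase.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

N_CLASSES, N_CTX, CTX_DIM, FEAT_DIM = 10, 16, 512, 512  # M = 16 context tokens

class FrozenBackbone(nn.Module):
    """Stand-in for the frozen, off-the-shelf CLIP encoders (never updated)."""
    def __init__(self):
        super().__init__()
        self.image_encoder = nn.Linear(32, FEAT_DIM)          # toy "image encoder"
        self.class_embed = nn.Embedding(N_CLASSES, CTX_DIM)   # fixed class-name embeddings
        self.text_proj = nn.Linear(CTX_DIM, FEAT_DIM)         # toy "text encoder" head
        for p in self.parameters():
            p.requires_grad_(False)

class PromptLearner(nn.Module):
    """The only trainable part: M continuous context vectors shared by all classes."""
    def __init__(self):
        super().__init__()
        self.ctx = nn.Parameter(0.02 * torch.randn(N_CTX, CTX_DIM))

    def class_weights(self, backbone):
        # Combine the shared soft prompt with each frozen class embedding,
        # loosely mimicking "[ctx_1 ... ctx_M] [CLASS]" -> text feature.
        ctx = self.ctx.mean(dim=0, keepdim=True)               # (1, CTX_DIM)
        text = backbone.class_embed.weight + ctx               # (N_CLASSES, CTX_DIM)
        return F.normalize(backbone.text_proj(text), dim=-1)   # (N_CLASSES, FEAT_DIM)

def local_train(prompt_state, backbone, images, labels, lr=0.01, steps=5):
    """One client's round: load the global prompt, optimize only it, return the update."""
    prompt = PromptLearner()
    prompt.load_state_dict(prompt_state)
    opt = torch.optim.SGD(prompt.parameters(), lr=lr)
    for _ in range(steps):
        feats = F.normalize(backbone.image_encoder(images), dim=-1)
        logits = 100.0 * feats @ prompt.class_weights(backbone).t()
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return prompt.state_dict()

def fedavg(states):
    """Server step: average the (tiny) prompt tensors returned by the clients."""
    avg = copy.deepcopy(states[0])
    for key in avg:
        avg[key] = torch.stack([s[key].detach() for s in states]).mean(dim=0)
    return avg

if __name__ == "__main__":
    backbone = FrozenBackbone()                 # shipped to every client once, never trained
    global_prompt = PromptLearner().state_dict()
    for _ in range(3):                          # communication rounds
        client_states = []
        for _ in range(4):                      # simulated clients with few-shot local data
            x, y = torch.randn(8, 32), torch.randint(0, N_CLASSES, (8,))
            client_states.append(local_train(global_prompt, backbone, x, y))
        global_prompt = fedavg(client_states)
    print("aggregated soft prompt:", tuple(global_prompt["ctx"].shape))
```

Because only the prompt tensor travels between clients and server, each communication round carries a few thousand parameters rather than a full model, which is the source of the reduced aggregation overhead described above.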
You can run `federated_main.py` with the specified arguments. `--root` takes as input the path to a dataset, such as `caltech101` or `oxford_flowers`. `--config-file` specifies which config file to use, such as `rn50_ep50` or `vit_b16_ep50`. You can select variables such as shots and users by changing `cfg`, or change any arguments you like in `main_pipeline.sh`; a sketch of a direct invocation follows below.
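For illustration only, a direct call could look like the following; the dataset root is a placeholder, the config name follows the examples above, and depending on the repository version the script may expect additional arguments beyond the two described here:

```bash
# Illustrative invocation; /path/to/datasets is a placeholder and further
# arguments may be required by federated_main.py in your checkout.
python federated_main.py --root /path/to/datasets --config-file rn50_ep50
```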
PromptFL (M=16, end):
If you want to train caltech101 with 2 shots, backbone rn50, and a fully independent non-IID setting, you can specify:
TRAINER=PromptFL
DATA=caltech101
SHOTS=2
REPEATRATE=0.0
and run `bash main_pipeline.sh rn50_ep50 end 16 False False False`.
FinetuningFL:
If you want to train caltech101 with fine-tuning, backbone rn50, and a fully independent non-IID setting, you can specify:
TRAINER=Baseline
DATA=caltech101
SHOTS=1
REPEATRATE=0.0
and run `bash main_pipeline.sh rn50_ep50 end 16 False False True`.
After the experiments finish, all results are saved to `output/`.
We build on and modify the code from Dassl and CoOp.
We will release the full version with a detailed description later to help facilitate the community and further study.
If this code is useful in your research, please consider citing our paper:
@article{guo2022promptfl,
title={PromptFL: Let Federated Participants Cooperatively Learn Prompts Instead of Models--Federated Learning in Age of Foundation Model},
author={Guo, Tao and Guo, Song and Wang, Junxiao and Xu, Wenchao},
journal={arXiv preprint arXiv:2208.11625},
year={2022}
}