PyTorch implementation of our paper 'TransMatch: a transfer learning scheme for semi-supervised few-shot learning'. The following code is for miniImageNet.
Requirements:
- Python 3.5
- PyTorch 0.4+
- Download miniImageNet: download link
- Download the model pretrained on base-class data: download link
- Then unzip MY_mini_data.zip.
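If you prefer to unzip from Python, here is a minimal sketch using the standard-library zipfile module. It assumes MY_mini_data.zip has already been downloaded. It also assumes the archive holds the MY_mini_data folder at its top level. The helper name unzip_data is hypothetical and is not part of this repository.

```python
import os
import zipfile


def unzip_data(zip_path="MY_mini_data.zip", target_dir="."):
    """Extract the dataset archive into target_dir (hypothetical helper, not in this repo).

    Returns the sorted top-level entries of the archive so the caller can
    check that MY_mini_data was extracted.
    """
    if not os.path.isfile(zip_path):
        raise FileNotFoundError("{} not found; download it first".format(zip_path))
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target_dir)
        return sorted({name.split("/")[0] for name in archive.namelist()})
```

For example, calling unzip_data() from the project folder should return ['MY_mini_data'] if the archive has the assumed layout.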
The project folder should look like this:
main.py
mini_loader.py
wide_models.py
wideresnet_legacy.py
pretrained_model_on_base_class.pth.tar
utils
MY_mini_data
├── base_data
├── novel_data
└── miniImagenet_base_novel
    ├── base
    └── novel
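Before launching a long run, you can check the folder layout programmatically. The sketch below is minimal and has some assumptions. It places base and novel inside MY_mini_data/miniImagenet_base_novel, as the tree indicates. It only tests whether each entry exists and does not distinguish files from folders. The names EXPECTED_PATHS and missing_paths are hypothetical.

```python
import os

# Entries from the project folder listing, relative to the project root.
# Assumption: base/novel are nested under MY_mini_data/miniImagenet_base_novel.
EXPECTED_PATHS = [
    "main.py",
    "mini_loader.py",
    "wide_models.py",
    "wideresnet_legacy.py",
    "pretrained_model_on_base_class.pth.tar",
    "utils",
    os.path.join("MY_mini_data", "base_data"),
    os.path.join("MY_mini_data", "novel_data"),
    os.path.join("MY_mini_data", "miniImagenet_base_novel", "base"),
    os.path.join("MY_mini_data", "miniImagenet_base_novel", "novel"),
]


def missing_paths(root="."):
    """Return the expected entries that do not exist under root."""
    return [p for p in EXPECTED_PATHS if not os.path.exists(os.path.join(root, p))]
```

An empty list from missing_paths() means every listed entry was found.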
Please check main.py for the details.
Notice:
- The final results are stored in final_result.csv (600 test results in total).
- The Imprinting results are stored in imprinting_result.csv (600 test results in total).
- logXXX.png plots test experiments 0 to XXX; each line shows how the test accuracy of one experiment changes over the epochs.
- Fine-tuning for each test experiment may take 10-20 minutes, so finishing all 600 test experiments may take 4-6 days. You can also let the code run for only 100 test experiments; the results are similar.
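The result tables report each accuracy as a mean with a ± interval over the test experiments. Below is a minimal sketch for summarizing final_result.csv in the same form. It rests on two assumptions that this README does not confirm (check main.py for the actual format):

- The CSV holds one accuracy (in %) per row, in its last column.
- The interval is the common few-shot convention of a 95% confidence interval, 1.96 × std / sqrt(n).

The helper names are hypothetical.

```python
import csv
import math


def mean_and_ci95(accuracies):
    """Mean and 95% confidence half-width, 1.96 * sample std / sqrt(n).

    Uses the sample (n - 1) standard deviation; with 600 runs the difference
    from the population std is negligible.
    """
    n = len(accuracies)
    if n < 2:
        raise ValueError("need at least two accuracies")
    mean = sum(accuracies) / n
    var = sum((a - mean) ** 2 for a in accuracies) / (n - 1)
    return mean, 1.96 * math.sqrt(var) / math.sqrt(n)


def load_accuracies(path="final_result.csv"):
    """Read the last column of every row as a float (assumed layout).

    Rows whose last field is not numeric, such as a header, are skipped.
    """
    values = []
    with open(path, newline="") as f:
        for row in csv.reader(f):
            if not row:
                continue
            try:
                values.append(float(row[-1]))
            except ValueError:
                continue
    return values
```

Usage: mean, ci = mean_and_ci95(load_accuracies()), then print("{:.2f} ± {:.2f}".format(mean, ci)).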
5-way 5-shot (TransMatch, MixMatch via --random, and TransMatch with distractor classes):
python main.py --gpu 0 --num-way 5 --num-sample 5 --unlabelnumber 100 --epoch 25 --checkpoint TransMatch_5_5_100
python main.py --gpu 0 --random --num-way 5 --num-sample 5 --unlabelnumber 100 --epoch 25 --checkpoint MixMatch_5_5_100
python main.py --gpu 0 --num-way 5 --num-sample 5 --unlabelnumber 100 --distractor --distractor_class 2 --epoch 25 --checkpoint TransMatch_5_5_100_distractor_2
A sample training log for one test experiment:
Learning_Rate Train_Loss Train_Loss_X Train_Loss_U Valid_Loss Valid_Acc
0.001000 0.908974 0.776229 0.026549 1.061434 56.000000
0.001000 0.768066 0.660366 0.021540 0.905195 66.666667
0.001000 0.813665 0.708609 0.021011 0.946537 65.333333
0.001000 0.765839 0.673886 0.018390 0.896166 66.666667
0.001000 0.705718 0.611745 0.018795 0.878929 69.333333
...
0.001000 0.639709 0.563627 0.015217 0.803816 72.000000
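Here is a minimal sketch for extracting the final and best Valid_Acc from a log like this one. It assumes whitespace-separated columns with the header row first; the log's file name and exact delimiter on disk are not stated here. The helpers parse_log and summarize are hypothetical. Rows whose field count does not match the header, such as the "..." elision in the sample, are skipped.

```python
def parse_log(text):
    """Parse a whitespace-separated log with a header row into a list of dicts."""
    lines = [ln.split() for ln in text.strip().splitlines() if ln.strip()]
    header = lines[0]
    rows = []
    for fields in lines[1:]:
        if len(fields) != len(header):
            continue  # skip elision markers such as "..." or malformed lines
        rows.append({name: float(value) for name, value in zip(header, fields)})
    return rows


def summarize(rows):
    """Final-epoch and best validation accuracy (in %)."""
    accs = [r["Valid_Acc"] for r in rows]
    return {"final": accs[-1], "best": max(accs)}


# A few rows copied from the sample log.
SAMPLE = """Learning_Rate Train_Loss Train_Loss_X Train_Loss_U Valid_Loss Valid_Acc
0.001000 0.908974 0.776229 0.026549 1.061434 56.000000
0.001000 0.768066 0.660366 0.021540 0.905195 66.666667
...
0.001000 0.639709 0.563627 0.015217 0.803816 72.000000"""
```

On these rows, summarize(parse_log(SAMPLE)) returns {'final': 72.0, 'best': 72.0}.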
[Figure: change of test accuracy over 25 epochs for the 600 experiments; each line represents one experiment.]
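You can also summarize the per-experiment curves numerically by averaging the test accuracy at each epoch across experiments. This is a minimal sketch. It assumes each experiment's curve is available as a list of per-epoch Valid_Acc values, for example parsed from its log, and that all curves have the same number of epochs. The helper name mean_curve is hypothetical.

```python
def mean_curve(curves):
    """Average accuracy per epoch across experiments.

    curves: one list per experiment, holding one accuracy (%) per epoch.
    """
    if not curves:
        raise ValueError("no curves given")
    n_epochs = len(curves[0])
    if any(len(c) != n_epochs for c in curves):
        raise ValueError("curves have different lengths")
    return [sum(c[e] for c in curves) / len(curves) for e in range(n_epochs)]
```

For example, mean_curve([[50.0, 60.0], [70.0, 80.0]]) returns [60.0, 70.0].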
5-way 1-shot (TransMatch, MixMatch via --random, and TransMatch with distractor classes):
python main.py --gpu 0 --num-way 5 --num-sample 1 --unlabelnumber 100 --epoch 20 --checkpoint TransMatch_5_1_100
python main.py --gpu 0 --random --num-way 5 --num-sample 1 --unlabelnumber 100 --epoch 20 --checkpoint MixMatch_5_1_100
python main.py --gpu 0 --num-way 5 --num-sample 1 --unlabelnumber 100 --distractor --distractor_class 2 --epoch 20 --checkpoint TransMatch_5_1_100_distractor_2
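To script several runs, the main.py command lines listed in this README can be rebuilt programmatically. This is a minimal sketch that uses only the flags appearing in those commands. It fixes the epoch counts to the README's values: 25 for 5-shot and 20 for 1-shot. Other shot counts are not covered. The build_command helper and its variant names are hypothetical.

```python
def build_command(shot, variant="transmatch", gpu=0, way=5, unlabeled=100, distractor_classes=2):
    """Rebuild one of the README's main.py command lines as an argument list.

    variant: "transmatch", "mixmatch" (adds --random) or "distractor".
    """
    if variant not in ("transmatch", "mixmatch", "distractor"):
        raise ValueError("unknown variant: {}".format(variant))
    epochs = {5: 25, 1: 20}[shot]  # epoch counts used in the README commands
    args = ["python", "main.py", "--gpu", str(gpu)]
    name = "TransMatch"
    if variant == "mixmatch":
        args.append("--random")
        name = "MixMatch"
    args += ["--num-way", str(way), "--num-sample", str(shot), "--unlabelnumber", str(unlabeled)]
    checkpoint = "{}_{}_{}_{}".format(name, way, shot, unlabeled)
    if variant == "distractor":
        args += ["--distractor", "--distractor_class", str(distractor_classes)]
        checkpoint += "_distractor_{}".format(distractor_classes)
    args += ["--epoch", str(epochs), "--checkpoint", checkpoint]
    return args
```

" ".join(build_command(5, "mixmatch")) reproduces the 5-shot MixMatch line. The list can also be passed directly to subprocess.run to launch the run.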
Results without distractor classes:
(%) | Imprinting | MixMatch | TransMatch (Ours) |
---|---|---|---|
1-shot | 58.68 ± 0.81 | 52.00 ± 1.20 | 63.02 ± 1.07 |
5-shot | 76.06 ± 0.59 | 79.97 ± 0.62 | 81.19 ± 0.59 |
Results with distractor classes:
(%) | Imprinting | MixMatch | TransMatch (Ours) |
---|---|---|---|
1-shot | 58.68 ± 0.81 | 50.68 ± 1.15 | 60.41 ± 1.02 |
5-shot | 76.06 ± 0.59 | 78.07 ± 0.69 | 79.48 ± 0.64 |
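To read off the margins in the two result tables, here is a small sketch. For each setting, it computes TransMatch's gain in percentage points over the stronger of the two baselines, using only the mean accuracies from the tables. It ignores the ± intervals, which should be kept in mind when reading small margins. The dictionary keys and helper name are illustrative.

```python
# Mean accuracies (%) copied from the two result tables: (Imprinting, MixMatch, TransMatch).
RESULTS = {
    ("first table", "1-shot"): (58.68, 52.00, 63.02),
    ("first table", "5-shot"): (76.06, 79.97, 81.19),
    ("second table", "1-shot"): (58.68, 50.68, 60.41),
    ("second table", "5-shot"): (76.06, 78.07, 79.48),
}


def gain_over_best_baseline(imprinting, mixmatch, transmatch):
    """Percentage-point margin of TransMatch over the stronger baseline, rounded to 2 decimals."""
    return round(transmatch - max(imprinting, mixmatch), 2)


for key, accs in sorted(RESULTS.items()):
    print(key, gain_over_best_baseline(*accs))
```

The resulting margins are 4.34 and 1.22 points (1-shot and 5-shot, first table) and 1.73 and 1.41 points (second table).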