Spatial-Channel Token Distillation for Vision MLPs
A PyTorch implementation of the paper Spatial-Channel Token Distillation for Vision MLPs. The codebase is built on top of the DeiT codebase from Facebook Research, with changes, additions, and removals where needed.
This project and repository are the outcome of a collective paper implementation effort conducted under METU-CENG502. Refer to the STD implementation on CENG502 for the full, detailed report.
Installation
To install, create the environment by running the following command in the project root:
conda env create -f environment.yml
Training
Start training with
python -m std.main --batch-size 128 --input-size 32 --patch-size 4 --model std-mlp-mixer --depth 8 --data-set CIFAR --data-path path/to/data --output_dir path/to/checkpoint_dir --teacher-model resnet32 resnet56
This starts a run with last-layer distillation only. To enable intermediate distillation, pass --distill-intermediate. To see all available arguments, run the following command:
python -m std.main --help
Evaluation
To evaluate a model, run the following command with the appropriate model and arguments:
python -m std.main --eval --resume path/to/model_folder --distillation-type none --teacher-model resnet32 resnet50 --model std-mlp-mixer --patch-size 4 --input-size 32 --data-set CIFAR --data-path path/to/dataset
This should give the following output for the STD-56 model:
* Acc@1 76.850 Acc@5 94.170 loss 0.871
Accuracy of the network on the 10000 test images: 76.9
One important thing to note: if the model was trained in a multi-teacher setting, you must pass the --teacher-model argument accordingly to supply the correct teacher count (all multi-teacher settings in the experiments used 2 teachers). Instead of model names, you can pass any placeholder values (e.g. --teacher-model 1 2 would work). Since this is evaluation only, the teacher models are not instantiated, but the argument tells the STD model how many layers and tokens to create so that the checkpoint can be loaded correctly.
Notes
Notes for implementation.
Notes regarding MINE Regularization implementation:
The setup for the MINE regularization is not explicitly described in the paper. Four main parts of the implementation rest on our own assumptions:
- Learning Rate: The learning rate for the updates in Algorithm 1 (see 2.1.2.) is not mentioned. After some small experiments, we fixed the learning rate for MINE updates at 0.01 for both the statistics network and the vision model. We set this learning rate somewhat high because of our assumption about the sample size used for the regularization.
- Sample size: The paper gives no explicit information on how many samples are used for this regularization. We made this a tunable argument whose default value is the batch size of the vision model. Thus, N samples are used in each epoch for regularization, where N is the batch size used to train the vision model. Since the update uses a single batch while the number of batches per epoch is much greater than 1, we set the regularization learning rate somewhat higher than the learning rate of the vision model so that the regularization has an effect.
- Note: The sample size can be increased, but there is currently no data loader for this part, so with a larger sample size training could fail under the available memory limits. A data loader for this part may be added later.
- Statistics Network: The paper describes the statistics network as a 3-layer MLP with 512 dims. We assumed GELU as the activation function for all layers (the same as in the MLP-Mixer layers) and assumed that no additional operations are applied to the layers.
- Selecting Samples: To select joint samples (paired tokens) and marginal samples (unpaired tokens), we implemented a very naive derangement. It shifts the sample indices by one position: indices p=0,1,2,3,4 are used for both $T_S$ and $T_C$, and for the unpaired tokens ($\overline{T_C}$) the indices become u=1,2,3,4,0. This approach is simple and guarantees a derangement, though more sophisticated algorithms exist. It is also safe to use because we collect the samples for regularization during training, effectively selecting a single instance from each batch, and batches are randomly shuffled in each epoch.
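The index shifting described in the last item can be illustrated with a minimal sketch. This is not the repository's code: the helper name `shifted_indices` and its `shift` parameter are hypothetical, and plain Python lists stand in for whatever tensor indexing the implementation actually uses.

```python
def shifted_indices(n, shift=1):
    """Hypothetical helper: return paired indices and a cyclically shifted copy.

    Pairing position i with position (i + shift) % n is a derangement
    as long as shift is not a multiple of n, because no index maps to itself.
    """
    if n < 2:
        raise ValueError("need at least two samples for a derangement")
    if shift % n == 0:
        raise ValueError("shift must not be a multiple of n")
    paired = list(range(n))                        # indices for T_S and T_C
    unpaired = [(i + shift) % n for i in paired]   # indices for the unpaired T_C
    return paired, unpaired


paired, unpaired = shifted_indices(5)
print(paired)    # [0, 1, 2, 3, 4]
print(unpaired)  # [1, 2, 3, 4, 0]
```

With five samples this reproduces the p and u index lists from the example above. A cyclic shift keeps the cost linear in the sample size, at the price of always pairing each sample with the same neighbour within a given ordering.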
Contribution
To check whether the code style passes, use
python -m scripts.run_code_style check
To reformat the codebase use
python -m scripts.run_code_style format
For easier setup, you can alternatively use the conda command
conda develop <path>
where <path> is the project root folder (not the source folder).
License
This work contains an implementation of the methodology and study presented in the Spatial-Channel Token Distillation for Vision MLPs paper. The codebase builds on facebookresearch/deit, which is used, modified, and adapted where necessary. The work here is licensed under the MIT License (extending the deit repository's Apache license).