FastMIM.pytorch

FastMIM: the official PyTorch implementation of our paper "FastMIM: Expediting Masked Image Modeling Pre-training for Vision" (https://arxiv.org/pdf/2212.06593.pdf).

Comparison among MAE, SimMIM and our FastMIM framework. MAE randomly masks and discards input patches. Although its encoder processes only a small number of patches, MAE can only pre-train isotropic ViTs, which produce single-scale intermediate features. SimMIM preserves the input resolution and can serve as a generic framework for all kinds of vision backbones, but it has to process a large number of patches. Our FastMIM simply reduces the input resolution and replaces the pixel target with a HOG (histogram of oriented gradients) target. These modifications are simple yet effective: FastMIM (i) pre-trains faster; (ii) has lower memory consumption; (iii) can serve as a generic framework for all kinds of architectures; and (iv) achieves performance comparable to or better than previous methods.
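To make the HOG target concrete, here is a minimal NumPy sketch of a MaskFeat-style HOG target for one grayscale image. The settings (9 unsigned orientation bins, 8x8-pixel cells, L2 normalization per cell, 16x16 patches) and the function names `hog_cells` / `hog_patch_targets` are assumptions for illustration only; the repository's HOG implementation may differ in binning, normalization and color handling.

```python
import numpy as np

def hog_cells(image, cell=8, bins=9, eps=1e-6):
    """Per-cell histogram of oriented gradients (illustrative settings, not the repo's).

    image: (H, W) grayscale array with H and W divisible by `cell`.
    Returns an array of shape (H // cell, W // cell, bins).
    """
    # np.gradient returns derivatives along axis 0 (rows, y) then axis 1 (cols, x).
    gy, gx = np.gradient(image.astype(np.float64))
    magnitude = np.hypot(gx, gy)
    # Unsigned orientation in [0, pi), quantized into `bins` equal bins.
    orientation = np.mod(np.arctan2(gy, gx), np.pi)
    bin_idx = np.minimum((orientation / np.pi * bins).astype(int), bins - 1)
    h, w = image.shape
    hist = np.zeros((h // cell, w // cell, bins))
    # Cell coordinates of every pixel.
    rows = np.broadcast_to((np.arange(h) // cell)[:, None], (h, w))
    cols = np.broadcast_to((np.arange(w) // cell)[None, :], (h, w))
    # Accumulate gradient magnitude into (cell_row, cell_col, bin).
    np.add.at(hist, (rows, cols, bin_idx), magnitude)
    # L2-normalize each cell's histogram (assumed normalization scheme).
    hist /= np.linalg.norm(hist, axis=-1, keepdims=True) + eps
    return hist

def hog_patch_targets(image, patch=16, cell=8, bins=9):
    """Group cell histograms into one flat target vector per patch."""
    hist = hog_cells(image, cell, bins)
    k = patch // cell  # cells per patch side
    gh, gw = hist.shape[0] // k, hist.shape[1] // k
    t = hist.reshape(gh, k, gw, k, bins).transpose(0, 2, 1, 3, 4)
    return t.reshape(gh * gw, k * k * bins)
```

At a 128x128 input with 16x16 patches, this produces 64 patch targets of 4 × 9 = 36 values each, compared with 196 targets at 224x224. In MIM training, the regression loss is typically computed only at the masked patch positions.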

Setup

- python==3.x
- cuda==10.x
- torch==1.7.0+
- mmcv-full==1.4.4+

# Other PyTorch/CUDA/timm versions may also work

# Install the Python dependencies with pip
sh requirement_pip_install.sh

# Build apex (optional)
cd /your_path_to/apex-master/;
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
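Before launching, it can help to confirm that the installed packages meet the minimum versions listed above. The sketch below uses only the standard library (`importlib.metadata`, which needs Python 3.8+). The helper names and the simplified version parsing (local suffixes such as `+cu110` are dropped and pre-release tags are ignored) are assumptions, not part of the repository.

```python
from importlib import metadata

# Minimum versions from the requirement list (the trailing "+" is read as ">=").
MINIMUMS = {"torch": "1.7.0", "mmcv-full": "1.4.4"}

def version_tuple(version):
    """'1.7.0+cu110' -> (1, 7, 0); non-numeric tails such as 'rc1' are dropped."""
    core = version.split("+", 1)[0]
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)

def at_least(installed, minimum):
    """Compare dotted versions, padding the shorter one with zeros."""
    a, b = version_tuple(installed), version_tuple(minimum)
    n = max(len(a), len(b))
    return a + (0,) * (n - len(a)) >= b + (0,) * (n - len(b))

def check_requirements(minimums=MINIMUMS):
    """Return {package: (installed_version_or_None, meets_minimum)}."""
    report = {}
    for pkg, minimum in minimums.items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            report[pkg] = (None, False)
        else:
            report[pkg] = (installed, at_least(installed, minimum))
    return report
```

Usage: `print(check_requirements())` lists each package with its installed version (or `None`) and whether it meets the minimum.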

Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure should be:

path/to/imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
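This is the one-folder-per-class layout that `torchvision.datasets.ImageFolder` expects. The official ILSVRC2012 validation archive extracts into a single flat folder, so its images must be sorted into class sub-folders first. The stdlib-only checker below is a hypothetical helper (not shipped with the repository) that reports missing splits, empty or mismatched class folders, and an unexpected class count (ImageNet-1K has 1000 classes).

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}

def count_images_per_class(split_dir):
    """Return {class_folder_name: number_of_images} for one split directory."""
    counts = {}
    for class_dir in sorted(p for p in Path(split_dir).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts

def check_imagenet_layout(root, expected_classes=None):
    """Return a list of problems; an empty list means the layout looks fine."""
    root = Path(root)
    problems, splits = [], {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        splits[split] = count_images_per_class(split_dir)
        empty = [name for name, n in splits[split].items() if n == 0]
        if empty:
            problems.append(f"{split}: {len(empty)} empty class folder(s), e.g. {empty[0]}")
        if expected_classes is not None and len(splits[split]) != expected_classes:
            problems.append(f"{split}: found {len(splits[split])} classes, expected {expected_classes}")
    if len(splits) == 2 and set(splits["train"]) != set(splits["val"]):
        problems.append("train/ and val/ have different class folders")
    return problems
```

Usage: `print(check_imagenet_layout("/your_path_to/data/imagenet/", expected_classes=1000))`.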

Pre-training on ImageNet-1K

ViT-B

To train ViT-B on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py --model mim_vit_base --data_path /your_path_to/data/imagenet/ --epochs 800 --warmup_epochs 20 --blr 1.5e-4 --weight_decay 0.05 --output_dir /your_path_to/fastmim_pretrain_output/ --batch_size 512 --save_ckpt_freq 100 --num_workers 10 --mask_ratio 0.75 --norm_pix_loss --rrc_scale 0.2 1.0 --input_size 128 --decoder_embed_dim 256 --decoder_depth 1 --block_size 16 --mim_loss HOG
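With `--input_size 128` and ViT-B's 16x16 patches, the encoder sees an 8x8 grid of 64 patches instead of the 14x14 = 196 patches at 224x224. The NumPy sketch below draws a random block-wise mask on that grid. It assumes that `--block_size` is the side length, in pixels, of one masking unit (analogous to SimMIM's mask patch size) and that `--mask_ratio` is the fraction of units masked; the function name is hypothetical and the repository's mask generator may differ.

```python
import numpy as np

def block_mask(input_size=128, block_size=16, patch_size=16, mask_ratio=0.75, rng=None):
    """Random block-wise mask on the patch grid (1 = masked, 0 = visible).

    Assumption: block_size is a multiple of patch_size, and input_size of block_size.
    """
    if block_size % patch_size or input_size % block_size:
        raise ValueError("sizes must divide evenly")
    rng = np.random.default_rng() if rng is None else rng
    blocks_per_side = input_size // block_size
    num_blocks = blocks_per_side ** 2
    num_masked = int(np.ceil(num_blocks * mask_ratio))
    # Mask a random subset of blocks.
    flat = np.zeros(num_blocks, dtype=int)
    flat[rng.permutation(num_blocks)[:num_masked]] = 1
    grid = flat.reshape(blocks_per_side, blocks_per_side)
    # Expand each block to the patches it covers.
    scale = block_size // patch_size
    return grid.repeat(scale, axis=0).repeat(scale, axis=1)
```

With the ViT-B settings this masks 48 of the 64 patches. Under the same assumption, the Swin-B settings (`--block_size 32`, 4x4 patches) give a 32x32 patch grid in which 12 of 16 blocks, each 8x8 patches, are masked.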

Swin-B

To train Swin-B on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py --model mim_swin_base --data_path /your_path_to/data/imagenet/ --epochs 400 --warmup_epochs 10 --blr 1.5e-4 --weight_decay 0.05 --output_dir /your_path_to/fastmim_pretrain_output/ --batch_size 256 --save_ckpt_freq 50 --num_workers 10 --mask_ratio 0.75 --norm_pix_loss --input_size 128 --rrc_scale 0.2 1.0 --window_size 4 --decoder_embed_dim 256 --decoder_depth 4 --mim_loss HOG --block_size 32
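The Swin-B pre-training command sets `--window_size 4` instead of the window of 7 used by `swin_base_patch4_window7_224`. A plausible reason, sketched below, is that Swin partitions each stage's token grid into non-overlapping windows: at 128x128 with 4x4 patches the four stages have grids of 32, 16, 8 and 4 tokens per side, which a window of 4 tiles exactly, while 7 does not divide 32, 16 or 8. The helper names are illustrative, not from the repository.

```python
def swin_stage_resolutions(input_size, patch_size=4, num_stages=4):
    """Token-grid side length at each Swin stage (patch merging halves it)."""
    side = input_size // patch_size
    return [side // (2 ** i) for i in range(num_stages)]

def windows_fit(input_size, window_size, patch_size=4, num_stages=4):
    """True if every stage's token grid is an exact multiple of the window size."""
    return all(
        r % window_size == 0
        for r in swin_stage_resolutions(input_size, patch_size, num_stages)
    )
```

For comparison, the 224x224 fine-tuning resolution gives grids of 56, 28, 14 and 7, which a window of 7 tiles exactly.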

Finetuning on ImageNet-1K

ViT-B

To fine-tune ViT-B on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py --model vit_base_patch16 --data_path /your_path_to/data/imagenet/ --batch_size 128 --accum_iter 1 --epochs 100 --blr 6e-4 --layer_decay 0.70 --weight_decay 0.05 --drop_path 0.1 --dist_eval --finetune /your_path_to_ckpt/checkpoint-799.pth --output_dir /your_path_to/fastmim_finetune_output/
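The command passes a base learning rate (`--blr`) rather than an absolute one. Assuming `main_finetune.py` keeps MAE's linear scaling rule (the Acknowledgement section lists MAE as a basis), the actual learning rate is `blr * effective_batch / 256`, where the effective batch is `batch_size * accum_iter * num_gpus`. Under the same assumption, `--layer_decay` scales each ViT block's learning rate geometrically, so layers closer to the input are updated less. The helpers below sketch that arithmetic; they are not the repository's code.

```python
def effective_lr(blr, batch_size, accum_iter=1, world_size=8):
    """Linear scaling rule (MAE convention, assumed): lr = blr * effective_batch / 256."""
    eff_batch = batch_size * accum_iter * world_size
    return blr * eff_batch / 256

def layer_scales(layer_decay, depth=12):
    """Per-layer LR multipliers (MAE-style layer-wise decay, assumed).

    Index 0 = patch/position embedding, 1..depth = transformer blocks,
    depth + 1 = classification head (multiplier 1.0).
    """
    num_layers = depth + 1
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]
```

For the ViT-B command, 128 × 1 × 8 = 1024 gives lr = 6e-4 × 1024 / 256 = 2.4e-3. The head keeps the full rate, while the patch embedding is scaled by 0.7^13 ≈ 0.0097. The same rule applied to the ViT-B pre-training command (512 × 8 = 4096, `--blr 1.5e-4`) also gives 2.4e-3.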

Swin-B

To fine-tune Swin-B on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py --model swin_base_patch4_window7_224 --data_path /your_path_to/data/imagenet/ --batch_size 128 --epochs 100 --blr 1.0e-3 --layer_decay 0.80 --weight_decay 0.05 --drop_path 0.1 --dist_eval --finetune /your_path_to_ckpt/checkpoint-399.pth --output_dir /your_path_to/fastmim_finetune_output/
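Pre-training uses `--window_size 4`, while the fine-tuning model name ends in `window7_224`. Each Swin attention block stores a relative position bias table of shape ((2W-1)^2, num_heads), so loading a window-4 checkpoint into a window-7 model means resizing each table from 49 to 169 rows. The NumPy sketch below does this with plain bilinear interpolation; it is an illustrative stand-in, and the repository may handle this step differently (for example, with cubic interpolation). The function name is hypothetical.

```python
import numpy as np

def resize_rel_pos_bias(table, old_window, new_window):
    """Bilinearly resize a Swin relative-position-bias table (illustrative sketch).

    table: array of shape ((2*old_window - 1)**2, num_heads).
    Returns an array of shape ((2*new_window - 1)**2, num_heads).
    """
    old_side, new_side = 2 * old_window - 1, 2 * new_window - 1
    num_heads = table.shape[1]
    if table.shape[0] != old_side * old_side:
        raise ValueError("table size does not match old_window")
    grid = table.reshape(old_side, old_side, num_heads)
    # Sample positions on the old grid with aligned endpoints.
    coords = np.linspace(0.0, old_side - 1, new_side)
    lo = np.floor(coords).astype(int)
    hi = np.minimum(lo + 1, old_side - 1)
    w = coords - lo  # fractional distance from the lower neighbour
    # Interpolate along rows, then along columns.
    rows = grid[lo] * (1 - w)[:, None, None] + grid[hi] * w[:, None, None]
    out = rows[:, lo] * (1 - w)[None, :, None] + rows[:, hi] * w[None, :, None]
    return out.reshape(new_side * new_side, num_heads)
```

Swin-B uses 4, 8, 16 and 32 heads in its four stages, so each block's table is resized independently. The `relative_position_index` buffer depends only on the window size and would normally be taken from the new model rather than from the checkpoint.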

Notice

We build our object detection and semantic segmentation codebases upon mmdet-v2.23 and mmseg-v0.28. However, we also port some features from newer mmdet versions (e.g., simple copy-paste) into our mmdet-v2.23, so if you directly download mmdet-v2.23 from MMDet, the code may report errors.

Results and Models

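Classification on ImageNet-1K (ViT-B/Swin-B/PVTv2-B2/CMT-S)

| Model | #Params | PT Res. | PT Epochs | PT log/ckpt | FT Res. | FT log/ckpt | Top-1 (%) |
|---|---|---|---|---|---|---|---|
| ViT-B | 86M | 128x128 | 800 | log/ckpt | 224x224 | log/ckpt | 83.8 |
| Swin-B | 88M | 128x128 | 400 | log/ckpt | 224x224 | log/ckpt | 84.1 |
| PVTv2-B2 | 25M | 128x128 | 800 | - | 224x224 | ckpt | 82.5 |
| CMT-S | 25M | 128x128 | 800 | - | 224x224 | ckpt | 83.9 |

PT = pre-training, FT = fine-tuning.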

Object Detection on COCO (Swin-B based Mask R-CNN)

| Model | Backbone | Pre-train | LR schedule | box AP | mask AP | Config | Checkpoint |
|---|---|---|---|---|---|---|---|
| Mask R-CNN | Swin-B | SimMIM | 3x | 52.3 | 46.4 | config | log/ckpt |
| Mask R-CNN | Swin-B | FastMIM | 3x | 52.0 | 46.0 | config | log/ckpt |

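Semantic Segmentation on ADE20K (ViT-B based UPerNet)

| Model | Backbone | Pre-train | Crop Size | Batch Size | LR schedule (iters) | mIoU (ss) | Config | Checkpoint |
|---|---|---|---|---|---|---|---|---|
| UPerNet | ViT-B | FastMIM | 512x512 | 16 | 160000 | 49.5 | config | log/ckpt |

ss = single-scale testing.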

Citation

If you find this project useful in your research, please consider citing:

@article{guo2022fastmim,
  title={FastMIM: Expediting Masked Image Modeling Pre-training for Vision},
  author={Guo, Jianyuan and Han, Kai and Wu, Han and Tang, Yehui and Wang, Yunhe and Xu, Chang},
  journal={arXiv preprint arXiv:2212.06593},
  year={2022}
}

Acknowledgement

The classification task in this repo is based on MAE, SimMIM, SlowFast and timm.

The object detection task in this repo is based on MMDet, ViDet and Swin-Transformer-Object-Detection.

The semantic segmentation task in this repo is based on MMSeg and BEiT.

License

License: MIT