SemMAE

[NeurIPS 2022] Code for the paper "SemMAE: Semantic-Guided Masking for Learning Masked Autoencoders".


Introduction

Paper accepted at NeurIPS 2022.

This is the official repository of SemMAE. Our code builds on MAE; many thanks to its authors for their outstanding work! For details of our method, see the paper Semantic-Guided Masking for Learning Masked Autoencoders.

Citation

@article{li2022semmae,
  title={SemMAE: Semantic-Guided Masking for Learning Masked Autoencoders},
  author={Li, Gang and Zheng, Heliang and Liu, Daqing and Wang, Chaoyue and Su, Bing and Zheng, Changwen},
  journal={arXiv preprint arXiv:2206.10207},
  year={2022}
}

This implementation is in PyTorch+GPU.

  • This repo is based on timm==0.3.2, for which a fix is needed to work with PyTorch 1.8.1+.
  • The repository may also require tensorboard, which can be installed with 'pip install tensorboard'.

Processed ImageNet dataset (including part masks and pixel values)

Patch size    16x16 patch    8x8 patch
link          download       pwd:1tum
md5           lost           pending

Pretrained models

800 epochs               ViT-Base, 16x16 patch    ViT-Base, 8x8 patch
pretrained checkpoint    download                 download
md5                      1482ae                   322b6a
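The md5 values in these tables are short digests. Assuming they are the first six hex digits of each file's full MD5 (the usual convention in MAE-style READMEs), a downloaded checkpoint can be checked with a small standard-library helper like the hypothetical one below; the function names are illustrative, not part of this repository.

```python
import hashlib


def md5_prefix(path, length=6, chunk_size=1 << 20):
    """Return the first `length` hex digits of a file's MD5, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Read in fixed-size chunks so multi-GB checkpoints do not need to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()[:length]


def verify_checkpoint(path, expected_prefix):
    """True if the file's MD5 starts with the short digest listed in the table."""
    expected = expected_prefix.strip().lower()
    return md5_prefix(path, length=len(expected)) == expected
```

For example, verify_checkpoint("SemMAE_epoch799_vit_base_checkpoint-99.pth", "bbc5ef") should return True for an intact download of the fine-tuned 16x16 checkpoint, if the prefix convention above holds.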

Evaluation

As a sanity check, run evaluation using our ImageNet fine-tuned models:

800 epochs                           ViT-Base, 16x16 patch    ViT-Base, 8x8 patch
fine-tuned checkpoint                download                 download
md5                                  bbc5ef                   6abd9e
reference ImageNet top-1 acc. (%)    83.352                   84.444

Evaluate ViT-Base_16 on a single GPU (${IMAGENET_DIR} is a directory containing the {train, val} sets of ImageNet):

python main_finetune.py --eval --resume SemMAE_epoch799_vit_base_checkpoint-99.pth --model vit_base_patch16 --batch_size 16 --data_path ${IMAGENET_DIR}

This should give:

* Acc@1 83.352 Acc@5 96.494 loss 0.745
Accuracy of the network on the 50000 test images: 83.4%

Evaluate ViT-Base_8 on a single GPU (${IMAGENET_DIR} is a directory containing the {train, val} sets of ImageNet):

python main_finetune.py --eval --resume SemMAE_epoch799_vit_base_checkpoint_patch8-78.pth --model vit_base_patch8 --batch_size 8 --data_path ${IMAGENET_DIR}

This should give:

* Acc@1 84.444 Acc@5 97.032 loss 0.683
Accuracy of the network on the 50000 test images: 84.44%. 
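To automate this sanity check, the summary line can be parsed from the evaluation output and compared with the reference top-1 accuracy in the fine-tuned checkpoint table. This is a hypothetical helper that assumes the "* Acc@1 ... Acc@5 ... loss ..." format shown above; the 0.01-point tolerance is an arbitrary choice to absorb small numerical differences across GPUs.

```python
import re

# Reference top-1 accuracies (%) copied from the fine-tuned checkpoint table.
REFERENCE_TOP1 = {"vit_base_patch16": 83.352, "vit_base_patch8": 84.444}

# Matches lines such as "* Acc@1 83.352 Acc@5 96.494 loss 0.745".
SUMMARY_RE = re.compile(r"\*\s+Acc@1\s+([\d.]+)\s+Acc@5\s+([\d.]+)\s+loss\s+([\d.]+)")


def parse_eval_summary(log_text):
    """Extract (acc1, acc5, loss) from the evaluation summary line."""
    match = SUMMARY_RE.search(log_text)
    if match is None:
        raise ValueError("no '* Acc@1 ... Acc@5 ... loss ...' line found")
    return tuple(float(g) for g in match.groups())


def matches_reference(log_text, model, tol=0.01):
    """True if the parsed top-1 accuracy is within `tol` points of the reference."""
    acc1, _, _ = parse_eval_summary(log_text)
    return abs(acc1 - REFERENCE_TOP1[model]) <= tol
```

For instance, after saving the evaluation output to a file with `| tee eval.log`, `matches_reference(open("eval.log").read(), "vit_base_patch8")` should return True for the output above.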

Note that all of our results are obtained with the 800-epoch pre-training setting. The best checkpoint for vit_base_patch8 was lost: the paper reports 84.5% top-1 accuracy, whereas the released checkpoint (from epoch 78) reaches 84.44%.

Pre-training

To pre-train ViT-Base (16x16 patch) with multi-node distributed training, run the following on 8 nodes with 8 GPUs each:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
        --nnodes=${NNODES} --node_rank=${SLURM_NODEID} --master_addr=${MASTER_ADDR} \
        --use_env main_pretrain_setting3.py \
        --output_dir ${OUTPUT_DIR} --log_dir=${OUTPUT_DIR} \
        --batch_size 128 \
        --model mae_vit_base_patch16 \
        --norm_pix_loss \
        --mask_ratio 0.75 \
        --epochs 800 \
        --warmup_epochs 40 \
        --blr 1.5e-4 --weight_decay 0.05 \
        --setting 3 \
        --data_path ${DATA_DIR}
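The command sets only a base learning rate (--blr). Assuming main_pretrain_setting3.py keeps MAE's linear scaling rule (absolute lr = blr × effective batch size / 256, where effective batch size = batch_size × accum_iter × number of GPUs) with accum_iter left at 1, and assuming MAE's 224×224 input resolution, the implied values can be worked out as below. The patch count uses the len_keep formula from MAE's random masking; SemMAE chooses *which* patches to mask differently, and this sketch assumes only the count follows the same formula.

```python
def pretrain_schedule(batch_size_per_gpu, gpus_per_node, nodes, blr, accum_iter=1):
    """Return (effective batch size, absolute lr) under MAE's linear scaling rule (assumed)."""
    eff_batch_size = batch_size_per_gpu * accum_iter * gpus_per_node * nodes
    lr = blr * eff_batch_size / 256
    return eff_batch_size, lr


def visible_patches(img_size, patch_size, mask_ratio):
    """Return (total patches, patches left visible) for a square image (224 assumed)."""
    num_patches = (img_size // patch_size) ** 2
    len_keep = int(num_patches * (1 - mask_ratio))  # same formula as MAE's random_masking
    return num_patches, len_keep


eff, lr = pretrain_schedule(128, gpus_per_node=8, nodes=8, blr=1.5e-4)
print(eff, lr)                         # 8192 GPUs-worth of samples, lr of about 4.8e-3
print(visible_patches(224, 16, 0.75))  # (196, 49)
print(visible_patches(224, 8, 0.75))   # (784, 196)
```

On fewer nodes, the effective batch size, and with it the derived learning rate, shrinks proportionally unless it is compensated for.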

Note that ${DATA_DIR} must point to our processed dataset (the one described above, with part masks and pixel values), not to the raw ImageNet directory.

Contact

This repo is currently maintained by Gang Li (@ucasligang).