git clone https://github.com/Seeeeeyo/mae.git
0.1) Use Python 3.8 (should work). I only tried with Python 3.10, where I had to make the following changes:
In “/usr/local/lib/python3.10/dist-packages/timm/models/layers/helpers.py”, replace the line “from torch._six import container_abcs” with:
import torch
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs
else:
    import collections.abc as container_abcs
In “/content/mae/util/misc.py”, change
from torch._six import inf
to
from torch import inf
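If you want to script both edits (e.g., at the top of a Colab session), here is a minimal sketch; the two file paths are the ones quoted above, so adjust them to your environment:

import pathlib

# Paths quoted in the steps above; adjust to your environment.
TIMM_HELPERS = pathlib.Path("/usr/local/lib/python3.10/dist-packages/timm/models/layers/helpers.py")
MAE_MISC = pathlib.Path("/content/mae/util/misc.py")

TIMM_FIX = """import torch
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs
else:
    import collections.abc as container_abcs"""

# Replace the failing torch._six imports in place (idempotent: skips a file
# if the old import is no longer present).
src = TIMM_HELPERS.read_text()
if "from torch._six import container_abcs" in src:
    TIMM_HELPERS.write_text(src.replace("from torch._six import container_abcs", TIMM_FIX))

src = MAE_MISC.read_text()
if "from torch._six import inf" in src:
    MAE_MISC.write_text(src.replace("from torch._six import inf", "from torch import inf"))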
1) Download the data:
- 'data_60k' from the drive (the full MedMNIST dataset) and 'data_sampled.zip' (fractions of the MedMNIST dataset).
- a medical classification dataset for evaluation. I'll let you choose one, as you know these better than I do; try to find one reasonably similar to MedMNIST so we can hopefully reach comparable performance.

The data structure should be as follows (see the sanity check right after):

eval_data
- train
    - class1
        - img1
        - img2
        - ...
    - class2
    - ...
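A quick sanity check of the layout: main_finetune.py builds its datasets with torchvision's ImageFolder on <data_path>/train and <data_path>/val, so eval_data also needs a val folder with the same class subfolders for the --eval runs below:

from torchvision import datasets

# Raises if a split folder is missing or not in ImageFolder layout.
for split in ("train", "val"):
    ds = datasets.ImageFolder(f"eval_data/{split}")
    print(split, ":", len(ds), "images across", len(ds.classes), "classes")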
cd mae
wget -nc https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth   # pre-trained backbone, used by the fine-tuning steps
wget -nc https://dl.fbaipublicfiles.com/mae/finetune/mae_finetuned_vit_base.pth  # ImageNet fine-tuned weights, used by the eval step below
pip install submitit
pip install timm==0.3.2
- Evaluate the ImageNet-fine-tuned mae_vit_base on eval_data (note: main_finetune.py defaults to --nb_classes 1000, so pass --nb_classes if your dataset has a different number of classes):
python main_finetune.py --eval --resume mae_finetuned_vit_base.pth --model vit_base_patch16 --batch_size 32 --data_path 'eval_data'
- FINETUNE
python main_finetune.py \
--accum_iter 1 \
--batch_size 32 \
--model vit_base_patch16 \
--finetune 'mae_pretrain_vit_base.pth' \
--epochs 50 \
--blr 5e-4 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
--dist_eval --data_path 'AGI/data_60k'
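For reference, main_finetune.py derives the actual learning rate from --blr as blr * effective batch size / 256, where the effective batch size is batch_size * accum_iter * number of GPUs. A worked example for the single-process command above:

# Learning-rate scaling used by main_finetune.py (values from the command
# above; assumes a single GPU process).
blr = 5e-4
batch_size, accum_iter, num_gpus = 32, 1, 1
eff_batch_size = batch_size * accum_iter * num_gpus
lr = blr * eff_batch_size / 256
print(f"effective batch size = {eff_batch_size}, actual lr = {lr:.2e}")  # 32, 6.25e-05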
OR (untested without a cluster and multiple GPUs; set --nodes to match your cluster):
python submitit_finetune.py \
--job_dir ${JOB_DIR} \
--nodes 4 \
--batch_size 32 \
--model vit_base_patch16 \
--finetune 'mae_pretrain_vit_base.pth' \
--epochs 50 \
--blr 5e-4 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
--dist_eval --data_path 'AGI/data_60k'
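(Assuming the default 8 GPUs per node, 4 nodes with batch_size 32 give an effective batch size of 32 * 4 * 8 = 1024, the target of the repo's ViT-Base fine-tuning recipe.)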
OR (if running on 1 node with 8 GPUs; also untested without a cluster and multiple GPUs):
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
--accum_iter 4 \
--batch_size 32 \
--model vit_base_patch16 \
--finetune 'mae_pretrain_vit_base.pth' \
--epochs 50 \
--blr 5e-4 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
--dist_eval --data_path 'AGI/data_60k'
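(Here --accum_iter 4 compensates for having a single node: 32 * 4 * 8 GPUs again gives an effective batch size of 1024.)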
MODEL_PATH='mae/output_dir/checkpoint-49.pth'
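Optionally sanity-check the fine-tuned checkpoint before evaluating; main_finetune.py saves a dict containing the model weights plus training state, so a minimal sketch:

import torch

ckpt = torch.load('mae/output_dir/checkpoint-49.pth', map_location='cpu')
print(sorted(ckpt.keys()))           # expect keys like 'args', 'epoch', 'model', 'optimizer'
print('last epoch:', ckpt.get('epoch'))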
- Evaluate the fine-tuned mae_vit_base on eval_data:
python main_finetune.py --eval --resume ${MODEL_PATH} --model vit_base_patch16 --batch_size 32 --data_path 'eval_data'
- Repeat the fine-tuning and evaluation steps above for the different dataset sizes ('data_sampled_6k', 'data_sampled_36k', 'data_sampled_600'); a driver sketch follows below.
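A driver loop for this step could look like the sketch below (dataset folder names from above; the per-run --output_dir values are my own, chosen to keep checkpoints separate):

import subprocess

# Fine-tune once per dataset fraction; evaluation then points at each
# output_<name>/checkpoint-49.pth as in the step above.
for name in ("data_sampled_600", "data_sampled_6k", "data_sampled_36k"):
    subprocess.run([
        "python", "main_finetune.py",
        "--accum_iter", "1",
        "--batch_size", "32",
        "--model", "vit_base_patch16",
        "--finetune", "mae_pretrain_vit_base.pth",
        "--epochs", "50",
        "--blr", "5e-4", "--layer_decay", "0.65",
        "--weight_decay", "0.05", "--drop_path", "0.1",
        "--mixup", "0.8", "--cutmix", "1.0", "--reprob", "0.25",
        "--data_path", f"AGI/{name}",
        "--output_dir", f"output_{name}",
    ], check=True)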
IF NEEDED
9)a) If the results are poor, we might need to pre-train the model on MedMNIST and then fine-tune on eval_data. Pre-train the ViT-Base MAE so the checkpoint matches the fine-tuning commands above:
python submitit_pretrain.py \
--job_dir ${JOB_DIR} \
--nodes 8 \
--use_volta32 \
--batch_size 64 \
--model mae_vit_base_patch16 \
--norm_pix_loss \
--mask_ratio 0.75 \
--epochs 800 \
--warmup_epochs 40 \
--blr 1.5e-4 --weight_decay 0.05 \
--data_path 'AGI/data_60k'
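(The same lr scaling applies to pre-training: assuming the default 8 GPUs per node, 8 nodes with --batch_size 64 give an effective batch size of 4096, i.e. an actual lr of 1.5e-4 * 4096 / 256 = 2.4e-3.)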
9)b) Then fine-tune from the checkpoint produced in 9)a) (the path below assumes the full 800 epochs ran; adjust it to your job dir):
python main_finetune.py \
--accum_iter 1 \
--batch_size 32 \
--model vit_base_patch16 \
--finetune ${JOB_DIR}/checkpoint-799.pth \
--epochs 50 \
--blr 5e-4 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
--dist_eval --data_path 'eval_data'
This is a PyTorch/GPU re-implementation of the paper Masked Autoencoders Are Scalable Vision Learners:
@Article{MaskedAutoencoders2021,
  author  = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Doll{\'a}r and Ross Girshick},
  journal = {arXiv:2111.06377},
  title   = {Masked Autoencoders Are Scalable Vision Learners},
  year    = {2021},
}
- The original implementation was in TensorFlow+TPU. This re-implementation is in PyTorch+GPU.
- This repo is a modification on the DeiT repo. Installation and preparation follow that repo.
- This repo is based on timm==0.3.2, for which a fix is needed to work with PyTorch 1.8.1+ (the fix applied in step 0.1 above).
- Visualization demo
- Pre-trained checkpoints + fine-tuning code
- Pre-training code
Run our interactive visualization demo using the Colab notebook (no GPU needed).
The following table provides the pre-trained checkpoints used in the paper, converted from TF/TPU to PT/GPU:
| | ViT-Base | ViT-Large | ViT-Huge |
|---|---|---|---|
| pre-trained checkpoint | download | download | download |
| md5 | 8cad7c | b8b06e | 9bdbb0 |
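The md5 row lists short checksum prefixes; a minimal verification sketch for the checkpoint downloaded earlier:

import hashlib

def md5_hex(path: str) -> str:
    # Stream the file so large checkpoints don't need to fit in memory.
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

print(md5_hex("mae_pretrain_vit_base.pth")[:6])  # expect 8cad7c per the table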
The fine-tuning instruction is in FINETUNE.md.
By fine-tuning these pre-trained models, we rank #1 in these classification tasks (detailed in the paper):
| | ViT-B | ViT-L | ViT-H | ViT-H448 | prev best |
|---|---|---|---|---|---|
| ImageNet-1K (no external data) | 83.6 | 85.9 | 86.9 | 87.8 | 87.1 |

The following are evaluations of the same model weights (fine-tuned on the original ImageNet-1K):

| | ViT-B | ViT-L | ViT-H | ViT-H448 | prev best |
|---|---|---|---|---|---|
| ImageNet-Corruption (error rate) | 51.7 | 41.8 | 33.8 | 36.8 | 42.5 |
| ImageNet-Adversarial | 35.9 | 57.1 | 68.2 | 76.7 | 35.8 |
| ImageNet-Rendition | 48.3 | 59.9 | 64.4 | 66.5 | 48.7 |
| ImageNet-Sketch | 34.5 | 45.3 | 49.6 | 50.9 | 36.0 |

The following are transfer learning results obtained by fine-tuning the pre-trained MAE on the target dataset:

| | ViT-B | ViT-L | ViT-H | ViT-H448 | prev best |
|---|---|---|---|---|---|
| iNaturalist 2017 | 70.5 | 75.7 | 79.3 | 83.4 | 75.4 |
| iNaturalist 2018 | 75.4 | 80.1 | 83.0 | 86.8 | 81.2 |
| iNaturalist 2019 | 80.5 | 83.4 | 85.7 | 88.3 | 84.1 |
| Places205 | 63.9 | 65.8 | 65.9 | 66.8 | 66.0 |
| Places365 | 57.9 | 59.4 | 59.8 | 60.3 | 58.0 |
The pre-training instruction is in PRETRAIN.md.
This project is under the CC-BY-NC 4.0 license. See LICENSE for details.