This is an updated PyTorch/GPU re-implementation of the paper Masked Autoencoders Are Scalable Vision Learners, targeted at consumer GPU users for learning purposes.
- Updated to the latest Torch and Timm
- Uses Imagenette as the default dataset, so you can run training on a consumer GPU and start debugging the code immediately without downloading the huge ImageNet
```shell
pip install maskedautoencoder
git clone https://github.com/henrywoo/mae/
cd mae
pip install -r requirements.txt
bash run.sh
```
Screenshot of training on a laptop with a 4 GB GPU:
One-line change to replace Imagenette with ImageNet-1K. Replace

```python
dataset_train = get_cv_dataset(path=DS_PATH_IMAGENETTE, transform=transform_train, name="full_size")
```

with

```python
dataset_train = get_cv_dataset(path=DS_PATH_IMAGENET1K, transform=transform_train)
```
- Visualization demo
- Pre-trained checkpoints + fine-tuning code
- Pre-training code
Run our interactive visualization demo using Colab notebook (no GPU needed):
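The demo reconstructs an image from the small fraction of patches left visible (25% by default). The operation at the heart of MAE, per-sample random masking of patch tokens via argsort of uniform noise, can be sketched in a few lines of PyTorch; function and variable names here are illustrative, not this repo's API:

```python
import torch

def random_masking(x, mask_ratio=0.75):
    """Keep a random (1 - mask_ratio) subset of patches per sample.

    x: (N, L, D) patch tokens. Returns the kept tokens, a binary mask
    (0 = kept, 1 = masked), and indices to restore the original order.
    """
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(N, L)                   # one uniform sample per patch
    ids_shuffle = torch.argsort(noise, dim=1)  # ascending: smallest kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    mask = torch.ones(N, L)
    mask[:, :len_keep] = 0                     # first len_keep slots are kept
    mask = torch.gather(mask, 1, ids_restore)  # unshuffle to original order
    return x_masked, mask, ids_restore

x = torch.randn(2, 196, 768)  # 14x14 patches of a 224px image, ViT-B width
x_masked, mask, ids_restore = random_masking(x)
print(x_masked.shape)  # torch.Size([2, 49, 768])
```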
The following table provides the pre-trained checkpoints used in the paper, converted from TF/TPU to PT/GPU:
| | ViT-Base | ViT-Large | ViT-Huge |
|---|---|---|---|
| pre-trained checkpoint | download | download | download |
| md5 | `8cad7c` | `b8b06e` | `9bdbb0` |
The fine-tuning instruction is in FINETUNE.md.
By fine-tuning these pre-trained models, we rank #1 in these classification tasks (detailed in the paper):
| | ViT-B | ViT-L | ViT-H | ViT-H448 | prev best |
|---|---|---|---|---|---|
| ImageNet-1K (no external data) | 83.6 | 85.9 | 86.9 | 87.8 | 87.1 |
| The following are evaluations of the same model weights (fine-tuned on the original ImageNet-1K): | | | | | |
| ImageNet-Corruption (error rate) | 51.7 | 41.8 | 33.8 | 36.8 | 42.5 |
| ImageNet-Adversarial | 35.9 | 57.1 | 68.2 | 76.7 | 35.8 |
| ImageNet-Rendition | 48.3 | 59.9 | 64.4 | 66.5 | 48.7 |
| ImageNet-Sketch | 34.5 | 45.3 | 49.6 | 50.9 | 36.0 |
| The following are transfer learning results, fine-tuning the pre-trained MAE on the target dataset: | | | | | |
| iNaturalist 2017 | 70.5 | 75.7 | 79.3 | 83.4 | 75.4 |
| iNaturalist 2018 | 75.4 | 80.1 | 83.0 | 86.8 | 81.2 |
| iNaturalist 2019 | 80.5 | 83.4 | 85.7 | 88.3 | 84.1 |
| Places205 | 63.9 | 65.8 | 65.9 | 66.8 | 66.0 |
| Places365 | 57.9 | 59.4 | 59.8 | 60.3 | 58.0 |
The pre-training instruction is in PRETRAIN.md.
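As a quick orientation before reading PRETRAIN.md: the pre-training objective is a mean-squared error on (optionally per-patch-normalized) pixels, computed only over the masked patches. A hedged sketch of that loss, with illustrative names rather than this repo's exact code:

```python
import torch

def mae_loss(pred, target, mask, norm_pix_loss=True):
    """MAE reconstruction loss.

    pred, target: (N, L, patch_dim) predicted and ground-truth patch pixels.
    mask: (N, L), 1 = masked patch (contributes to loss), 0 = visible.
    """
    if norm_pix_loss:
        # Normalize each target patch by its own mean and variance.
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1e-6) ** 0.5
    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)                 # mean loss per patch
    return (loss * mask).sum() / mask.sum()  # average over masked patches only

target = torch.randn(2, 196, 768)
mask = torch.ones(2, 196)
print(mae_loss(target, target, mask, norm_pix_loss=False).item())  # 0.0
```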
This project is under the CC-BY-NC 4.0 license. See LICENSE for details.
- The original version: PyTorch Version
- Other versions: TF, MAE-pytorch 1, MAE-pytorch 2