I-JEPA Finetuning

PyTorch codebase for finetuning I-JEPA based on the Masked Autoencoders (MAE) finetuning recipe.

Features

  • Randaugment
  • Cutmix
  • Mixup
  • Gradient accumulation
  • Label smoothing
  • Drop path
  • Average-pooled patch representation of the last layer
  • Concatenation of the last 4 layers of the average-pooled patch representations
  • Layer-wise decay (help appreciated)
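Since layer-wise decay is listed as still needing help, the following is a minimal sketch of how per-layer learning-rate multipliers are commonly computed in MAE/BEiT-style finetuning. It is not this repository's implementation: the function names `layer_lr_scales` and `scaled_lrs`, the grouping (group 0 for the patch embedding, groups 1..N for the transformer blocks, group N+1 for the head), and the values 0.75 and 1e-3 are assumptions for illustration.

```python
def layer_lr_scales(num_blocks, layer_decay):
    """Return one learning-rate multiplier per layer group (hypothetical helper).

    Group 0 is the patch embedding, groups 1..num_blocks are the transformer
    blocks, and group num_blocks + 1 is the classification head.
    """
    num_groups = num_blocks + 2
    # The head keeps the base learning rate (scale 1.0); each step towards
    # the input multiplies the scale by `layer_decay` once more.
    return [layer_decay ** (num_groups - 1 - i) for i in range(num_groups)]


def scaled_lrs(base_lr, num_blocks, layer_decay):
    """Per-group learning rates, e.g. for optimizer parameter groups."""
    return [base_lr * s for s in layer_lr_scales(num_blocks, layer_decay)]


# Illustrative values only: 4 blocks, decay 0.75, base learning rate 1e-3.
scales = layer_lr_scales(num_blocks=4, layer_decay=0.75)
lrs = scaled_lrs(base_lr=1e-3, num_blocks=4, layer_decay=0.75)
print([round(s, 4) for s in scales])
```

Each multiplier would then be attached to the optimizer parameter group that holds the corresponding layer's parameters, so earlier layers are updated more conservatively than the head.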

In addition, we provide ViT-L weights pre-trained on ImageNet-1k here.

Launching I-JEPA finetuning

To launch finetuning, either run the finetune.sh script or run the command below. The default settings for the features above are in configs/in1k_vith14_ep300_finetuning.yaml.

python main_finetuning.py \
  --fname configs/in1k_vith14_ep300_finetuning.yaml \
  --devices cuda:0 cuda:1 cuda:2

Disclaimer

Some of these settings were chosen for a ViT-H model and should be changed for other architectures. See the recommended finetuning configurations on page 11 of the MAE paper appendix (https://arxiv.org/pdf/2111.06377).

Randaugment, gradient accumulation, and other settings that are not exposed in the .yaml file can be set directly in engine_finetune.py.
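For readers unfamiliar with gradient accumulation, the pattern can be sketched as below. This is a generic PyTorch illustration, not the code in engine_finetune.py: the tiny linear model, the random micro-batches, and the name `accum_steps` are assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins for the finetuned model and data (assumptions for illustration).
model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accum_steps = 4  # hypothetical setting: micro-batches per optimizer update
micro_batches = [(torch.randn(2, 8), torch.randint(0, 3, (2,))) for _ in range(8)]

num_updates = 0
optimizer.zero_grad()
for step, (images, labels) in enumerate(micro_batches, start=1):
    loss = criterion(model(images), labels)
    # Scale the loss so the accumulated gradient is the mean over the
    # effective batch rather than the sum over micro-batches.
    (loss / accum_steps).backward()
    if step % accum_steps == 0:
        optimizer.step()       # one update per `accum_steps` micro-batches
        optimizer.zero_grad()
        num_updates += 1

print(num_updates)
```

With 8 micro-batches of size 2 and `accum_steps = 4`, the optimizer steps twice, each time on gradients equivalent to a batch of 8 samples.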

To use the concatenation of the last 4 layers of the average-pooled patch representations as input to the classifier, modify the default forward function and adjust the input size of the linear classification head in src/helper.py.
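A minimal sketch of that change, assuming the encoder can return the patch tokens of every transformer block as a list of tensors shaped (batch, num_patches, embed_dim). The class name `ConcatLastLayersClassifier` and the way the per-layer outputs are obtained are hypothetical and do not mirror the actual class in src/helper.py.

```python
import torch
from torch import nn


class ConcatLastLayersClassifier(nn.Module):
    """Linear head on the concatenated, average-pooled outputs of the last k layers."""

    def __init__(self, embed_dim, num_classes, num_layers_to_concat=4):
        super().__init__()
        self.k = num_layers_to_concat
        # The head input grows from embed_dim to k * embed_dim.
        self.head = nn.Linear(self.k * embed_dim, num_classes)

    def forward(self, layer_outputs):
        # layer_outputs: list of (batch, num_patches, embed_dim) tensors,
        # one per transformer block, ordered from first to last.
        pooled = [x.mean(dim=1) for x in layer_outputs[-self.k:]]  # average over patches
        return self.head(torch.cat(pooled, dim=-1))


# Random tensors stand in for the per-block encoder outputs.
batch, num_patches, embed_dim = 2, 16, 32
layer_outputs = [torch.randn(batch, num_patches, embed_dim) for _ in range(6)]
classifier = ConcatLastLayersClassifier(embed_dim, num_classes=10)
logits = classifier(layer_outputs)
print(logits.shape)
```

The only size that changes relative to the single-layer variant is the head's input dimension, which becomes 4 times the embedding dimension.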

Other changes to the classification pipeline can also be made in the classification model class in src/helper.py.

Method

I-JEPA is a method for self-supervised learning. At a high level, I-JEPA predicts the representations of part of an image from the representations of other parts of the same image. Notably, this approach learns semantic image features:

  1. without relying on pre-specified invariances to hand-crafted data transformations, which tend to be biased for particular downstream tasks,
  2. and without having the model fill in pixel-level details, which tend to result in learning less semantically meaningful representations.
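The idea of predicting in representation space can be illustrated with a toy sketch. Everything here is a simplification under stated assumptions: linear layers stand in for the Vision Transformer encoders and predictor, the context is mean-pooled before prediction, the patch index split is arbitrary, and the mean squared error is used as the distance between predicted and target representations.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

num_patches, patch_dim, embed_dim = 16, 12, 8
patches = torch.randn(1, num_patches, patch_dim)  # one image as a patch sequence

# Toy stand-ins (assumptions): the real models are Vision Transformers.
context_encoder = nn.Linear(patch_dim, embed_dim)
target_encoder = nn.Linear(patch_dim, embed_dim)
predictor = nn.Linear(embed_dim, embed_dim)
pos_embed = nn.Parameter(torch.randn(num_patches, embed_dim) * 0.02)

context_idx = torch.arange(0, 10)   # patches visible to the context encoder
target_idx = torch.arange(10, 16)   # patches whose representations are predicted

# Targets are representations, not pixels, and receive no gradient.
with torch.no_grad():
    targets = target_encoder(patches)[:, target_idx]

# Only the context patches pass through the context encoder.
context = context_encoder(patches[:, context_idx])

# Toy predictor: pooled context plus the position of each target patch.
queries = context.mean(dim=1, keepdim=True) + pos_embed[target_idx]
predictions = predictor(queries)

loss = F.mse_loss(predictions, targets)
loss.backward()
print(predictions.shape, float(loss))
```

The loss compares embeddings with embeddings, so no pixel-level reconstruction is involved and no augmented second view of the image is needed.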

Evaluations

I-JEPA pretraining is also computationally efficient. It does not involve any overhead associated with applying more computationally intensive data augmentations to produce multiple views. Only one view of the image needs to be processed by the target encoder, and only the context blocks need to be processed by the context encoder. Empirically, I-JEPA learns strong off-the-shelf semantic representations without the use of hand-crafted view augmentations.

[Figures: 1% evaluation, linear evaluation]

Pretrained models

arch.   patch size   resolution   epochs   data           download
ViT-L   14x14        224x224      150      ImageNet-1K    full checkpoint, logs, configs
ViT-H   14x14        224x224      300      ImageNet-1K    full checkpoint, logs, configs
ViT-H   16x16        448x448      300      ImageNet-1K    full checkpoint, logs, configs
ViT-H   14x14        224x224      66       ImageNet-22K   full checkpoint, logs, configs
ViT-g   16x16        224x224      44       ImageNet-22K   full checkpoint, logs, configs

Code Structure

.
├── configs                   # directory in which all experiment '.yaml' configs are stored
├── src                       # the package
│   ├── train.py              #   the I-JEPA training loop
│   ├── helper.py             #   helper functions for init of models & opt/loading checkpoint
│   ├── transforms.py         #   pre-train data transforms
│   ├── datasets              #   datasets, data loaders, ...
│   ├── models                #   model definitions
│   ├── masks                 #   mask collators, masking utilities, ...
│   └── utils                 #   shared utilities
├── main_distributed.py       # entrypoint for launching distributed I-JEPA pretraining on a SLURM cluster
└── main.py                   # entrypoint for launching I-JEPA pretraining locally on your machine

Config files: Note that all experiment parameters are specified in config files (as opposed to command-line arguments). See the configs/ directory for example config files.
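As a rough illustration of this config-driven workflow, the sketch below parses a small YAML document with pyyaml and reads values from the resulting nested dictionary. The section and key names in `CONFIG_TEXT` are assumptions chosen for illustration and are not copied from the files in configs/.

```python
import yaml

# Illustrative config text; section and key names are hypothetical.
CONFIG_TEXT = """
data:
  batch_size: 128
optimization:
  epochs: 300
  lr: 0.001
"""

# safe_load turns the YAML document into nested Python dictionaries.
params = yaml.safe_load(CONFIG_TEXT)

batch_size = params["data"]["batch_size"]
epochs = params["optimization"]["epochs"]
lr = params["optimization"]["lr"]
print(batch_size, epochs, lr)
```

Changing an experiment then means editing the .yaml file (or copying it under a new name) rather than adding command-line flags.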

Launching I-JEPA pretraining

Single-GPU training

This implementation starts from main.py, which parses the experiment config file and runs the pre-training locally on a multi-GPU (or single-GPU) machine. For example, to run I-JEPA pretraining on GPUs "0", "1", and "2" on a local machine using the config configs/in1k_vith14_ep300.yaml, type the command:

python main.py \
  --fname configs/in1k_vith14_ep300.yaml \
  --devices cuda:0 cuda:1 cuda:2

Note: This example is just used for illustrative purposes, as the ViT-H/14 config should be run on 16 A100 80G GPUs for an effective batch-size of 2048, in order to reproduce our results.
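The arithmetic behind this note can be checked with a small helper. The function name `per_gpu_batch_size` is hypothetical; it only restates that the effective batch size is the per-GPU batch size multiplied by the number of GPUs.

```python
def per_gpu_batch_size(effective_batch_size, num_gpus):
    """Batch size each GPU must process so the total matches the target."""
    if effective_batch_size % num_gpus != 0:
        raise ValueError(
            f"{effective_batch_size} is not divisible by {num_gpus} GPUs"
        )
    return effective_batch_size // num_gpus


# 16 GPUs at an effective batch size of 2048, as in the note above.
print(per_gpu_batch_size(2048, 16))  # 128

# The 3-GPU example command cannot reach exactly 2048 with equal shards.
try:
    per_gpu_batch_size(2048, 3)
except ValueError as err:
    print(err)
```

This is why the 3-GPU command is illustrative only: reproducing the reported results requires matching the effective batch size, not just the config file.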

Multi-GPU training

In the multi-GPU setting, the implementation starts from main_distributed.py, which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source submitit tool and provide examples for a SLURM cluster.

For example, to pre-train on 16 A100 80G GPUs using the pre-training experiment configs specified inside configs/in1k_vith14_ep300.yaml, type the command:

python main_distributed.py \
  --fname configs/in1k_vith14_ep300.yaml \
  --folder $path_to_save_submitit_logs \
  --partition $slurm_partition \
  --nodes 2 --tasks-per-node 8 \
  --time 1000

Requirements

  • Python 3.8 (or newer)
  • PyTorch 2.0
  • torchvision
  • Other dependencies: pyyaml, numpy, opencv, submitit

License

See the LICENSE file for details about the license under which this code is made available.

Citation

If you find this repository useful in your research, please consider giving it a star ⭐ and a citation:

@article{assran2023self,
  title={Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture},
  author={Assran, Mahmoud and Duval, Quentin and Misra, Ishan and Bojanowski, Piotr and Vincent, Pascal and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas},
  journal={arXiv preprint arXiv:2301.08243},
  year={2023}
}