This repository provides a simple PyTorch implementation of the Joint-Embedding Predictive Architecture (JEPA) for text data, inspired by Karpathy's nanoGPT repo. The project takes inspiration from I-JEPA and V-JEPA, with some differences: all targets for a single masked sample are packed into the same predictor input; the predictor's context inputs do not attend to the target embeddings, while the target embeddings attend to the context and to their respective target chunks. The target masking strategy is taken from the UL2 paper. With the provided configs the model trains without mode collapse on the TinyStories dataset. Testing and evaluation scripts still need to be added.
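The predictor's attention pattern described above can be sketched as a boolean mask. This is an illustration only, not the repository's actual code: segment id 0 marks context tokens, and ids 1, 2, ... mark the packed target chunks.

```python
def predictor_attention_mask(segment_ids):
    """Return mask[i][j] == True if position i may attend to position j.

    Context tokens (segment id 0) attend only to the context; target tokens
    attend to the context and to their own target chunk, never across chunks.
    """
    n = len(segment_ids)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if segment_ids[j] == 0:
                # every position may attend to the context
                mask[i][j] = True
            elif segment_ids[i] == segment_ids[j]:
                # target tokens attend within their own chunk
                mask[i][j] = True
    return mask

# Example: three context tokens followed by two packed target chunks.
seg = [0, 0, 0, 1, 1, 2]
m = predictor_attention_mask(seg)
assert m[0][3] is False   # context does not attend to targets
assert m[3][0] is True    # targets attend to the context
assert m[3][4] is True    # within-chunk attention
assert m[3][5] is False   # no cross-chunk attention
```

In an actual model this mask would be passed to the attention layers (e.g. as an additive bias of 0 / -inf), so one forward pass can score all packed target chunks at once.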
Install Python 3.11 with pyenv:
pyenv install 3.11
Create a pyenv virtual environment with Python 3.11:
pyenv virtualenv 3.11 tjepa
Activate the virtual environment:
pyenv activate tjepa
Optional (but recommended): set the virtual environment to activate automatically when entering the project directory:
pyenv local tjepa
Install the required packages:
pip install -r requirements.txt
Install the rotary embeddings package (which supports custom indexing):
git clone https://github.com/AlxSp/rotary-embedding-torch
pip install -e ./rotary-embedding-torch
python train.py
You can either modify the input arguments to the train.py script, or add them to a set of config files and pass those files as arguments to train.py (see the files in configs/tiny_Stories_training).
--init_from (str, default: 'scratch'): init from 'scratch' or 'resume'
--encoder_config_path (str, required): path to the encoder config (if init_from == 'resume', it is loaded from the training directory)
--predictor_config_path (str, required): path to the predictor config
--opt_config_path (str, required): path to the optimizer config
--train_run_config_path (str, required): path to the train run config
--target_masking_strategies_path (str, required): path to the target masking strategies
target_block_size (int, required): mean of the target block size
max_target_mask_ratio (float, default: 0.15): maximum ratio of the target to be masked
max_target_mask_start_ratio (float, default: 0.15): maximum ratio of the target start to be masked
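To show how these three parameters might interact, here is a hedged sketch of UL2-style target span selection. It is not the repository's implementation: the Gaussian span-length draw and the gap policy between spans are assumptions made for illustration.

```python
import random

def sample_target_spans(seq_len, target_block_size, max_target_mask_ratio,
                        max_target_mask_start_ratio, rng=None):
    """Return (start, end) target spans covering at most
    max_target_mask_ratio of a sequence of length seq_len."""
    rng = rng or random.Random(0)
    budget = int(seq_len * max_target_mask_ratio)
    # the first span may start anywhere in the leading portion of the sequence
    pos = rng.randint(0, max(0, int(seq_len * max_target_mask_start_ratio)))
    spans, masked = [], 0
    while masked < budget and pos < seq_len:
        # span length drawn around the mean target block size (assumed policy)
        length = max(1, round(rng.gauss(target_block_size, 1)))
        length = min(length, budget - masked, seq_len - pos)
        spans.append((pos, pos + length))
        masked += length
        pos += length + length  # leave an unmasked gap roughly one span long
    return spans

spans = sample_target_spans(100, 5, 0.15, 0.15, rng=random.Random(42))
assert sum(e - s for s, e in spans) <= 15   # never exceeds the mask budget
```

Each returned span would then become one packed target chunk in the predictor input.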
Run the prepare.py script to download and prepare the dataset in data/TinyStories:
python data/TinyStories/prepare.py
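If prepare.py follows the nanoGPT convention of writing token ids as flat binary files of uint16 values (an assumption about the format, not confirmed here; the file name below is a throwaway example), the prepared data can be memory-mapped rather than loaded into RAM:

```python
import os
import tempfile
import numpy as np

def read_token_file(path, dtype=np.uint16):
    """Memory-map a flat binary file of token ids without loading it fully."""
    return np.memmap(path, dtype=dtype, mode="r")

# Round-trip demo with a throwaway file standing in for train.bin:
tmp = os.path.join(tempfile.gettempdir(), "demo_tokens.bin")
np.array([10, 20, 30, 40], dtype=np.uint16).tofile(tmp)
loaded = read_token_file(tmp)
assert loaded.tolist() == [10, 20, 30, 40]
```

Memory-mapping keeps startup cheap and lets the dataloader slice arbitrary windows of tokens for batching.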
TODO: add the prepare.py script to download and prepare the dataset in data/SlimPajama
@article{assran2023self,
title={Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture},
author={Assran, Mahmoud and Duval, Quentin and Misra, Ishan and Bojanowski, Piotr and Vincent, Pascal and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas},
journal={arXiv preprint arXiv:2301.08243},
year={2023},
github={https://github.com/facebookresearch/ijepa}
}
@article{bardes2024revisiting,
title={Revisiting Feature Prediction for Learning Visual Representations from Video},
author={Bardes, Adrien and Garrido, Quentin and Ponce, Jean and Rabbat, Michael and LeCun, Yann and Assran, Mahmoud and Ballas, Nicolas},
journal={arXiv preprint},
year={2024}
}
@misc{tay2023ul2,
title={UL2: Unifying Language Learning Paradigms},
author={Yi Tay and Mostafa Dehghani and Vinh Q. Tran and Xavier Garcia and Jason Wei and Xuezhi Wang and Hyung Won Chung and Siamak Shakeri and Dara Bahri and Tal Schuster and Huaixiu Steven Zheng and Denny Zhou and Neil Houlsby and Donald Metzler},
year={2023},
eprint={2205.05131},
archivePrefix={arXiv},
primaryClass={cs.CL}
}