/baseline-pretraining

Code for pre-training BabyLM baseline models.

Environment

Python 3.9, plus the Hugging Face transformers and datasets packages.

Also install pt_framework: https://github.com/chengxuz/pt_framework

Install this repository with pip install . (or pip install -e . for an editable install).
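
As a quick sanity check of the environment described above, the following is a minimal sketch that reports which of the required Hugging Face packages cannot be found and whether the interpreter is Python 3.9. The helper names missing_packages and python_version_matches are hypothetical and not part of this repository; pt_framework is left out because its import name is not stated here.

```python
import importlib.util
import sys


def missing_packages(names):
    """Return the top-level package names that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_version_matches(major, minor):
    """Return True if the running interpreter has the given major.minor version."""
    return sys.version_info[:2] == (major, minor)


# Packages named in the Environment section above.
missing = missing_packages(["transformers", "datasets"])
if missing:
    print("Missing packages:", ", ".join(missing))
if not python_version_matches(3, 9):
    print("Running Python %d.%d rather than 3.9" % sys.version_info[:2])
```

The script only prints warnings; it does not install anything or stop on a mismatch.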

Where to put data

First, set the environment variable BABYLM_ROOT_DIR to the directory where your models and data will live. Put the downloaded data in ${BABYLM_ROOT_DIR}/datasets/ so that this folder contains the following four subfolders: babylm_100M, babylm_10M, babylm_dev, and babylm_test.

The T5 training script expects .txt input files, so create a single dev file by running this command in the ${BABYLM_ROOT_DIR}/datasets/babylm_dev/ folder: cat *.dev > babylm_dev.txt

Trained models are written to ${BABYLM_ROOT_DIR}/models/ and the records to ${BABYLM_ROOT_DIR}/model_recs/.
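
The data layout described above can be checked, and the combined dev file built, with a short Python script. This is only a sketch: the helper names missing_subfolders and build_dev_file are hypothetical and not part of this repository, and the concatenation is meant to approximate cat *.dev > babylm_dev.txt by reading the files in sorted order.

```python
from pathlib import Path

# Subfolders that the README expects under ${BABYLM_ROOT_DIR}/datasets/.
EXPECTED_SUBFOLDERS = ("babylm_100M", "babylm_10M", "babylm_dev", "babylm_test")


def missing_subfolders(root_dir):
    """Return the expected dataset subfolders that are absent under root_dir/datasets."""
    datasets_dir = Path(root_dir) / "datasets"
    return [name for name in EXPECTED_SUBFOLDERS if not (datasets_dir / name).is_dir()]


def build_dev_file(root_dir):
    """Concatenate every *.dev file in babylm_dev into babylm_dev.txt."""
    dev_dir = Path(root_dir) / "datasets" / "babylm_dev"
    out_path = dev_dir / "babylm_dev.txt"
    with open(out_path, "wb") as out_file:
        # Sorted order approximates the order of shell glob expansion.
        for dev_file in sorted(dev_dir.glob("*.dev")):
            out_file.write(dev_file.read_bytes())
    return out_path


# Example usage (assumes BABYLM_ROOT_DIR is set in the environment):
#   import os
#   root = os.environ["BABYLM_ROOT_DIR"]
#   print(missing_subfolders(root))
#   print(build_dev_file(root))
```

An empty list from missing_subfolders means all four subfolders are in place.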

Training Command

OPT-125M

Run the following command under the scripts folder.

python -m torch.distributed.launch --nproc_per_node=1 --master_port=29123 general_train.py --setting "BabyLM/exp_strict.py:opt125m_s1"

This command loads the training setting defined by the function opt125m_s1 in src/babylm_baseline_train/configs/BabyLM/exp_strict.py.
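
To make the --setting format concrete, here is a minimal sketch of how a value such as "BabyLM/exp_strict.py:opt125m_s1" corresponds to a config file and a function name. The split_setting helper is hypothetical: the actual resolution is done by the training code, and splitting at the last colon is an assumption made for illustration.

```python
from pathlib import PurePosixPath

# Folder that holds the setting files, as stated in the README.
CONFIG_ROOT = PurePosixPath("src/babylm_baseline_train/configs")


def split_setting(setting):
    """Split '<relative/path>.py:<function>' into (config file path, function name)."""
    rel_path, sep, func_name = setting.rpartition(":")
    if not sep or not rel_path or not func_name:
        raise ValueError(f"expected '<file>.py:<function>', got {setting!r}")
    return CONFIG_ROOT / rel_path, func_name


config_path, func_name = split_setting("BabyLM/exp_strict.py:opt125m_s1")
print(config_path)  # src/babylm_baseline_train/configs/BabyLM/exp_strict.py
print(func_name)    # opt125m_s1
```

Under the same assumption, the RoBERTa setting "BabyLM/exp_strict_mask.py:roberta_s1" would point to the function roberta_s1 in exp_strict_mask.py in the same configs folder.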

RoBERTa-Base

Run the following command under the scripts folder.

python -m torch.distributed.launch --nproc_per_node=1 --master_port=29123 general_train.py --setting "BabyLM/exp_strict_mask.py:roberta_s1"

T5-Base

Run the following command under the scripts folder.

./train_t5_babylm.sh

Note that this training script uses a different backend than the OPT and RoBERTa models. It is a slightly modified version of the Flax T5 pre-training script from Hugging Face; the original lives here.

Where important parameters are defined

The learning rate schedule is defined in the function get_learning_rate_params in basic_param_setter.py, under the src/babylm_baseline_train folder.

The optimizer is defined in the get_key_params function in scripts/general_train.py.

How to load the pretrained models

See the functions in src/babylm_baseline_train/models/ckpt_loader.py.

Questions?

Feel free to open an issue here, or contact us through Slack or email.