We are still in the process of cleaning up this repo. Use it with caution, and please report errors and questions to us.
This repository is the official implementation of our paper:
The Surprising Effectiveness of Test-Time Training for Abstract Reasoning
Ekin Akyürek, Mehul Damani, Linlu Qiu, Han Guo, Yoon Kim, Jacob Andreas
To install the requirements, start a fresh conda environment and install the following with pip:
# For the TTT pipeline, we use a fork of the torchtune library.
# You need to install it first.
conda create -n arc python=3.10
conda activate arc
# You can install it with pip, or clone it and install it as editable
pip install torchtune@git+https://github.com/ekinakyurek/torchtune.git@ekin/llama32
pip install torch torchao --pre --upgrade --index-url https://download.pytorch.org/whl/nightly/cu121
# Then install the remaining requirements:
pip install -r requirements.txt
You need to download the ARC dataset from Kaggle: https://www.kaggle.com/competitions/arc-prize-2024/data
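Before running anything, it can help to confirm that the files landed where the commands below expect them. This is a minimal sketch with hypothetical helpers (not part of this repo); it assumes the Kaggle archive was unzipped into a single directory, by default `/kaggle/input/arc-prize-2024` as used for `data_file` below, and it uses `python3`'s standard `json` module to count tasks.

```shell
# Hypothetical helpers (not part of this repo) for sanity-checking the Kaggle ARC files.
# ASSUMPTION: the Kaggle archive was unzipped into a single directory.

# Report each missing evaluation file; return non-zero if any is missing.
check_arc_data() {
  data_dir="${1:-/kaggle/input/arc-prize-2024}"
  status=0
  for f in arc-agi_evaluation_challenges.json arc-agi_evaluation_solutions.json; do
    if [ ! -f "$data_dir/$f" ]; then
      echo "missing: $data_dir/$f" >&2
      status=1
    fi
  done
  return "$status"
}

# Print the number of tasks (top-level JSON keys) in a challenges file.
count_arc_tasks() {
  python3 -c 'import json, sys; print(len(json.load(open(sys.argv[1]))))' "$1"
}
```

For example, `check_arc_data && count_arc_tasks /kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json`. If you use the Kaggle CLI, `kaggle competitions download -c arc-prize-2024` fetches the same archive (it needs an API token and acceptance of the competition rules).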
We will be uploading checkpoints to the following Hugging Face repositories very soon:
- For Llama-3.x checkpoints: https://huggingface.co/meta-llama
- For our finetuned Llama-3 8B checkpoints: https://huggingface.co/ekinakyurek/marc-8B-finetuned-llama3
- For finetuned BARC checkpoints: https://huggingface.co/barc0/Llama-3.1-ARC-Potpourri-Transduction-8B-test-time-finetune
- For our LoRA adapters for the Llama-3 8B model: https://huggingface.co/ekinakyurek/marc-lora-adapters-8B-finetuned-llama3
- For our LoRA adapters for the BARC model: https://huggingface.co/ekinakyurek/marc-lora-adapters-Llama-3.1-ARC-Potpourri-Transduction-8B
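One way to fetch these repositories is the `huggingface-cli download` command from the `huggingface_hub` package. The wrapper below is a hypothetical convenience (not part of this repo); setting `DRY_RUN=1` only prints the command it would run. The `meta-llama` repositories are gated, so you need to accept their license and run `huggingface-cli login` first.

```shell
# Hypothetical wrapper around `huggingface-cli download` (huggingface_hub package).
# Usage: fetch_checkpoint <repo_id> <local_dir>
# With DRY_RUN set to a non-empty value, the command is printed instead of executed.
fetch_checkpoint() {
  repo_id="$1"
  dest="$2"
  if [ -z "$repo_id" ] || [ -z "$dest" ]; then
    echo "usage: fetch_checkpoint <repo_id> <local_dir>" >&2
    return 2
  fi
  # ${DRY_RUN:+echo} expands to "echo" only when DRY_RUN is non-empty.
  ${DRY_RUN:+echo} huggingface-cli download "$repo_id" --local-dir "$dest"
}
```

For example, `fetch_checkpoint ekinakyurek/marc-8B-finetuned-llama3 /path/to/finetuned/model/folder/` downloads into the folder you later pass as `base_checkpoint_dir`.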
To run test-time training (TTT) for the model(s) in the paper, run:
# Specify data path
data_file=/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json
# Specify the finetuned model path
base_checkpoint_dir=/path/to/finetuned/model/folder/
# Specify where TTT adapters should be saved
ttt_folder=/path/to/ttt/folder
mkdir -p $ttt_folder
# You need to provide an initial config file that is compatible with torchtune configs
# One is provided in this repo
lora_config_file=configs/ttt/8B_lora_single_device.yaml
# lora_config_file=configs/ttt/8.1B_lora_single_device.yaml # for barc
# But you can override some of the variables
batch_size=2
epochs=2
learning_rate=5e-5
lora_rank=128
lora_alpha=16.0
lora_to_output=False # doesn't apply for Llama3.2 models for now.
# You can specify how many tasks you want to train on.
num_tasks=100
# Then run the main script
python test_time_train.py --lora_config=$lora_config_file \
--base_checkpoint_dir $base_checkpoint_dir \
--experiment_folder $ttt_folder \
--data_file $data_file \
--batch_size $batch_size \
--epochs $epochs \
--num_tasks=${num_tasks} \
--lora_rank=$lora_rank \
--lora_alpha=$lora_alpha \
--lora_to_output=$lora_to_output \
--new_format # use --barc_format for barc
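For orientation when changing `lora_rank` and `lora_alpha`: in standard LoRA, and in upstream torchtune's LoRA layers, the low-rank update is multiplied by `alpha / rank`, so the values above (alpha 16.0, rank 128) scale it by 0.125. The helper below is an illustrative sketch (not part of this repo) that only does this arithmetic; we assume the torchtune fork keeps the upstream scaling.

```shell
# Hypothetical helper: LoRA scaling factor alpha / rank.
# At fixed alpha, doubling the rank halves the scale applied to the LoRA update.
lora_scale() {
  awk -v alpha="$1" -v rank="$2" 'BEGIN {
    if (rank == 0) { print "rank must be non-zero" > "/dev/stderr"; exit 1 }
    printf "%g\n", alpha / rank
  }'
}

# lora_scale "$lora_alpha" "$lora_rank"   # 16.0 / 128 -> 0.125
```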
If you are using BARC checkpoints with unmask_outputs=True in the program arguments, you need to uncomment these lines in my torchtune clone here.
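The commented-out alternatives in the training command differ by model family: BARC uses `configs/ttt/8.1B_lora_single_device.yaml` with `--barc_format`, while our finetuned Llama-3 model uses `configs/ttt/8B_lora_single_device.yaml` with `--new_format`. A hypothetical helper (not in this repo) that keeps each pair together, so the config and the format flag cannot drift apart:

```shell
# Hypothetical helper: print "<lora_config_file> <format_flag>" for a model family.
# The pairings come from the comments in the training command above.
ttt_model_settings() {
  case "$1" in
    llama3) echo "configs/ttt/8B_lora_single_device.yaml --new_format" ;;
    barc)   echo "configs/ttt/8.1B_lora_single_device.yaml --barc_format" ;;
    *)      echo "unknown model family: $1" >&2; return 1 ;;
  esac
}
```

To split the output into two variables, you can use `read -r lora_config_file format_flag` with the helper's output as input, then pass `$format_flag` in place of `--new_format`.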
TTT training saves adapter checkpoints under the `ttt_folder` you specified above.
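Before running inference, you may want to check that TTT produced adapters for the tasks you expected. The exact layout of `ttt_folder` is not documented here, so this sketch assumes (hypothetically) one adapter subdirectory per task and simply counts them, which you can compare with `num_tasks`.

```shell
# Hypothetical check. ASSUMPTION: one adapter subdirectory per task under ttt_folder.
count_ttt_adapters() {
  if [ ! -d "$1" ]; then
    echo "not a directory: $1" >&2
    return 1
  fi
  # Count only immediate subdirectories; plain files (e.g. logs) are ignored.
  find "$1" -mindepth 1 -maxdepth 1 -type d | wc -l | tr -d ' '
}

# count_ttt_adapters "$ttt_folder"   # compare with $num_tasks
```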
To run inference with the TTT adapters, run predict.py:
# Specify where predictions and submissions should be saved
tti_folder=/path/to/tti/folder
mkdir -p $tti_folder
# Specify where your finetuned (named as base) and TTT checkpoints are
base_checkpoint_dir=/path/to/finetuned/model/folder/
ttt_folder=/path/to/ttt/folder
# If a solution file is given, predict.py will also evaluate the model
solution_file=/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions_selected.json
# data_file and lora_rank are reused from the training step above;
# set temperature and n_sample to your sampling settings before running.
# --include_n=1 means we use leave-1-out prompts
python predict.py --experiment_folder=$tti_folder \
--pretrained_checkpoint $base_checkpoint_dir \
--lora_checkpoints_folder $ttt_folder \
--temperature $temperature \
--n_sample $n_sample \
--data_file $data_file \
--solution_file $solution_file \
--max_lora_rank=$lora_rank \
--include_n=1 \
--new_format
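The `--include_n=1` comment refers to leave-1-out prompts. As we read it (an assumption, not confirmed from `predict.py`), a task with n demonstration pairs yields n prompts, each omitting one demonstration. The hypothetical function below only enumerates which demonstration indices each prompt would keep; it does not build prompts.

```shell
# Hypothetical illustration of leave-one-out subsets over n demonstrations.
# Prints one line per held-out index: "drop=<i> keep=<comma-separated kept indices>".
leave_one_out() {
  n="$1"
  i=0
  while [ "$i" -lt "$n" ]; do
    keep=""
    j=0
    while [ "$j" -lt "$n" ]; do
      if [ "$j" -ne "$i" ]; then
        # Append j, adding a comma only after the first kept index.
        keep="${keep:+$keep,}$j"
      fi
      j=$((j + 1))
    done
    echo "drop=$i keep=$keep"
    i=$((i + 1))
  done
}

leave_one_out 3
# drop=0 keep=1,2
# drop=1 keep=0,2
# drop=2 keep=0,1
```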
For Llama-3 and Llama-3.2 we used different versions of vLLM, and the second one is not compatible with the torchtune version we use. For reproducibility, we give vLLM setup instructions for both Llama-3 and Llama-3.2. We use separate conda environments for the inference pipeline.
# For Llama3 and 3.1 models
conda create -n vllm python=3.10
conda activate vllm
pip install vllm@git+https://github.com/ekinakyurek/vllm.git@ekin/torchtunecompat
# For Llama3.2 models
conda create -n vllmnew python=3.10
conda activate vllmnew
pip install vllm@git+https://github.com/ekinakyurek/vllm.git@ekin/ekin/newvllm
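Since the two vLLM forks live in separate environments, a small hypothetical helper (not part of this repo) can map the Llama version to the environment names created above, which avoids running inference in the wrong one.

```shell
# Hypothetical helper: map a Llama version to the conda env created above.
inference_env() {
  case "$1" in
    llama3|llama3.1) echo vllm ;;
    llama3.2)        echo vllmnew ;;
    *)               echo "unknown model version: $1" >&2; return 1 ;;
  esac
}

# conda activate "$(inference_env llama3.2)"
```

Note that `conda activate` inside a script only works after conda's shell hook has been loaded (for example via `conda init`).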