
Public repository for "The Surprising Effectiveness of Test-Time Training for Abstract Reasoning"

Primary language: Python. License: MIT.

📋 We are still in the process of cleaning up this repo. Use it with caution, and please report errors and questions to us.

The Surprising Effectiveness of Test-Time Training for Abstract Reasoning

This repository is the official implementation of our paper:

The Surprising Effectiveness of Test-Time Training for Abstract Reasoning

Ekin Akyürek, Mehul Damani, Linlu Qiu, Han Guo, Yoon Kim, Jacob Andreas

Requirements

To install the requirements, start a fresh conda environment and install the following with pip:

# For the TTT pipeline, we used a fork of the torchtune library.
# You need to install it first.
conda create -n arc python=3.10
conda activate arc
# You can install it with pip, or clone it and install it as editable
pip install torchtune@git+https://github.com/ekinakyurek/torchtune.git@ekin/llama32
pip install torch torchao --pre --upgrade --index-url https://download.pytorch.org/whl/nightly/cu121

# Then the remaining requirements can be installed with:
pip install -r requirements.txt

📋 You need to download the ARC dataset from Kaggle: https://www.kaggle.com/competitions/arc-prize-2024/data

📋 We will be uploading checkpoints to Hugging Face repositories very soon:

Test-Time Training

To run test-time training (TTT) on the model(s) in the paper, run these commands:

# Specify data path
data_file=/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json
# Specify the path to the fine-tuned model
base_checkpoint_dir=/path/to/finetuned/model/folder/
# Specify where TTT adapters should be saved
ttt_folder=/path/to/ttt/folder
mkdir -p $ttt_folder


# You need to provide an initial config file that is compatible with torchtune configs
# One is provided in this repo
lora_config_file=configs/ttt/8B_lora_single_device.yaml
# lora_config_file=configs/ttt/8.1B_lora_single_device.yaml # for barc
# But you can override some of the variables
batch_size=2
epochs=2
learning_rate=5e-5
lora_rank=128
lora_alpha=16.0
lora_to_output=False # doesn't apply for Llama3.2 models for now.
# You can specify how many tasks you want to train on.
num_tasks=100

# You can run the main script
python test_time_train.py --lora_config=$lora_config_file \
--base_checkpoint_dir $base_checkpoint_dir \
--experiment_folder $ttt_folder \
--data_file $data_file \
--batch_size $batch_size \
--epochs $epochs \
--num_tasks=${num_tasks} \
--lora_rank=$lora_rank \
--lora_alpha=$lora_alpha \
--lora_to_output=$lora_to_output \
--new_format # use --barc_format for barc
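The `lora_rank` and `lora_alpha` values above plug into the standard LoRA parameterization, in which the trainable low-rank update B·A is added to a frozen weight and scaled by alpha / rank; with the values above that factor is 16.0 / 128 = 0.125. The NumPy sketch below shows that formulation only; we assume the torchtune fork follows it, and `lora_forward` is a hypothetical name, not a torchtune API.

```python
import numpy as np


def lora_forward(x, W, A, B, alpha: float):
    """Standard LoRA linear layer: y = x W^T + (alpha / r) * x A^T B^T.

    Shapes: x (batch, d_in), W (d_out, d_in), A (r, d_in), B (d_out, r).
    W is the frozen base weight; only A and B would be trained.
    """
    rank = A.shape[0]
    scaling = alpha / rank
    return x @ W.T + scaling * (x @ A.T @ B.T)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d_in, d_out = 256, 64
    rank, alpha = 128, 16.0  # mirror lora_rank / lora_alpha in the TTT command above
    x = rng.normal(size=(2, d_in))
    W = rng.normal(size=(d_out, d_in))
    A = rng.normal(size=(rank, d_in)) * 0.01
    B = np.zeros((d_out, rank))  # B starts at zero, so the adapter is initially a no-op
    y = lora_forward(x, W, A, B, alpha)
    print(alpha / rank)             # 0.125
    print(np.allclose(y, x @ W.T))  # True
```

Because the update is scaled by alpha / rank, raising `lora_rank` while keeping `lora_alpha` fixed shrinks the per-direction contribution of the adapter.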

📋 If you are using BARC checkpoints and set unmask_outputs=True in the program arguments, you need to uncomment these lines in my torchtune clone here.

📋 TTT training will save adapter checkpoints under the ttt_folder you specified above.

Inference

To run inference with TTT, run predict.py:

# You need to tell where predictions and submissions should be saved
tti_folder=/path/to/tti/folder
mkdir -p $tti_folder
# Tell where your fine-tuned (named as base) and TTT checkpoints are
base_checkpoint_dir=/path/to/finetuned/model/folder/
ttt_folder=/path/to/ttt/folder

# If a solution file is given, predict.py will also evaluate the model
solution_file=/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions_selected.json

# data_file and lora_rank are reused from the TTT step above;
# temperature and n_sample must be set before running.
# --include_n=1 means we use leave-1-out prompts.
python predict.py --experiment_folder=$tti_folder \
--pretrained_checkpoint $base_checkpoint_dir \
--lora_checkpoints_folder $ttt_folder \
--temperature $temperature \
--n_sample $n_sample \
--data_file $data_file \
--solution_file $solution_file \
--max_lora_rank=$lora_rank \
--include_n=1 \
--new_format
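For intuition about leave-one-out, here is the general construction on a task with n demonstration pairs: it yields n splits, each showing n-1 pairs as demonstrations and holding out the remaining pair as the query. This is a hypothetical illustration, not the code in predict.py; the helper names and the plain-text prompt format are assumptions, and the repo's `--new_format` prompts will look different.

```python
from typing import Any

Pair = dict[str, Any]  # {"input": grid, "output": grid}


def leave_one_out(train_pairs: list[Pair]) -> list[tuple[list[Pair], Pair]]:
    """Build one (demonstrations, held_out) split per training pair.

    For n pairs this returns n splits; split j uses every pair except j as
    demonstrations and pair j as the query whose output must be predicted.
    """
    return [
        (train_pairs[:j] + train_pairs[j + 1:], train_pairs[j])
        for j in range(len(train_pairs))
    ]


def format_prompt(demos: list[Pair], query_input) -> str:
    """Hypothetical plain-text prompt: demonstrations followed by the query input."""
    lines = [f"input: {p['input']}\noutput: {p['output']}" for p in demos]
    lines.append(f"input: {query_input}\noutput:")
    return "\n".join(lines)


if __name__ == "__main__":
    pairs = [{"input": [[i]], "output": [[i + 1]]} for i in range(3)]
    for demos, held_out in leave_one_out(pairs):
        print(len(demos), held_out["input"])
    # 2 [[0]]
    # 2 [[1]]
    # 2 [[2]]
```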
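When a solution file is given, predict.py scores the predictions, and the exact metric it reports is defined there. The sketch below illustrates only exact-match scoring. It assumes the solutions file maps each task id to a list of expected output grids (one per test input) and that each test input may get up to two attempted grids, as in the ARC Prize 2024 submission format. The function `score` and the shape of `predictions` are hypothetical.

```python
Grid = list[list[int]]


def score(
    predictions: dict[str, list[list[Grid]]],
    solutions: dict[str, list[Grid]],
    max_attempts: int = 2,
) -> float:
    """Fraction of test inputs solved by exact match within max_attempts attempts.

    predictions[task_id][i] holds the attempted grids for test input i;
    solutions[task_id][i] is the expected output grid for that input.
    Missing tasks or inputs count as unsolved.
    """
    solved, total = 0, 0
    for task_id, expected_grids in solutions.items():
        attempts_per_input = predictions.get(task_id, [])
        for i, expected in enumerate(expected_grids):
            total += 1
            attempts = attempts_per_input[i] if i < len(attempts_per_input) else []
            if any(grid == expected for grid in attempts[:max_attempts]):
                solved += 1
    return solved / total if total else 0.0


if __name__ == "__main__":
    solutions = {"t1": [[[1, 2]]], "t2": [[[0]], [[5]]]}
    predictions = {
        "t1": [[[[9, 9]], [[1, 2]]]],  # second attempt matches
        "t2": [[[[0]]]],               # first test input solved, second missing
    }
    print(score(predictions, solutions))  # 0.6666666666666666
```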

📋 For Llama-3 and Llama-3.2 we used different versions of vLLM, and the one for Llama-3.2 is not compatible with the torchtune version that we use. So, for reproducibility, we give setup instructions for vLLM for Llama-3 and vLLM for Llama-3.2. We use separate conda environments for the inference pipeline.

# For Llama3 and 3.1 models
conda create -n vllm python=3.10
conda activate vllm
pip install vllm@git+https://github.com/ekinakyurek/vllm.git@ekin/torchtunecompat
# For Llama3.2 models
conda create -n vllmnew python=3.10
conda activate vllmnew
pip install vllm@git+https://github.com/ekinakyurek/vllm.git@ekin/ekin/newvllm