This repository is the official implementation of the TRAC optimizer introduced in "Fast TRAC: A Parameter-Free Optimizer for Lifelong Reinforcement Learning".
How can you quickly adapt to new tasks or distribution shifts? Without knowing when or how much to adapt? And without ANY tuning? 🤔💭
Well, we suggest you get on the fast TRAC 🏎️💨.
TRAC is a parameter-free optimizer for continual environments. It is inspired by online convex optimization and uses discounted adaptive online prediction.
Update [08/20/24]: TRAC is now supported for JAX and Optax!
Like other meta-tuners, TRAC can work with any of your continual, fine-tuning, or lifelong experiments with just one line change.
pip install trac-optimizer
PyTorch
import torch
from trac_optimizer import start_trac

# original optimizer (the class, not an instance)
optimizer = torch.optim.Adam
lr = 0.001
# wrap the optimizer class with TRAC, then construct it as usual
optimizer = start_trac('logs/trac.text', optimizer)(model.parameters(), lr=lr)
JAX
import optax
from trac_optimizer.experimental.jax.trac import start_trac

# original optimizer
optimizer = optax.adam(1e-3)
# wrap it with TRAC
optimizer = start_trac(optimizer)
After this modification, you can continue using your optimizer exactly as you did before. In PyTorch, calling optimizer.step() to update your model's parameters or optimizer.zero_grad() to clear gradients stays the same. TRAC integrates into your existing workflow without further changes to your training code.
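To see why the PyTorch call has two sets of parentheses and why step() and zero_grad() keep working, it can help to look at the general wrapper pattern: a function takes an optimizer class and returns a new class that is constructed with the same arguments. The toy below is only a plain-Python illustration of that pattern. ToySGD and wrap_optimizer are hypothetical names, and the sketch does not reproduce TRAC's update rule or the internals of start_trac.

```python
class ToySGD:
    """Minimal stand-in for an optimizer class (hypothetical, not torch.optim)."""

    def __init__(self, params, lr):
        self.params = params  # list of {"value": float, "grad": float}
        self.lr = lr

    def step(self):
        # Plain gradient descent update.
        for p in self.params:
            p["value"] -= self.lr * p["grad"]

    def zero_grad(self):
        for p in self.params:
            p["grad"] = 0.0


def wrap_optimizer(base_cls):
    """Return a subclass of base_cls (hypothetical helper, not start_trac).

    A real meta-tuner would adjust the update inside step(); this toy
    only counts steps and then defers to the base optimizer.
    """

    class Wrapped(base_cls):
        def step(self):
            self.num_steps = getattr(self, "num_steps", 0) + 1
            super().step()

    return Wrapped


params = [{"value": 1.0, "grad": 0.5}]
# First call wraps the class, second call constructs the optimizer.
optimizer = wrap_optimizer(ToySGD)(params, lr=0.1)
optimizer.step()       # same method name as the base optimizer
optimizer.zero_grad()  # inherited unchanged from the base optimizer
```

Because the wrapped object is still an instance of the base class in this sketch, an existing training loop does not need to know that a wrapper is involved.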
We recommend running main.ipynb in Google Colab. This approach requires no setup, making it easy to get started with our control experiments. To run locally instead, install the necessary dependencies:
pip install -r requirements.txt
Our vision-based experiments for Procgen and Atari are hosted in the vision_exp directory, which is based on this Procgen PyTorch implementation.
To start an experiment with the default configuration in the Procgen "starpilot" environment, use the command below. You can switch to other game suites, like Atari, by changing the --exp_name parameter (for example, --exp_name="atari"):
python vision_exp/train.py --exp_name="procgen" --env_name="starpilot" --optimizer="TRAC" --warmstart_step=0
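If you want to launch the same command for several environments, a small helper can assemble the argument list for you. This is a minimal sketch that uses only the flags shown above. The build_command helper is a hypothetical name, and "dodgeball" is an assumed second environment that this repository may or may not support.

```python
import shlex


def build_command(env_name, exp_name="procgen", optimizer="TRAC", warmstart_step=0):
    """Assemble the train.py invocation shown above as an argument list."""
    return [
        "python",
        "vision_exp/train.py",
        f"--exp_name={exp_name}",
        f"--env_name={env_name}",
        f"--optimizer={optimizer}",
        f"--warmstart_step={warmstart_step}",
    ]


# "starpilot" comes from the command above; "dodgeball" is an assumption.
for env in ["starpilot", "dodgeball"]:
    cmd = build_command(env)
    # Print the command instead of running it; to execute it, pass cmd
    # to subprocess.run(cmd, check=True) from the repository root.
    print(shlex.join(cmd))
```

Building the command as a list rather than a single string avoids shell quoting issues if you later hand it to subprocess.run.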