JaxIRL

Installation | Setup | Algorithms | Citation

Inverse Reinforcement Learning in JAX

This repository contains JAX implementations of algorithms for inverse reinforcement learning (IRL). Inverse RL is an online approach to imitation learning in which we try to extract a reward function that makes the expert optimal. IRL doesn't suffer from compounding errors (unlike behavioural cloning) and doesn't need expert actions to train (only example trajectories of states). Depending on the environment and hyperparameters, our implementation is about 🔥 100x 🔥 faster than standard IRL implementations in PyTorch (e.g. 3.5 minutes to train a single hopper agent ⚡). By running multiple agents in parallel, you can be even faster! (e.g. 200 walker agents can be trained in ~400 minutes on 1 GPU! That's 2 minutes per agent ⚡⚡).

(Demo GIFs: hopper, walker, ant, halfcheetah)

A game-theoretic perspective on IRL

IRL is commonly framed as a two-player zero-sum game between a policy player and a reward-function player. Intuitively, the reward-function player tries to pick out differences between the current learner policy and the expert, while the policy player tries to maximise this reward function and thereby move closer to expert behaviour. This setup is effectively a GAN in trajectory space, where the reward player is the discriminator and the policy player is the generator.
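Written out (using notation that is not spelled out in this README: π_E for the expert policy, π for the learner, r for the reward player, and τ = (s_0, s_1, ...) for a state trajectory), the zero-sum game is roughly

$$\min_{\pi}\;\max_{r}\;\;\mathbb{E}_{\tau \sim \pi_E}\Big[\sum_t r(s_t)\Big]\;-\;\mathbb{E}_{\tau \sim \pi}\Big[\sum_t r(s_t)\Big]$$

i.e. the reward player maximises the return gap between expert and learner, and the policy player minimises it by matching the expert's state visitations.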

Why JAX?

JAX is a game-changer in the world of machine learning, empowering researchers and developers to train models with unprecedented efficiency and scalability. Here's how it sets a new standard for performance:

  • GPU Acceleration: JAX harnesses the full power of GPUs by JIT-compiling code with XLA. By executing environments directly on the GPU, we eliminate CPU-GPU data-transfer bottlenecks. This results in remarkable speedups compared to traditional frameworks like PyTorch.
  • Parallel Training at Scale: JAX effortlessly scales to multi-environment and multi-agent training scenarios, enabling efficient parallelization for massive performance gains.

All our code can be used with jit, vmap, pmap and scan inside other pipelines. This allows you to:

  • 🎲 Efficiently run tons of seeds in parallel on one GPU (see the sketch after this list)
  • 💻 Perform rapid hyperparameter tuning
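As a minimal illustration of the parallel-seeds point above (a self-contained sketch, not jaxirl's actual training entry point; train_one_seed is a hypothetical stand-in for a full training run):

import jax
import jax.numpy as jnp

def train_one_seed(rng):
    # Stand-in for a real training run: a few random "updates" on a tiny
    # parameter vector, just enough for the example to execute.
    def step(params, key):
        return params + 0.1 * jax.random.normal(key, params.shape), None
    keys = jax.random.split(rng, 10)
    params, _ = jax.lax.scan(step, jnp.zeros(4), keys)
    return params

# One GPU, many runs: vmap over the seed axis and jit the whole thing.
seeds = jax.random.split(jax.random.PRNGKey(0), 32)    # 32 independent seeds
all_params = jax.jit(jax.vmap(train_one_seed))(seeds)  # shape (32, 4)

The same pattern applies to a real training function: as long as it is pure and takes a PRNG key, vmap gives you one trained agent per seed in a single compiled call.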

Running Experiments

The experts are already provided, but to retrain them, simply delete the corresponding expert file and they will be retrained automatically. The default configs for the experts are in jaxirl/configs/inner_training_configs.py, and the default configs for IRL training are in jaxirl/configs/outer_training_configs.py; edit these files to change the defaults.

To train an IRL agent, run:

python jaxirl/irl/main.py --loss loss_type --env env_name

This package supports training via:

  • Behavioral Cloning (loss_type = BC)
  • IRL (loss_type = IRL)
  • Standard RL (loss_type = NONE)

We support the following Brax environments:

  • halfcheetah
  • hopper
  • walker
  • ant

and classic control environments:

  • cartpole
  • pendulum
  • reacher
  • gridworld
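For example, to train an IRL agent on hopper (one value picked from each list above; other supported combinations follow the same pattern), run:

python jaxirl/irl/main.py --loss IRL --env hopper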

Setup

The high-level structure of this repository is as follows:

├── jaxirl  # package folder
│   ├── configs # standard configs for inner and outer loop
│   ├── envs # extra envs
│   ├── irl # main scripts that implement Imitation Learning and IRL algorithms
│   │   ├── bc.py # Code for standard Behavioural Cloning, called when loss_type = BC
│   │   ├── irl.py # Code implementing basic IRL algorithm, called when loss_type = IRL
│   │   ├── gail_discriminator.py # Used by irl.py to implement IRL algorithm
│   │   ├── main.py # Main script to call to execute all algorithms
│   │   └── rl.py # Code used to train basic RL agent, called when loss_type = NONE
│   ├── training # generated expert demos
│   │   ├── ppo_v2_cont_irl.py # PPO implementation for continuous action envs
│   │   ├── ppo_v2_irl.py # PPO implementation for discrete action envs
│   │   ├── supervised.py # Standard supervised training implementation
│   │   └── wrappers.py # Utility wrappers for training
│   └── utils # utility functions
├── experts # expert policies
└── experts_test # expert policies for test version of environment
Install

conda create -n jaxirl python=3.10.8
conda activate jaxirl
pip install -r requirements.txt
pip install -e .
export PYTHONPATH=jaxirl:$PYTHONPATH

Important

All scripts should be run from under jaxirl/.

Algorithms

Our IRL implementation is the moment-matching version. It includes implementation tricks to make learning more stable, such as decaying the discriminator and learner learning rates and applying gradient penalties to the discriminator.
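As a rough illustration of the gradient-penalty trick (a generic WGAN-GP-style sketch, not the exact code in gail_discriminator.py; disc_logit and the toy data below are hypothetical stand-ins):

import jax
import jax.numpy as jnp

def disc_logit(params, state):
    # Hypothetical linear discriminator standing in for the real network.
    return jnp.dot(state, params["w"]) + params["b"]

def gradient_penalty(params, expert_states, learner_states, key):
    # Interpolate between expert and learner states, then penalise the
    # deviation of the discriminator's input-gradient norm from 1.
    eps = jax.random.uniform(key, (expert_states.shape[0], 1))
    interp = eps * expert_states + (1.0 - eps) * learner_states
    grad_fn = jax.vmap(jax.grad(disc_logit, argnums=1), in_axes=(None, 0))
    grads = grad_fn(params, interp)                       # (batch, state_dim)
    norms = jnp.sqrt(jnp.sum(grads ** 2, axis=-1) + 1e-12)
    return jnp.mean((norms - 1.0) ** 2)

# Toy usage: the penalty is added (with some weight) to the discriminator loss.
key = jax.random.PRNGKey(0)
params = {"w": jnp.ones(3), "b": jnp.array(0.0)}
expert_s = jax.random.normal(key, (8, 3))
learner_s = jax.random.normal(jax.random.split(key)[0], (8, 3))
gp = gradient_penalty(params, expert_s, learner_s, key)

Keeping the discriminator's input gradients bounded in this way, together with the learning-rate decay, is a common way to keep the reward signal smooth and the two-player game stable.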

Reproduce Results

Simply run

python3 jaxirl/irl/main.py --env env_name --loss IRL -sd 1

and the default parameters in outer_training_configs.py and the trained experts in experts/ will be used.

Citation

If you find this code useful in your research, please cite:

@misc{sapora2024evil,
      title={EvIL: Evolution Strategies for Generalisable Imitation Learning}, 
      author={Silvia Sapora and Gokul Swamy and Chris Lu and Yee Whye Teh and Jakob Nicolaus Foerster},
      year={2024},
}

See Also 🙌

Our work reused code, tricks and implementation details from the following libraries; we encourage you to take a look!

  • FastIRL: PyTorch implementation of moment matching IRL and FILTER algorithms.
  • PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.