
Coherent Soft Imitation Learning


This repository contains an implementation of coherent soft imitation learning (CSIL), published at NeurIPS 2023.

We also provide implementations of other 'soft' imitation learning (SIL) algorithms: Inverse soft Q-learning (IQ-Learn) and proximal point imitation learning (PPIL).


The implementation is built on top of Acme and follows their agent structure.

├── run_csil.py                      - Example of running CSIL on continuous control tasks.
├── run_iqlearn.py                   - Example of running IQ-Learn on continuous control tasks.
├── run_ppil.py                      - Example of running PPIL on continuous control tasks.
├── soft_policy_iteration.ipynb      - Evaluation of SIL algorithms in a discrete tabular setting.
├── helpers.py                       - Utilities such as dataset iterators and environment creation.
├── experiment_logger.py             - Implements a Weights & Biases logger within the Acme framework.
├── sil
|   ├── config.py                    - Algorithm-specific configurations for soft imitation learning (SIL).
|   ├── builder.py                   - Creates the learner, actor, and policy.
|   ├── evaluator.py                 - Creates the evaluators and video recorders.
|   ├── learning.py                  - Implements the imitation learners.
|   ├── networks.py                  - Defines the policy, reward and critic networks.
|   └── pretraining.py               - Implements pre-training for policy and critic.


Usage

Before running any code, first activate the conda environment and set the PYTHONPATH:

conda activate csil
export PYTHONPATH=$(pwd)/..
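If you work from Python directly instead (for example, in a notebook kernel that was not launched from the shell above, so the exported PYTHONPATH does not apply), the same effect can be sketched by prepending the repository's parent directory to sys.path. This is a minimal sketch assuming the current working directory is the repository root; `add_parent_to_path` is a hypothetical helper, not part of the repository.

```python
import os
import sys
from typing import Optional


def add_parent_to_path(repo_root: Optional[str] = None) -> str:
    """Hypothetical helper: prepend the parent of `repo_root` (default: cwd) to sys.path.

    Mirrors `export PYTHONPATH=$(pwd)/..`, but only for the current interpreter.
    """
    root = os.path.abspath(repo_root or os.getcwd())
    parent = os.path.dirname(root)
    if parent not in sys.path:
        sys.path.insert(0, parent)
    return parent


if __name__ == "__main__":
    print("Added to sys.path:", add_parent_to_path())
```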

To run CSIL with default settings:

python scripts/run_csil.py

This runs the online version of CSIL on HalfCheetah-v2.

The experiment configurations for each algorithm (CSIL, IQ-Learn, and PPIL) can be adjusted via the flags defined at the start of run_*.py.

The available tasks (specified with the --env_name flag) include:

door-v0         # Adroit hand
hammer-v0       # Adroit hand
pen-v0          # Adroit hand

The default setting is online soft imitation learning. To run the offline version on the Adroit door task, for example:

python scripts/run_{algo_name}.py --offline=True --env_name=door-v0

replacing {algo_name} with csil, iqlearn, or ppil.
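To queue the offline door-v0 runs for all three algorithms, the command above can be generated programmatically. This is a minimal sketch: `build_command` is a hypothetical helper, and it only uses the two flags shown in this README (--offline and --env_name); any other flag would need to be checked against the definitions at the start of each run_*.py.

```python
import shlex
from typing import List

# The three entry points named in this README.
ALGOS = ("csil", "iqlearn", "ppil")


def build_command(algo: str, env_name: str = "door-v0", offline: bool = True) -> List[str]:
    """Hypothetical helper: build the run_{algo_name}.py command shown above."""
    if algo not in ALGOS:
        raise ValueError(f"unknown algorithm {algo!r}; expected one of {ALGOS}")
    return [
        "python",
        f"scripts/run_{algo}.py",
        f"--offline={offline}",  # a bool formats as True/False, matching --offline=True
        f"--env_name={env_name}",
    ]


if __name__ == "__main__":
    for algo in ALGOS:
        # Print shell-ready commands; pass the list to subprocess.run to launch instead.
        print(shlex.join(build_command(algo)))
```

Each list can be passed unchanged to subprocess.run(cmd, check=True) from the repository root to launch the runs one after another.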

We have also included a Colab notebook, soft_policy_iteration.ipynb, that reproduces the discrete grid world experiments shown in the paper for a range of imitation learning algorithms.

We strongly encourage the use of accelerators (e.g. GPUs or TPUs) for CSIL. Because CSIL requires a larger policy architecture, its wall-clock time is slow when run on CPUs only.

For a reproduction of the paper's experiments, see this Weights & Biases project.

The additional imitation learning baselines shown in the paper are available in Acme.

Open issues

Distributed Acme experiments currently do not finish cleanly, so they appear as 'Crashed' on W&B even when they finish successfully.

The robomimic experiments are currently not open-sourced.

Citing this work

  author       = {Joe Watson and
                  Sandy H. Huang and
                  Nicolas Heess},
  title        = {Coherent Soft Imitation Learning},
  booktitle    = {Advances in Neural Information Processing Systems},
  year         = {2023}


Installation

First clone this code repository into a local directory:

git clone https://github.com/google-deepmind/csil.git
cd csil

We recommend installing required dependencies inside a conda environment. To do this, first install Anaconda and then create and activate the conda environment:

conda create --name csil python=3.9
conda activate csil

CSIL is written in JAX, so first install the correct version of JAX for your system by following the JAX installation instructions. Acme requires jax 0.4.3 and will install that version, which may need to be uninstalled and replaced for a CUDA-based JAX installation, e.g.

pip install jax==0.4.7 https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.7+cuda12.cudnn88-cp39-cp39-manylinux2014_x86_64.whl

MuJoCo must also be installed in order to load the environments. Please follow the MuJoCo installation instructions to install the MuJoCo binary and place it in a directory where mujoco-py can find it. For compatibility reasons, this installation uses mujoco200, gym < 0.24.0 and mujoco-py.

Then install pip and use it to install all the dependencies:

pip install -r requirements.txt
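Before running the full verification, it can help to compare the installed versions with the constraints mentioned above (the Python 3.9 environment and gym < 0.24.0). The sketch below reads package metadata with importlib.metadata instead of importing the packages, so it still works when an import would fail. The helper names and the simplified version parser are assumptions for illustration, not part of the repository.

```python
import sys
from importlib import metadata
from typing import Dict, Optional, Tuple


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def release_tuple(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. '0.23.1' -> (0, 23, 1).

    Simplification (assumption): local tags such as '+cuda12' are dropped and
    parsing stops at the first segment without a leading digit.
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_environment() -> Dict[str, bool]:
    """Check the constraints stated in this README."""
    gym_version = installed_version("gym")
    return {
        "python == 3.9": sys.version_info[:2] == (3, 9),
        "gym installed and < 0.24.0": gym_version is not None
        and release_tuple(gym_version) < (0, 24, 0),
        "jax installed": installed_version("jax") is not None,
    }


if __name__ == "__main__":
    print("jax", installed_version("jax"), "| gym", installed_version("gym"))
    for name, ok in check_environment().items():
        print(f"{'OK  ' if ok else 'FAIL'} {name}")
```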

To verify the installation, run

python -c "import jax; print(jax.devices()); import acme; import mujoco_py; import gym; print(gym.make('HalfCheetah-v2').reset())"

If this fails, follow the guidance below.
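To narrow down which component is failing, a minimal sketch that imports each module from the one-line check separately and records the error for each failure can help; `diagnose_imports` is a hypothetical helper. The recorded error text can then be compared with the cases below.

```python
import importlib
from typing import Dict, Iterable, Optional

# Modules exercised by the one-line check above, roughly in installation order.
DEFAULT_MODULES = ("jax", "acme", "mujoco_py", "gym")


def diagnose_imports(modules: Iterable[str] = DEFAULT_MODULES) -> Dict[str, Optional[str]]:
    """Import each module on its own; map name -> None on success, else the error text."""
    results: Dict[str, Optional[str]] = {}
    for name in modules:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:  # surface any import-time failure, e.g. a mujoco_py build error
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    for name, error in diagnose_imports().items():
        print(f"{name:10s} {'ok' if error is None else error}")
```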


Troubleshooting

If you get the error

Command conda not found

then you need to add the folder where Anaconda is installed to your PATH variable:

export PATH=/path/to/anaconda/bin:$PATH

If you get the error

ImportError: libpython3.9.so.1.0: cannot open shared object file: No such file or directory

first activate the conda environment and then add its lib directory to the LD_LIBRARY_PATH:

conda activate csil
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
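To check whether the library is actually visible once the environment's lib directory is on LD_LIBRARY_PATH, a small sketch can scan each directory listed in an environment variable for a given file name. `search_library_path` is a hypothetical helper; it only inspects the variable's directories and does not replicate the dynamic loader's full search order (e.g. the ldconfig cache or RPATH entries).

```python
import os
from typing import List


def search_library_path(filename: str, env_var: str = "LD_LIBRARY_PATH") -> List[str]:
    """Return every directory listed in `env_var` that contains `filename`."""
    hits = []
    for directory in os.environ.get(env_var, "").split(os.pathsep):
        # Empty entries (e.g. from a leading ':') are skipped rather than treated as cwd.
        if directory and os.path.isfile(os.path.join(directory, filename)):
            hits.append(directory)
    return hits


if __name__ == "__main__":
    print("CONDA_PREFIX:", os.environ.get("CONDA_PREFIX"))
    found = search_library_path("libpython3.9.so.1.0")
    print("libpython3.9.so.1.0 found in:", found or "nowhere on LD_LIBRARY_PATH")
```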

If you get the error

cannot find -lGL: No such file or directory

then install libGL with:

sudo apt install libgl-dev

If you get the error

fatal error: GL/glew.h: No such file or directory

then you need to install the following in your conda environment and update the CPATH:

conda install -c conda-forge glew
conda install -c conda-forge mesalib
conda install -c menpo glfw3
export CPATH=$CONDA_PREFIX/include

If you get the error

ImportError: libgmpxx.so.4: cannot open shared object file: No such file or directory

then you need to install the following in your conda environment and update the CPATH:

conda install -c conda-forge gmp
export CPATH=$CONDA_PREFIX/include

If you get the error

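ImportError: ../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /lib/x86_64-linux-gnu/libLLVM-15.so.1)

then, from within $CONDA_PREFIX/lib, replace the environment's libstdc++ with a symlink to the system library: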
mv libstdc++.so.6 libstdc++.so.6.old
ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6 libstdc++.so.6

according to this advice.
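To confirm which copy of libstdc++ provides the required symbol version, a rough Python stand-in for `strings libstdc++.so.6 | grep GLIBCXX` can scan the library's raw bytes for GLIBCXX_* tags. This is a sketch with hypothetical helper names; the two paths it inspects are the conda environment's lib directory (assumed to be where the failing ../lib/libstdc++.so.6 lives) and the system path used in the ln command above.

```python
import os
import re
from typing import List

# Symbol-version tags such as GLIBCXX_3.4.30 are stored as plain strings in libstdc++.
_TAG = re.compile(rb"GLIBCXX_\d+(?:\.\d+)*")


def glibcxx_versions(path: str) -> List[str]:
    """Return the GLIBCXX_* version tags found in a file, sorted numerically."""
    with open(path, "rb") as f:
        data = f.read()
    tags = {m.group().decode("ascii") for m in _TAG.finditer(data)}
    return sorted(tags, key=lambda t: tuple(int(x) for x in t.split("_", 1)[1].split(".")))


def supports(path: str, required: str = "GLIBCXX_3.4.30") -> bool:
    """True if the library at `path` advertises the `required` version tag."""
    return required in glibcxx_versions(path)


if __name__ == "__main__":
    prefix = os.environ.get("CONDA_PREFIX", "")
    candidates = (
        os.path.join(prefix, "lib", "libstdc++.so.6"),
        "/usr/lib/x86_64-linux-gnu/libstdc++.so.6",
    )
    for lib in candidates:
        if os.path.exists(lib):  # follows symlinks, so this reports the linked target
            print(lib, "->", "has" if supports(lib) else "lacks", "GLIBCXX_3.4.30")
```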

License and disclaimer

Copyright 2023 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.