
Efficient Diffusion Policy

Official JAX implementation of EDP, from the following paper:

Efficient Diffusion Policies for Offline Reinforcement Learning. NeurIPS 2023.
Bingyi Kang, Xiao Ma, Chao Du, Tianyu Pang, Shuicheng Yan
Sea AI Lab
[arxiv]


We propose Efficient Diffusion Policies (EDP), a class of diffusion policies that are efficient to train and broadly compatible with a variety of RL algorithms. EDP serves as a more powerful policy representation for decision making and can be used as a drop-in replacement for feed-forward policies (e.g., Gaussian policies). It has the following features:

  • Enables training diffusion policies with long denoising horizons, e.g., 1000 steps.
  • Delivers a $25\times$ speed-up in training, cutting training time from 5 days to 5 hours (see the sketch after this list).
  • Applies to both likelihood-based methods (PG, CRR, AWR, IQL) and value-maximization methods (DDPG, TD3).
  • Sets a new state of the art on all four domains in D4RL.
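
Roughly, the speed-up comes from avoiding the full reverse diffusion chain during training: a dataset action is corrupted to a random diffusion step via the forward process, and a clean action is recovered in closed form from a single call to the noise-prediction network. Below is a minimal JAX sketch of this action-approximation idea; the names (eps_model, alpha_bar, params) are illustrative placeholders, not the identifiers used in this codebase:

import jax
import jax.numpy as jnp

def one_step_action_approx(eps_model, params, obs, action, alpha_bar, rng):
    # Recover an approximately clean action with a single denoising call,
    # instead of simulating the full reverse chain at every training step.
    K = alpha_bar.shape[0]                    # total diffusion steps, e.g. 1000
    k_rng, noise_rng = jax.random.split(rng)
    k = jax.random.randint(k_rng, (), 0, K)   # sample a random diffusion step
    noise = jax.random.normal(noise_rng, action.shape)
    # Forward process: a_k = sqrt(abar_k) * a_0 + sqrt(1 - abar_k) * eps
    a_k = jnp.sqrt(alpha_bar[k]) * action + jnp.sqrt(1.0 - alpha_bar[k]) * noise
    eps_hat = eps_model(params, obs, a_k, k)  # single network evaluation
    # Invert the forward process in closed form using the predicted noise
    return (a_k - jnp.sqrt(1.0 - alpha_bar[k]) * eps_hat) / jnp.sqrt(alpha_bar[k])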

Main Results

Usage

Before you start, make sure to run

pip install -e .

Apart from this, you will also need to set up your MuJoCo environment and license key. Please follow the D4RL repo and configure the environment accordingly.
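
If you want to sanity-check the setup before training, a quick script like the following (using the standard gym/D4RL API, not code from this repo) should load a dataset without errors:

# Verify that MuJoCo and D4RL are installed correctly.
import gym
import d4rl  # registers the D4RL offline environments with gym

env = gym.make('walker2d-medium-v2')
dataset = d4rl.qlearning_dataset(env)  # standard (s, a, r, s') arrays
print(dataset['observations'].shape, dataset['actions'].shape)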

Run Experiments

You can run EDP experiments using the following command:

python -m diffusion.trainer --env 'walker2d-medium-v2' --logging.output_dir './experiment_output' --algo_cfg.loss_type=TD3

To use other offline RL algorithms, simply change the --algo_cfg.loss_type parameter. For example:

python -m diffusion.trainer --env 'walker2d-medium-v2' --logging.output_dir './experiment_output' --algo_cfg.loss_type=IQL --norm_reward=True

By default, we use the DDPM solver. To use the DPM solver instead, set --sample_method=dpm and --algo_cfg.num_timesteps=1000.
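
For instance, the TD3 run above with the DPM solver would look like:

python -m diffusion.trainer --env 'walker2d-medium-v2' --logging.output_dir './experiment_output' --algo_cfg.loss_type=TD3 --sample_method=dpm --algo_cfg.num_timesteps=1000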

Weights and Biases Online Visualization Integration

This codebase can also log to the W&B online visualization platform. To log to W&B, first set your W&B API key environment variable. Alternatively, you can simply run wandb login.
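
For example, using wandb's standard environment variable (a wandb convention, not specific to this repo):

export WANDB_API_KEY=<your-api-key>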

Credits

The project structure borrows from the JAX CQL implementation.

We also refer to the diffusion model implementation from OpenAI and the official Diffusion Q-Learning implementation.