Shang-Fu Chen*, Hsiang-Chun Wang*, Ming-Hao Hsu, Chun-Mao Lai, Shao-Hua Sun at NTU RLL lab
This is the official PyTorch implementation of the paper "Diffusion Model-Augmented Behavioral Cloning" (ICML 2024).
- This code base requires Python 3.7.2 or higher. All package requirements are in `requirements.txt`. To install from scratch using Anaconda, run the following commands. A quick sanity check of the resulting environment is sketched after the commands.
conda create -n [your_env_name] python=3.7.2
conda activate [your_env_name]
pip install -r requirements.txt
cd d4rl
pip install -e .
cd ../rl-toolkit
pip install -e .
cd ..
mkdir -p data/trained_models
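After installing, a quick import check can catch a broken `d4rl` install or missing MuJoCo bindings early. This is only a convenience sketch, not part of the repository; it assumes the editable installs above succeeded and that MuJoCo is set up where required.

```python
# Optional installation sanity check: verifies that the core dependencies
# import and that a D4RL Maze2D environment can be constructed.
import torch
import gym
import d4rl  # noqa: F401  (importing d4rl registers its environments with gym)

print("torch:", torch.__version__)
env = gym.make("maze2d-medium-v2")
obs = env.reset()
print("maze2d-medium-v2 observation shape:", getattr(obs, "shape", None))
```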
- Set up Weights and Biases by first logging in with `wandb login <YOUR_API_KEY>` and then editing `config.yaml` with your W&B username and project name.
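If you want to confirm the W&B setup before launching long sweeps, a minimal smoke test like the one below can help. It is only a convenience sketch; the project and entity names are placeholders for the values you put in `config.yaml`.

```python
# Optional W&B smoke test. The project/entity names are placeholders --
# substitute the values you configured in config.yaml.
import wandb

wandb.login()  # uses the API key stored by `wandb login <YOUR_API_KEY>`
run = wandb.init(project="your-project-name", entity="your-wandb-username")
run.log({"setup_ok": 1})
run.finish()
```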
- For diffusion model pretraining, run `dbc/ddpm.py`.
- For policy learning, you can either run `dbc/main.py` for a single experiment or run `wandb sweep configs/<env>/<alg>.yaml` to launch a wandb sweep.
- We provide both methods to reproduce our results. Configuration files for policy learning of all tasks can be found in `configs`. Expert demonstrations are stored as `.pt` files under `expert_datasets/` and are passed to both scripts via `--traj-load-path`; a quick way to inspect them is sketched below.
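The exact contents of these demonstration files depend on how each dataset was generated, so the snippet below just prints whatever tensors are stored rather than assuming a schema.

```python
# Inspect an expert demonstration file before training.
# No schema is assumed; we simply report the stored keys and tensor shapes.
import torch

data = torch.load("expert_datasets/maze.pt", map_location="cpu")
if isinstance(data, dict):
    for key, value in data.items():
        shape = tuple(value.shape) if hasattr(value, "shape") else type(value)
        print(f"{key}: {shape}")
else:
    print(type(data))
```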
We specify below how to train the diffusion model and where to find the configuration files for each task. A conceptual sketch of what the DM pretraining step optimizes follows the per-task commands.
Maze2D:
- Ours:
  - DM pretraining:
    python dbc/ddpm.py --traj-load-path expert_datasets/maze.pt --num-epoch 8000 --lr 0.0001 --hidden-dim 128
  - Policy learning:
    - single experiment:
      python dbc/main.py --alg dbc --bc-num-epochs 2000 --depth 3 --hidden-dim 256 --coeff 30 --coeff-bc 1 --ddpm-path data/dm/trained_models/maze_ddpm.pt --env-name maze2d-medium-v2 --lr 0.00005 --traj-load-path ./expert_datasets/maze.pt --seed 1
      To run a single experiment on other environments, please refer to the configuration files for the parameters of each environment. A quick check that the checkpoint passed to `--ddpm-path` loads correctly is sketched after this example.
    - wandb sweep:
      ./wandb.sh ./configs/maze/dbc.yaml
- BC:
  ./wandb.sh ./configs/maze/bc.yaml
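Before starting policy learning, it can be useful to confirm that DM pretraining produced a loadable checkpoint at the path you pass to `--ddpm-path`. The snippet below only assumes the checkpoint is a standard `torch.save` artifact; its exact contents (full module vs. state dict) depend on how `dbc/ddpm.py` saves it, so run it from the repository root.

```python
# Sanity-check a pretrained diffusion model checkpoint before policy learning.
# We only confirm that the file exists and that torch can deserialize it;
# no assumption is made about whether it holds a state_dict or a pickled module.
import os
import torch

ckpt_path = "data/dm/trained_models/maze_ddpm.pt"
assert os.path.exists(ckpt_path), f"missing checkpoint: {ckpt_path}"

ckpt = torch.load(ckpt_path, map_location="cpu")
if isinstance(ckpt, dict):
    print("checkpoint keys:", list(ckpt.keys())[:10])
else:
    print("checkpoint object:", type(ckpt))
```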
Fetch Pick:
- Ours:
  - DM pretraining:
    python dbc/ddpm.py --traj-load-path expert_datasets/pick.pt --num-epoch 10000 --lr 0.001 --hidden-dim 1024
  - Policy learning:
    ./wandb.sh ./configs/fetchPick/dbc.yaml
- BC:
  ./wandb.sh ./configs/fetchPick/bc.yaml
Hand Rotate:
- Ours:
  - DM pretraining:
    python dbc/ddpm.py --traj-load-path expert_datasets/hand.pt --num-epoch 10000 --lr 0.00003 --hidden-dim 2048
  - Policy learning:
    ./wandb.sh ./configs/hand/dbc.yaml
- BC:
  ./wandb.sh ./configs/hand/bc.yaml
HalfCheetah:
- Ours:
  - DM pretraining:
    python dbc/ddpm.py --traj-load-path expert_datasets/halfcheetah.pt --num-epoch 8000 --lr 0.0002 --hidden-dim 1024
  - Policy learning:
    ./wandb.sh ./configs/halfcheetah/dbc.yaml
- BC:
  ./wandb.sh ./configs/halfcheetah/bc.yaml
Walker2d:
- Ours:
  - DM pretraining:
    python dbc/ddpm.py --traj-load-path expert_datasets/walker.pt --num-epoch 8000 --lr 0.0002 --hidden-dim 1024
  - Policy learning:
    ./wandb.sh ./configs/walker/dbc.yaml
- BC:
  ./wandb.sh ./configs/walker/bc.yaml
Ant Reach:
- Ours:
  - DM pretraining:
    python dbc/ddpm.py --traj-load-path expert_datasets/ant.pt --num-epoch 20000 --lr 0.0002 --hidden-dim 1024 --norm False
  - Policy learning:
    ./wandb.sh ./configs/antReach/dbc.yaml
- BC:
  ./wandb.sh ./configs/antReach/bc.yaml
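For reference, the DM pretraining step fits a denoising diffusion model to the expert state-action distribution. The sketch below is a generic, minimal DDPM training step on concatenated (state, action) vectors, intended only to illustrate the idea; the actual architecture, noise schedule, and options (e.g. `--hidden-dim`, `--norm`) are defined in `dbc/ddpm.py`. The MLP denoiser and the linear schedule here are illustrative assumptions, and `expert_sa` is a placeholder tensor loaded from an `expert_datasets/*.pt` file.

```python
# Generic DDPM training sketch on expert (state, action) vectors.
# Conceptual illustration only -- see dbc/ddpm.py for the real implementation.
import torch
import torch.nn as nn

T = 1000                                       # number of diffusion steps
betas = torch.linspace(1e-4, 0.02, T)          # illustrative linear noise schedule
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def make_denoiser(dim, hidden_dim=128):
    # Small MLP that predicts the injected noise from (noisy sample, timestep).
    return nn.Sequential(
        nn.Linear(dim + 1, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, dim),
    )

def train_step(denoiser, optimizer, expert_sa):
    # Sample a batch of expert state-action pairs and random timesteps.
    batch = expert_sa[torch.randint(0, expert_sa.shape[0], (256,))]
    t = torch.randint(0, T, (batch.shape[0],))
    noise = torch.randn_like(batch)
    a_bar = alphas_bar[t].unsqueeze(-1)
    noisy = a_bar.sqrt() * batch + (1.0 - a_bar).sqrt() * noise  # forward process
    # Train the denoiser to recover the injected noise (epsilon prediction).
    pred = denoiser(torch.cat([noisy, t.unsqueeze(-1).float() / T], dim=-1))
    loss = ((pred - noise) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```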
- Methods:
  - `rl-toolkit/rlf/algos/il/dbc.py`: algorithm of our method (DBC); a conceptual sketch of the combined objective follows this list.
  - `rl-toolkit/rlf/algos/il/bc.py`: algorithm of BC
- Environments:
  - `d4rl/d4rl/pointmaze/maze_model.py`: Maze2D task
  - `dbc/envs/fetch/custom_fetch.py`: Fetch Pick task
  - `dbc/envs/hand/manipulate.py`: Hand Rotate task
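The actual training objective lives in `rl-toolkit/rlf/algos/il/dbc.py`. The sketch below only illustrates, under stated assumptions, how a behavioral cloning loss can be combined with a diffusion-model term weighted by coefficients such as `--coeff-bc` and `--coeff`; it is not the repository's implementation, and `policy` and `diffusion_model_loss` are placeholder callables.

```python
# Conceptual sketch only -- NOT the repository's implementation.
# Illustrates a combined objective of the form
#   L = coeff_bc * L_BC + coeff * L_DM,
# where L_BC is the usual behavioral cloning loss and L_DM scores the
# agent's (state, action) pairs with a diffusion model pretrained on
# expert data. `policy` and `diffusion_model_loss` are placeholders.
import torch
import torch.nn.functional as F

def dbc_style_loss(policy, diffusion_model_loss, states, expert_actions,
                   coeff_bc=1.0, coeff=30.0):
    predicted_actions = policy(states)

    # Behavioral cloning term: match the expert actions.
    bc_loss = F.mse_loss(predicted_actions, expert_actions)

    # Diffusion-model term: how well the predicted (state, action) pair is
    # modeled by the pretrained diffusion model, abstracted here as a
    # callable that returns a scalar loss.
    dm_loss = diffusion_model_loss(torch.cat([states, predicted_actions], dim=-1))

    return coeff_bc * bc_loss + coeff * dm_loss
```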
- This repo is based on the official PyTorch implementation of the paper "Generalizable Imitation Learning from Observation via Inferring Goal Proximity".
- The base RL code and the imitation learning baselines come from rl-toolkit.
- The Maze2D environment is based on D4RL: Datasets for Deep Data-Driven Reinforcement Learning (https://github.com/rail-berkeley/d4rl).
- The Fetch and Hand Rotate environments are adapted, with some tweaks, from OpenAI Gym.
- The HalfCheetah and Walker2d environments are from OpenAI Gym.
- The Ant environment is adapted, with some tweaks, from DnC.
@inproceedings{
chen2024diffusion,
title={Diffusion Model-Augmented Behavioral Cloning},
author={Shang-Fu Chen and Hsiang-Chun Wang and Ming-Hao Hsu and Chun-Mao Lai and Shao-Hua Sun},
booktitle={Forty-first International Conference on Machine Learning},
year={2024},
url={https://openreview.net/forum?id=OnidGtOhg3}
}
