
Diffusion Model-Augmented Behavioral Cloning

Shang-Fu Chen*, Hsiang-Chun Wang*, Ming-Hao Hsu, Chun-Mao Lai, Shao-Hua Sun at NTU RLL lab

[Project website] [Paper]

This is the official PyTorch implementation of the paper "Diffusion Model-Augmented Behavioral Cloning" (ICML 2024).


Installation

  1. This codebase requires Python 3.7.2 or higher. All package requirements are listed in requirements.txt. To install from scratch with Anaconda, run the following commands.
conda create -n [your_env_name] python=3.7.2
conda activate [your_env_name]
pip install -r requirements.txt

cd d4rl
pip install -e .
cd ../rl-toolkit
pip install -e .

cd ..
mkdir -p data/trained_models
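
After installation, a quick sanity check that the core dependencies import and the D4RL environments are registered (a minimal check, assuming the editable d4rl install above succeeded):

python -c "import torch, gym, d4rl; print(gym.make('maze2d-medium-v2').observation_space)"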
  2. Set up Weights and Biases by logging in with wandb login <YOUR_API_KEY>, then edit config.yaml with your W&B username and project name.
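
If you prefer to verify the W&B setup from Python instead, a minimal sketch (the entity and project names below are placeholders, not values from this repo):

import wandb

wandb.login()  # reuses the API key stored by `wandb login`
run = wandb.init(entity="<your_wandb_username>", project="<your_project>")  # placeholders
run.finish()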

How to reproduce experiments

  • For diffusion model pretraining, run dbc/ddpm.py.
  • For policy learning, either run dbc/main.py for a single experiment or launch a W&B sweep with wandb sweep configs/<env>/<alg>.yaml.
  • Both methods reproduce our results. Configuration files for policy learning on all tasks can be found in configs. A conceptual sketch of the objective that policy learning optimizes follows this list.
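
DBC trains the policy with a weighted sum of the standard behavioral cloning loss and a loss derived from the pretrained diffusion model, weighted by --coeff-bc and --coeff respectively (see the Maze2D command below). The following is a minimal conceptual sketch, not the repo's implementation; the policy and diffusion-model interfaces (in particular denoising_loss) are illustrative placeholders:

import torch
import torch.nn.functional as F

def dbc_loss(policy, diffusion_model, states, expert_actions, coeff=30.0, coeff_bc=1.0):
    # BC term: regress the policy onto the expert actions.
    pred_actions = policy(states)
    bc_loss = F.mse_loss(pred_actions, expert_actions)
    # Diffusion term: score policy-produced (state, action) pairs under the
    # diffusion model pretrained on expert pairs. `denoising_loss` is a
    # hypothetical helper standing in for the repo's diffusion objective.
    dm_loss = diffusion_model.denoising_loss(torch.cat([states, pred_actions], dim=-1))
    return coeff_bc * bc_loss + coeff * dm_loss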

We specify how to train the diffusion models and where to find the configuration files for each task below. (A minimal sketch of the DDPM pretraining step follows the Maze2D section, and a checkpoint-loading check follows the Ant Goal section.)

Maze2D

  • Ours:
    1. DM pretraining: python dbc/ddpm.py --traj-load-path expert_datasets/maze.pt --num-epoch 8000 --lr 0.0001 --hidden-dim 128
    2. Policy learning:
      • single experiment: python dbc/main.py --alg dbc --bc-num-epochs 2000 --depth 3 --hidden-dim 256 --coeff 30 --coeff-bc 1 --ddpm-path data/dm/trained_models/maze_ddpm.pt --env-name maze2d-medium-v2 --lr 0.00005 --traj-load-path ./expert_datasets/maze.pt --seed 1
      • To run a single experiment on other environments, please refer to the configuration files to see the parameters for each environment.
      • wandb sweep: ./wandb.sh ./configs/maze/dbc.yaml
  • BC: ./wandb.sh ./configs/maze/bc.yaml
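
For intuition, dbc/ddpm.py fits a denoising diffusion model to the expert state-action data. Below is a minimal, self-contained sketch of one DDPM training step with the standard noise-prediction objective; the MLP, dimensions, and noise schedule are illustrative, not the repo's exact choices:

import torch
import torch.nn as nn
import torch.nn.functional as F

T = 1000                                  # number of diffusion steps
betas = torch.linspace(1e-4, 0.02, T)     # linear noise schedule
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(256, 10)                 # stand-in batch of (state, action) pairs
denoiser = nn.Sequential(nn.Linear(10 + 1, 128), nn.ReLU(), nn.Linear(128, 10))
opt = torch.optim.Adam(denoiser.parameters(), lr=1e-4)

t = torch.randint(0, T, (x0.shape[0],))   # random timestep per sample
noise = torch.randn_like(x0)
ab = alphas_bar[t].unsqueeze(-1)
xt = ab.sqrt() * x0 + (1.0 - ab).sqrt() * noise   # forward (noising) process
# Condition on the normalized timestep and predict the injected noise.
pred = denoiser(torch.cat([xt, t.float().unsqueeze(-1) / T], dim=-1))
loss = F.mse_loss(pred, noise)
opt.zero_grad(); loss.backward(); opt.step()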

Fetch Pick

  • Ours:
    1. DM pretraining: python dbc/ddpm.py --traj-load-path expert_datasets/pick.pt --num-epoch 10000 --lr 0.001 --hidden-dim 1024
    2. Policy learning: ./wandb.sh ./configs/fetchPick/dbc.yaml
  • BC: ./wandb.sh ./configs/fetchPick/bc.yaml

Hand Rotate

  • Ours:
    1. DM pretraining: python dbc/ddpm.py --traj-load-path expert_datasets/hand.pt --num-epoch 10000 --lr 0.00003 --hidden-dim 2048
    2. Policy learning: ./wandb.sh ./configs/hand/dbc.yaml
  • BC: ./wandb.sh ./configs/hand/bc.yaml

Half Cheetah

  • Ours:
    1. DM pretraining: python dbc/ddpm.py --traj-load-path expert_datasets/halfcheetah.pt --num-epoch 8000 --lr 0.0002 --hidden-dim 1024
    2. Policy learning: ./wandb.sh ./configs/halfcheetah/dbc.yaml
  • BC: ./wandb.sh ./configs/halfcheetah/bc.yaml

Walker

  • Ours:
    1. DM pretraining: python dbc/ddpm.py --traj-load-path expert_datasets/walker.pt --num-epoch 8000 --lr 0.0002 --hidden-dim 1024
    2. Policy learning: ./wandb.sh ./configs/walker/dbc.yaml
  • BC: ./wandb.sh ./configs/walker/bc.yaml

Ant Goal

  • Ours:
    1. DM pretraining: python dbc/ddpm.py --traj-load-path expert_datasets/ant.pt --num-epoch 20000 --lr 0.0002 --hidden-dim 1024 --norm False
    2. Policy learning: ./wandb.sh ./configs/antReach/dbc.yaml
  • BC: ./wandb.sh ./configs/antReach/bc.yaml
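
Each DM pretraining command above produces a checkpoint that policy learning consumes via --ddpm-path (e.g., data/dm/trained_models/maze_ddpm.pt in the Maze2D command). Before launching a sweep, you can confirm the checkpoint file deserializes; the internal layout is repo-specific, so this only checks loading:

import torch

ckpt = torch.load("data/dm/trained_models/maze_ddpm.pt", map_location="cpu")
print(type(ckpt))  # inspect the top-level structure before assuming its keys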

Code Structure

  • Methods:
    • rl-toolkit/rlf/algos/il/dbc.py: Algorithm of our method
    • rl-toolkit/rlf/algos/il/bc.py: Algorithm of BC
  • Environments:
    • d4rl/d4rl/pointmaze/maze_model.py: Maze2D task
    • dbc/envs/fetch/custom_fetch.py: Fetch Pick task
    • dbc/envs/hand/manipulate.py: Hand Rotate task

Acknowledgement

Citation

@inproceedings{
    chen2024diffusion,
    title={Diffusion Model-Augmented Behavioral Cloning},
    author={Shang-Fu Chen and Hsiang-Chun Wang and Ming-Hao Hsu and Chun-Mao Lai and Shao-Hua Sun},
    booktitle={Forty-first International Conference on Machine Learning},
    year={2024},
    url={https://openreview.net/forum?id=OnidGtOhg3}
}