The official implementation of Policy-Guided Diffusion (https://arxiv.org/abs/2404.06356) - built by Matthew Jackson and Michael Matthews.
- Offline RL agents (TD3+BC, IQL),
- Trajectory-level U-Net diffusion model,
- EDM diffusion training and sampling,
- Runs on the D4RL benchmark.
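
As a rough illustration of the EDM sampling component listed above, here is a minimal sketch of the noise-level schedule from the EDM paper (Karras et al., 2022). The function name `karras_sigmas` is hypothetical, and the defaults are the paper's published values rather than anything confirmed for this repository, whose actual schedule and hyperparameters live in its own code.

```python
def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Return noise levels decreasing from sigma_max to sigma_min.

    Implements sigma_i = (sigma_max^(1/rho)
                          + i / (N - 1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho.
    The defaults are the EDM paper's values, assumed here for illustration.
    """
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    inv_rho = 1.0 / rho
    start = sigma_max ** inv_rho
    end = sigma_min ** inv_rho
    return [
        (start + i / (num_steps - 1) * (end - start)) ** rho
        for i in range(num_steps)
    ]


# Example: a 10-step schedule, printed from the noisiest to the cleanest level.
sigmas = karras_sigmas(10)
for step, sigma in enumerate(sigmas):
    print(f"step {step}: sigma = {sigma:.4f}")
```

With `rho` above 1 the steps are spaced more densely at low noise levels, which is the motivation given in the EDM paper for this parameterization.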
Diffusion and agent training is implemented entirely in JAX, with extensive JIT compilation and parallelization!
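
To give a feel for what JIT compilation and parallelization look like in JAX, the sketch below compiles a toy gradient step with `jax.jit` and runs it for several independent parameter sets at once with `jax.vmap`. The loss, the names `loss_fn`, `sgd_step` and `parallel_step`, and the learning rate are all illustrative assumptions, not code from this repository.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Toy least-squares loss standing in for a real training objective.
    preds = batch["x"] @ params
    return jnp.mean((preds - batch["y"]) ** 2)


@jax.jit
def sgd_step(params, batch, lr=0.1):
    # One gradient-descent update, traced and compiled on first call.
    grads = jax.grad(loss_fn)(params, batch)
    return params - lr * grads


# Map the update over the leading axis of params (e.g. one row per seed),
# while sharing the same batch across all rows.
parallel_step = jax.jit(jax.vmap(sgd_step, in_axes=(0, None)))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 4))
batch = {"x": x, "y": x @ jnp.ones(4)}
params = jnp.zeros((3, 4))  # three independent parameter vectors

new_params = parallel_step(params, batch)
print(new_params.shape)
```

Because the batched update is a single compiled function, all three parameter sets are updated in one call rather than in a Python loop.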
Diffusion and agent training is executed with python3 train_diffusion.py and python3 train_agent.py, with all arguments found in util/args.py.
- --log --wandb_entity [entity] --wandb_project [project] enables logging to WandB.
- --debug disables JIT compilation.
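
For orientation, here is a minimal sketch of how flags like these could be wired up with `argparse`. The flag names are taken from the text above; the types, defaults, help strings, and the helper names `build_parser` and `apply_debug_mode` are assumptions, and the authoritative definitions remain in util/args.py. Disabling JIT through the `jax_disable_jit` config option is one standard JAX mechanism, not necessarily the one this repository uses.

```python
import argparse


def build_parser():
    # Hypothetical parser mirroring the flags described in this README.
    parser = argparse.ArgumentParser(description="Training flags (sketch)")
    parser.add_argument("--log", action="store_true", help="log to WandB")
    parser.add_argument("--wandb_entity", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--debug", action="store_true", help="disable JIT")
    return parser


def apply_debug_mode(debug):
    # Imported here so that parsing flags alone does not require JAX.
    import jax

    # With this option set, jitted functions run eagerly, op by op,
    # which makes print statements and debuggers usable inside them.
    jax.config.update("jax_disable_jit", debug)


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--log", "--wandb_project", "diff", "--debug"])
print(args.log, args.wandb_project, args.debug)
```

A training script following this pattern would call `apply_debug_mode(args.debug)` once, before any jitted function is first invoked.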
- Build the Docker image
cd docker && ./build.sh && cd ..
- (To enable WandB logging) Add your account key to setup/wandb_key:
echo [KEY] > setup/wandb_key
- Launch a training script inside the container
./run_docker.sh [GPU index] python3.9 [train_script] [args]
Diffusion training example:
./run_docker.sh 0 python3.9 train_diffusion.py --log --wandb_project diff --wandb_team flair --dataset_name walker2d-medium-v2
Agent training example:
./run_docker.sh 6 python3.9 train_agent.py --log --wandb_project agents --wandb_team flair --dataset_name walker2d-medium-v2 --agent iql
If you use this implementation in your work, please cite us with the following:
@misc{jackson2024policyguided,
title={Policy-Guided Diffusion},
author={Matthew Thomas Jackson and Michael Tryfan Matthews and Cong Lu and Benjamin Ellis and Shimon Whiteson and Jakob Foerster},
year={2024},
eprint={2404.06356},
archivePrefix={arXiv},
primaryClass={cs.LG}
}