Code accompanying the paper "Score Regularized Policy Optimization through Diffusion Behavior" (ICLR 2024).

Score Regularized Policy Optimization through Diffusion Behavior

Huayu Chen, Cheng Lu, Zhengyi Wang, Hang Su, Jun Zhu

D4RL experiments

Requirements

PyTorch, MuJoCo, and D4RL must be installed.
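
A quick way to confirm the requirements are present is to ask Python whether it can locate each package. The sketch below is only illustrative: the import names `torch`, `mujoco_py`, and `d4rl` are assumptions (newer MuJoCo bindings are imported as `mujoco`, for instance), so adjust the list to match your setup.

```python
import importlib.util

# Assumed import names for PyTorch, MuJoCo, and D4RL; adjust to your installation.
REQUIRED_MODULES = ["torch", "mujoco_py", "d4rl"]


def missing_modules(names):
    """Return the module names that cannot be located in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

This only checks that the packages can be found; it does not verify versions or that MuJoCo itself is licensed and working.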

Running

Download the pretrained behavior and critic checkpoints from here and store them under ./SRPO_model_factory/.

Alternatively, pretrain the behavior and critic models yourself by running the following two commands, respectively:

TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_behavior.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed}
TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_critic.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed}
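
To repeat the two pretraining commands above for several seeds, a small helper can assemble the argument lists. This is a minimal sketch rather than part of the repository: `pretrain_commands` is a hypothetical helper, the seed values are arbitrary, and the script only prints the commands (a dry run) instead of executing them.

```python
import shlex


def pretrain_commands(task, seed):
    """Build argument lists for the behavior and critic pretraining scripts."""
    expid = f"{task}-baseline-seed{seed}"
    return [
        ["python3", "-u", script, "--expid", expid, "--env", task, "--seed", str(seed)]
        for script in ("train_behavior.py", "train_critic.py")
    ]


if __name__ == "__main__":
    # Dry run: print the commands for a few seeds instead of executing them.
    for seed in (0, 1, 2):  # hypothetical seed choice
        for cmd in pretrain_commands("halfcheetah-medium-v2", seed):
            print(shlex.join(cmd))
```

To actually launch a run, each list could be passed to `subprocess.run(cmd, check=True)` from the repository root.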

Finally, train the policy, loading the pretrained behavior and critic checkpoints:

TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_policy.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed} --actor_load_path ./SRPO_model_factory/${TASK}-baseline-seed${seed}/behavior_ckpt200.pth --critic_load_path ./SRPO_model_factory/${TASK}-baseline-seed${seed}/critic_ckpt150.pth
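
Before launching the command above, it can help to confirm that both checkpoint files exist where `--actor_load_path` and `--critic_load_path` expect them. The following sketch derives the paths from the task name and seed using the same naming pattern as the command; `checkpoint_paths` and `missing_checkpoints` are hypothetical helpers, not functions from the repository.

```python
from pathlib import Path


def checkpoint_paths(task, seed, root="./SRPO_model_factory"):
    """Return the checkpoint paths referenced by the train_policy.py command."""
    run_dir = Path(root) / f"{task}-baseline-seed{seed}"
    return {
        "behavior": run_dir / "behavior_ckpt200.pth",
        "critic": run_dir / "critic_ckpt150.pth",
    }


def missing_checkpoints(task, seed, root="./SRPO_model_factory"):
    """List the checkpoint files that are not present on disk."""
    return [p for p in checkpoint_paths(task, seed, root).values() if not p.is_file()]


if __name__ == "__main__":
    for path in missing_checkpoints("halfcheetah-medium-v2", 0):
        print("Missing checkpoint:", path)
```

If the script prints nothing, both files were found; otherwise, download the checkpoints or run the pretraining steps first.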

License

MIT