
PWM: Policy Learning with Large World Models

Primary language: Jupyter Notebook. License: MIT.

Ignat Georgiev, Varun Giridhar, Nicklas Hansen, Animesh Garg

Project website | Paper | Models & Datasets

This repository is a soft fork of FoRL.

Overview

We introduce Policy learning with large World Models (PWM), a novel Model-Based RL (MBRL) algorithm and framework for deriving effective continuous control policies from large, multi-task world models. We use pre-trained TD-MPC2 world models to efficiently learn control policies with first-order gradients in under 10 minutes per task. Our empirical evaluations on complex locomotion tasks indicate that PWM not only achieves higher reward than baselines but also outperforms methods that use ground-truth simulation dynamics.
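To make "first-order gradients through a world model" concrete, here is a toy sketch. It is not PWM's implementation: PWM works with pre-trained TD-MPC2 world models, while the `world_model` below is a hypothetical 1-D linear system, and `Dual`, `rollout_return`, the gains, horizon and discount are all illustrative assumptions. The policy gain is improved by differentiating the model-predicted discounted return with respect to the policy parameter (using forward-mode dual numbers here instead of backpropagation; both give the same first-order gradient).

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Forward-mode dual number val + grad * eps, with eps**2 == 0."""
    val: float
    grad: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.grad + other.grad)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.val * other.grad + self.grad * other.val)

    __rmul__ = __mul__

    def __neg__(self):
        return Dual(-self.val, -self.grad)


def world_model(s, a):
    # Hypothetical differentiable model: s' = 0.9 s + 0.5 a, reward -(s^2 + 0.1 a^2).
    return 0.9 * s + 0.5 * a, -(s * s + 0.1 * a * a)


def rollout_return(k, s0=1.0, horizon=16, gamma=0.99):
    """Discounted H-step return predicted by the model for the policy a = k * s."""
    s, ret = Dual(s0), Dual(0.0)
    for t in range(horizon):
        a = k * s
        s, r = world_model(s, a)
        ret = ret + (gamma ** t) * r
    return ret


k = -0.1
for _ in range(200):
    out = rollout_return(Dual(k, 1.0))  # seed dk/dk = 1 to get d(return)/dk
    k += 0.05 * out.grad                # first-order gradient ascent on the return
print(f"policy gain k = {k:.3f}, predicted return = {rollout_return(Dual(k)).val:.3f}")
```

The gradient flows through every step of the model rollout, which is what distinguishes this family of methods from zeroth-order (likelihood-ratio) policy gradients.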

Installation

Tested only on Ubuntu 22.04. Requires Python, conda, and an NVIDIA GPU with more than 24 GB of VRAM.

git clone --recursive git@github.com:imgeorgiev/PWM.git
cd PWM
conda env create -f environment.yaml
conda activate pwm
ln -s $CONDA_PREFIX/lib $CONDA_PREFIX/lib64  # hack to get CUDA to work inside conda
pip install -e .
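Before launching runs, it can help to confirm that a GPU meets the VRAM requirement. The following is a hypothetical pre-flight check, not part of the repository. It parses the output of `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits`, which reports total memory in MiB; the threshold of 24 GiB is an assumption about how to read the requirement and is exposed as a parameter.

```python
import shutil
import subprocess

MIN_VRAM_MIB = 24 * 1024  # assumption: interpret the requirement as "at least 24 GiB"


def parse_gpus(csv_text):
    """Return (name, total_mib) tuples from nvidia-smi CSV output (noheader, nounits)."""
    gpus = []
    for line in csv_text.strip().splitlines():
        if not line.strip():
            continue
        name, total = (field.strip() for field in line.rsplit(",", 1))
        gpus.append((name, int(float(total))))
    return gpus


def gpus_with_enough_vram(csv_text, min_mib=MIN_VRAM_MIB):
    """Names of GPUs whose total memory meets the threshold."""
    return [name for name, total in parse_gpus(csv_text) if total >= min_mib]


if __name__ == "__main__" and shutil.which("nvidia-smi"):
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    suitable = gpus_with_enough_vram(out)
    print("Suitable GPUs:", suitable if suitable else "none found")
```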
pip install -e external/tdmpc2

Single environment tasks

The first option is to run PWM on complex single tasks with up to 152 action dimensions in the Dflex simulator. These runs use pre-trained world models, which can be downloaded from Hugging Face.

cd scripts
conda activate pwm
python train_dflex.py env=dflex_ant alg=pwm general.checkpoint=path/to/model

Due to the nature of GPU acceleration, it is currently impossible to guarantee deterministic experiments. You can make them "less random" by using seeding(seed, True), but that slows down GPU computation.
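For context on why a deterministic flag costs speed, here is a sketch of what such a helper commonly toggles in PyTorch. `seed_everything` is a hypothetical name and this is not the repository's `seeding` function; numpy and torch are imported optionally so the sketch also runs without them.

```python
import os
import random


def seed_everything(seed, torch_deterministic=False):
    """Hypothetical seeding helper; mirrors common PyTorch reproducibility settings."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)  # only affects subprocesses started afterwards
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return seed
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA
    if torch_deterministic:
        # cuBLAS needs this workspace setting for deterministic results (CUDA >= 10.2).
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.backends.cudnn.benchmark = False  # no kernel autotuning: slower, but repeatable
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)  # error on ops without a deterministic version
    return seed
```

Disabling cuDNN autotuning and forcing deterministic kernels is the main source of the slowdown; even then, some GPU operations have no deterministic implementation, which is why runs can only be made "less random".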

Single environment with pretraining

Instead of loading a pre-trained world model, you can pretrain one yourself using the data:

cd scripts
conda activate pwm
python train_dflex.py env=dflex_ant alg=pwm general.pretrain=path/to/model pretrain_steps=XX

To reproduce the results from the original paper, set pretrain_steps as follows:

Task            Pretrain gradient steps
Hopper          50_000
Ant             100_000
Anymal          100_000
Humanoid        200_000
SNU Humanoid    200_000
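The table above can be turned into one command per task. This sketch assumes each row maps to the env config of the same name under cfg/env and keeps path/to/model as a placeholder, exactly as in the pretraining command; `PRETRAIN_STEPS` and `pretrain_command` are hypothetical names. The commands are meant to be run from the scripts directory.

```python
import shlex

# Gradient steps from the table, keyed by the env config names in cfg/env.
PRETRAIN_STEPS = {
    "dflex_hopper": 50_000,
    "dflex_ant": 100_000,
    "dflex_anymal": 100_000,
    "dflex_humanoid": 200_000,
    "dflex_snu_humanoid": 200_000,
}


def pretrain_command(env, model_path="path/to/model"):
    """Build the train_dflex.py pretraining command for one env."""
    args = [
        "python", "train_dflex.py",
        f"env={env}",
        "alg=pwm",
        f"general.pretrain={model_path}",
        f"pretrain_steps={PRETRAIN_STEPS[env]}",
    ]
    return shlex.join(args)  # shell-safe quoting (Python 3.8+)


if __name__ == "__main__":
    for env in PRETRAIN_STEPS:
        print(pretrain_command(env))
```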

Multitask setting

Training large world model

We evaluate on the MT30 and MT80 task settings proposed by TD-MPC2.

  1. Download the data for each task following the TD-MPC2 instructions.
  2. Train a world model from the TD-MPC2 repository using the settings below. The settings horizon=16 and rho=0.99 are crucial. Training takes about 2 weeks on an RTX 3090. Alternatively, you can use one of the pre-trained multi-task world models we provide.
cd external/tdmpc2/tdmpc2
python -u train.py task=mt30 model_size=48 horizon=16 batch_size=1024 rho=0.99 mpc=false disable_wandb=False data_dir=path/to/data

where path/to/data is the full TD-MPC2 dataset for either MT30 or MT80.
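One way to see why rho matters: in TD-MPC2, rho is a temporal coefficient that weights the loss at rollout step t by rho**t. Assuming that weighting applies here, the sketch below compares a small value (0.5) with rho=0.99 over horizon=16. `step_weights` is a hypothetical helper; the weights are normalised to sum to 1 so the two settings are comparable.

```python
def step_weights(rho, horizon=16):
    """Normalised per-step loss weights rho**t for t = 0 .. horizon - 1."""
    weights = [rho ** t for t in range(horizon)]
    total = sum(weights)
    return [w / total for w in weights]


for rho in (0.5, 0.99):
    w = step_weights(rho)
    print(f"rho={rho}: first-step share {w[0]:.3f}, share of steps 8-15 {sum(w[8:]):.3f}")
```

With rho=0.5, almost all of the training signal comes from the first few steps; with rho=0.99, the later steps of a 16-step rollout carry nearly as much weight as the early ones, which is plausibly what matters when a policy is later optimized through long model rollouts.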

Evaluate multi-task

Train a policy for a specific task using the pre-trained world model:

cd scripts
python train_multitask.py -cn config_mt30 alg=pwm_48M task=pendulum-swingup general.data_dir=path/to/data general.checkpoint=path/to/model
  • where path/to/data is the full TD-MPC2 dataset for either MT30 or MT80.
  • where path/to/model is the pre-trained world model as provided here.

We also provide scripts that launch Slurm jobs for all tasks: scripts/mt30.bash and scripts/mt80.bash.
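The contents of scripts/mt30.bash and scripts/mt80.bash are not shown here, so the following is only a hypothetical stand-in. It wraps one train_multitask.py run per task in an `sbatch --wrap` call. The task list (only pendulum-swingup, taken from the example above), the Slurm options, and all function names are assumptions; it defaults to a dry run that prints the commands, which are meant to be submitted from the scripts directory.

```python
import shlex
import subprocess

TASKS = ("pendulum-swingup",)  # placeholder: extend with the MT30/MT80 task names you need


def train_command(task, data_dir="path/to/data", checkpoint="path/to/model",
                  config="config_mt30", alg="pwm_48M"):
    """Build the train_multitask.py command for a single task."""
    return shlex.join([
        "python", "train_multitask.py", "-cn", config, f"alg={alg}", f"task={task}",
        f"general.data_dir={data_dir}", f"general.checkpoint={checkpoint}",
    ])


def sbatch_command(task, gpus=1, **kwargs):
    """Wrap the training command in a Slurm batch job (options are illustrative)."""
    return ["sbatch", f"--job-name=pwm-{task}", f"--gres=gpu:{gpus}",
            "--wrap", train_command(task, **kwargs)]


def launch(tasks=TASKS, dry_run=True):
    for task in tasks:
        cmd = sbatch_command(task)
        if dry_run:
            print(shlex.join(cmd))
        else:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    launch()
```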

Configs

cfg
├── alg
│   ├── pwm_19M.yaml - PWM models of different sizes; these are the main models that should be used. Paired with train_multitask.py
│   ├── pwm_317M.yaml - to be used with train_multitask.py
│   ├── pwm_48M.yaml 
│   ├── pwm_5M.yaml
│   ├── pwm.yaml - redundant but provided for reproducibility; to be run with train_dflex.py
│   └── shac.yaml - works only with train_dflex.py
├── config_mt30.yaml - to be used with train_multitask.py
├── config_mt80.yaml - to be used with train_multitask.py
├── config.yaml  - to be used with train_dflex.py
└── env - dflex env config files
    ├── dflex_ant.yaml
    ├── dflex_anymal.yaml
    ├── dflex_cartpole.yaml
    ├── dflex_doublependulum.yaml
    ├── dflex_hopper.yaml
    ├── dflex_humanoid.yaml
    └── dflex_snu_humanoid.yaml
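The pairings in the tree above can be encoded in a small check so that a mismatched combination fails before a run is launched. This is a hypothetical helper, not part of the repository; it reads the comment on pwm_19M.yaml as covering all the size-suffixed pwm_*.yaml configs.

```python
from pathlib import Path

# Which entry script each alg config is meant for, per the tree above.
SCRIPT_FOR_ALG = {
    "pwm_5M": "train_multitask.py",
    "pwm_19M": "train_multitask.py",
    "pwm_48M": "train_multitask.py",
    "pwm_317M": "train_multitask.py",
    "pwm": "train_dflex.py",
    "shac": "train_dflex.py",
}
# Which top-level config each entry script is meant for.
CONFIG_FOR_SCRIPT = {
    "train_multitask.py": {"config_mt30", "config_mt80"},
    "train_dflex.py": {"config"},
}


def check_pairing(script, alg, config_name):
    """Raise ValueError if the script/alg/config combination is not a documented pairing."""
    script = Path(script).name
    expected = SCRIPT_FOR_ALG.get(alg)
    if expected is None:
        raise ValueError(f"unknown alg config: {alg}")
    if expected != script:
        raise ValueError(f"alg={alg} is meant for {expected}, not {script}")
    if config_name not in CONFIG_FOR_SCRIPT[script]:
        raise ValueError(f"{config_name}.yaml is not meant for {script}")
```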

Citation

@misc{georgiev2024pwm,
    title={PWM: Policy Learning with Large World Models},
    author={Ignat Georgiev and Varun Giridhar and Nicklas Hansen and Animesh Garg},
    eprint={2407.02466},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    year={2024}
}