
First-order World Models


FoRL

A library for First-order Reinforcement Learning algorithms.

In a world dominated by policy-gradient-based approaches, we have created a library that attempts to learn policies via First-Order Gradients (FOG), also known as path-wise gradients or the reparameterization trick. Why? FOGs are known to have lower variance, which translates into more efficient learning, but they do not perform very well on discontinuous loss landscapes.
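To make the variance claim concrete, here is a small stdlib-only toy of our own (not part of the library) that estimates d/dμ E[f(x)] with x ~ N(μ, σ²) in two ways: the first-order (path-wise) estimator and the score-function (REINFORCE) estimator that policy-gradient methods rely on. For the smooth f(x) = x² at μ = 1, σ = 1, both are unbiased for the true gradient 2, but the path-wise samples have variance 4 against 30 for the score-function samples. For the discontinuous step f(x) = 1[x > 0] at μ = 0, the true gradient is 1/√(2π) ≈ 0.399, yet f'(x) = 0 almost everywhere, so the path-wise estimate is exactly 0.

```python
import math
import random
import statistics


def pathwise_grads(df, mu, sigma, n, rng):
    """First-order (reparameterized) samples of d/dmu E[f(x)], x ~ N(mu, sigma^2).

    With x = mu + sigma * eps we have dx/dmu = 1, so each sample is f'(x).
    """
    return [df(mu + sigma * rng.gauss(0.0, 1.0)) for _ in range(n)]


def score_function_grads(f, mu, sigma, n, rng):
    """Score-function (REINFORCE / likelihood-ratio) samples of the same gradient.

    Each sample is f(x) * d/dmu log N(x; mu, sigma^2) = f(x) * (x - mu) / sigma^2.
    """
    samples = []
    for _ in range(n):
        x = mu + sigma * rng.gauss(0.0, 1.0)
        samples.append(f(x) * (x - mu) / sigma**2)
    return samples


rng = random.Random(0)
n = 20_000

# Smooth loss f(x) = x^2 at mu = 1: the true gradient is 2 * mu = 2.
pw_smooth = pathwise_grads(lambda x: 2.0 * x, 1.0, 1.0, n, rng)
sf_smooth = score_function_grads(lambda x: x * x, 1.0, 1.0, n, rng)

# Step loss f(x) = 1[x > 0] at mu = 0: the true gradient is 1 / sqrt(2 * pi).
# Its derivative is 0 almost everywhere, which is all the path-wise estimator sees.
pw_step = pathwise_grads(lambda x: 0.0, 0.0, 1.0, n, rng)
sf_step = score_function_grads(lambda x: float(x > 0.0), 0.0, 1.0, n, rng)

for name, g in [("path-wise, smooth", pw_smooth), ("score fn, smooth", sf_smooth),
                ("path-wise, step", pw_step), ("score fn, step", sf_step)]:
    print(f"{name:18s} mean={statistics.mean(g):+.3f} var={statistics.variance(g):.3f}")
print("true step gradient:", 1.0 / math.sqrt(2.0 * math.pi))
```

In a differentiable simulator the "step" case corresponds to contacts and other discontinuities, which is where first-order methods can receive uninformative gradients.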

Applications:

  • Differentiable simulation
  • Model-based RL
  • World models

Installation

Tested only on Ubuntu 22.04. Requires Python, conda and an Nvidia GPU with >12GB VRAM.

  1. git clone --recursive git@github.com:pairlab/FoRL.git
  2. cd FoRL
  3. conda env create -f environment.yaml
  4. ln -s $CONDA_PREFIX/lib $CONDA_PREFIX/lib64 (hack to get CUDA to work inside conda)
  5. pip install -e .

Examples

Dflex

One of the first differentiable simulators for robotics. It was first proposed with the SHAC algorithm but is now deprecated.

cd scripts
conda activate forl
python train_dflex.py env=dflex_ant

The script is fully configured with Hydra; options such as env=dflex_ant are passed as command-line overrides.

Warp

The successor of dflex, Warp is NVIDIA's current effort to create a universal differentiable simulator.

TODO examples

Gym interface

We try to comply with the normal gym interface, but due to the nature of FOG methods we cannot do so fully. As such, gym envs passed to our algorithms must satisfy the following:

  • s, r, d, info = env.step(a) must accept and return PyTorch tensors and maintain gradients through the function.
  • The info dict must contain termination and truncation key-value pairs. Our library does not use the d (done) flag.
  • env.reset(grad=True) must accept an optional kwarg grad which, if True, resets the gradient graph but does not reset the environment.
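A minimal sketch of an environment that meets these three requirements, followed by a short first-order training loop that backpropagates the summed reward of a 10-step rollout into a linear policy. The class name PointMassEnv, its dynamics, reward and horizon, and the loop's hyperparameters are our own illustrative assumptions; this is not one of the library's environments or its training code (the library's algorithms may differ, e.g. SHAC also learns a critic).

```python
import torch


class PointMassEnv:
    """Hypothetical batched 1-D point mass following the interface above.

    Everything here (name, dynamics, reward, horizon) is illustrative only.
    """

    def __init__(self, num_envs: int = 4, dt: float = 0.1, max_steps: int = 50):
        self.num_envs = num_envs
        self.dt = dt
        self.max_steps = max_steps
        self.reset()

    def _obs(self) -> torch.Tensor:
        return torch.stack([self.pos, self.vel], dim=-1)

    def reset(self, grad: bool = False) -> torch.Tensor:
        if grad:
            # Cut the autograd graph but keep the current physical state.
            self.pos = self.pos.detach()
            self.vel = self.vel.detach()
        else:
            # Full reset of the state and the episode counter.
            self.pos = torch.zeros(self.num_envs)
            self.vel = torch.zeros(self.num_envs)
            self.t = 0
        return self._obs()

    def step(self, action: torch.Tensor):
        # Differentiable semi-implicit Euler update; action has shape (num_envs, 1).
        self.vel = self.vel + action.squeeze(-1) * self.dt
        self.pos = self.pos + self.vel * self.dt
        self.t += 1
        reward = -(self.pos - 1.0) ** 2  # reward for reaching x = 1
        termination = torch.zeros(self.num_envs, dtype=torch.bool)
        truncation = torch.full((self.num_envs,), self.t >= self.max_steps, dtype=torch.bool)
        info = {"termination": termination, "truncation": truncation}
        done = termination | truncation  # returned for gym compatibility only
        return self._obs(), reward, done, info


# Short-horizon first-order policy optimization on the toy env.
torch.manual_seed(0)
env = PointMassEnv()
policy = torch.nn.Linear(2, 1)
opt = torch.optim.Adam(policy.parameters(), lr=0.05)
obs = env.reset()
for update in range(20):
    total_reward = torch.zeros(())
    for _ in range(10):
        obs, reward, done, info = env.step(policy(obs))
        total_reward = total_reward + reward.mean()
    opt.zero_grad()
    (-total_reward).backward()  # gradients flow through the dynamics into the policy
    opt.step()
    if info["truncation"].all():
        obs = env.reset()  # episode over: full reset
    else:
        obs = env.reset(grad=True)  # keep the state, drop the graph
```

Calling env.reset(grad=True) between updates lets the rollout continue from the same physical state while keeping each backward pass short.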

Example implementation of this interface

Current algorithms

Notes

  • Due to the nature of GPU acceleration, it is currently impossible to guarantee deterministic experiments. You can make them "less random" by using seeding(seed, True), but that slows down GPUs.
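The library's seeding function is not shown in this README, so the helper below is a hypothetical sketch (seed_everything is our name, not the library's) of what such a function typically does with standard PyTorch calls. Its second argument switches on PyTorch's deterministic-kernel settings, which is the kind of option that costs GPU throughput. Custom CUDA kernels inside simulators may still be nondeterministic (e.g. floating-point atomic adds), which is consistent with the note above.

```python
import os
import random

import torch


def seed_everything(seed: int, torch_deterministic: bool = False) -> None:
    """Hypothetical seeding helper in the spirit of seeding(seed, True)."""
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    if torch_deterministic:
        # Needed for deterministic cuBLAS on CUDA >= 10.2; only takes
        # effect if set before cuBLAS is first used in the process.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # cuDNN autotuning picks kernels nondeterministically, so disable it
    # when determinism is requested.
    torch.backends.cudnn.benchmark = not torch_deterministic
    torch.backends.cudnn.deterministic = torch_deterministic
    # Makes PyTorch raise an error on ops that have no deterministic version.
    torch.use_deterministic_algorithms(torch_deterministic)


seed_everything(42, torch_deterministic=True)
```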