Unifying Gradient Estimators for Meta-Reinforcement Learning via Off-Policy Evaluation @ NeurIPS 2021

This is the open-source implementation of the toy example from the NeurIPS 2021 paper.

In the toy example, we examine the properties of several gradient and Hessian estimates of value functions in tabular MDPs. These estimates are used as subroutines for meta-RL gradient estimates.
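
For concreteness, the following is a minimal JAX sketch (not taken from this repository; all names are illustrative) of the ground-truth quantities such estimates can be compared against: the exact discounted value of a softmax policy in a tabular MDP, differentiated with jax.grad and jax.hessian.

import jax
import jax.numpy as jnp

def exact_value(logits, P, R, gamma, mu):
    # logits: (S, A) policy parameters; P: (S, A, S) transition tensor;
    # R: (S, A) rewards; gamma: discount factor; mu: (S,) initial state distribution.
    pi = jax.nn.softmax(logits, axis=-1)            # softmax policy, shape (S, A)
    P_pi = jnp.einsum('sa,sat->st', pi, P)          # state-to-state transition matrix under pi
    r_pi = jnp.einsum('sa,sa->s', pi, R)            # expected one-step reward per state under pi
    v = jnp.linalg.solve(jnp.eye(P.shape[0]) - gamma * P_pi, r_pi)  # solve (I - gamma P_pi) v = r_pi
    return mu @ v                                   # scalar objective J(pi)

# Ground-truth gradient and Hessian of J with respect to the policy logits,
# obtained by automatic differentiation of the exact value above.
grad_J = jax.grad(exact_value)
hess_J = jax.hessian(exact_value)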

Installation

You need to install JAX. Our code works under Python 3.8, and you can install JAX by running the following:

pip install jax
pip install jaxlib
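
You can check that the installation succeeded by printing the installed version:

python -c "import jax, jaxlib; print(jax.__version__)"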

Introduction to the code structure

The code contains a few components.

  • main.py implements the main loop for the experiments. It creates MDP instances, generates trajectories and computes estimates and their accuracy measures.
  • evaluation_utils.py implements the different estimates through off-policy evaluation subroutines (a simplified sketch of such a subroutine follows this list).
  • tabular_mdp.py implements the tabular MDP object.
  • plot_results.py plots results similar to Figure 1 in the paper.
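
As a simplified, self-contained illustration of the kind of off-policy evaluation subroutine used as a building block (the function and argument names below are illustrative, not the repository's actual API), here is a per-decision importance-sampling estimate of a target policy's value from trajectories collected under a behavior policy:

import jax.numpy as jnp

def per_decision_is_value(trajs, pi_target, pi_behavior, gamma):
    # trajs: list of trajectories, each a list of (state, action, reward) tuples
    # collected under the behavior policy; pi_target / pi_behavior: (S, A) arrays
    # of action probabilities; gamma: discount factor.
    estimates = []
    for traj in trajs:
        rho = 1.0       # cumulative importance weight up to the current step
        value = 0.0
        for t, (s, a, r) in enumerate(traj):
            rho *= pi_target[s, a] / pi_behavior[s, a]   # update the per-decision weight
            value += (gamma ** t) * rho * r              # discounted, reweighted reward
        estimates.append(value)
    return jnp.mean(jnp.asarray(estimates))              # average over trajectories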

Running the code

To run all experiments, run the following command. Note that you can adjust the hyperparameters directly in main.py.

python main.py

After the experiments terminate, run the following to plot results.

python plot_results.py

Citation

If you find this codebase useful, you are encouraged to cite the following paper:

@article{tang2021unifying,
  title={Unifying Gradient Estimators for Meta-Reinforcement Learning via Off-Policy Evaluation},
  author={Tang, Yunhao and Kozuno, Tadashi and Rowland, Mark and Munos, R{\'e}mi and Valko, Michal},
  journal={arXiv preprint arXiv:2106.13125},
  year={2021}
}