CSRO

Context Shift Reduction for Offline Meta-Reinforcement Learning

Offline meta-reinforcement learning (OMRL) utilizes pre-collected offline datasets to enhance the agent's generalization ability on unseen tasks. However, the context shift problem arises due to the distribution discrepancy between the contexts used for training (from the behavior policy) and testing (from the exploration policy). The context shift problem leads to incorrect task inference and further deteriorates the generalization ability of the meta-policy. Existing OMRL methods either overlook this problem or attempt to mitigate it with additional information. In this paper, we propose a novel approach called Context Shift Reduction for OMRL (CSRO) to address the context shift problem with only offline datasets. The key insight of CSRO is to minimize the influence of policy in context during both the meta-training and meta-test phases. During meta-training, we design a max-min mutual information representation learning mechanism to diminish the impact of the behavior policy on task representation. In the meta-test phase, we introduce the non-prior context collection strategy to reduce the effect of the exploration policy. Experimental results demonstrate that CSRO significantly reduces the context shift and improves the generalization ability, surpassing previous methods across various challenging domains.

Installation

To install locally, first install MuJoCo. For task distributions in which the reward function varies (Cheetah, Ant, Humanoid), install MuJoCo 200. Set LD_LIBRARY_PATH to point to the MuJoCo binaries ($HOME/.mujoco/mujoco200/bin).

For the remaining dependencies, create the conda environment with:

conda env create -f environment.yaml
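
After creating the environment, activate it before running any of the commands below. A minimal example, assuming the environment defined in environment.yaml is named csro (check the name: field for the actual name):

conda activate csro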

For the Walker and Hopper environments, MuJoCo 131 is required; install it the same way as MuJoCo 200. To switch between the two MuJoCo versions, export the corresponding paths (MuJoCo 131 first, MuJoCo 200 second):

export MUJOCO_PY_MJPRO_PATH=~/.mujoco/mjpro131
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mjpro131/bin

export MUJOCO_PY_MJPRO_PATH=~/.mujoco/mujoco200
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin

The environments make use of the rand_param_envs module, which is included as a git submodule from https://github.com/dennisl88/rand_param_envs. We modify some environment parameters in rand_param_envs.
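
If the submodule directory is empty after cloning, the standard git submodule workflow should fetch it (a sketch; the exact submodule path in this repository may differ):

# fetch rand_param_envs (and any other submodules) into the working tree
git submodule update --init --recursive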

The whole pipeline consists of two stages: data generation and offline RL experiments.

Data Generation

CSRO requires fixed (batch) data for meta-training and meta-testing, which is generated by trained SAC behavior policies. Experiments at this stage are configured via train.yaml and train_point.yaml, located in ./rlkit/torch/sac/pytorch_sac/config/.

The following command divides all environments into 8 parts; all environments in part 0 are trained on GPU 0:

CUDA_VISIBLE_DEVICES=0 python policy_train.py --config ./configs/[ENV].json --split 8 --split_idx 0

Generated data will be saved in ./offline_dataset/
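
To generate data for all splits in parallel, a simple shell loop works; a sketch, assuming 8 GPUs are available and split index i runs on GPU i:

# launch one data-generation job per GPU, one split index each
for i in 0 1 2 3 4 5 6 7; do
  CUDA_VISIBLE_DEVICES=$i python policy_train.py --config ./configs/[ENV].json --split 8 --split_idx $i &
done
wait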

Offline RL Experiments

Experiments are configured via json configuration files located in ./configs. Basic settings are defined and described in ./configs/default.py. To reproduce an experiment, run:

CUDA_VISIBLE_DEVICES=0 python launch_experiment.py ./configs/[ENV].json --seed 0
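
To run the same experiment with several seeds, looping over the --seed flag is one option (a sketch; the seed values are arbitrary):

# run three seeds sequentially on GPU 0
for s in 0 1 2; do
  CUDA_VISIBLE_DEVICES=0 python launch_experiment.py ./configs/[ENV].json --seed $s
done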

Output files will be written to ./output/[ENV]/[EXP NAME]/seed[seed], where the experiment name corresponds to the process start time. The file progress.csv contains statistics logged over the course of training. We recommend viskit for visualizing learning curves: https://github.com/vitchyr/viskit. Network weights are also snapshotted during training.
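
To browse the logged curves, pointing the viskit frontend at the output directory is typically enough (a sketch, assuming viskit is cloned and its dependencies installed per its own README):

# serve an interactive plot of the progress.csv files found under the output directory
python viskit/frontend.py ./output/[ENV]/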