Official implementation of the Diffusion Model-Based Predictor (DMBP), presented at ICLR 2024.

Primary language: Python. License: MIT.

DMBP: Diffusion Model-Based Predictor for robust offline reinforcement learning against state observation perturbations

This is the official implementation of the Diffusion Model-Based Predictor (DMBP) and the reproduced baseline algorithms (BCQ, CQL, TD3+BC, RORL, and Diffusion QL).

Introduction

A major challenge for the real-world application of offline RL is robustness against state observation perturbations, e.g., those caused by sensor errors or adversarial attacks. Unlike in online robust RL, agents cannot be adversarially trained in the offline setting.

Our proposed Diffusion Model-Based Predictor (DMBP) is a new framework that recovers the actual states with conditional diffusion models for state-based RL tasks. Our derived non-Markovian training objective reduces error accumulation, which is commonly observed in model-based methods for state-based RL tasks.

Below, we present a visualization of the denoising effect of DMBP with Diffusion QL (trained on hopper-medium-replay-v2). The observation is perturbed with Gaussian random noise with a standard deviation of 0.10.

DMBP_Visualization
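
As a rough illustration of the perturbation described above, the following sketch adds zero-mean Gaussian noise with a standard deviation of 0.10 to a state vector. The helper name `perturb_observation` is hypothetical and is not taken from this repository. Whether the evaluation scripts scale the noise per dimension or normalize states first is not stated here, so treat this as a sketch under that assumption, written with NumPy.

```python
import numpy as np


def perturb_observation(obs, std=0.10, rng=None):
    """Return a copy of `obs` with zero-mean Gaussian noise added.

    Hypothetical helper for illustration only; not part of this repository.
    """
    rng = np.random.default_rng() if rng is None else rng
    obs = np.asarray(obs, dtype=np.float64)
    # Independent noise is drawn for every state dimension.
    return obs + rng.normal(loc=0.0, scale=std, size=obs.shape)


# Hopper observations are 11-dimensional; zeros stand in for a real state.
clean = np.zeros(11)
noisy = perturb_observation(clean, std=0.10, rng=np.random.default_rng(0))
print(np.round(noisy - clean, 3))  # the injected noise, one value per dimension
```

A policy evaluated under this setting receives `noisy` instead of `clean`; DMBP's role is to recover an estimate of the clean state before it reaches the policy.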

Requirements

Our experiments are performed on the D4RL benchmark environments and datasets. Please install MuJoCo version 2.1 before getting started. See requirements.txt for the detailed environment setup.

Training

Baseline Algorithms Training

Before training DMBP, first train a baseline offline RL algorithm:

python -m scripts.train_baseline --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]

DMBP Training

DMBP uses the trajectory datasets for training. Download the datasets for the corresponding domain with:

python -m scripts.download_datasets --domain [DOMAIN_NAME]

Then DMBP can be trained through:

python -m scripts.train_DMBP --task [TASK_NAME] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]

Here, the previously trained baseline algorithm is used only for evaluation during the training process.

Evaluation

Robustness against noised state observations

To evaluate the baseline algorithm against different attacks on state observations, run:

python -m evaluations.eval_baseline_noise --noise_type [ATTACK_METHOD] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]

Then, the corresponding baseline algorithm strengthened by DMBP can be evaluated with:

python -m evaluations.eval_DMBP_noise --noise_type [ATTACK_METHOD] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]
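
Because the two commands above take the same flags, one way to keep a comparison consistent is to build both from a single set of values. The sketch below does this with Python's standard `subprocess` module. The helper names `build_command` and `run_noise_comparison` are hypothetical, and the bracketed placeholders must be replaced with values that the scripts actually accept.

```python
import subprocess
import sys

# Flags shared by both noise-evaluation scripts, as listed in this README.
SHARED_FLAGS = ("noise_type", "algo", "env_name", "dataset")


def build_command(module, **values):
    """Return the argv list for `python -m <module> --flag value ...`."""
    missing = [flag for flag in SHARED_FLAGS if flag not in values]
    if missing:
        raise ValueError(f"missing values for: {missing}")
    cmd = [sys.executable, "-m", module]
    for flag in SHARED_FLAGS:
        cmd += [f"--{flag}", str(values[flag])]
    return cmd


def run_noise_comparison(**values):
    """Run the baseline evaluation, then the DMBP evaluation, with identical settings."""
    for module in ("evaluations.eval_baseline_noise", "evaluations.eval_DMBP_noise"):
        subprocess.run(build_command(module, **values), check=True)


# Placeholders are kept exactly as in the README; fill them in before running.
example = build_command(
    "evaluations.eval_DMBP_noise",
    noise_type="[ATTACK_METHOD]",
    algo="[ALGORITHM_NAME]",
    env_name="[ENV_NAME]",
    dataset="[DATASET_NAME]",
)
print(" ".join(example))
```

Calling `run_noise_comparison(...)` with real values from the repository root runs both evaluations in sequence and stops at the first failure, since `check=True` raises on a non-zero exit code.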

Robustness against incomplete state observations with unobserved dimensions

Similarly, to evaluate the baseline algorithm against incomplete state observations, run:

python -m evaluations.eval_baseline_mask --mask_dim [MASKED_DIM] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]

Then, the DMBP-strengthened baseline algorithm can be evaluated with:

python -m evaluations.eval_DMBP_mask --mask_dim [MASKED_DIM] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]
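
To make the incomplete-observation setting concrete, the sketch below hides selected state dimensions from an observation. It is only an illustration: the helper name `mask_observation` is hypothetical, and replacing unobserved entries with zero is an assumption, since this README does not say how the evaluation scripts represent masked dimensions.

```python
import numpy as np


def mask_observation(obs, masked_dims, fill_value=0.0):
    """Return a copy of `obs` with the given dimensions overwritten.

    Hypothetical helper; the zero fill value is an assumption.
    """
    masked = np.array(obs, dtype=np.float64, copy=True)
    masked[list(masked_dims)] = fill_value
    return masked


state = np.arange(5.0)                      # stand-in for a real state vector
observed = mask_observation(state, masked_dims=[1, 3])
print(observed)                             # dimensions 1 and 3 are hidden
```

The original `state` is left untouched, so the true and masked observations can be compared side by side.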

Bibtex

If you find DMBP helpful for your work, please cite:

@inproceedings{
yang2024dmbp,
title={{DMBP}: Diffusion model-based predictor for robust offline reinforcement learning against state observation perturbations},
author={Yang, Zhihe and Xu, Yunjian},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024}
}