
This repo is the official implementation of "Mask-based Latent Reconstruction for Reinforcement Learning" (NeurIPS 2022).

Primary language: Python. License: MIT.

Mask-based Latent Reconstruction for Reinforcement Learning

This is the official implementation of Mask-based Latent Reconstruction for Reinforcement Learning (accepted by NeurIPS 2022). MLR outperforms state-of-the-art sample-efficient reinforcement learning methods such as CURL, DrQ, SPR, and PlayVirtual.

Abstract

For deep reinforcement learning (RL) from pixels, learning effective state representations is crucial for achieving high performance. However, in practice, limited experience and high-dimensional inputs prevent effective representation learning. To address this, motivated by the success of mask-based modeling in other research fields, we introduce mask-based reconstruction to promote state representation learning in RL. Specifically, we propose a simple yet effective self-supervised method, Mask-based Latent Reconstruction (MLR), to predict complete state representations in the latent space from the observations with spatially and temporally masked pixels. MLR enables better use of context information when learning state representations to make them more informative, which facilitates the training of RL agents. Extensive experiments show that our MLR significantly improves the sample efficiency in RL and outperforms the state-of-the-art sample-efficient RL methods on multiple continuous and discrete control benchmarks.
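The spatial and temporal masking mentioned in the abstract can be illustrated with a minimal pure-Python sketch of cube masking. It assumes grayscale frames stored as nested lists, non-overlapping cubes spanning `depth` consecutive frames and `patch` by `patch` pixels, and masked pixels set to zero. The helper `cube_mask` and its parameters are hypothetical illustrations, not this repository's API; the repository's actual cube shapes, mask ratio, and tensor layout may differ.

```python
import random
from typing import List

Frame = List[List[float]]  # one H x W grayscale frame (assumed layout)


def cube_mask(frames: List[Frame], patch: int, depth: int,
              mask_ratio: float, seed: int = 0) -> List[Frame]:
    """Hypothetical helper: zero out a random subset of spatio-temporal cubes."""
    T = len(frames)
    H, W = len(frames[0]), len(frames[0][0])
    assert T % depth == 0 and H % patch == 0 and W % patch == 0
    # Enumerate every cube by (temporal chunk, patch row, patch column).
    cubes = [(t, i, j)
             for t in range(T // depth)
             for i in range(H // patch)
             for j in range(W // patch)]
    rng = random.Random(seed)
    n_masked = int(round(mask_ratio * len(cubes)))
    chosen = rng.sample(cubes, n_masked)
    # Copy the input so the unmasked observations remain available,
    # e.g. for computing reconstruction targets.
    out = [[row[:] for row in frame] for frame in frames]
    for t, i, j in chosen:
        # A masked cube hides the same spatial patch in `depth` consecutive frames.
        for dt in range(depth):
            for di in range(patch):
                for dj in range(patch):
                    out[t * depth + dt][i * patch + di][j * patch + dj] = 0.0
    return out


frames = [[[1.0] * 4 for _ in range(4)] for _ in range(4)]  # T=4 frames of 4x4
masked = cube_mask(frames, patch=2, depth=2, mask_ratio=0.5)
print(sum(v == 0.0 for f in masked for row in f for v in row))  # 8 cubes, 4 masked x 8 pixels -> 32
```

Masking whole cubes rather than independent pixels means a hidden region cannot be trivially copied from the same location in a neighboring frame, so recovering it requires using wider spatial and temporal context.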

Framework


Figure 1. The framework of the proposed MLR. We perform random spatial-temporal masking (i.e., cube masking) on a sequence of consecutive observations in pixel space. An online encoder encodes the masked observations into latent states. A predictive latent decoder then decodes/predicts the latent states, conditioned on the corresponding action sequence and temporal positional embeddings. The networks are trained to reconstruct the information in the missing content within an appropriate latent space, using a cosine-similarity-based distance between the predicted features of the reconstructed states and the target features that momentum networks infer from the original observations.
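The training signal described in the Figure 1 caption can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the repository's implementation: it assumes the distance is averaged as 1 minus cosine similarity over time steps, and that the momentum networks are updated as an exponential moving average (EMA) of the online parameters with a hypothetical coefficient `tau`. The function names are illustrative only.

```python
import math
from typing import List

Vec = List[float]


def cosine_similarity(a: Vec, b: Vec, eps: float = 1e-8) -> float:
    """Cosine of the angle between two feature vectors (eps avoids division by zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / max(na * nb, eps)


def reconstruction_loss(predicted: List[Vec], target: List[Vec]) -> float:
    """Assumed form: mean of (1 - cosine similarity) over the time steps.

    In an autodiff framework the target features would be detached so that
    no gradient reaches the momentum networks; plain lists carry no gradients.
    """
    assert predicted and len(predicted) == len(target)
    total = sum(1.0 - cosine_similarity(p, t) for p, t in zip(predicted, target))
    return total / len(predicted)


def momentum_update(target_params: Vec, online_params: Vec, tau: float) -> Vec:
    """Assumed EMA rule: theta_target <- tau * theta_target + (1 - tau) * theta_online."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target_params, online_params)]


pred = [[1.0, 0.0], [0.0, 1.0]]
target = [[2.0, 0.0], [1.0, 0.0]]
print(reconstruction_loss(pred, target))  # (0 + 1) / 2 -> 0.5

new_target = momentum_update([0.0, 1.0], [1.0, 1.0], tau=0.9)
```

Because cosine similarity ignores vector norms, the loss rewards matching the direction of the target features rather than their scale; the slowly moving EMA target is a common way to provide stable regression targets in self-supervised setups of this kind.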

Run MLR

We provide code for two benchmarks: Atari and DMControl.

.
├── Atari
|   ├── README.md
|   └── ...
├── DMControl
|   ├── README.md
|   └── ...
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SUPPORT.md
└── SECURITY.md

To run the Atari code, enter ./Atari and see its README.md for more information.

cd ./Atari

To run the DMControl code, enter ./DMControl and see its README.md for more information.

cd ./DMControl

Citation

Please use the following arXiv citation to cite our work until the NeurIPS 2022 proceedings are available.

@article{yu2022mask,
title={Mask-based Latent Reconstruction for Reinforcement Learning},
author={Yu, Tao and Zhang, Zhizheng and Lan, Cuiling and Chen, Zhibo and Lu, Yan},
journal={arXiv preprint arXiv:2201.12096},
year={2022}
}

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information, see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.