stitching-is-combinatorial-generalisation

[ICLR 2024] Closing the Gap between TD Learning and Supervised Learning - A Generalisation Point of View.

Primary language: Python. License: MIT.

Raj Ghugare, Matthieu Geist, Glen Berseth*, Benjamin Eysenbach*

* Equal advising.

Installation

Create a virtual environment named env_stuff using the command:

python3 -m venv env_stuff

Activate the environment (for example, source env_stuff/bin/activate on Linux or macOS) and install all the packages needed to run the code from the requirements.txt file:

pip install -r requirements.txt

Training

To train an RvS (decision-mlp) agent on pointmaze-umaze using temporal data augmentation, with $\epsilon=0.5$ and $K=40$:

python train_dmlp.py dataset_name=pointmaze-umaze-v0 augment_data=True nclusters=40

To train a DT (decision-transformer) agent on pointmaze-umaze using temporal data augmentation, with $\epsilon=0.5$ and $K=40$:

python train_dt.py dataset_name=pointmaze-umaze-v0 augment_data=True nclusters=40
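Both training commands use temporal data augmentation controlled by $\epsilon$ and $K$ (passed as nclusters). The sketch below is a hypothetical illustration of one way such an augmentation could work, not the repository's implementation. It assumes that $\epsilon$ is the probability of attempting to augment a sampled goal and that $K$ is the number of k-means clusters used to match states across trajectories. The function names kmeans_labels and sample_state_goal are our own; see train_dmlp.py and train_dt.py for the actual logic.

```python
import numpy as np


def kmeans_labels(states, k, iters=20, seed=0):
    """Plain NumPy k-means over all dataset states; returns one label per state."""
    rng = np.random.default_rng(seed)
    # Initialise centers from k distinct states (fancy indexing makes a copy).
    centers = states[rng.choice(len(states), size=k, replace=False)].astype(float)
    for _ in range(iters):
        # Squared distance from every state to every center, shape (N, k).
        dists = ((states[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for c in range(k):
            members = states[labels == c]
            if len(members) > 0:  # keep the old center if a cluster empties
                centers[c] = members.mean(axis=0)
    return labels


def sample_state_goal(trajs, labels, eps, rng):
    """Sample a (state, goal) pair as index tuples ((i, t), (j, g)).

    trajs:  list of arrays; trajectory i has shape (T_i, state_dim), T_i >= 2.
    labels: list of int arrays; labels[i][t] is the cluster of trajs[i][t].
    eps:    probability of attempting the cross-trajectory augmentation.
    """
    i = int(rng.integers(len(trajs)))
    T = len(trajs[i])
    t = int(rng.integers(T - 1))
    w = int(rng.integers(t + 1, T))  # a future state of the same trajectory
    if rng.random() >= eps:
        # Ordinary hindsight relabelling: the future state itself is the goal.
        return (i, t), (i, w)
    # Augmentation: treat w as a waypoint and find states in other trajectories
    # in the same cluster (excluding their last step so a future state exists).
    c = labels[i][w]
    candidates = [
        (j, int(k))
        for j in range(len(trajs))
        if j != i
        for k in np.flatnonzero(labels[j][:-1] == c)
    ]
    if not candidates:
        return (i, t), (i, w)  # no match: fall back to the ordinary goal
    j, k = candidates[int(rng.integers(len(candidates)))]
    # The new goal is drawn from the future of the matched trajectory.
    g = int(rng.integers(k + 1, len(trajs[j])))
    return (i, t), (j, g)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Two toy 2-D trajectories that cross near the origin.
    trajs = [
        np.linspace([-1.0, 0.0], [1.0, 0.0], 20),
        np.linspace([0.0, -1.0], [0.0, 1.0], 20),
    ]
    flat_labels = kmeans_labels(np.concatenate(trajs), k=4)
    labels = np.split(flat_labels, np.cumsum([len(x) for x in trajs])[:-1])
    for _ in range(5):
        print(sample_state_goal(trajs, labels, eps=0.5, rng=rng))
```

With eps=0 every goal comes from the same trajectory as its state. With a larger eps, some goals come from a different trajectory that passes through a similar region, which is the kind of "stitched" (state, goal) pair such an augmentation aims to add to the training data.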

Datasets

To download the pre-collected datasets, visit this Google Drive link.

To collect the pointmaze-large dataset with $10^6$ transitions and seed 1:

python collect_pointmaze_data.py pointmaze-large-v0 1 1000000

To collect the antmaze-umaze dataset with $10^6$ transitions and seed 1:

python collect_antmaze_data.py antmaze-umaze-v0 1 1000000
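Judging from the examples, both collection scripts take three positional arguments in the order: environment name, seed, number of transitions. The parser below is a hypothetical sketch that only illustrates this argument order; the argument names env_name, seed and num_transitions are our own, and the real scripts may read their inputs differently.

```python
import argparse


def parse_collection_args(argv=None):
    """Hypothetical parser mirroring the positional arguments of the
    collection commands: <env_name> <seed> <num_transitions>."""
    parser = argparse.ArgumentParser(description="Collect a maze dataset (illustrative only).")
    parser.add_argument("env_name", help="environment id, e.g. pointmaze-large-v0")
    parser.add_argument("seed", type=int, help="random seed for data collection")
    parser.add_argument("num_transitions", type=int, help="number of transitions to collect")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Same arguments as the pointmaze-large collection command.
    args = parse_collection_args(["pointmaze-large-v0", "1", "1000000"])
    print(args.env_name, args.seed, args.num_transitions)
```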

Acknowledgment

Our codebase has been built using, or on top of, the following codebases. We thank the respective authors for their awesome contributions.

Correspondence

If you have any questions or suggestions, please reach out to me via raj.ghugare@mila.quebec.