THICK world models

Implementation of THICK world models from our ICLR 2024 paper Learning Hierarchical World Models with Adaptive Temporal Abstractions from Discrete Latent Dynamics in TensorFlow 2.

Abstract

Hierarchical world models can significantly improve model-based reinforcement learning (MBRL) and planning by enabling reasoning across multiple time scales. Nonetheless, the majority of state-of-the-art MBRL methods employ flat, non-hierarchical models. We propose Temporal Hierarchies from Invariant Context Kernels (THICK), an algorithm that learns a world model hierarchy via discrete latent dynamics. The lower level of THICK updates parts of its latent state sparsely in time, forming invariant contexts. The higher level exclusively predicts situations involving context changes. Our experiments demonstrate that THICK learns categorical, interpretable, temporal abstractions on the high level, while maintaining precise low-level predictions. Furthermore, we show that the emergent hierarchical predictive model seamlessly enhances the abilities of MBRL or planning methods. We believe that THICK contributes to the further development of hierarchical agents capable of more sophisticated planning and reasoning abilities.
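
To make the two-level structure concrete, the following is a minimal conceptual sketch in Python. It is not code from this repository; low_level, high_level, and predict_next_change are hypothetical names standing in for the actual model components.

def thick_prediction_sketch(low_level, high_level, state, context, horizon):
    # Low level: the full latent state is updated at every step, but its
    # context part is only rewritten when a learned, (near-)binary gate
    # opens, so contexts stay invariant over long stretches of time.
    for _ in range(horizon):
        state, gate_open, proposed_context = low_level.step(state, context)
        if gate_open:
            context = proposed_context  # sparse context change
    # High level: rather than stepping through time, directly predict the
    # next context change, abstracting over all steps in between.
    next_context, state_at_change = high_level.predict_next_change(state, context)
    return state, context, next_context, state_at_change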

Installation

We provide a requirements file with our dependencies:

pip3 install -r requirements.txt
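
To keep the dependencies isolated, the same install can be done inside a fresh virtual environment (standard Python tooling, nothing specific to this repository):

python3 -m venv thick-env
source thick-env/bin/activate
pip3 install -r requirements.txt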

Note: If you run into problems installing nle, refer to its installation guide. Multiworld requires a working MuJoCo installation.

Running experiments

MiniHack

Train THICK Dreamer on MiniHack, e.g. the EscapeRoom problem:

python3 train.py --logdir ~/logdir/minihack_escape/thick/0 --configs thick_minihack --task Minihack_escaperoomwater
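
Training metrics are written to the log directory. Assuming the TensorBoard logging inherited from the DreamerV2 code base, a run can be monitored with:

tensorboard --logdir ~/logdir/minihack_escape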

For KeyCorridor tasks, the scale of the sparsity loss needs to be adjusted:

python3 train.py --logdir ~/logdir/minihack_corridor8/thick/0 --configs thick_minihack --loss_scales.ctxt_sparsity 1 --task Minihack_keycorridor8

Train a default DreamerV2 agent as a baseline:

python3 train.py --logdir ~/logdir/minihack_corridor8/dreamerv2/0 --configs minihack --task Minihack_keycorridor8
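
To average results over several seeds, the runs can be scripted. This assumes the seed config flag inherited from DreamerV2; the seed is also encoded in the log directory, matching the /0 suffix used above:

for seed in 0 1 2; do
  python3 train.py --logdir ~/logdir/minihack_corridor8/dreamerv2/$seed --configs minihack --task Minihack_keycorridor8 --seed $seed
done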

PinPad

For the VisualPinpad experiments in our paper, first generate an exploration dataset with Plan2Explore:

python3 train.py --logdir ~/logdir/pp4/plan2explore/0 --configs visualpinpad plan2explore --task pinpad_four

Then load the data to jumpstart training THICK Dreamer:

python3 train_from_replay.py --logdir ~/logdir/pp4/thick/0 --load_dir ~/logdir/pp4/plan2explore/0 --configs thick_pinpad --task pinpad_four

Multiworld

For zero-shot planning with THICK PlaNet, first generate a dataset with Plan2Explore:

python3 train.py --logdir ~/logdir/multiworld_door/plan2explore/0 --configs multiworld plan2explore --task MultiWorld_doorhook

We can use the offline dataset to train a THICK world model for hierarchical zero-shot planning:

python3 train_only_on_replay.py --logdir ~/logdir/multiworld_door/thick/0 --load_dir ~/logdir/multiworld_door/plan2explore/0 --configs thick_multiworld --task MultiWorld_doorhook

New Environments

Try our default settings and run, for example, THICK Dreamer:

python3 train.py --logdir ~/logdir/new_task/thick/0 --configs thick --thick_dreamer --task your_new_task --loss_scales.ctxt_sparsity 1
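
For this command to work, train.py must be able to construct your_new_task. As a rough orientation only: the DreamerV2-style code this repository builds on operates on image observations, so a new task will look roughly like the sketch below. The interface shown is hypothetical and has to be adapted to how train.py actually constructs its environments.

import numpy as np

class YourNewTask:
    # Hypothetical sketch of an image-based task; the exact interface
    # expected by train.py may differ from what is shown here.

    def reset(self):
        self._t = 0
        return {'image': np.zeros((64, 64, 3), np.uint8),
                'reward': 0.0, 'is_first': True, 'is_last': False}

    def step(self, action):
        self._t += 1
        done = self._t >= 100  # placeholder episode length
        return {'image': np.zeros((64, 64, 3), np.uint8),
                'reward': 0.0, 'is_first': False, 'is_last': done}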

The hyperparameter ctxt_sparsity needs to be tuned per task to a value for which context changes are sparse.
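
One way to find such a value is a small sweep over candidate scales. The values below are placeholders; inspect the resulting runs and keep a scale for which the learned context changes are sparse:

for s in 0.1 1 10; do
  python3 train.py --logdir ~/logdir/new_task/thick/sparsity_$s --configs thick --thick_dreamer --task your_new_task --loss_scales.ctxt_sparsity $s
done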

Acknowledgements

This code was developed on top of the DreamerV2 code base. The CEM implementation is based on PlaNet. The MCTS implementation is based on MuZero. We provide a modified version of the Multiworld environment in this code base.

Citation

@inproceedings{gumbsch2024thick,
  title={Learning Hierarchical World Models with Adaptive Temporal Abstractions from Discrete Latent Dynamics},
  author={Christian Gumbsch and Noor Sajid and Georg Martius and Martin V. Butz},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=TjCDNssXKU}
}