KG-DQN

Preliminary release of the code from the paper "Playing Text-Adventure Games with Graph-Based Deep Reinforcement Learning", Prithviraj Ammanabrolu and Mark O. Riedl, NAACL-HLT 2019, Minneapolis, MN. arXiv: https://arxiv.org/abs/1812.01628; ACL Anthology: https://www.aclweb.org/anthology/N19-1358

BibTeX:

@inproceedings{ammanabrolu-riedl-2019-playing,
    title = "Playing Text-Adventure Games with Graph-Based Deep Reinforcement Learning",
    author = "Ammanabrolu, Prithviraj  and
      Riedl, Mark",
    booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
    month = jun,
    year = "2019",
    address = "Minneapolis, Minnesota",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/N19-1358",
    pages = "3557--3565",
    abstract = "Text-based adventure games provide a platform on which to explore reinforcement learning in the context of a combinatorial action space, such as natural language. We present a deep reinforcement learning architecture that represents the game state as a knowledge graph which is learned during exploration. This graph is used to prune the action space, enabling more efficient exploration. The question of which action to take can be reduced to a question-answering task, a form of transfer learning that pre-trains certain parts of our architecture. In experiments using the TextWorld framework, we show that our proposed technique can learn a control policy faster than baseline alternatives. We have also open-sourced our code at https://github.com/rajammanabrolu/KG-DQN.",
}

Disclaimer: This code is not actively maintained.

Data/Pre-training

  • Games are created using TextWorld's tw-make, as specified in the paper.
  • Pre-training is done with DrQA, using traces generated by TextWorld's WalkthroughAgent.
    • These traces consist of (observation, action) pairs, which are used to train DrQA to answer the question "What action do I take?"
    • Run python scripts/datacollector.py <game-directory> oracle, then python scripts/format_to_drqa.py, to generate the .json files required to train DrQA.
      • Then, using these .json files and 100-dimensional GloVe embeddings, run preprocess.py and then train.py in the DrQA repo.
      • The remaining DrQA hyperparameters are listed below (set them in train.py in the DrQA repo):
      'doc_hidden_size': 64,
      'doc_layers': 3,
      'doc_dropout_rnn': 0.2,
      'doc_dropout_rnn_output': True,
      'doc_concat_rnn_layers': True,
      'doc_rnn_padding': True
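The exact conversion done by scripts/format_to_drqa.py is not described here. As a rough sketch, (observation, action) pairs could be turned into the SQuAD v1.1-style JSON that DrQA's preprocess.py reads, with "What action do I take?" as the question. Because DrQA is extractive, this sketch assumes (hypothetically) that each trace step also carries a list of candidate actions, which are appended to the observation so the gold action appears verbatim as an answer span. The function names and the "Actions:" layout are illustrative and are not taken from this repo.

```python
import json
from typing import Dict, List, Sequence, Tuple

QUESTION = "What action do I take?"

# Hypothetical trace step: (observation, gold action, candidate actions).
Step = Tuple[str, str, Sequence[str]]


def to_squad_paragraph(qid: int, observation: str, action: str,
                       candidates: Sequence[str] = ()) -> Dict:
    """Build one SQuAD v1.1-style paragraph for a single trace step.

    Assumption: candidate actions are appended to the observation as
    "<obs> Actions: a | b | ..." so the gold action is an extractive span.
    """
    cands = list(candidates)
    if action not in cands:
        cands.append(action)  # guarantee the answer exists in the context
    prefix = observation.strip() + " Actions: "
    offset, answer_start = len(prefix), -1
    for i, cand in enumerate(cands):
        if i > 0:
            offset += len(" | ")
        if cand == action and answer_start < 0:
            answer_start = offset  # character offset of the gold span
        offset += len(cand)
    context = prefix + " | ".join(cands)
    return {
        "context": context,
        "qas": [{
            "id": str(qid),
            "question": QUESTION,
            "answers": [{"answer_start": answer_start, "text": action}],
        }],
    }


def traces_to_squad(steps: List[Step], title: str = "textworld") -> Dict:
    """Wrap every trace step into a single SQuAD v1.1-style dataset dict."""
    paragraphs = [to_squad_paragraph(i, obs, act, cands)
                  for i, (obs, act, cands) in enumerate(steps)]
    return {"version": "1.1",
            "data": [{"title": title, "paragraphs": paragraphs}]}


if __name__ == "__main__":
    demo = [("You are in a kitchen. There is a closed fridge.",
             "open fridge", ["go north", "open fridge"])]
    print(json.dumps(traces_to_squad(demo), indent=2))
```

Computing answer_start while building the context, rather than searching for the action string afterwards, avoids pointing at an earlier occurrence of the same words inside the observation text (e.g. "You can go north.").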
      

Running the code

  • The code runs in an Anaconda environment for Python 3.6, defined in env.yml. Run conda env create -f env.yml, then source activate kgdqn to enter the environment.
  • The baseline BOW-DQN implementation is in dqn/.
  • The KG-DQN implementation is in kgdqn/.
  • Download Stanford CoreNLP and the corresponding English model .jar files.
    • Start the CoreNLP server from the directory containing the .jar files with java -mx6g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 15000
  • w2id.txt and act2id.txt are required; they are the dictionaries for the vocabulary and the full action set of the specific game.
    • w2id is the vocabulary. It can be extracted with tw-extract from the .z8 game file produced by tw-make, with two additional tokens, <UNK> and <PAD>, added at indices 0 and 1.
    • act2id can be generated by taking the union of all admissible actions found while exploring the world with both the walkthrough agent and a random agent.
  • Similarly, entity2id.tsv and relation2id.tsv define, in dictionary format, the entities and relations that OpenIE can extract for the game.
    • These are also extracted by running both the walkthrough agent and a random agent with only the triple extraction process from representations.py, and enumerating all entities and relations found. Entities and relations not in these files are ignored at test time.
  • Run python scripts/datacollector.py <game-directory> collect to generate entity2id.tsv, relation2id.tsv, and act2id.txt.
  • For both games, set the required parameters and the game in train.py, then run python train.py.
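The two dictionary rules for w2id and act2id (special tokens at fixed indices, and the action set as a union over exploration traces) can be sketched as follows. This is a minimal illustration, not the repo's scripts/datacollector.py: it assumes the vocabulary tokens have already been extracted (e.g. with tw-extract) and that the admissible actions seen at each step of each agent's run are available as lists of strings. The one-pair-per-line "key<TAB>id" layout in the hypothetical save_dict helper is also an assumption, not the repo's documented file format.

```python
from typing import Dict, Iterable


def build_w2id(tokens: Iterable[str]) -> Dict[str, int]:
    """Vocabulary dict with <UNK> at index 0 and <PAD> at index 1."""
    w2id = {"<UNK>": 0, "<PAD>": 1}
    for tok in tokens:
        tok = tok.strip()
        if tok and tok not in w2id:  # skip blanks and duplicates
            w2id[tok] = len(w2id)
    return w2id


def build_act2id(*traces: Iterable[Iterable[str]]) -> Dict[str, int]:
    """Union of admissible actions over all steps of all traces.

    Each trace (e.g. one walkthrough-agent run, one random-agent run) is a
    sequence of per-step admissible-action lists. Actions are sorted so the
    ids are reproducible across runs.
    """
    actions = set()
    for trace in traces:
        for step_actions in trace:
            actions.update(a.strip() for a in step_actions if a.strip())
    return {a: i for i, a in enumerate(sorted(actions))}


def save_dict(d: Dict[str, int], path: str) -> None:
    """Hypothetical file layout: one "key<TAB>id" pair per line, ordered by id."""
    with open(path, "w", encoding="utf-8") as f:
        for key, idx in sorted(d.items(), key=lambda kv: kv[1]):
            f.write(f"{key}\t{idx}\n")
```

A call such as build_act2id(walkthrough_trace, random_trace) merges both agents' observations of the action space; the random agent matters because the walkthrough alone only visits actions on the optimal path.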
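For entity2id.tsv and relation2id.tsv, a rough sketch of the OpenIE side follows. It queries a CoreNLP server on port 9000 (as started by the java command in this list) through its HTTP interface with the openie annotator and JSON output, collects (subject, relation, object) triples, enumerates entities and relations in first-seen order, and drops triples containing unknown entries, mirroring the rule that unknown entities and relations are ignored at test time. It does not reproduce the extraction logic in representations.py, and all helper names here are hypothetical.

```python
import json
import urllib.parse
import urllib.request
from typing import Dict, Iterable, List, Tuple

Triple = Tuple[str, str, str]
PROPS = {"annotators": "tokenize,ssplit,pos,lemma,depparse,natlog,openie",
         "outputFormat": "json"}


def query_openie(text: str, server: str = "http://localhost:9000") -> Dict:
    """POST raw text to a running CoreNLP server and return its JSON reply."""
    url = server + "/?" + urllib.parse.urlencode({"properties": json.dumps(PROPS)})
    req = urllib.request.Request(url, data=text.encode("utf-8"))
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.loads(resp.read().decode("utf-8"))


def parse_triples(response: Dict) -> List[Triple]:
    """Collect (subject, relation, object) triples from every sentence."""
    return [(t["subject"], t["relation"], t["object"])
            for sent in response.get("sentences", [])
            for t in sent.get("openie", [])]


def build_kg_dicts(triples: Iterable[Triple]) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Enumerate entities and relations in first-seen order."""
    entity2id: Dict[str, int] = {}
    relation2id: Dict[str, int] = {}
    for subj, rel, obj in triples:
        entity2id.setdefault(subj, len(entity2id))
        entity2id.setdefault(obj, len(entity2id))
        relation2id.setdefault(rel, len(relation2id))
    return entity2id, relation2id


def triples_to_ids(triples: Iterable[Triple], entity2id: Dict[str, int],
                   relation2id: Dict[str, int]) -> List[Tuple[int, int, int]]:
    """Map triples to ids, dropping any with an unknown entity or relation."""
    known = []
    for subj, rel, obj in triples:
        if subj in entity2id and obj in entity2id and rel in relation2id:
            known.append((entity2id[subj], relation2id[rel], entity2id[obj]))
    return known
```

With the server running, parse_triples(query_openie(observation_text)) yields whatever triples CoreNLP extracts for that text; the exact triples depend on the CoreNLP version and models, so the dictionaries should be built from the same server setup used during training.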