TPR-RNN-Torch

PyTorch implementation of TPR-RNN (Learning to Reason with Third-Order Tensor Products paper)


Tensor Product Representation Recurrent Neural Network

This repository contains a PyTorch implementation of the paper Learning to Reason with Third-Order Tensor Products, published at NeurIPS 2018. TPR-RNN is applied to the bAbI tasks, where it achieves state-of-the-art (SOTA) results. This implementation is based primarily on the original implementation.

Requirements

  • Python 3.6
  • Pytorch==1.0.0
  • tensorboardX==1.5

How to set up the environment

  1. Download and install conda
  2. Create the conda environment from the environment.yml file
     conda env create -n tpr_rnn -f environment.yml
  3. Activate the conda environment
     source activate tpr_rnn
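
Once the environment is active, a quick check that the pinned versions were actually installed can save a failed run later. The following is a minimal sketch and not part of the repository: the import names torch and tensorboardX, and the assumption that each package exposes a __version__ attribute, are illustrative.

```python
import importlib
import sys

# Versions copied from the Requirements list. The import names below
# ("torch", "tensorboardX") are assumptions made for this sketch.
REQUIRED = {"torch": "1.0.0", "tensorboardX": "1.5"}


def check_requirement(module_name, wanted):
    """Return 'ok', 'missing', or 'mismatch (found X)' for one package."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return "missing"
    # Not every package defines __version__; fall back to "unknown".
    found = str(getattr(module, "__version__", "unknown"))
    # Accept an exact match or a longer version with the same prefix,
    # so that "1.5" also accepts "1.5.0".
    if found == wanted or found.startswith(wanted + "."):
        return "ok"
    return "mismatch (found {})".format(found)


def report():
    """Build one status line for the interpreter and one per package."""
    lines = ["python {}.{} (README lists 3.6)".format(*sys.version_info[:2])]
    for name, wanted in REQUIRED.items():
        status = check_requirement(name, wanted)
        lines.append("{}=={}: {}".format(name, wanted, status))
    return lines


if __name__ == "__main__":
    print("\n".join(report()))
```

A "mismatch" line does not necessarily mean the code will fail; it only flags a difference from the versions listed above.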

Usage

Run the pre-trained model.

python3 eval.py --model-dir PATH [--no-cuda]

Train from scratch. (See the train.py file for details.)

python3 train.py --config-file PATH --serialization-path PATH
[--eval-test] [--logging-level LEVEL]
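
When launching several training runs (for example one per bAbI task), it can help to assemble the command line programmatically. The sketch below only reuses the flags listed above; the helper name build_train_command, the example paths, and the example logging level "INFO" are hypothetical and not taken from the repository.

```python
import shlex


def build_train_command(config_file, serialization_path,
                        eval_test=False, logging_level=None):
    """Assemble the argument list for train.py from the documented flags."""
    cmd = ["python3", "train.py",
           "--config-file", config_file,
           "--serialization-path", serialization_path]
    if eval_test:
        # Optional switch, passed without a value.
        cmd.append("--eval-test")
    if logging_level is not None:
        # Accepted LEVEL values are defined by train.py, not by this helper.
        cmd.extend(["--logging-level", logging_level])
    return cmd


if __name__ == "__main__":
    # Hypothetical paths and level, for illustration only.
    cmd = build_train_command("configs/task3.json", "out/task3",
                              eval_test=True, logging_level="INFO")
    # Print a shell-safe version of the command instead of running it.
    print(" ".join(shlex.quote(part) for part in cmd))
```

The returned list can be passed directly to subprocess.run, which avoids quoting problems with paths that contain spaces.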

Run cluster analysis.

python3 cluster_analysis.py --model-path PATH [--num-stories N]

Cluster Analysis

Results of cluster analysis on random stories related to task 3.

[Images: e1, e2, r1, r2, r3]

Evaluation results

A separate model was trained for each task; the test-set results are listed below. No thorough hyper-parameter search was done, and Adam is used instead of the NAdam optimizer of the original implementation, but the results are comparable to those reported in the original paper.

Task     1    2    3    4    5    6    7    8    9   10   11   12   13   14   15   16   17   18   19   20
Error  0.0  0.1  1.1  0.0  0.3  2.5  0.4  0.4  2.2  0.1  0.3  0.2  1.6  0.4  0.0  0.0  3.9  2.3  0.2  0.0
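
bAbI results are usually summarised by the mean error over all tasks and the number of failed tasks. The sketch below computes both from the table above. It assumes the listed errors are percentages and uses the common bAbI convention that a task counts as failed when its error exceeds 5%; neither assumption is stated in this README.

```python
# Per-task test errors copied from the table above (tasks 1 to 20).
# Assumed to be percentages.
ERRORS = [0.0, 0.1, 1.1, 0.0, 0.3, 2.5, 0.4, 0.4, 2.2, 0.1,
          0.3, 0.2, 1.6, 0.4, 0.0, 0.0, 3.9, 2.3, 0.2, 0.0]


def summarize(errors, fail_threshold=5.0):
    """Return (mean error, list of task numbers above the threshold)."""
    mean_error = sum(errors) / len(errors)
    failed = [task for task, err in enumerate(errors, start=1)
              if err > fail_threshold]
    return mean_error, failed


if __name__ == "__main__":
    mean_error, failed = summarize(ERRORS)
    # Task numbers are 1-based, so shift the index of the maximum by one.
    worst_task = ERRORS.index(max(ERRORS)) + 1
    print("mean error: {:.2f}".format(mean_error))   # mean error: 0.80
    print("failed tasks: {}".format(len(failed)))    # failed tasks: 0
    print("worst task: {}".format(worst_task))       # worst task: 17
```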

Pre-trained models

Pre-trained models are stored on Google Drive.