
PyTorch implementation of Attention-over-Attention Neural Networks for Reading Comprehension


Attention-over-Attention Model for Reading Comprehension

This is a PyTorch implementation of the Attention-over-Attention (AoA) model proposed by Cui et al. (paper).


Requirements

  • PyTorch with CUDA
  • Python 3.6+
  • NLTK (with punkt data)
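The requirements above can be verified before running anything else. The following is a minimal sketch, not part of this repository; the function name check_requirements is hypothetical. It only reports what is importable and whether PyTorch can see a CUDA device.

```python
import importlib.util
import sys


def check_requirements():
    """Return a dict mapping each requirement to True or False."""
    status = {"python>=3.6": sys.version_info >= (3, 6)}

    # PyTorch: is it installed, and can it see a CUDA device?
    has_torch = importlib.util.find_spec("torch") is not None
    status["torch"] = has_torch
    status["cuda"] = False
    if has_torch:
        import torch
        status["cuda"] = torch.cuda.is_available()

    # NLTK and its punkt tokenizer data.
    has_nltk = importlib.util.find_spec("nltk") is not None
    status["nltk"] = has_nltk
    status["punkt"] = False
    if has_nltk:
        import nltk
        try:
            nltk.data.find("tokenizers/punkt")
            status["punkt"] = True
        except LookupError:
            pass  # punkt data has not been downloaded
    return status
```

If punkt is reported missing, it can usually be fetched with python -m nltk.downloader punkt.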


This implementation uses Facebook’s Children’s Book Test (CBT) dataset.
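For orientation, here is a rough sketch of how a CBT file is laid out and how it could be read. It assumes the publicly distributed CBT format: 20 numbered context sentences, then a line numbered 21 holding the query (with the missing word written as XXXXX), the answer, and the candidates separated by "|", with a blank line between examples. The function parse_cbt is hypothetical and is not taken from this repository's preprocess.py.

```python
def parse_cbt(lines):
    """Group CBT lines into examples: document, query, answer, candidates."""
    examples, sentences = [], []
    for raw in lines:
        line = raw.rstrip("\n")
        if not line.strip():
            sentences = []              # a blank line separates examples
            continue
        number, _, text = line.partition(" ")
        if number != "21":
            sentences.append(text)      # context sentences 1..20
            continue
        # Assumed layout of line 21: query <TAB> answer <TAB> <TAB> c1|c2|...
        fields = text.split("\t")
        examples.append({
            "document": sentences,
            "query": fields[0],
            "answer": fields[1],
            "candidates": fields[-1].split("|"),
        })
        sentences = []
    return examples
```

Calling parse_cbt(open("data/train.txt", encoding="utf-8")) would then yield one dictionary per example, provided the files follow the assumed layout.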


Make sure the data files (train.txt, dev.txt, test.txt) are present in the data directory.
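A quick way to confirm the files are in place is sketched below. This helper is not part of the repository; the name missing_data_files and the default directory "data" are assumptions based on the paths used in this README.

```python
from pathlib import Path

REQUIRED_FILES = ("train.txt", "dev.txt", "test.txt")


def missing_data_files(data_dir="data"):
    """Return the required file names that are absent from data_dir."""
    root = Path(data_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]
```

An empty list means all three files were found; otherwise the list names the files still to be added.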

To preprocess the data:

python preprocess.py

This generates the dictionary (dict.pt) from all words that appear in the dataset and vectorizes all the data (train.txt.pt, dev.txt.pt, test.txt.pt).
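The idea behind this step can be illustrated with a small, pure-Python sketch. It is not the code in preprocess.py: the special tokens, the function names, and the use of plain lists instead of saved .pt files are all assumptions made for illustration, and the input is expected to be already tokenized.

```python
from collections import Counter

PAD, UNK = "<pad>", "<unk>"  # hypothetical special tokens


def build_dictionary(token_lists):
    """Map every distinct word to an integer id, most frequent first."""
    counts = Counter(tok for tokens in token_lists for tok in tokens)
    word2id = {PAD: 0, UNK: 1}
    for word, _ in counts.most_common():
        word2id[word] = len(word2id)
    return word2id


def vectorize(tokens, word2id):
    """Replace each token with its id; unseen words map to UNK."""
    return [word2id.get(tok, word2id[UNK]) for tok in tokens]
```

In this sketch the dictionary is built once over every split, so that words seen only in the dev or test files still receive their own ids, mirroring the "all words that appear in the dataset" behaviour described above.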

Train the model

Below is an example training command; adjust the parameters as needed.

python train.py -traindata data/train.txt.pt -validdata data/dev.txt.pt -dict data/dict.pt \
 -save_model model1 -gru_size 384 -embed_size 384 -batch_size 64 -dropout 0.1 \
 -epochs 13 -learning_rate 0.001 -weigth_decay 0.0001 -gpu 1 -log_interval 50

A checkpoint is saved after each epoch. To resume training from a checkpoint:

python train.py -train_from xxx_model_xxx_epoch_x.pt
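When many checkpoints accumulate, picking the most recent one by hand gets tedious. The helper below is a sketch under one assumption: that checkpoint file names end in "epoch", an optional underscore, the epoch number, and ".pt", as the placeholder names in this README suggest. The function latest_checkpoint is hypothetical and not part of the repository.

```python
import re
from pathlib import Path

# Assumed pattern: name ends in "epoch", optional "_", a number, then ".pt".
EPOCH_RE = re.compile(r"epoch_?(\d+)\.pt$")


def latest_checkpoint(directory="."):
    """Return the checkpoint Path with the highest epoch number, or None."""
    best, best_epoch = None, -1
    for path in Path(directory).glob("*.pt"):
        match = EPOCH_RE.search(path.name)
        if match and int(match.group(1)) > best_epoch:
            best, best_epoch = path, int(match.group(1))
    return best
```

The returned path can then be passed to train.py through the -train_from option shown above.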


Test the model

To run a trained model on the test data:

python test.py -testdata data/test.txt.pt -dict data/dict.pt -out result.txt -model models/xx_checkpoint_epochxx.pt


MIT License