Primary language: Python. License: MIT.

Latent Alignment and Variational Attention

This is a PyTorch implementation of the paper Latent Alignment and Variational Attention, built on a fork of OpenNMT.

Dependencies

The code was tested with Python 3.6 and PyTorch 0.4. To install the dependencies, run:

pip install -r requirements.txt

Running the code

All commands are defined in the script va.sh, which must be sourced before a command is called.

Preprocessing the data

To preprocess the data, run

source va.sh && preprocess_bpe

The raw data in data/iwslt14-de-en was obtained from the fairseq repo with BPE_TOKENS=14000.

Training the model

To train a model, run one of the following commands:

  • Soft attention
source va.sh && CUDA_VISIBLE_DEVICES=0 train_soft_b6
  • Categorical attention with exact evidence
source va.sh && CUDA_VISIBLE_DEVICES=0 train_exact_b6
  • Variational categorical attention with exact ELBO
source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_enum_b6
  • Variational categorical attention with REINFORCE
source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_sample_b6

Checkpoints will be saved to the project's root directory.
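
The variant names above refer to different training objectives for a categorical alignment variable. As a rough illustration only, the sketch below shows how an exact evidence and an exact (enumerated) ELBO relate when the latent variable is small enough to sum over. The function names `log_evidence` and `elbo` and all numbers are hypothetical, and this is not the repository's code.

```python
import math

def log_evidence(prior, likelihood):
    """Exact log p(y | x) = log sum_z p(z | x) * p(y | z, x), enumerating z."""
    return math.log(sum(p * l for p, l in zip(prior, likelihood)))

def elbo(q, prior, likelihood):
    """Exact ELBO: E_q[log p(y | z, x)] - KL(q || prior), enumerating z."""
    expected_ll = sum(qz * math.log(l) for qz, l in zip(q, likelihood) if qz > 0)
    kl = sum(qz * math.log(qz / p) for qz, p in zip(q, prior) if qz > 0)
    return expected_ll - kl

# Toy values (made up for illustration): three source positions.
prior = [0.5, 0.3, 0.2]        # p(z | x), the attention distribution
likelihood = [0.1, 0.6, 0.3]   # p(y | z, x) for one target word

# The true posterior p(z | x, y) follows from Bayes' rule.
evidence = sum(p * l for p, l in zip(prior, likelihood))
posterior = [p * l / evidence for p, l in zip(prior, likelihood)]

print(log_evidence(prior, likelihood))
print(elbo(prior, prior, likelihood))      # q = prior: a loose lower bound
print(elbo(posterior, prior, likelihood))  # q = posterior: the bound is tight
```

The ELBO never exceeds the log evidence, and the two coincide when q equals the true posterior. A sampling-based estimator such as REINFORCE targets the same bound but replaces the sum over z with samples from q.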

Evaluating on test

To compute the exact perplexity of the generative model, run the following command, replacing $model with the path to a saved checkpoint.

source va.sh && CUDA_VISIBLE_DEVICES=0 eval_cat $model
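
For readers unfamiliar with the metric, perplexity is conventionally the exponential of the average negative log-likelihood per token. The sketch below shows that standard definition. The function name `perplexity` is hypothetical, and the assumption that eval_cat normalizes per token in the same way is not confirmed by this README.

```python
import math

def perplexity(token_log_probs):
    """Perplexity = exp(-(1/N) * sum of natural-log token probabilities)."""
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    return math.exp(-sum(token_log_probs) / len(token_log_probs))

# Toy example: four tokens, each assigned probability 0.25 by the model.
log_probs = [math.log(0.25)] * 4
print(perplexity(log_probs))  # approximately 4: as uncertain as a uniform 4-way choice
```

Lower values are better: a perplexity of about 6 means the model is, on average, roughly as uncertain as a uniform choice among six tokens.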

The model can also be used to generate translations of the test data:

source va.sh && CUDA_VISIBLE_DEVICES=0 gen_cat $model
sed -e "s/@@ //g" $model.out | perl tools/multi-bleu.perl data/iwslt14-de-en/test.en
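
The sed expression in the pipeline above deletes every "@@ " marker, which rejoins BPE subword pieces into whole words before BLEU is computed. A minimal Python equivalent of that one substitution is sketched here. The function name `remove_bpe` is hypothetical and not part of the repository.

```python
def remove_bpe(line, separator="@@ "):
    """Join BPE subword pieces by deleting every occurrence of the separator."""
    # str.replace substitutes all occurrences, like the `g` flag in sed.
    return line.replace(separator, "")

print(remove_bpe("the quick bro@@ wn fo@@ x"))  # the quick brown fox
```

Like the sed command, this leaves a trailing "@@" at the very end of a line untouched, since it is not followed by a space.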

Trained Models

For each variant, the checkpoint with the lowest validation perplexity (PPL) was selected for evaluation on the test set.

Model             Test PPL       Test BLEU
Soft              7.17 (soft)    32.77
Exact             6.34 (hard)    33.29
VAE Exact ELBO    6.08 (hard)    33.69
VAE REINFORCE     6.17 (hard)    33.30