
ICLR2023 - Tailoring Language Generation Models under Total Variation Distance


Tailoring Language Generation Models under Total Variation Distance

This repository contains the code for our paper (ICLR 2023):

Tailoring Language Generation Models under Total Variation Distance

In this work, we showed that MLE, as an objective for training text generation models, tends to overestimate degenerate samples due to the zero-avoiding property of the KL divergence (KLD). Instead, we leveraged the Total Variation Distance (TVD) as an alternative to KLD and proposed the TaiLr objective, which is based on TVD and is more robust to outliers because it downweights the low-probability samples that MLE models tend to overestimate.

To illustrate this, consider a toy experiment that fits a single Gaussian to a mixture of two Gaussians by minimizing either KLD or TVD: the KLD fit is sensitive to outliers, while the TVD fit is robust.
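The toy experiment can be approximated with a short NumPy sketch. The mixture (weight 0.9 on N(0, 1), weight 0.1 on an "outlier" mode N(6, 1)) and the grid-search ranges are illustrative assumptions, not the settings used for the figure in the paper. Minimizing KL(p || q) over q is equivalent to minimizing the cross-entropy that MLE minimizes, because the entropy of p does not depend on q.

```python
import numpy as np

def normal_logpdf(x, mu, sigma):
    """Log-density of N(mu, sigma^2), evaluated element-wise."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

# Target p: a dominant mode at 0 plus a small "outlier" mode at 6 (illustrative choice).
x = np.linspace(-15.0, 20.0, 3501)  # integration grid, step 0.01
dx = x[1] - x[0]
p = 0.9 * np.exp(normal_logpdf(x, 0.0, 1.0)) + 0.1 * np.exp(normal_logpdf(x, 6.0, 1.0))

mus = np.linspace(-2.0, 8.0, 201)   # candidate means, step 0.05
sigmas = np.linspace(0.5, 4.0, 71)  # candidate standard deviations, step 0.05

best_kl = (np.inf, None, None)   # (objective value, mu, sigma)
best_tvd = (np.inf, None, None)
for sigma in sigmas:
    log_q = normal_logpdf(x[None, :], mus[:, None], sigma)  # shape (len(mus), len(x))
    # KL(p || q) = cross-entropy(p, q) - entropy(p); only the cross-entropy depends on q.
    cross_ent = -(p[None, :] * log_q).sum(axis=1) * dx
    # TVD(p, q) = 0.5 * integral of |p - q|
    tvd = 0.5 * np.abs(p[None, :] - np.exp(log_q)).sum(axis=1) * dx
    i, j = cross_ent.argmin(), tvd.argmin()
    if cross_ent[i] < best_kl[0]:
        best_kl = (cross_ent[i], mus[i], sigma)
    if tvd[j] < best_tvd[0]:
        best_tvd = (tvd[j], mus[j], sigma)

print("KLD fit: mu=%.2f sigma=%.2f" % best_kl[1:])
print("TVD fit: mu=%.2f sigma=%.2f" % best_tvd[1:])
```

With these settings the KLD fit lands at the mixture mean (about 0.6) with a wide sigma (about 2), spreading mass over the gap between the modes, while the TVD fit stays on the dominant mode near 0 with sigma close to 1.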

Prerequisites

torch                   1.10.0
fairseq                 1.0.0
tensorboardX            2.5
sacremoses              0.0.53
regex                   2022.4.24
(Python 3.7.3)

Run pip install -e . to build fairseq locally. To work with the latest version of fairseq, copy the file fairseq/criterions/tailr.py into the fairseq/criterions directory of the newer fairseq package.

Quick Start

Use the fairseq command line tool fairseq-train to run training. Besides the default arguments, our method adds the following arguments:

  • --density-ratio-threshold: the coefficient $\gamma$ in the proxy distribution (see Sec 3.3 of the paper)
  • --density-min-weight: the lower bound of the weighting coefficient in the training objective (see Sec 3.4 of the paper)
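How the two arguments enter the loss can be sketched in NumPy. This is a minimal sketch, assuming the token-level weight takes the form p / (gamma + (1 - gamma) * p), where p is the model probability of the gold token, gamma corresponds to --density-ratio-threshold, and the weight is clipped below at --density-min-weight and treated as a constant (no gradient). The helper names are hypothetical; fairseq/criterions/tailr.py is the authoritative implementation and may differ in details.

```python
import numpy as np

def tailr_token_weights(log_probs, gamma, min_weight=0.0):
    """Hypothetical helper: w = p / (gamma + (1 - gamma) * p), clipped below at min_weight.

    log_probs:  log-probabilities the model assigns to the gold tokens.
    gamma:      assumed to play the role of --density-ratio-threshold (gamma = 0 gives w = 1, i.e. MLE).
    min_weight: assumed to play the role of --density-min-weight.
    """
    p = np.exp(log_probs)
    w = p / (gamma + (1.0 - gamma) * p)
    return np.maximum(w, min_weight)

def tailr_loss(log_probs, gamma, min_weight=0.0):
    """Weighted NLL. In an autograd framework the weights would be detached (stop-gradient)."""
    w = tailr_token_weights(log_probs, gamma, min_weight)
    return -(w * log_probs).sum()

log_probs = np.log([0.9, 0.5, 0.01])  # a confident, a medium and a low-probability token
print(tailr_token_weights(log_probs, gamma=0.0))                  # all ones: plain MLE
print(tailr_token_weights(log_probs, gamma=0.5, min_weight=0.1))  # approx. [0.947, 0.667, 0.1]
```

The weight grows with p, so low-probability tokens contribute less to the gradient; the lower bound keeps them from being ignored entirely.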

Below, we provide more examples to run the experiments in our paper.

Synthetic Experiment

First go to examples/synthetic and download the COCO dataset from here.

Preprocess

Run the command python preproc_synth.py to preprocess the data. The split text data is saved in data/coco, and the fairseq-processed data is saved in data/coco-bin.

Generate Synthetic Data

First train an oracle model by running bash train_oracle.sh (the default architecture is a one-layer LSTM). The model checkpoint will be saved in models/coco-mle-4096-lr1e-3-ep50.

Then sample synthetic data from the oracle model by running the following command:

python sample_synth.py \
       --root_dir data/coco_pseudo \
       --model_dir models/coco-mle-4096-lr1e-3-ep50 \
       --train_num 10000 \
       --valid_num 5000

The data will be saved in data/coco_pseudo. In our paper, we sampled 10K samples for training and 5K samples for validation and testing.

Binarize the data with fairseq-preprocess by running:

bash binarize.sh data/coco_pseudo \
     --srcdict data/coco_pseudo-bin/dict.src.txt \
     --tgtdict data/coco_pseudo-bin/dict.tgt.txt

The result will be saved in data/coco_pseudo-bin.

We also provide the checkpoint of the oracle model and the synthetic dataset.

Train

To train a generation model using TaiLr, run the training script bash train_tailr.sh. The default model architecture is the same as the oracle model.

Note: the argument --best-checkpoint-metric is set to nll_loss, so the selected best checkpoint is the one with the lowest validation NLL loss.

Evaluation

PPL_oracle measures the perplexity of the oracle model evaluated on the data generated by the trained model. To evaluate this metric, there are two steps:

  1. First generate samples from the trained model (in the paper we generate 20K samples), and save the data as the test split. For example, given the trained model saved in <MODEL_DIR>, run the following command to generate data:
python sample_synth.py \
       --root_dir data/coco_tailr \
       --model_dir <MODEL_DIR> \
       --test_num 20000

Then binarize the data by running the command bash binarize_test.sh data/coco_tailr.

  2. To evaluate the result using the oracle model, run the command: bash eval_lm.sh models/coco-mle-4096-lr1e-3-ep50 50 test <GPU_ID> data/coco_tailr.

PPL_test measures the perplexity of the trained model evaluated on the oracle data. In the paper, we sampled 20K samples from the oracle model for evaluation. The evaluation procedure mirrors that of PPL_oracle, with the roles of the two models swapped.
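Both metrics reduce to the same formula: perplexity is the exponential of the average per-token negative log-likelihood. For PPL_oracle the NLLs come from the oracle model scoring generated samples; for PPL_test they come from the trained model scoring oracle samples. A minimal sketch (helper names are hypothetical), noting that fairseq logs losses in base 2, so the matching perplexity is 2 ** loss:

```python
import math

def perplexity(token_nlls, base=math.e):
    """Perplexity = base ** (mean per-token negative log-likelihood, measured in that base)."""
    return base ** (sum(token_nlls) / len(token_nlls))

def nats_to_bits(nll):
    """Convert a negative log-likelihood from nats (natural log) to bits (base 2)."""
    return nll / math.log(2)

nlls = [math.log(4)] * 3  # three tokens, each assigned probability 1/4
print(perplexity(nlls))                                      # approximately 4.0
print(perplexity([nats_to_bits(v) for v in nlls], base=2))   # same value from base-2 losses
```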

BLEU-4 measures the n-gram overlap (up to 4-grams) between the data generated by the trained model and the oracle data. For example, given the generated data saved in <GEN_FILE> and the oracle data saved in <GT_FILE>, run the command: python bleu.py --s <GEN_FILE> --r <GT_FILE>.

SelfBLEU-4 measures the diversity of the generated data by computing the BLEU-4 score of each generated sample against the other generated samples (lower means more diverse). For example, given the generated data saved in <GEN_FILE>, run the command: python selfbleu.py --s <GEN_FILE>.
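The two BLEU-based metrics can be made concrete with a small pure-Python sketch. It assumes whitespace-tokenized sentences, uniform weights over 1- to 4-grams, no smoothing, and the common synthetic-data convention of scoring each generated sentence against the whole reference set and averaging; bleu.py and selfbleu.py may differ (for example in smoothing or subsampling). All function names are hypothetical.

```python
import math
from collections import Counter

def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def sentence_bleu(hyp, refs, max_n=4):
    """BLEU-max_n of one tokenized hypothesis against a list of tokenized references (no smoothing)."""
    log_prec = 0.0
    for n in range(1, max_n + 1):
        hyp_counts = ngrams(hyp, n)
        total = sum(hyp_counts.values())
        if total == 0:
            return 0.0
        # Clip each hypothesis n-gram count by its maximum count in any single reference.
        max_ref = Counter()
        for ref in refs:
            for g, c in ngrams(ref, n).items():
                max_ref[g] = max(max_ref[g], c)
        clipped = sum(min(c, max_ref[g]) for g, c in hyp_counts.items())
        if clipped == 0:
            return 0.0
        log_prec += math.log(clipped / total) / max_n  # geometric mean of the precisions
    # Brevity penalty against the reference length closest to the hypothesis length.
    ref_len = min((abs(len(r) - len(hyp)), len(r)) for r in refs)[1]
    bp = 1.0 if len(hyp) > ref_len else math.exp(1 - ref_len / len(hyp))
    return bp * math.exp(log_prec)

def corpus_avg_bleu(hyps, refs, max_n=4):
    """BLEU: average sentence BLEU of each generated sentence against all oracle sentences."""
    return sum(sentence_bleu(h, refs, max_n) for h in hyps) / len(hyps)

def self_bleu(hyps, max_n=4):
    """SelfBLEU: each generated sentence is scored against all *other* generated sentences."""
    scores = [sentence_bleu(h, hyps[:i] + hyps[i + 1:], max_n) for i, h in enumerate(hyps)]
    return sum(scores) / len(scores)

gen = [s.split() for s in ["the cat sat on the mat", "a dog ran in the park"]]
print(corpus_avg_bleu(gen, gen), self_bleu(gen))
```

This naive version is quadratic in the number of sentences, so scoring 20K samples against 20K references this way is slow; it is meant to show the definition rather than replace the provided scripts.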

Machine Translation

First go to examples/translation.

Preprocess the data following https://github.com/facebookresearch/fairseq/tree/main/examples/translation; the preprocessed data should be saved in a new directory data-bin.

To train a generation model using TaiLr, run the command bash train_tailr.sh. To generate samples on the test set with the model saved in <MODEL_DIR>, run the command: bash generate.sh <MODEL_DIR> <GPU_ID> test. By default, this uses the checkpoint with the highest BLEU score on the validation set.

To calculate the BLEU score of the generated texts saved in <GEN_FILE> against the reference texts saved in <GT_FILE>, run the command fairseq-score -s <GEN_FILE> -r <GT_FILE>.
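fairseq-score expects plain-text system and reference files with one sentence per line, aligned by position. If the output of generate.sh is the raw log of fairseq-generate (an assumption; the script may already extract the hypotheses), the hypothesis (H-) and target (T-) lines can be split out and re-ordered by sentence id with a sketch like the following. The helper name is hypothetical.

```python
def split_generate_output(lines):
    """Hypothetical helper: collect H- (hypothesis) and T- (target) lines, ordered by sentence id.

    fairseq-generate prints hypotheses as "H-<id>\t<score>\t<text>" and targets as "T-<id>\t<text>",
    interleaved in batch order rather than id order.
    """
    hyps, refs = {}, {}
    for line in lines:
        line = line.rstrip("\n")
        if line.startswith("H-"):
            sid, _score, text = line.split("\t", 2)
            hyps.setdefault(int(sid[2:]), text)  # keep the first (top-ranked) hypothesis per id
        elif line.startswith("T-"):
            sid, text = line.split("\t", 1)
            refs[int(sid[2:])] = text
    ids = sorted(hyps)
    return [hyps[i] for i in ids], [refs[i] for i in ids]

sample = [
    "S-1\tdas ist gut",
    "T-1\tthat is good",
    "H-1\t-0.31\tthis is good",
    "S-0\thallo welt",
    "T-0\thello world",
    "H-0\t-0.12\thello world",
]
hyps, refs = split_generate_output(sample)
print(hyps, refs)
```

Writing hyps and refs to two files, one sentence per line, gives the <GEN_FILE> and <GT_FILE> inputs for fairseq-score.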

Summarization

First go to examples/bart/summarization.

To prepare the pre-trained BART, download the pre-trained BART-base checkpoint from here and extract the files to ../models/bart-base. Then download the BPE files by running the following commands:

mkdir -p ../models/bart-base/bpe
wget -O ../models/bart-base/bpe/encoder.json 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -O ../models/bart-base/bpe/vocab.bpe 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -O ../models/bart-base/bpe/dict.txt 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

Download the gigaword data from the Google Drive link provided by https://github.com/microsoft/unilm/. Copy the org_data directory from the extracted files to data/gigaword.

To tokenize the data using BPE, run the command bash encode_bpe.sh data/gigaword models/bart-base. Then binarize the data by running the command bash binarize.sh data/gigaword. The result will be saved in data/gigaword-bin.

To train a generation model using TaiLr, run the command bash train_tailr.sh.

To evaluate ROUGE scores, first install files2rouge. To pick the checkpoint with the highest ROUGE score, run the command bash validate.sh <MODEL_DIR> <GPU_ID>; the validation scores are saved in dev.score.x, where x is the epoch number.

To generate samples on the test set using the model saved at epoch <EPOCH>, run the command: bash generate.sh <MODEL_DIR> <EPOCH> test.src <GPU_ID> data/gigaword. Then calculate the ROUGE score of the generated summaries with files2rouge data/gigaword/test.tgt <GEN_FILE>.
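For intuition about what the reported numbers measure, ROUGE-N and ROUGE-L can be sketched on tokenized sentences. This is only an illustration: files2rouge relies on the original ROUGE toolkit, whose preprocessing (such as stemming) and F-measure settings mean that scores from this sketch will not match exactly. The function names are hypothetical, and F1 (beta = 1) is assumed.

```python
from collections import Counter

def _f1(overlap, hyp_total, ref_total):
    if overlap == 0:
        return 0.0
    precision, recall = overlap / hyp_total, overlap / ref_total
    return 2 * precision * recall / (precision + recall)

def rouge_n(hyp, ref, n=1):
    """ROUGE-N F1 between two token lists, using clipped n-gram overlap."""
    hyp_ngrams = Counter(tuple(hyp[i:i + n]) for i in range(len(hyp) - n + 1))
    ref_ngrams = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
    overlap = sum((hyp_ngrams & ref_ngrams).values())  # & keeps the minimum count per n-gram
    return _f1(overlap, sum(hyp_ngrams.values()), sum(ref_ngrams.values()))

def lcs_length(a, b):
    """Length of the longest common subsequence (dynamic programming, O(len(a) * len(b)))."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b):
            cur.append(prev[j] + 1 if x == y else max(prev[j + 1], cur[j]))
        prev = cur
    return prev[-1]

def rouge_l(hyp, ref):
    """ROUGE-L F1 based on the LCS of the two token lists."""
    return _f1(lcs_length(hyp, ref), len(hyp), len(ref))

summary, reference = "police arrest suspect".split(), "police arrest murder suspect".split()
print(rouge_n(summary, reference, 1), rouge_n(summary, reference, 2), rouge_l(summary, reference))
```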

Long Text Generation

First go to examples/bart/long.

Prepare the BART checkpoint following the previous section.

Download the writingPrompts data from here. Extract the files to data/writingPrompts.

To tokenize the data using BPE, run the command bash encode_bpe.sh data/writingPrompts models/bart-base. Then binarize the data by running the command bash binarize.sh data/writingPrompts. The result will be saved in data/writingPrompts-bin.

To train a generation model using TaiLr, run the command bash train_tailr.sh. The best checkpoint is picked with the lowest perplexity on the valid set.

To generate samples on the test set (in the paper, we evaluate on a subset of 1,000 examples due to time constraints) using the model saved at epoch <EPOCH>, run the following commands:

head -n 1000 data/writingPrompts/test.src > data/writingPrompts/test.src.1000
bash sample.sh <MODEL_DIR> <EPOCH> test.src.1000 <GPU_ID> data/writingPrompts

Before evaluation, we use the script post_wp.sh to post-process the target and generated texts; it strips the <newline> tokens from the texts.

To evaluate the corpus-level BLEU score, run the command python bleu_score.py --h <GEN_FILE> --r <GT_FILE> --corpus.

To evaluate Rep-l, run the command python rep.py --h <GEN_FILE>.

To evaluate Distinct, run the command python dist.py --h <GEN_FILE>.
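The exact definitions used by rep.py and dist.py (window size l, tokenization, corpus-level versus per-sample averaging) are not stated here. The sketch below assumes the common definitions: Rep-l is the fraction of tokens that repeat a token from the preceding l tokens (as in the unlikelihood-training literature), and Distinct-n is the ratio of unique to total n-grams over the whole corpus. The function names are hypothetical.

```python
def rep_l(tokens, l):
    """Fraction of tokens that already occurred within the previous l tokens (assumed definition)."""
    if not tokens:
        return 0.0
    hits = sum(1 for t in range(len(tokens)) if tokens[t] in tokens[max(0, t - l):t])
    return hits / len(tokens)

def distinct_n(samples, n):
    """Corpus-level Distinct-n: unique n-grams divided by total n-grams over all samples."""
    all_ngrams = [tuple(s[i:i + n]) for s in samples for i in range(len(s) - n + 1)]
    return len(set(all_ngrams)) / len(all_ngrams) if all_ngrams else 0.0

story = "the cat sat and the cat sat again".split()
print(rep_l(story, l=16), distinct_n([story], 1), distinct_n([story], 2))
```

A corpus-level Rep-l would then typically be the average of rep_l over all generated samples; lower Rep-l and higher Distinct indicate less repetitive, more diverse text.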

Citation

Please cite our paper if you find this paper and the code useful.

@inproceedings{ji2023Tailoring,
    title = "Tailoring Language Generation Models under Total Variation Distance",
    author = "Haozhe Ji and Pei Ke and Zhipeng Hu and Rongsheng Zhang and Minlie Huang",
    booktitle = "The Eleventh International Conference on Learning Representations",
    year = "2023",
}