Critic-Driven Decoding for Mitigating Hallucinations in Data-to-text Generation

Critic classifier training

If you want to use the WebNLG data, you can download it with:

python critics/dataset_generators/download-webnlg.py SPLIT_NAME
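
If you need all three splits at once, a small loop saves retyping. This is a convenience sketch only; it assumes download-webnlg.py accepts the same train/test/dev split names used throughout this README:

# Convenience sketch: fetch all three WebNLG splits in one go.
# Assumes download-webnlg.py accepts the train/test/dev split names
# used throughout this README.
import subprocess

for split in ["train", "test", "dev"]:
    subprocess.run(
        ["python", "critics/dataset_generators/download-webnlg.py", split],
        check=True,  # stop early if a download fails
    )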

Generating data

  • ver 1. critic (base)
python critics/dataset_generators/gen_train_onlystop.py SPLIT_NAME

SPLIT_NAME is a placeholder for "train", "test", and "dev". To generate all necessary data, you should run the command three times, i.e.:

python critics/dataset_generators/gen_train_onlystop.py train
python critics/dataset_generators/gen_train_onlystop.py test
python critics/dataset_generators/gen_train_onlystop.py dev
  • ver 2. critic (base with full sentences)
python critics/dataset_generators/gen_train.py SPLIT_NAME 
  • ver 3. critic (vanilla LM)
python3 ./bin/decode.py \
    --model_name facebook/bart-base \
    --experiment webnlg \
    --in_dir data/webnlg \
    --split SPLIT_NAME \
    --accelerator gpu \
    --devices 1 \
    --beam_size 1 \
    --condition_lambda 1.0 \
    --critic_top_k 5 \
    --batch_size 8 \
    --out_filename FILE_NAME \
    --wrapper data \
    --load_in_8bit

where --model_name is the name of the Hugging Face LM used to generate the data (here: facebook/bart-base). A scripted version of this two-step pipeline is sketched after this list.

python critics/dataset_generators/gen_train_fromLM.py SPLIT_NAME FILE_NAME-data
  • ver 4. critic (fine-tuned LM)

Put the checkpoint of the fine-tuned language model into the experiments/webnlg/CHECKPOINT_NAME path. Our BART-based LM fine-tuned on WebNLG can be downloaded from https://we.tl/t-1aufs3tnyS

python3 ./bin/decode.py \
    --experiment webnlg \
    --checkpoint CHECKPOINT_NAME \
    --in_dir data/webnlg \
    --split SPLIT_NAME \
    --accelerator gpu \
    --devices 1 \
    --beam_size 1 \
    --condition_lambda 1.0 \
    --critic_top_k 5 \
    --batch_size 8 \
    --out_filename FILE_NAME \
    --wrapper data \
    --load_in_8bit

python critics/dataset_generators/gen_train_fromLM.py SPLIT_NAME FILE_NAME-data
  • ver 5. critic (fine-tuned LM with full sentences)
python3 ./bin/decode.py \
    --experiment webnlg \
    --checkpoint CHECKPOINT_NAME \
    --in_dir data/webnlg \
    --split SPLIT_NAME \
    --accelerator gpu \
    --devices 1 \
    --beam_size 1 \
    --condition_lambda 1.0 \
    --critic_top_k 5 \
    --batch_size 8 \
    --out_filename FILE_NAME \
    --wrapper data-full \
    --load_in_8bit

python critics/dataset_generators/gen_train_fromLM.py SPLIT_NAME FILE_NAME-data
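
For critics ver. 3-5 the recipe is the same two steps: decode with the (vanilla or fine-tuned) LM, then turn the decoder's FILE_NAME-data output into critic training data. A minimal scripted version of that pipeline, run once per split, might look as follows; the paths and flags mirror the commands above, while the output name is a hypothetical placeholder:

# Sketch of the ver. 3 data-generation pipeline: run decode.py per split,
# then feed its "<out_filename>-data" output to gen_train_fromLM.py.
# For ver. 4/5, swap --model_name for --checkpoint (and use --wrapper
# data-full for ver. 5).
import subprocess

for split in ["train", "test", "dev"]:
    out_name = f"lm-output-{split}"  # hypothetical FILE_NAME
    subprocess.run(
        ["python3", "./bin/decode.py",
         "--model_name", "facebook/bart-base",
         "--experiment", "webnlg", "--in_dir", "data/webnlg",
         "--split", split, "--accelerator", "gpu", "--devices", "1",
         "--beam_size", "1", "--condition_lambda", "1.0",
         "--critic_top_k", "5", "--batch_size", "8",
         "--out_filename", out_name, "--wrapper", "data", "--load_in_8bit"],
        check=True,
    )
    subprocess.run(
        ["python", "critics/dataset_generators/gen_train_fromLM.py",
         split, f"{out_name}-data"],
        check=True,
    )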

Training the critic

Put the generated training data into OUT_DIR. The OUT_DIR directory should contain three files: train.csv, test.csv, and dev.csv, holding the training, test, and validation data (these files are generated by the gen_train*.py scripts; see above).

python critics/run.py --batch_size 32 --outdir OUT_DIR --model MLPSELU --lr 1e-5
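
Before launching a long training run, it can help to verify the OUT_DIR layout described above. A simple sanity check (the column format inside the CSVs is whatever the chosen gen_train*.py script produced):

# Sanity check: OUT_DIR must contain the three CSVs produced by gen_train*.py.
import os
import sys

out_dir = sys.argv[1] if len(sys.argv) > 1 else "OUT_DIR"
missing = [f for f in ("train.csv", "test.csv", "dev.csv")
           if not os.path.isfile(os.path.join(out_dir, f))]
if missing:
    sys.exit(f"Missing in {out_dir}: {', '.join(missing)}")
print(f"{out_dir} looks complete; ready for critics/run.py")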

Critic-aware decoding

Put the checkpoint of the fine-tuned LM into the experiments/webnlg/LM_CHECKPOINT_NAME path. Our BART-based LM fine-tuned on WebNLG can be downloaded from https://we.tl/t-1aufs3tnyS (the same link as above). The checkpoint of a trained critic should be located at CRITIC_CHECKPOINT_NAME. The name of the output file with the decoded text is specified by FILE_NAME.

python3 ./bin/decode.py \
    --experiment webnlg \
    --checkpoint LM_CHECKPOINT_NAME \
    --in_dir data/webnlg \
    --split test \
    --accelerator gpu \
    --devices 1 \
    --beam_size 1 \
    --condition_lambda 0.25 \
    --critic_top_k 5 \
    --linear_warmup \
    --batch_size 8 \
    --critic_ckpt CRITIC_CHECKPOINT_NAME \
    --out_filename FILE_NAME \
    --wrapper classifier \
    --load_in_8bit
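
To make --condition_lambda and --critic_top_k concrete, below is a minimal sketch of one critic-aware decoding step: the critic rescores only the k most likely next tokens under the LM, and lambda controls how strongly the critic's judgment shifts the LM's distribution. This is an illustration of the idea only, not the code in bin/decode.py; critic_score stands in for whatever interface the trained critic checkpoint exposes.

# Illustrative critic-aware decoding step (not the repository's implementation).
# lm_log_probs: 1-D tensor of next-token log-probabilities from the LM.
# critic_score(prefix_ids, token_id): assumed helper returning the critic's
# probability that the prefix extended by token_id remains faithful to the data.
import torch

def critic_aware_step(lm_log_probs, prefix_ids, critic_score,
                      condition_lambda=0.25, critic_top_k=5):
    # Rescore only the top-k LM candidates (--critic_top_k); querying the
    # critic over the whole vocabulary would dominate decoding time.
    top_logp, top_ids = lm_log_probs.topk(critic_top_k)
    scores = []
    for logp, tok in zip(top_logp, top_ids):
        p_critic = critic_score(prefix_ids, tok.item())
        # --condition_lambda weighs the critic's log-probability against the
        # LM's; lambda = 0 falls back to plain greedy LM decoding.
        scores.append(logp + condition_lambda * torch.log(torch.tensor(p_critic)))
    best = torch.stack(scores).argmax()
    return top_ids[best].item()  # token emitted at this step

Judging by its name, --linear_warmup likely ramps the critic's weight up over the first generated tokens rather than applying the full lambda from the start.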