[Paper] [Open Review] [Long Video]
ICLR Oral 2022
Rose E Wang, Esin Durmus, Noah Goodman, Tatsunori Hashimoto
Abstract: Modern language models can generate high-quality short texts. However, they often meander or are incoherent when generating longer texts. These issues arise from the next-token-only language modeling objective. Recent work in self-supervised learning suggests that models can learn good latent representations via contrastive learning, which can be effective for discriminative tasks. Our work analyzes the application of contrastive representations for generative tasks, like long text generation. We propose one approach for leveraging contrastive representations, which we call Time Control (TC). TC first learns a contrastive representation of the target text domain, then generates text by decoding from these representations. Compared to domain-specific methods and fine-tuning GPT2 across a variety of text domains, TC performs competitively with methods built specifically for learning sentence representations on discourse coherence. In long text generation settings, TC preserves the text structure in terms of both ordering (up to +15% better) and text length consistency (up to +90% better).
Contents:
- Follow the commands in setup.sh
- Make sure you are in the virtual environment:
conda activate language_modeling_via_stochastic_processes
- Install the decoder's version of the transformers library:
cd decoder # enter the decoder repo
pip install -e . # Installing transformers locally; I modified their GPT2 module to take in our learned embeddings for decoding.
- Make sure you have a wandb account! (These setup steps are collected into one sketch below.)
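For reference, here is the setup sequence above collected into one shell sketch. It assumes setup.sh can be run directly (rather than copied command by command) and that you authenticate wandb via wandb login:

```bash
# From the repo root
bash setup.sh                                              # environment setup commands
conda activate language_modeling_via_stochastic_processes  # enter the virtual environment
cd decoder
pip install -e .                                           # local transformers with the modified GPT2 module
cd ..
wandb login                                                # log in to your wandb account
```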
This repo contains all but two of the datasets: Wikihow and Recipe NLG need to be downloaded separately (instructions below), while the other four datasets are already included in this repo.
The Wikihow dataset needs to be downloaded from this link. It's a pkl file that should be placed at path/2/repo/data/wikihow/wiki_how_data.pkl.
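As a sketch, placing the downloaded file could look like the following, where path/2/repo stands in for your clone of this repo and the download location is assumed to be ~/Downloads:

```bash
mkdir -p path/2/repo/data/wikihow
mv ~/Downloads/wiki_how_data.pkl path/2/repo/data/wikihow/wiki_how_data.pkl
```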
The Wikisection dataset used in this paper is already included.
It came from this prior work -- specifically, we used the English city Wikipedia articles.
The Recipe NLG dataset also needs to be downloaded; put the data under encoder/data/recipe_nlg.
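Similarly, a placement sketch for Recipe NLG, assuming the downloaded archive has been extracted to ~/Downloads/recipe_nlg:

```bash
mkdir -p encoder/data/recipe_nlg
mv ~/Downloads/recipe_nlg/* encoder/data/recipe_nlg/
```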
The TM2 dataset used in this paper is already included. It came from the TM2 Restaurant Search dataset.
The TicketTalk dataset used in this paper is already included; it consists of all the TicketTalk json files.
Before running experiments, cd encoder/code; source init_env.sh
In encoder/code/scripts/run_ou.py, set the variable ckpt_dir to your checkpoint directory.
The script for training the encoders (TC, VAE, Brownian, InfoNCE) can be found at encoder/code/scripts/train_encoders.sh.
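Putting the encoder steps together, a minimal sketch (it assumes ckpt_dir in run_ou.py has already been set, and that train_encoders.sh is meant to be sourced like the other scripts in this repo; check the script itself for the exact invocation):

```bash
cd encoder/code
source init_env.sh
# Train the encoders (TC, VAE, Brownian, InfoNCE)
source scripts/train_encoders.sh
```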
Before running experiments, cd encoder/code; source init_env.sh
In encoder/code/scripts/run_discourse.py and encoder/code/src/systems/discourse_system.py, set the correct paths to your data directory and repo.
The script for running the discourse coherence experiments can be found at encoder/code/scripts/discourse.sh.
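Likewise for the discourse coherence experiments, assuming the paths in run_discourse.py and discourse_system.py have already been updated:

```bash
cd encoder/code
source init_env.sh
# Run the discourse coherence experiments
source scripts/discourse.sh
```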
For training the decoder, you'll need to be in the directory decoder/examples/pytorch/language-modeling/.
The script for training the decoder can be found at decoder/examples/pytorch/language-modeling/train_encoders.sh. Make sure to change the path2repo variable.
You'll also need to update the paths to your data directory as appropriate in run_time_clm.py.
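A sketch of the decoder training workflow, assuming path2repo in the script and the data paths in run_time_clm.py have been set, and that the script is sourced like the others in this repo:

```bash
cd decoder/examples/pytorch/language-modeling/
# Train the decoder with the learned encoder embeddings
source train_encoders.sh
```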
For generating texts, you'll need to be in the directory decoder/transformers/examples/pytorch/text-generation/.
The script for generating text and measuring per-section length mismatches can be found at decoder/transformers/examples/pytorch/text-generation/toy_wikisection_generation.sh.
The script for generating long texts can be found at decoder/transformers/examples/pytorch/text-generation/long_generation.sh.
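And a sketch for generation, assuming both scripts are sourced like the others in this repo:

```bash
cd decoder/transformers/examples/pytorch/text-generation/
# Generate text and measure per-section length mismatches
source toy_wikisection_generation.sh
# Generate long texts
source long_generation.sh
```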
To collect all the metrics, check out analysis/run_analysis.sh; you can run all the evaluations with source analysis/run_analysis.sh.
Remember to change the wandb username and project name to match what you used in the encoder and decoder experiments.