Code for the EMNLP 2020 paper "Content Planning for Neural Story Generation with Aristotelian Rescoring".
If you use this code, please cite as:
@inproceedings{goldfarb-tarrant-etal-2020-content,
title = "Content Planning for Neural Story Generation with Aristotelian Rescoring",
author = "Goldfarb-Tarrant, Seraphina and
Chakrabarty, Tuhin and
Weischedel, Ralph and
Peng, Nanyun",
booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
month = nov,
year = "2020",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2020.emnlp-main.351",
doi = "10.18653/v1/2020.emnlp-main.351",
pages = "4319--4338",
abstract = "Long-form narrative text generated from large language models manages a fluent impersonation of human writing, but only at the local sentence level, and lacks structure or global cohesion. We posit that many of the problems of story generation can be addressed via high-quality content planning, and present a system that focuses on how to learn good plot structures to guide story generation. We utilize a plot-generation language model along with an ensemble of rescoring models that each implement an aspect of good story-writing as detailed in Aristotle{'}s Poetics. We find that stories written with our more principled plot-structure are both more relevant to a given prompt and higher quality than baselines that do not content plan, or that plan in an unprincipled way.",
}
Direct questions to Seraphina
story-gen-BART
Introduction
Story generation comprises two stages: plot generation and story generation. Both stages use a fine-tuned BART model, but with different inputs: the first maps a story prompt to a plot, and the second maps a plot to a story.
Training
Plot generation
Input: <story prompt>
Output: <story plot>
Story prompts are short writing instructions that set up a scene, a conflict, a character, or any combination of these. The prompts are tokenized, so punctuation and clitics are separated by spaces. An example prompt:
You are an undead , resurrected unwillingly and controlled to serve as part of a necromancer 's army . Slowly , but steadily , you start to regain control of your body .
A plot is a plan for a story, expressed as semantic role labeling (SRL) tuples for each sentence. For example:
<A0> you <V> realized <A1> you ' ve blinked # <A0> you <V> blinked </s> <A1> ent 0 <V> passes </s> <A2> ent 0 <V> covers <A1> ent 1 # <A1> ent 1 <V> happens </s> </s> <A1> ent 2 <V> moves # <A0> you <V> track <A1> ent 2 </s> <A0> You <V> keep <A1> marching , onward and onward # <A0> You <V> marching </s> <A1> Nothing <V> stops </s> <A1> It <V> stops </s> <A0> you <V> turn <A1> your head <A2> to the left </s> <A0> you <V> continue <A1> counting # <A0> you <V> counting </s> <A0> You <V> see <A1> the others # <A0> you <V> make </s> <A0> You <V> watch <A1> the shadows change # <A0> you <V> recognize <A1> the crossroads # <A0> you <V> count <A1> more than the days </s> <V> Growing <A2> more and more aware # <A0> ent 3 <V> keep <A1> going # <A0> ent 3 <V> going </s> <A0> you <V> stopped </s> </s> </s> <A1> the others <V> turn # <A1> ent 3 <V> stopped # <A1> what <V> left <A2> of who you # <A0> you <V> need <A1> ent 4 to survive # <A0> you <V> survive # <A0> you <V> call <A1> ent 4 <A2> that </s> <A1> what <A0> you <V> doing # <A1> the amount of roads <A0> you <V> pass # <A0> you <V> turn <A1> ent 5 head # <A1> ent 5 head <V> left # <V> catch <A1> ent 6 bearings </s> <A0> ent 3 <V> keep <A1> ent 3 your blinking # <A0> ent 3 <V> counting <A1> ent 3 blinking # <A0> ent 3 <V> blinking </s> <A1> ent 7 <V> matter </s> <A0> You <V> fight <A1> the fog draped over you # <A1> the fog <V> draped <A2> over you # <A1> You <V> straining # <A0> you <V> keep <A1> moving on # <A1> you <V> moving </s> <A0> you <V> turn <A1> your head </s> <A0> You <V> clench <A1> your hands </s> <A0> You <V> feel <A1> stronger </s> <A0> you <V> control <A1> your feet
The example comes from the plot directory, which contains several more examples.
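The plot format can be read mechanically. The sketch below is a hypothetical parser (parse_plot is not part of the repo) based only on the conventions visible in the example: </s> ends a sentence, # separates predicate-argument tuples within a sentence, and tags such as <A0>, <V>, <A1>, and <A2> open a role whose text runs until the next tag. Empty sentences (consecutive </s> markers) are kept as empty lists so sentence indices stay aligned with the story.

```python
import re

# Role tags such as <A0>, <A1>, <A2>, <V> (inferred from the example plot).
ROLE_TAG = re.compile(r"<(A\d|V)>")


def parse_plot(plot):
    """Hypothetical helper: plot string -> sentences -> events -> (role, text).

    Assumes "</s>" ends a sentence and "#" separates events in a sentence.
    Empty sentences are kept as [] so indices match the story's sentences.
    """
    sentences = []
    for sentence in plot.split("</s>"):
        events = []
        for event in sentence.split("#"):
            # A capturing split alternates text and tag names:
            # [prefix, tag1, text1, tag2, text2, ...]
            parts = ROLE_TAG.split(event)
            pairs = [
                (parts[i], parts[i + 1].strip())
                for i in range(1, len(parts) - 1, 2)
            ]
            if pairs:
                events.append(pairs)
        sentences.append(events)
    return sentences
```

A list of (role, text) pairs is used instead of a dict because an event can repeat or reorder roles, as in "<A1> what <A0> you <V> doing".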
To generate plots, BART is fine-tuned on pairs of prompts and plots. The method in this paper adds rescorers on top of this model: discriminators trained to encourage Aristotelian writing principles (see Aristotle's Poetics on Wikipedia, the original text, or explanations of the text). During decoding, their weighted scores are combined with BART's scores for candidate continuations, as described under mixture weight training below.
Fine-tuning prompt to plot
- Use the encoder.json and dict.txt files already provided in the repo, since they contain additional delimiter tokens relevant for story generation.
- Create a directory to store the plot data:
  cd fairseq
  mkdir plot
- Since this is a seq2seq task, you need source and target files. Put four files in the plot directory: train.source, train.target, val.source, and val.target. Sample data is present in fairseq/plot.
- Tokenize the input with BPE. First download vocab.bpe from fbaipublicfiles:
  wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
  Then run the BPE preprocessing:
  sh create_bpe.sh <directory_name>
- Binarize the dataset:
  fairseq-preprocess \
    --source-lang "source" \
    --target-lang "target" \
    --trainpref "plot/train.bpe" \
    --validpref "plot/val.bpe" \
    --destdir "plot/" \
    --workers 60 \
    --srcdict dict.txt \
    --tgtdict dict.txt
- Download pretrained BART from https://github.com/pytorch/fairseq/tree/4b986c51ed89f9895df2f30e57909301b4a4f19b/examples/bart
- Fine-tune BART:
  sh run.sh <data_dir> <save_dir>
  Update the BART_PATH field to point to your pretrained model.pt file. You can customize MAX_TOKENS and UPDATE_FREQ based on your GPU memory and number of GPUs. You should also adjust the --max-epoch 100 parameter: if you only want to verify that the code runs and finishes, set it to a low value; if you want to fully train the model, remove --max-epoch.
  To help with reproducibility, we have shared the data used to fine-tune BART and the fine-tuned PromptToPlot model. In the STORYEMNLP folder, "full" contains the data required to train the model and "checkpoint-full" contains the fine-tuned BART model: https://drive.google.com/drive/folders/1cOouBxVsORnNdQJuZlH9fu3ACc7p9CwG?usp=sharing
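The data preparation steps above hinge on two details: line i of train.source must pair with line i of train.target, and every path in the binarization command should point at the same directory. A minimal sketch of both, assuming you already hold (prompt, plot) pairs in memory; write_split and preprocess_args are hypothetical helpers, not part of the repo, and preprocess_args only rebuilds the fairseq-preprocess arguments shown above.

```python
from pathlib import Path


def write_split(data_dir, split, pairs):
    """Write <split>.source and <split>.target, one example per line.

    Line i of the .source file is paired with line i of the .target file,
    so both files always get the same number of lines.
    """
    out = Path(data_dir)
    out.mkdir(parents=True, exist_ok=True)
    sources, targets = [], []
    for src, tgt in pairs:
        # An embedded newline would shift every later example out of alignment.
        if "\n" in src or "\n" in tgt:
            raise ValueError("examples must fit on a single line")
        sources.append(src)
        targets.append(tgt)
    (out / f"{split}.source").write_text("".join(s + "\n" for s in sources), encoding="utf-8")
    (out / f"{split}.target").write_text("".join(t + "\n" for t in targets), encoding="utf-8")
    return len(sources)


def preprocess_args(data_dir, workers=60, dict_path="dict.txt"):
    """Rebuild the fairseq-preprocess call with every path under data_dir."""
    return [
        "fairseq-preprocess",
        "--source-lang", "source",
        "--target-lang", "target",
        "--trainpref", f"{data_dir}/train.bpe",
        "--validpref", f"{data_dir}/val.bpe",
        "--destdir", f"{data_dir}/",
        "--workers", str(workers),
        "--srcdict", dict_path,
        "--tgtdict", dict_path,
    ]
```

For example, write_split("plot", "train", train_pairs) and write_split("plot", "val", val_pairs) produce the four files. After create_bpe.sh has run, subprocess.run(preprocess_args("plot"), check=True) would execute the binarization step (fairseq must be installed), and preprocess_args("story") gives the matching command for the plot-to-story data.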
Story generation
Full outputs
For use in analysis and comparison with our system, the full set of outputs from the Aristotelian system and the baselines used in the paper (Naive and Prompt2Story) can be found here: https://drive.google.com/drive/folders/10VFDzJvH1ssByTch4UG8mh1DsTM0xpkI?usp=sharing
Files ending in .auto were used for automatic evaluation (1000 stories). Files ending in .human were used for human evaluation (95 stories). Titles coindexed with the stories are in title.auto and title.human.filtered. Note that the stories and titles for human evaluation were filtered (refer to the paper for details) and have been detokenized. No modifications were made to the stories used for automatic evaluation. We also include title+plot files so you can view the intermediate plot representation.
Note that the human title+plot files have not been filtered or detokenized, so they are a superset of the title.human.filtered titles and will not string-match them.
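Since titles are coindexed with stories, the title on line i of title.auto belongs to the story on line i of an .auto output file. A minimal loader under that assumption, which also assumes one story per line; load_pairs is a hypothetical helper, and the story file name in the usage note is a placeholder because the individual output file names are not listed here.

```python
from pathlib import Path


def load_pairs(title_path, story_path):
    """Zip coindexed title and story files, one entry per line (hypothetical)."""
    titles = Path(title_path).read_text(encoding="utf-8").splitlines()
    stories = Path(story_path).read_text(encoding="utf-8").splitlines()
    # Coindexed files must have the same length; anything else means a mismatch.
    if len(titles) != len(stories):
        raise ValueError(
            f"{len(titles)} titles but {len(stories)} stories; files are not coindexed"
        )
    return list(zip(titles, stories))
```

For example, load_pairs("title.auto", "<system_output>.auto") should yield 1000 pairs. Do not pair the human title+plot files with title.human.filtered this way, since those files are a superset and will not line up.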
Implementation
Input: <story plot>
Output: <story>
For example, given the story plot from the previous section:
<A0> you <V> realized <A1> you ' ve blinked # <A0> you <V> blinked </s> <A1> ent 0 <V> passes </s> <A2> ent 0 <V> covers <A1> ent 1 # <A1> ent 1 <V> happens </s> </s> <A1> ent 2 <V> moves # <A0> you <V> track <A1> ent 2 </s> <A0> You <V> keep <A1> marching , onward and onward # <A0> You <V> marching </s> <A1> Nothing <V> stops </s> <A1> It <V> stops </s> <A0> you <V> turn <A1> your head <A2> to the left </s> <A0> you <V> continue <A1> counting # <A0> you <V> counting </s> <A0> You <V> see <A1> the others # <A0> you <V> make </s> <A0> You <V> watch <A1> the shadows change # <A0> you <V> recognize <A1> the crossroads # <A0> you <V> count <A1> more than the days </s> <V> Growing <A2> more and more aware # <A0> ent 3 <V> keep <A1> going # <A0> ent 3 <V> going </s> <A0> you <V> stopped </s> </s> </s> <A1> the others <V> turn # <A1> ent 3 <V> stopped # <A1> what <V> left <A2> of who you # <A0> you <V> need <A1> ent 4 to survive # <A0> you <V> survive # <A0> you <V> call <A1> ent 4 <A2> that </s> <A1> what <A0> you <V> doing # <A1> the amount of roads <A0> you <V> pass # <A0> you <V> turn <A1> ent 5 head # <A1> ent 5 head <V> left # <V> catch <A1> ent 6 bearings </s> <A0> ent 3 <V> keep <A1> ent 3 your blinking # <A0> ent 3 <V> counting <A1> ent 3 blinking # <A0> ent 3 <V> blinking </s> <A1> ent 7 <V> matter </s> <A0> You <V> fight <A1> the fog draped over you # <A1> the fog <V> draped <A2> over you # <A1> You <V> straining # <A0> you <V> keep <A1> moving on # <A1> you <V> moving </s> <A0> you <V> turn <A1> your head </s> <A0> You <V> clench <A1> your hands </s> <A0> You <V> feel <A1> stronger </s> <A0> you <V> control <A1> your feet
BART is tasked with generating the corresponding story. For the plot above, the gold story is:
One day , you realized you ' ve blinked . </s> ent 7 <P> <P> ent 0 A shadow passes in front of your eyes , frequently and [UNK] . </s> ent 0 It covers ent 1 everything for a moment when ent 1 it happens . </s> The others in front of you . </s> The way ent 2 the sun moves in the sky- you can track ent 2 it now , counting the number of times your eyes open and close . </s> You keep marching , onward and onward . </s> Nothing stops . </s> It never stops . </s> <P> <P> Three weeks later , by your calculations , you can turn your head , ever so slightly , to the left . </s> Stiff , but deliberate , you continue counting . </s> You can see the others next to you , as far as you can make out . </s> You watch the shadows change and you learn to recognize the crossroads as you past and you start to count more than the days . </s> <P> <P> Growing more and more aware , ent 3 you keep going . </s> If you stopped , you 'd be- <P> <P> Well , not dead . </s> ent 3 You 're already dead . </s> Or ent 3 you were . </s> But the others would turn on you if ent 3 your stopped and part of what 's left of who you used to be knows that you need ent 4 the safety to survive , if you can call ent 4 it that . </s> <P> <P> You know that what you 're doing is wrong , you can tell by the amount of roads you pass and how ent 6 the ones ahead of you are full and by the time you can turn ent 5 your head left enough to catch your bearings ent 6 they 're empty . </s> ent 3 You keep counting ent 3 your blinking and ca n't help but wonder how many days it 's been since you were- ent 7 <P> <P> Nevermind . </s> ent 7 That does n't matter . </s> You fight the fog draped over you every day , straining as you keep moving on . </s> Eventually you can turn your head to the right . </s> You can clench your hands , ever so gently . </s> You feel stronger . </s> But you ca n't control your feet
Note that:
- The text is tokenized.
- Stories are split into sentences (separated by </s>) that correspond to the sentences in the SRL plot. In the training and validation data, the plots are generated from the stories (see Section 2 of the paper).
- Entities are coreferenced throughout the story (ent 0, ent 1, and so on). This helps BART keep entity references consistent and produce more coherent stories.
- New paragraphs are marked with <P> tokens in the story text. The markers carry no story content but are kept because they are present in the original WritingPrompts dataset.
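The conventions in these notes can be checked programmatically. A sketch, assuming </s> ends a sentence, <P> marks a paragraph break, and entity mentions are written as ent followed by an integer id, as in the example above; story_sentences, entity_ids, and aligned are hypothetical helper names.

```python
import re

# "ent 0", "ent 12", ... ; \b keeps words such as "parent 3" from matching.
ENTITY = re.compile(r"\bent (\d+)\b")


def story_sentences(story):
    """Split a tokenized story on </s>, dropping <P> paragraph markers."""
    sentences = []
    for chunk in story.split("</s>"):
        sentences.append(" ".join(tok for tok in chunk.split() if tok != "<P>"))
    return sentences


def entity_ids(text):
    """Return the sorted, de-duplicated entity ids mentioned in text."""
    return sorted({int(m) for m in ENTITY.findall(text)})


def aligned(plot, story):
    """True if plot and story have the same number of </s>-separated sentences."""
    return len(plot.split("</s>")) == len(story.split("</s>"))
```

Comparing entity_ids over a plot and its story is a quick way to confirm that both sides use the same coreference ids.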
The example comes from the story directory, which contains several more examples.
Fine-tuning plot to story
- Use the encoder.json and dict.txt files already provided in the repo, since they contain additional delimiter tokens relevant for story generation.
- Create a directory to store the story data:
  cd fairseq
  mkdir story
- Since this is a seq2seq task, you need source and target files. Put four files in the story directory: train.source, train.target, val.source, and val.target. Sample data is present in fairseq/story.
- Tokenize the input with BPE. First download vocab.bpe from fbaipublicfiles:
  wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
  Then run the BPE preprocessing:
  sh create_bpe.sh <directory_name>
- Binarize the dataset:
  fairseq-preprocess \
    --source-lang "source" \
    --target-lang "target" \
    --trainpref "story/train.bpe" \
    --validpref "story/val.bpe" \
    --destdir "story/" \
    --workers 60 \
    --srcdict dict.txt \
    --tgtdict dict.txt
- Fine-tune BART. You can use the pretrained model downloaded for the prompt-to-plot training:
  sh run.sh <data_dir> <save_dir>
  Update the BART_PATH field to point to your pretrained model.pt file. You can customize MAX_TOKENS and UPDATE_FREQ based on your GPU memory and number of GPUs. You should also adjust the --max-epoch 100 parameter: if you only want to verify that the code runs and finishes, set it to a low value; if you want to fully train the model, remove --max-epoch.
  To help with reproducibility, we have shared the data used to fine-tune BART and the fine-tuned PlotToStory model. In the STORYEMNLP folder, "fullstory" contains the data required to train the model and "checkpoint-fullstory" contains the fine-tuned BART model: https://drive.google.com/drive/folders/1cOouBxVsORnNdQJuZlH9fu3ACc7p9CwG?usp=sharing
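MAX_TOKENS and UPDATE_FREQ trade off against each other. Assuming run.sh passes them to fairseq's --max-tokens and --update-freq, as the standard BART fine-tuning recipe does, each GPU batch holds at most MAX_TOKENS tokens and gradients are accumulated over UPDATE_FREQ batches per optimizer step, so one update sees at most MAX_TOKENS × UPDATE_FREQ × (number of GPUs) tokens. When you move to fewer GPUs or less memory, raising UPDATE_FREQ keeps that product roughly constant. A small helper for the arithmetic; the function names and numbers are illustrative, not the values in run.sh.

```python
import math


def tokens_per_update(max_tokens, update_freq, num_gpus):
    """Upper bound on the tokens contributing to one optimizer update."""
    return max_tokens * update_freq * num_gpus


def update_freq_for(target_tokens, max_tokens, num_gpus):
    """Smallest UPDATE_FREQ that reaches target_tokens per update."""
    return max(1, math.ceil(target_tokens / (max_tokens * num_gpus)))
```

For instance, a setup with MAX_TOKENS=2048, UPDATE_FREQ=4 on 4 GPUs gives tokens_per_update(2048, 4, 4) = 32768; moving to a single GPU with the same MAX_TOKENS, update_freq_for(32768, 2048, 1) suggests UPDATE_FREQ=16.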
Train Aristotelian Rescorers (aka classifiers, aka discriminators)
To train a rescorer, go through the following steps:
- Split the prompt + plot data into source and target. If you are using WritingPrompts, the data is already split into train, valid, and test.
- Generate continuation data. Use the script preprocessing/make_cc_version_pnw_data.py.
- Compile TSV files for the train and validation data. Use the script preprocessing/create_classifier_dataset.py to generate positive and negative examples and write the TSV from the output of the previous step.
- Preprocess the task data into the binary .bin files needed to fine-tune RoBERTa:
  ./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
  We use the RTE task (a sentence pair with a label of 0 or 1) as a proxy, so we recommend using RTE here.
- Fine-tune each discriminator from RoBERTa-large with run_disc.sh (the script for fine-tuning on a GLUE task). Fine-tune 4 discriminators:
  - 1.0 /nas/home/fairseq/relevance-roberta/ checkpoint_best.pt roberta.large/
  - 1.0 /nas/home/fairseq/eventinter-roberta checkpoint_best.pt roberta.large/
  - 1.0 /nas/home/fairseq/eventintraV-roberta checkpoint_best.pt roberta.large/
  - 1.0 /nas/home/fairseq/entity-roberta checkpoint_best.pt roberta.large/
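The actual positive and negative examples are produced by preprocessing/create_classifier_dataset.py and differ between rescorers. Purely to illustrate the RTE-style shape, the sketch below assumes that a context paired with its real continuation is a positive (label 1) and the same context paired with a corrupted continuation is a negative (label 0). The corruption used here (shuffling sentence order), the helper names, and the four-column header modeled on GLUE's RTE files (index, sentence1, sentence2, label) are all assumptions.

```python
import random


def _clean(text):
    """Collapse tabs and newlines so each row stays on one TSV line."""
    return " ".join(text.split())


def shuffled(continuation, rng):
    """Hypothetical corruption: shuffle the </s>-separated sentences."""
    parts = [p.strip() for p in continuation.split("</s>")]
    if len(set(parts)) < 2:
        return None  # no ordering differs from the original
    corrupted = parts[:]
    while corrupted == parts:
        rng.shuffle(corrupted)
    return " </s> ".join(corrupted)


def write_rte_tsv(path, examples, seed=0):
    """Write (context, continuation) pairs as RTE-style rows labeled 1/0."""
    rng = random.Random(seed)  # fixed seed so the negatives are reproducible
    rows = []
    for context, continuation in examples:
        rows.append((context, continuation, 1))  # real continuation
        negative = shuffled(continuation, rng)
        if negative is not None:
            rows.append((context, negative, 0))  # corrupted continuation
    with open(path, "w", encoding="utf-8") as f:
        f.write("index\tsentence1\tsentence2\tlabel\n")
        for i, (first, second, label) in enumerate(rows):
            f.write(f"{i}\t{_clean(first)}\t{_clean(second)}\t{label}\n")
    return len(rows)
```

A continuation with a single sentence yields only a positive row, since shuffling cannot produce a distinct negative from it.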
Mixture Weight training for the Aristotelian rescorers
The next step is mixture coefficient training. train_coefs.py is a decoding script (like inference.py) that specifically trains the coefficients.
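The objective used by train_coefs.py is not spelled out in this README. Purely as an illustration of what learning mixture coefficients can look like, the sketch below fits one non-negative weight per rescorer with a pairwise logistic ranking loss, so that the weighted rescorer scores of a gold continuation exceed those of a model-generated one. The loss, learning rate, epoch count, clipping at zero, and the name train_coefs_sketch are assumptions, not the repository's settings.

```python
import math


def _sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def train_coefs_sketch(pairs, num_scorers, lr=0.1, epochs=200):
    """Fit one mixture weight per rescorer (illustrative only).

    pairs: list of (gold_scores, generated_scores), each holding one score
    per rescorer. Minimizes mean log(1 + exp(-w . (gold - generated))).
    """
    if not pairs:
        raise ValueError("need at least one (gold, generated) pair")
    w = [1.0] * num_scorers  # start from equal weights
    for _ in range(epochs):
        grad = [0.0] * num_scorers
        for gold, gen in pairs:
            diff = [g - s for g, s in zip(gold, gen)]
            margin = sum(wi * di for wi, di in zip(w, diff))
            # d/dw log(1 + exp(-margin)) = -sigmoid(-margin) * diff
            scale = _sigmoid(-margin)
            for i, d in enumerate(diff):
                grad[i] -= scale * d
        # Averaged gradient step, clipped so weights stay non-negative.
        w = [max(0.0, wi - lr * gi / len(pairs)) for wi, gi in zip(w, grad)]
    return w
```

With toy data where the first rescorer prefers gold continuations and the second prefers generated ones, the first weight grows and the second shrinks toward zero.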
The rescoring is done here:
./story-gen-BART/fairseq/fairseq/search.py#L267-L324
This is the method that the sequence generator calls to sample the next k hypotheses and return them as candidates along with their probabilities.
In these lines we concatenate the source tokens and all tokens generated so far with the current k hypotheses:
./story-gen-BART/fairseq/fairseq/search.py#L281-L284
In this line we call RoBERTa on that tensor:
./story-gen-BART/fairseq/fairseq/search.py#L298
which returns a probability distribution over the vocabulary, which we multiply by the coefficients here:
./story-gen-BART/fairseq/fairseq/search.py#L307
and then add to the raw lprobs here:
./story-gen-BART/fairseq/fairseq/search.py#L323-L324
Ignore all the "if learn" lines in between; those are only active when training coefficients.
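Stripped of the batched tensor bookkeeping, the steps above amount to adding a weighted sum of rescorer scores to the language model's log-probabilities and re-ranking the candidates. A framework-free sketch of that combination, assuming each rescorer gives one score per candidate; rescore is a hypothetical name, and the real code in search.py works on torch tensors rather than Python lists.

```python
def rescore(lm_lprobs, scorer_scores, coefs):
    """Combine LM log-probs with weighted rescorer scores (hypothetical).

    lm_lprobs:     one log-probability per candidate.
    scorer_scores: one list per rescorer, each with one score per candidate.
    coefs:         one mixture weight per rescorer.
    Returns (combined scores, candidate indices sorted best-first).
    """
    if len(scorer_scores) != len(coefs):
        raise ValueError("need exactly one coefficient per rescorer")
    combined = list(lm_lprobs)
    for coef, scores in zip(coefs, scorer_scores):
        if len(scores) != len(combined):
            raise ValueError("each rescorer must score every candidate")
        for i, s in enumerate(scores):
            combined[i] += coef * s  # weighted rescorer term added to lprobs
    order = sorted(range(len(combined)), key=lambda i: combined[i], reverse=True)
    return combined, order
```

Setting every coefficient to 0 recovers plain BART decoding order, which is a convenient sanity check when experimenting with the weights.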
Inference
Prompt-to-plot:
python inference_plot.py
Calling the script with --help shows all the required arguments. The script needs dict.source.txt and dict.target.txt to be copied from the plot directory to the checkpoint directory. If you want to use the Aristotelian rescorers, look at the --apply_disc and --scorers arguments. You might need to install requests from pip.
Plot-to-story:
python inference_story.py
Calling the script with --help shows all the required arguments. The script needs dict.source.txt and dict.target.txt to be copied from the story directory to the checkpoint directory. You might need to install requests from pip.
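Both inference scripts expect the dictionaries that fairseq-preprocess wrote into the data directory to sit next to the checkpoint. A small sketch of that copy step; copy_dicts is a hypothetical helper, and the directory names in the usage note are placeholders for your own data and save directories.

```python
import shutil
from pathlib import Path


def copy_dicts(data_dir, checkpoint_dir):
    """Copy dict.source.txt and dict.target.txt next to a checkpoint."""
    dest = Path(checkpoint_dir)
    dest.mkdir(parents=True, exist_ok=True)
    copied = []
    for name in ("dict.source.txt", "dict.target.txt"):
        src = Path(data_dir) / name
        if not src.is_file():
            raise FileNotFoundError(f"{src} not found; run fairseq-preprocess first")
        shutil.copyfile(src, dest / name)
        copied.append(dest / name)
    return copied
```

For example, copy_dicts("plot", "<plot_save_dir>") before running inference_plot.py and copy_dicts("story", "<story_save_dir>") before running inference_story.py.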