This is the repo for our work “An Extensible Plug-and-Play Method for Multi-Aspect Controllable Text Generation” (ACL 2023).
Our work focuses on Multi-Aspect Controllable Text Generation (MCTG), which controls the generated text in multiple aspects (e.g., sentiment, topic, and keywords). Our approach, Prompt Gating, builds on parameter-efficient tuning methods (e.g., prefix-tuning), which steer the model by inserting trainable prompts or prefixes, referred to as plugins. Prompt Gating achieves higher fulfillment of constraints while maintaining training and inference efficiency. Each plugin controls a single aspect and is trained individually, and only once; multi-aspect control is realized by combining multiple plugins without further training. Please refer to our paper for more details.
## Text Generation
Topic | Sentiment | Generated Text |
---|---|---|
Mexican | Positive | While this is happening i’m going here for dinner for the first time. the food here was very, very good and very tasty!! we ordered a couple of different salads and some tacos. i got a vego beef taco with a spicy sauce (it is very good). i also got an onion rings (it does not have any onions, nor are there many onions in this recipe), and it was delicious! |
Asian | Positive | Once upon a time i was at the restaurant and had dinner at a friend’s place, the food was great. the chef is very knowledgeable (the owner’s daughter is a sushi professional), but i can only say that the restaurant was very clean and the food was very flavor. |
Asian | Negative | While this is happening, i’ve never been here to the japanese section and have never eaten here before. the place is small and the place is crowded. the food is small, and the service isn’t very friendly. i did not expect to spend a lot of money for a sushi place because it’s pretty small. they have a few places where they can take a large order and they give it to those who pay for the food. also the staff seems to ignore you if you come for sushi. |
## Machine Translation
Keywords | Tense | Knowledge (French) | Source (German) | Generated Translation |
---|---|---|---|---|
"This", "bedroom", "completely" | Past | Cette chambre et une autre ont été complètement brûlées. | Dieses und ein weiteres Zimmer brannten vollständig aus. | This and another bedroom burned completely out. |
"attempt" | Future | Demain, il tentera de s’entraîner avec l’équipe. | Morgen wird er versuchen, mit der Mannschaft zu trainieren. | Tomorrow he will attempt to train with the team. |
Overview of our method:
Each plugin for a single aspect is trained individually and only once. Combining constraints from multiple aspects is realized by concatenating multiple plugins without further training (i.e., in a zero-shot manner).
Model structure:
It shows the inference stage of double-aspect controllable text generation. Blue and purple represent lexical and sentimental constraints, respectively. Continuous prompts and the input context are fed into the model, and trainable gates are applied to steer the pretrained model as well as to alleviate the mutual interference between plugins.
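For intuition, the per-layer gating amounts to an elementwise scale and shift applied to the prompt positions of the hidden states. Below is a minimal sketch of that operation (the function name and shapes are ours for illustration; the actual implementation is in `codes/thumt/models/PromptGating.py` and the encoder code quoted later in this README):

```python
import torch

def apply_gate(x: torch.Tensor, gate: torch.Tensor, add: torch.Tensor) -> torch.Tensor:
    """Scale and shift the prompt positions of `x`; leave the rest untouched.

    x:    (seq_len, batch, dim) hidden states with prompts prepended
    gate: (prompt_len, batch, dim) multiplicative gate
    add:  (prompt_len, batch, dim) additive shift
    """
    prompt_len = gate.size(0)
    prompt, rest = x[:prompt_len], x[prompt_len:]
    return torch.cat([prompt * gate + add, rest], dim=0)
```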
Two different virtual environments are required for training and evaluation.

### Training

```bash
pip install -r requirements_training.txt
```

### Evaluation

```bash
pip install -r requirements_eval.txt
```
To quickly start with our code, you can download the checkpoints of our model and follow the steps below to generate text with constraints.
- Put the downloaded checkpoints into `codes/checkpoints/`.
- Merge the prompts and gates that you need into one checkpoint. For example, to generate text constrained to positive sentiment and the USA topic, merge `codes/checkpoints/model-pos.pt` and `codes/checkpoints/model-usa.pt` into one checkpoint and save it in `codes/checkpoints/test`. You can use the scripts in `codes/{thumt,thumt_gen}/scripts/substitute_prompt_gating.py` for each attribute (use `codes/thumt_gen/scripts/substitute_prompt_gating_pre_encoder.py` for the tense attribute in the generation task and for the French constraints in the translation task). A hedged sketch of this merging step follows this list.
- Prepare your data for inference like the Yelp or WMT datasets in the Data Preparation section, and change the file paths in `codes/infer_{translation,generation}.sh`.
- Run the following commands to run inference with constraints:

  ```bash
  cd codes
  bash infer_{translation,generation}.sh
  ```

  The generated text will be saved in `codes/checkpoints/test`.
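For reference, here is a minimal sketch of what merging two single-aspect checkpoints might look like. The key names and output path are assumptions for illustration only; the authoritative logic is in `codes/{thumt,thumt_gen}/scripts/substitute_prompt_gating.py`:

```python
import torch

# Load two single-aspect checkpoints (paths from the steps above).
base = torch.load("codes/checkpoints/model-pos.pt", map_location="cpu")
donor = torch.load("codes/checkpoints/model-usa.pt", map_location="cpu")

# Copy the donor's plugin parameters into the base checkpoint. The substring
# match on "prompt"/"gating" is a guess at the parameter naming, not the
# repo's actual key layout; distinct aspects occupy distinct plugin slots.
for key, value in donor.items():
    if "prompt" in key or "gating" in key:
        base[key] = value

torch.save(base, "codes/checkpoints/test/model-merged.pt")
```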
The instructions below are for reproducing the results of our paper.
- Find `SegmentEmbedding`, `PromptEncoder`, and `GatingModule` in `codes/thumt/models/PromptGating.py` and include them in your code (a hypothetical usage sketch follows this list):

  ```python
  class SegmentEmbedding(nn.Module):
      def __init__(self, num_embeddings: int, embedding_dim: int):
          ...

  class PromptEncoder(nn.Module):
      def __init__(self, embed_dim, hidden_dim, prompt_num=100, seg_num=2, pri_dim=512):
          ...

  class GatingModule(nn.Module):
      def __init__(self, prompt_num: int, embed_dim: int, dropout,
                   hidden_dim: int = 512, pri_dim: int = 512,
                   seg_num: int = 2, layer_num: int = 12) -> None:
          ...
  ```
- Add the segment embeddings to the different sources:

  ```python
  # Process inputs
  # Suppose your first input sequence is features["source"] and the rest in features["hypothesis"]
  seg_mask = torch.cat(
      [torch.zeros(features["source"].shape, dtype=torch.long)
       if self.src_attached[0] else torch.tensor([], dtype=torch.long)]
      + [torch.ones(hyp.shape, dtype=torch.long).fill_(idp + 1)
         if self.src_attached[idp + 1] else torch.tensor([], dtype=torch.long)
         for idp, hyp in enumerate(features["hypothesis"])],
      axis=1)

  # in BartEncoder.forward
  # Suppose your first input ids is input_ids["src"] and the rest in input_ids["hypo"]
  def forward(
      self,
      input_ids,
      attention_mask=None,
      output_attentions=False,
      output_hidden_states=False,
      return_dict=False,
      prompt_enc=None,
      gating_enc=None,
      prefix_enc=None,
      pre_encode_fr=False,
      fr_mask=None,
  ):
      inputs_embeds_list = []
      if input_ids["src"] is not None:
          inputs_embeds_list.append(self.embed_tokens(input_ids["src"]) * self.embed_scale
                                    + self.embed_positions(input_ids["src"]))
      if input_ids["hypo"] is not None and len(input_ids["hypo"]) > 0:
          for hyp in input_ids["hypo"]:
              inputs_embeds_list.append(self.embed_tokens(hyp) * self.embed_scale
                                        + self.embed_positions(hyp))
      # Concatenate the per-source embeddings along the sequence dimension
      inputs_embeds = torch.cat(inputs_embeds_list, axis=1)
      input_segment_ids = input_ids["seg_mask"]
      embed_seg = self.embed_segments(input_segment_ids)
      x = inputs_embeds + embed_seg
      # Proceed with `x`
      ...
  ```
- Feed the prompt generated by `PromptEncoder` into the model:

  ```python
  # in BartEncoder.forward
  # Suppose your first input ids is input_ids["src"] and the rest in input_ids["hypo"]
  def forward(
      self,
      input_ids,
      attention_mask=None,
      output_attentions=False,
      output_hidden_states=False,
      return_dict=False,
      prompt_enc=None,
      gating_enc=None,
      prefix_enc=None,
      pre_encode_fr=False,
      fr_mask=None,
  ):
      ...
      # Proceed with `x`
      if prompt_enc is not None:
          prompt = prompt_enc([True for _ in [input_ids["src"]] + input_ids["hypo"]])
          if prompt is not None:
              # Broadcast the prompt over the batch and prepend it to the input
              prompt = prompt.expand(x.size()[0], -1, -1)
              x = torch.cat([prompt, x], axis=1)
  ```
- Append gates at each layer:

  ```python
  # in EncoderLayer.forward
  def forward(self, x, encoder_padding_mask, output_attentions=False,
              attach=None, adding=None, gating=None, prefixkv=None):
      ...
      if gating is not None:
          # Multiplicative gate over the prompt positions; pad with ones elsewhere
          gating = gating.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1)
          gating = torch.cat([gating, torch.ones([x.size(0) - gating.size(0), x.size(1), x.size(2)])], axis=0)
          x = x * gating
      if adding is not None:
          # Additive shift over the prompt positions; pad with zeros elsewhere
          adding = adding.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1)
          adding = torch.cat([adding, torch.zeros([x.size(0) - adding.size(0), x.size(1), x.size(2)])], axis=0)
          x = adding + x
      ...

  # in BartEncoder.forward
  # Suppose `gating_enc` is an instance of `GatingModule`
  def forward(
      self,
      input_ids,
      attention_mask=None,
      output_attentions=False,
      output_hidden_states=False,
      return_dict=False,
      prompt_enc=None,
      gating_enc=None,
      prefix_enc=None,
      pre_encode_fr=False,
      fr_mask=None,
  ):
      ...
      gating = gating_enc(input_ids['attach'])
      adding = None  # ensure `adding` is defined even when no gates are attached
      if gating is not None:
          adding, gating = gating
      ...
      # Feeding the input to each encoder layer
      encoder_states = [] if output_hidden_states else None
      all_attentions = () if output_attentions else None
      for ide, encoder_layer in enumerate(self.layers):
          if output_hidden_states:
              encoder_states.append(x)
          # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
          dropout_probability = random.uniform(0, 1)
          if self.training and (dropout_probability < self.layerdrop):  # skip the layer
              attn = None
          else:
              x, attn = encoder_layer(
                  x, attention_mask,
                  output_attentions=output_attentions,
                  attach=input_ids['attach'] if 'attach' in input_ids and input_ids['attach'] is not None else None,
                  adding=adding[:, ide, :] if adding is not None else None,  # Changed Here
                  gating=gating[:, ide, :] if gating is not None else None,  # Changed Here
              )
          if output_attentions:
              all_attentions = all_attentions + (attn,)
      ...
  ```
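As promised above, here is a hypothetical sketch of how the three modules could be wired together. The dimensions are illustrative, and the assumed return values follow the encoder code above but are not guaranteed to match the repo exactly:

```python
import torch

# Illustrative hyperparameters, not the values used in the paper.
embed_dim, layer_num = 768, 12

prompt_enc = PromptEncoder(embed_dim=embed_dim, hidden_dim=512,
                           prompt_num=100, seg_num=2)
gating_enc = GatingModule(prompt_num=100, embed_dim=embed_dim,
                          dropout=0.1, layer_num=layer_num)

# One boolean per plugin slot: which aspects are attached for this batch.
attach = [True, True]

prompt = prompt_enc(attach)  # continuous prompts to prepend to the input
out = gating_enc(attach)     # per-layer (adding, gating) terms, or None
if out is not None:
    adding, gating = out
```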
## Data Preparation

- Detokenize the raw text and discard the samples with a sentiment score of 3 (neutral).
- Sample 30k/3k examples for training and validation, respectively, for each attribute value (positive/negative sentiment; Asian/USA/Mexican food). All other attributes should be balanced within each training/validation set (see the sampling sketch after the file listing below).

  To identify the tense of each review sentence, we use this repo. We modified some of its code, which is presented in `codes/exp_tense/tense_detector_t.py`.

  To extract keywords from each review sentence, we use this command:

  ```bash
  cat /path/to/yelp/sent_valid.txt | python codes/thumt_gen/scripts/sample_mask.py > /path/to/processed/yelp/mask_sent_valid.txt
  ```
- Save these samples into `/path/to/processed/yelp`. The training/validation set of each attribute consists of two files, one with the text and one with the manual constraint, aligned line by line:

  ```
  # pos_sent_valid.txt
  ...
  super friendly, courteous, and excellent service. serg and the other young woman who works here are great and checked up on us
  ...

  # pos_label_valid.txt
  ...
  This is a positive review.
  ...
  ```
- Perform sentencepiece tokenization on the text files and save the tokenized files into `/path/to/processed/yelp`. The script is in `codes/spm_gen.sh`:

  ```
  # pos_sent_valid.spm.txt
  ...
  super Ġfriendly , Ġcour te ous , Ġand Ġexcellent Ġservice . Ġser g Ġand Ġthe Ġother Ġyoung Ġwoman Ġwho Ġworks Ġhere Ġare Ġgreat Ġand Ġchecked Ġup Ġon Ġus
  ...

  # pos_label_valid.spm.txt
  ...
  This Ġis Ġa Ġpositive Ġreview .
  ...
  ```
Now, the files in `/path/to/processed/yelp` should look like this:

```
├── pos_label_train.spm.txt
├── pos_label_valid.spm.txt
├── pos_sent_train.spm.txt
├── pos_sent_valid.spm.txt
└── {neg,usa,mexi,asian,mask(keywords),future,past,present}_{label,sent}_{train,valid}.spm.txt
```
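As referenced in the sampling step, here is a hedged sketch of the filtering and sampling. The field names (`stars`, `text`) assume the public Yelp review dump (one JSON object per line); the exact balancing procedure is described in the paper:

```python
import json
import random

def sample_reviews(path, want_positive, n, seed=0):
    """Keep non-neutral reviews of the requested polarity and sample n of them."""
    kept = []
    with open(path) as f:
        for line in f:
            review = json.loads(line)
            if review["stars"] == 3:  # discard neutral reviews
                continue
            if (review["stars"] > 3) == want_positive:
                kept.append(review["text"])
    random.seed(seed)
    return random.sample(kept, n)

# e.g., 30k positive reviews for training and 3k for validation
positives = sample_reviews("/path/to/yelp/review.json", want_positive=True, n=33000)
train_pos, valid_pos = positives[:30000], positives[30000:]
```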
- Prepare the WMT14 De-En dataset in parallel format in `/path/to/processed/wmt`:

  ```
  /path/to/processed/wmt
  ├── train.de
  └── train.en
  ```
- Extract keyword constraints from the target sentences using the command below:

  ```bash
  cat /path/to/processed/wmt/train.en | python codes/thumt/scripts/sample_mask.py > /path/to/processed/wmt/train.en.masked.num.s
  ```

  Label the tense of each source sentence using this repo as well.

  Translate the target sentences into French using any MT model you like; we use `codes/translate_en_fr/summon_dummy.sh` (a hedged example appears after the file listing below).

  Repeat the same operations for the validation and test sets (newstest2013 and newstest2014).
- Perform sentencepiece tokenization on the text files and save the tokenized files into `/path/to/processed/wmt`. The script is in `codes/spm.sh`.
Now, the files in `/path/to/processed/wmt` should look like this:

```
├── train.spm.en
├── train.spm.fr
├── train.spm.de
└── train.spm.{en.masked.num.s,en.tense}
```
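As mentioned above, one hedged way to produce the French knowledge sentences is with an off-the-shelf MT model. The repo's own script is `codes/translate_en_fr/summon_dummy.sh`; the Helsinki-NLP model below is our stand-in choice, not necessarily the one used in the paper:

```python
from transformers import MarianMTModel, MarianTokenizer

name = "Helsinki-NLP/opus-mt-en-fr"
tokenizer = MarianTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)

# Translate the target side line by line into French.
with open("/path/to/processed/wmt/train.en") as src, \
     open("/path/to/processed/wmt/train.fr", "w") as out:
    for line in src:
        batch = tokenizer(line.strip(), return_tensors="pt", truncation=True)
        generated = model.generate(**batch)
        out.write(tokenizer.decode(generated[0], skip_special_tokens=True) + "\n")
```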
- Download the vocabulary file of the BART model into `/path/to/pretrain_BART`. If the file is in JSON format, convert it with `codes/pretrain_BART/process_vocab.py`. Also prepare `config.json` and `pytorch_model.bin` in `/path/to/pretrain_BART`.
. -
Choose one script of
exp_train_generation_*.sh
and run it.
Because of the large vocabulary of the mBART model, we provide a trimmed model with a vocabulary size of 60k to save memory.
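Trimming an embedding matrix to a smaller vocabulary is conceptually simple. The sketch below is only illustrative: it assumes the kept token ids are contiguous at the front and uses Hugging Face BART-style key names, which may differ from this repo's checkpoint layout:

```python
import torch

KEEP = 60000  # target vocabulary size

state = torch.load("/path/to/pretrain_mBART/pytorch_model.bin", map_location="cpu")
for key in list(state):
    if key.endswith("embed_tokens.weight") or key == "model.shared.weight":
        state[key] = state[key][:KEEP]      # (vocab, dim) -> (KEEP, dim)
    elif key == "final_logits_bias":
        state[key] = state[key][:, :KEEP]   # (1, vocab) -> (1, KEEP)
torch.save(state, "/path/to/pretrain_mBART/pytorch_model.trimmed.bin")
```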
- Download the vocabulary file of the mBART model and the other files from the Google Drive into `/path/to/pretrain_mBART`.
Choose one script of
exp_train_translation_*.sh
and run it.
After merging different prompts and gates, you can run inference with the model on different tasks. The first step below is for inference; the remaining steps are for evaluation.
- Process the manual constraints for the test set in `/path/to/processed/yelp/infer`; each file contains 375 lines of constraint sentences. For example, `neg_label.375.spm.txt` contains 375 tokenized lines of `This is a negative review.`

  The tokenized prefixes used in the paper are in `data/infer_gen/pre_tokens.25.spm.txt`. Positive and negative keywords for inference are randomly sampled and tokenized in `data/infer_gen/maskfix.pos.375.spm.txt` and `data/infer_gen/maskfix.neg.375.spm.txt`.
- Train classifiers for sentiment and food category; the scripts are in `codes/train_classifier/train_*.py`. Follow the paper to oversample the raw data and train the classifiers (a hedged evaluation sketch follows this list).
- You still need to install this repo to evaluate the tense attribute.
- Run `codes/infer_train_generation_yelp.sh`, remembering to change `src_attached` and `prompt_attached` together according to the combination of prompts and gates you want to evaluate.
- Run `codes/infer_train_translation.sh`, remembering to change `src_attached` and `prompt_attached` together according to the combination of prompts and gates you want to evaluate.
- With the translated results, use `codes/exp_tense/calc_tense_acc.py` to evaluate the accuracy of the tense attribute.
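Here is a hedged sketch of the classifier-based attribute evaluation mentioned above. `classifier` and `tokenizer` stand for whatever `codes/train_classifier/train_*.py` produced, and the function name is ours:

```python
import torch

def attribute_accuracy(sentences, target_label, classifier, tokenizer):
    """Fraction of generated sentences whose predicted attribute matches the constraint."""
    classifier.eval()
    correct = 0
    with torch.no_grad():
        for sent in sentences:
            inputs = tokenizer(sent, return_tensors="pt", truncation=True)
            pred = classifier(**inputs).logits.argmax(dim=-1).item()
            correct += int(pred == target_label)
    return correct / len(sentences)
```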
Remember to check how many prompts and gates are in the checkpoint, and merge them into one model if necessary.
The checkpoints with prompts and gates of different attributes are on Google Drive.
You can transfer the prompts and gates of each model into another model (or merge prompts and gates of different attributes into one model) using the example scripts `codes/thumt/scripts/substitute_prompt_gating.py` or `codes/thumt_gen/scripts/substitute_prompt_gating.py`.
The checkpoints with prompts and gates are on Google Drive.
The prompts and gates for keywords, tense, and French are already merged into the same checkpoint.
If you find this repo helpful, please cite the following:

```bibtex
@inproceedings{huang2023extensible,
  title={An Extensible Plug-and-Play Method for Multi-Aspect Controllable Text Generation},
  author={Xuancheng Huang and Zijun Liu and Peng Li and Tao Li and Maosong Sun and Yang Liu},
  booktitle={Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics},
  publisher={Association for Computational Linguistics},
  year={2023}
}
```