The official repository for the ACL 2023 paper "RECAP: Retrieval-Enhanced Context-Aware Prefix Encoder for Personalized Dialogue Response Generation".
Commands for environment setup with conda:
conda create --name recap python=3.8
conda activate recap
pip install -U pip
pip install -r requirements.txt
The data is extracted from the Reddit dump from pushshift.io. To preserve persona and personal writing style as much as possible, we did not filter out conversations with unethical content. You can download the raw data from the link here.
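Pushshift dumps are typically newline-delimited JSON, one comment per line. A minimal reading sketch, assuming hypothetical field names (author, body); check the actual dump schema before relying on them:

```python
import json
import io

# Hypothetical pushshift-style dump: one JSON object per line.
raw = io.StringIO(
    '{"author": "user_a", "body": "first comment"}\n'
    '{"author": "user_b", "body": "second comment"}\n'
)

# Group comments by author to preserve each user's history,
# which is what persona/style modeling needs downstream.
history = {}
for line in raw:
    comment = json.loads(line)
    history.setdefault(comment["author"], []).append(comment["body"])

print(history["user_a"])  # ['first comment']
```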
Pre-process the raw data into the formats required for retrieval and generation:
python src/preprocess/encode_comments.py -d <raw_data_path> -o <output_path>
python src/preprocess/retrieval.py -d <raw_data_path> -o <output_path>
python src/preprocess/recent.py -d <raw_data_path> -o <output_path>
This requires the retriever output in retrieved_path. Please see the sections on training the retriever and retrieving at inference time for details on how to train and retrieve with the hierarchical transformer retriever.
python src/preprocess/retrieved.py -d <raw_data_path> -r <retrieved_path> -o <output_path>
Train the retriever and the generator on a single GPU. The code also works with multiple GPUs, but batch_size here is the per-device batch size, so adjust it accordingly if you use more than one GPU.
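To keep the effective batch size constant when scaling to more GPUs, divide it by the device count and the gradient-accumulation steps. A quick sanity-check sketch (the helper name is ours, not part of the repo); the retriever setting below of batch_size 4 with grad_accumulation 8 on one GPU gives an effective batch size of 32:

```python
def per_device_batch_size(effective_batch_size, n_gpus, grad_accumulation=1):
    """Per-device batch size that keeps the effective batch size fixed."""
    per_step = n_gpus * grad_accumulation
    assert effective_batch_size % per_step == 0, "must divide evenly"
    return effective_batch_size // per_step

# Moving the retriever run from 1 GPU to 2 GPUs: halve --batch_size.
print(per_device_batch_size(32, n_gpus=2, grad_accumulation=8))  # 2
```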
python src/train_retriever.py \
--data_path <data_path> \
--raw_data_path <raw_data_path> \
--reps_path <representations_path> \
--save_path <save_path> \
--ref_type <style OR semantic> \
--lr 5e-5 \
--batch_size 4 \
--grad_accumulation 8 \
--warmup 6250 \
--nhead 12
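The --warmup flag is presumably a warmup step count for the usual linear learning-rate ramp; a sketch of that schedule (an assumption -- check src/train_retriever.py for the exact scheduler used):

```python
def warmup_lr(step, base_lr=5e-5, warmup_steps=6250):
    """Linearly ramp the learning rate from 0 to base_lr over warmup_steps,
    then hold it constant (decay, if any, is applied separately)."""
    return base_lr * min(step / warmup_steps, 1.0)

print(warmup_lr(0))     # 0.0
print(warmup_lr(6250))  # 5e-05
```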
python src/train_generator.py \
--data_path <data_path> \
--save_path <save_path> \
--injection_mode <(optional) concat OR context-prefix> \
--ref_type <(optional) style OR semantic> \
--lr 5e-5 \
--batch_size 128 \
--warmup 10000
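In concat mode, the retrieved references are presumably concatenated with the dialogue context before being fed to the generator. A hypothetical sketch of that input construction (the separator token, ordering, and function name are our assumptions, not the paper's exact format):

```python
def build_generator_input(context, references, sep=" <sep> "):
    """Join retrieved references and dialogue context into one
    generator input string (hypothetical concat-mode format)."""
    return sep.join(references + context)

inp = build_generator_input(
    context=["Hi, how are you?", "Doing great, thanks!"],
    references=["I always answer cheerfully."],
)
print(inp)
```

The context-prefix mode instead encodes the references into prefix vectors rather than raw tokens; see the paper for details.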
Retrieve and generate with the trained models.
python src/retrieve.py \
--data_path <data_path> \
--model_path <retriever_model_path> \
--save_path <save_path> \
--ref_type <style OR semantic>
python src/generated.py \
--data_path <data_path> \
--model_path <generator_model_path> \
--save_path <save_path> \
--injection_mode <(optional) concat OR context-prefix> \
--ref_type <(optional) style OR semantic>
Please download the BLEURT checkpoint BLEURT-20-D3 from here before running the evaluation.
python src/eval.py \
--generated_path <generated_responses_path> \
--dataset_path <data_path> \
--cache_dir <eval_cache_dir> \
--cav_samples <eval_cav_samples_file>