
Retrieval-Enhanced Context-Aware Prefix Encoder for Personalized Dialogue Response Generation

Primary language: Python. License: MIT.

RECAP

The official repository for the ACL 2023 paper "RECAP: Retrieval-Enhanced Context-Aware Prefix Encoder for Personalized Dialogue Response Generation".

Installation

Commands for environment setup with conda:

conda create --name recap python=3.8
conda activate recap
pip install -U pip
pip install -r requirements.txt

Data

The data was extracted from the Reddit dump hosted on pushshift.io. To preserve each user's persona and personal writing style as much as possible, we did not filter out conversations with unethical content. You can download the raw data from the link here.

Pre-processing

Pre-process the raw data into the formats used for retrieval and generation.

Retrieval Data

Encode text representations

python src/preprocess/encode_comments.py -d <raw_data_path> -o <output_path>

Retrieval

python src/preprocess/retrieval.py -d <raw_data_path> -o <output_path>

Generation Data

Most recent history responses

python src/preprocess/recent.py -d <raw_data_path> -o <output_path>

Retrieved by hierarchical transformer

This step requires the retriever output at <retrieved_path>. See the Retriever subsection under Training and the Retrieve subsection under Inference for how to train the hierarchical transformer retriever and retrieve with it.

python src/preprocess/retrieved.py -d <raw_data_path> -r <retrieved_path> -o <output_path>
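Because the generation data built by retrieved.py depends on a trained retriever, the sections of this README are not meant to be run strictly top to bottom. The following is a dry-run sketch of one possible order for the steps up to generator training. By default it only prints the commands (set DRY_RUN=0 to execute them). All directory names are hypothetical, the hyperparameter flags from the Training section are omitted for brevity, and the wiring of each step's output into the next step's input (for example, passing the retrieval.py output as --data_path to both train_retriever.py and retrieve.py) is an assumption, not something this README documents.

```shell
#!/usr/bin/env bash
# Dry-run sketch of the step order for retrieval-based generation data.
# ASSUMPTION: every path below is hypothetical, and the output-to-input
# wiring between steps is inferred, not documented by the repository.
set -euo pipefail

RAW=data/raw        # hypothetical location of the downloaded raw data
WORK=work           # hypothetical directory for intermediate outputs
REF_TYPE=style      # style OR semantic

# Print each command by default; set DRY_RUN=0 to execute it instead.
run() {
  if [ "${DRY_RUN:-1}" = 1 ]; then
    echo "+ $*"
  else
    "$@"
  fi
}

# 1. Retrieval data: text representations, then retrieval-format data.
run python src/preprocess/encode_comments.py -d "$RAW" -o "$WORK/reps"
run python src/preprocess/retrieval.py -d "$RAW" -o "$WORK/retrieval"

# 2. Train the retriever (hyperparameter flags omitted), then retrieve with it.
run python src/train_retriever.py --data_path "$WORK/retrieval" \
  --raw_data_path "$RAW" --reps_path "$WORK/reps" \
  --save_path "$WORK/retriever" --ref_type "$REF_TYPE"
run python src/retrieve.py --data_path "$WORK/retrieval" \
  --model_path "$WORK/retriever" --save_path "$WORK/retrieved" \
  --ref_type "$REF_TYPE"

# 3. Generation data built from the retriever output.
run python src/preprocess/retrieved.py -d "$RAW" -r "$WORK/retrieved" -o "$WORK/generation"

# 4. Train the generator on that data (hyperparameter flags omitted).
run python src/train_generator.py --data_path "$WORK/generation" \
  --save_path "$WORK/generator" --ref_type "$REF_TYPE"
```

If you use the most recent history responses from recent.py instead, steps 1 to 3 reduce to a single recent.py call, since that variant does not need a trained retriever.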

Training

The commands below train the retriever and the generator on a single GPU. The code also works with multiple GPUs, but --batch_size is the per-device batch size, so adjust it accordingly if you use more than one GPU.

Retriever

python src/train_retriever.py \
    --data_path <data_path> \
    --raw_data_path <raw_data_path> \
    --reps_path <representations_path> \
    --save_path <save_path> \
    --ref_type <style OR semantic> \
    --lr 5e-5 \
    --batch_size 4 \
    --grad_accumulation 8 \
    --warmup 6250 \
    --nhead 12

Generator

python src/train_generator.py \
    --data_path <data_path> \
    --save_path <save_path> \
    --injection_mode <(optional) concat OR context-prefix> \
    --ref_type <(optional) style OR semantic> \
    --lr 5e-5 \
    --batch_size 128 \
    --warmup 10000
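The per-device note above matters because the number of examples contributing to each optimizer step grows with the number of GPUs. Below is a minimal sketch of that arithmetic, assuming standard data parallelism where each GPU processes --batch_size examples and --grad_accumulation is the number of gradient-accumulation steps (the flag name suggests this, but the README does not state it). The helper names effective_batch_size and per_device_batch_size are hypothetical.

```shell
#!/usr/bin/env bash
# ASSUMPTION: effective batch = per-device batch x GPUs x accumulation steps,
# as in standard data-parallel training with gradient accumulation.

# Hypothetical helper: effective batch size for a given setup.
effective_batch_size() {
  local per_device=$1 num_gpus=$2 accum=${3:-1}
  echo $(( per_device * num_gpus * accum ))
}

# Hypothetical helper: per-device batch size that keeps a target effective
# batch size on num_gpus GPUs; fails if the target does not divide evenly.
per_device_batch_size() {
  local target=$1 num_gpus=$2 accum=${3:-1}
  local denom=$(( num_gpus * accum ))
  if (( target % denom != 0 )); then
    echo "target $target is not divisible by $num_gpus GPUs x $accum steps" >&2
    return 1
  fi
  echo $(( target / denom ))
}

effective_batch_size 4 1 8      # retriever settings on 1 GPU  -> 32
effective_batch_size 128 1      # generator settings on 1 GPU  -> 128
per_device_batch_size 128 4     # generator on 4 GPUs          -> 32
per_device_batch_size 32 4 8    # retriever on 4 GPUs, 8 steps -> 1
```

For the retriever, an alternative way to keep an effective batch size of 32 on 4 GPUs is to leave --batch_size at 4 and lower --grad_accumulation to 2 (4 x 4 x 2 = 32), which keeps more examples per forward pass.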

Inference

Retrieve and generate with the trained models.

Retrieve

python src/retrieve.py \
    --data_path <data_path> \
    --model_path <retriever_model_path> \
    --save_path <save_path> \
    --ref_type <style OR semantic>

Generate

python src/generated.py \
    --data_path <data_path> \
    --model_path <generator_model_path> \
    --save_path <save_path> \
    --injection_mode <(optional) concat OR context-prefix> \
    --ref_type <(optional) style OR semantic>

Evaluate

Please download the BLEURT checkpoint BLEURT-20-D3 from here before running the evaluation.

python src/eval.py \
    --generated_path <generated_responses_path> \
    --dataset_path <data_path> \
    --cache_dir <eval_cache_dir> \
    --cav_samples <eval_cav_samples_file>