This is the implementation of the paper:
CASPI: Causal-aware Safe Policy Improvement for Task-oriented Dialogue. [paper]
If you use this code, data, or our results in your research, please cite the paper using the BibTeX entry below:
@article{ramachandran2021causal,
  title={Causal-aware Safe Policy Improvement for Task-oriented dialogue},
  author={Ramachandran, Govardana Sachithanandam and Hashimoto, Kazuma and Xiong, Caiming},
  journal={arXiv preprint arXiv:2103.06370},
  year={2021}
}
Run the following command to install dependencies
pip install -r requirements.txt
Run the following command for data setup
./damd_multiwoz/scripts/data_setup.sh
Create K-fold datasets
Please choose an appropriate number of folds. In our work, we use 10 folds. The larger the number of folds, the more models need to be trained. For a quicker turnaround, use a smaller number of folds, at a marginal loss in performance.
python CreateKFoldDataset.py --seed 111 --folds 10
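To make the trade-off concrete, here is a minimal sketch of a seeded K-fold split. It is not the logic of CreateKFoldDataset.py, whose internals are not described here; the function name `k_fold_split` and the dialogue IDs are hypothetical. It only illustrates that K folds give K disjoint held-out parts, and therefore K models to train.

```python
import random


def k_fold_split(dialogue_ids, folds, seed):
    """Shuffle IDs with a fixed seed and deal them into `folds` disjoint parts."""
    ids = sorted(dialogue_ids)  # sort first so the result depends only on the seed
    random.Random(seed).shuffle(ids)
    # Part i takes every `folds`-th ID starting at offset i.
    return [ids[i::folds] for i in range(folds)]


# Hypothetical dialogue IDs, for illustration only.
example_ids = [f"dlg_{i}" for i in range(25)]
parts = k_fold_split(example_ids, folds=10, seed=111)
for held_out, part in enumerate(parts):
    print(f"fold {held_out}: {len(part)} held-out dialogues")
```

With fewer folds, each held-out part is larger and fewer models are trained, which is the quicker turnaround mentioned above.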
Generate dataset for pairwise reward model
The following script needs to be run K times. On each run, increment the value passed to the --fold argument by 1, so that it covers 0 through K-1, preferably with a different seed each time.
./damd_multiwoz/scripts/gen_reward_rollout.sh --cuda 0 --K 10 --fold 0 --metric soft --seed 68690
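The K runs can be generated rather than typed by hand. The sketch below builds one command per fold using only the flags shown in the command above. The helper name `rollout_commands` and the seed rule (base seed plus fold index) are assumptions made for illustration; any set of distinct seeds should serve the same purpose.

```python
def rollout_commands(k, base_seed=68690, cuda=0, metric="soft"):
    """Return one argument list per fold, for folds 0 .. k-1."""
    script = "./damd_multiwoz/scripts/gen_reward_rollout.sh"
    commands = []
    for fold in range(k):
        seed = base_seed + fold  # assumption: a distinct seed per fold
        commands.append([
            script,
            "--cuda", str(cuda),
            "--K", str(k),
            "--fold", str(fold),
            "--metric", metric,
            "--seed", str(seed),
        ])
    return commands


# Print the commands for inspection before running anything.
for args in rollout_commands(10):
    print(" ".join(args))
```

To execute the runs instead of printing them, each argument list can be passed to `subprocess.run(args, check=True)` from the repository root.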
Pairwise Reward Learning
Please ensure the number of folds chosen matches the previous step.
python RewardLearning.py --seed 11 --folds 10 --action_space act --gamma 0.0 --metric soft
Estimate Behavior Policy
Please ensure the number of folds, gamma, and action_space match those used in the reward learning step.
python EstimateBehaviorPolicy.py --seed 111 --folds 10 --action_space act --gamma 0.0 --metric soft
CASPI(MinTL),M_soft(act)
In this version of CASPI, we use MinTL as the base model.
Please ensure the argument --caspi_returns_file matches the choices made in the previous steps. The file name has the form fn_Gs_<K>_<gamma>_<action_space>_<metric>.json, for example fn_Gs_10_0.0_act_soft.json.
python train.py --mode train --context_window 2 --pretrained_checkpoint bart-large-cnn --gradient_accumulation_steps 8 --lr 3e-5 --back_bone bart --cfg seed=111 cuda_device=0 batch_size=8 early_stop_count=7 --caspi_returns_file=fn_Gs_10_0.0_act_soft.json --caspi_wt=5. --caspi_data_file=data_for_damd.json --caspi_val_fraction=.5 --caspi
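A mismatched returns file name is an easy mistake, so it can help to derive the name from the same values used in the earlier steps. The sketch below assumes the naming pattern inferred from the example file fn_Gs_10_0.0_act_soft.json (folds, gamma, action space, metric, in that order); the helper `caspi_returns_file` is hypothetical and not part of the repository.

```python
def caspi_returns_file(folds, gamma, action_space, metric):
    """Build the returns file name from the choices made in earlier steps.

    `gamma` is taken as a string so it is written exactly as it was
    passed on the command line (for example "0.0").
    Assumption: the pattern is inferred from the example file name.
    """
    return f"fn_Gs_{folds}_{gamma}_{action_space}_{metric}.json"


print(caspi_returns_file(10, "0.0", "act", "soft"))  # fn_Gs_10_0.0_act_soft.json
```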
CASPI(DAMD),M_soft(act) End-to-end
In this version of CASPI, we use DAMD as the base model. This script tests end-to-end performance. Please ensure the arguments match the choices made in the reward learning steps.
./damd_multiwoz/scripts/caspi_damd.sh --cuda 0 --seed 111 --K 10 --gamma 0.0 --policy_loss L_det,L_sto --action_space act --metric soft --train_e2e True
CASPI(DAMD),M_soft(act) dialogue-context-to-text
In this version of CASPI, we use DAMD as the base model. This script tests only the dialogue-context-to-text generation task of MultiWOZ 2.0. Please ensure the arguments match the choices made in the reward learning steps.
./damd_multiwoz/scripts/caspi_damd.sh --cuda 0 --seed 111 --K 10 --gamma 0.0 --policy_loss L_det --action_space act --metric soft --train_e2e False
This code extends or uses the following prior codebases and data:
Please refer to the [paper] and feel free to reach out to us.