FloNet

Code for "End-to-End Learning of Flowchart Grounded Task-Oriented Dialogs"

This repository contains an implementation of FloNet, described in the paper End-to-End Learning of Flowchart Grounded Task-Oriented Dialogs.

Reproduce Validation Numbers Reported in the Paper

  1. Clone this repo

  2. Set up a python environment using the requirements.txt

  3. Download the pre-trained GloVe embeddings and unzip the contents into the folder code/glove6B/

  4. Download the pretrained checkpoints. The compressed file contains the pretrained retriever and generator models for both the S-Flo and U-Flo settings.

  5. Run the inference script

    a. S-Flo setting:

     python flonet.py --save-name='FlonetInferValS' --retriever_checkpoint=path-to-the-sflo-pretrained-retriever-checkpoint.pth.tar  --gpt_model_checkpoint=path-to-the-sflo-pretrained-generator-folder --dialog-dir='../data/dialogs/' --cached-dialog-path='../data/saved_data/cached_in_domain_hard_dialogs.pkl' --domain='in_domain_hard' --saved-glove-path=./glove6B/  --inference=1 --num-epochs=0 --max_length=60
    

    b. U-Flo setting:

    python flonet.py --save-name='FlonetInferValU' --si_model_checkpoint=path-to-the-uflo-pretrained-retriever-checkpoint.pth.tar --gpt_model_checkpoint=path-to-the-uflo-pretrained-generator-folder --dialog-dir='../data/dialogs/' --cached-dialog-path='../data/saved_data/cached_out_domain_dialogs.pkl' --domain='out_domain' --saved-glove-path=./glove6B/  --inference=1 --max_length=60 --num-epochs=0 --emb-size=200 --hidden-size=600
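
The checkpoint paths in the two commands above are placeholders. One way to keep them manageable is to unpack the pretrained archive into a fixed layout and pass the paths through shell variables; a minimal sketch, where every file and folder name is hypothetical:

```shell
# Hypothetical layout for the unpacked pretrained checkpoints;
# substitute the real names from the downloaded archive.
mkdir -p pretrained/sflo/generator
touch pretrained/sflo/retriever.pth.tar   # placeholder file for illustration

SFLO_RETRIEVER=pretrained/sflo/retriever.pth.tar
SFLO_GENERATOR=pretrained/sflo/generator

# These slot into the inference command's checkpoint flags:
echo "--retriever_checkpoint=$SFLO_RETRIEVER --gpt_model_checkpoint=$SFLO_GENERATOR"
```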
    

Training FloNet from Scratch

  1. Clone this repo

  2. Set up a python environment using the requirements.txt

  3. Download the pre-trained GloVe embeddings and unzip the contents into the folder code/glove6B/

  4. Pre-train the retriever using retriever.py. Example command shown below:

    python retriever.py --cached-dialog-path='../data/saved_data/flodial_out.pkl' --domain=out_domain --hidden-size=600 --emb-size=200
    
  5. Pre-train the generator using generator.py, using the file in data/gpt_data/ generated by retriever.py as input. Example command shown below:

    python generator.py --dataset_path="../data/gpt_data/Retriever_out_domain.json" --dataset_cache="../data/saved_data/flonet_out_cache" 
    
  6. Rename the generator's last checkpoint (in code/generator/*model_name*/) to pytorch_model.bin

  7. Feed the pre-trained retriever checkpoint (in data/model/proxybest_checkpoint...), the pre-trained generator checkpoint (in code/generator/*model_name*/) and the retriever input to flonet.py. Example command shown below:

    python flonet.py --cached-dialog-path='../data/saved_data/flodial_out.pkl' --domain=out_domain --hidden-size=600 --emb-size=200  --si_model_checkpoint='../data/model/Retriever_checkpoint_out_domain_600_0.0001_16.pth.tar' --gpt_model_checkpoint='../data/generator/_gpt2_flowchart_out_cache_BLEU_1628354260/'
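
Step 6's rename is a simple move; a sketch with a hypothetical checkpoint name (use whatever file name your generator run actually produced):

```shell
# Create a stand-in for the generator output folder (illustration only;
# in a real run generator.py creates this, and the checkpoint name differs).
mkdir -p code/generator/model_name
touch code/generator/model_name/checkpoint_mymodel_42.pth

# Rename the last checkpoint so the folder can be loaded as a generator checkpoint
mv code/generator/model_name/checkpoint_mymodel_42.pth \
   code/generator/model_name/pytorch_model.bin
ls code/generator/model_name/
```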
    

Scripts

  • retriever.py : Used to pretrain the retriever of FloNet.

    • The following arguments may need to be changed:
      • cached-dialog-path : path to a processed copy of input data
      • domain : in_domain (s-flo) or out_domain (u-flo)
      • save-name : prefix for the saved data and logs
      • dialog-dir : path to the dataset folder, ../data/flodial/
      • cached-scores-path : path for storing the proxy scores
      • saved-glove-path : point it to the GloVe embedding folder
      • hidden-size and emb-size : set according to the domain, as explained in the paper
    • Saves the data required for pretraining GPT in the data/gpt_data/ folder
    • The model checkpoint is saved in the path given by the mode-dir argument
    • Logs are saved in the path given by the log-dir argument
  • generator.py : Used to pretrain the generator. Requires the GPT input generated by retriever.py for training (to train with only flowchart + dialog history or with only dialog history, use the GPT input file generated by generate_data_for_generator.py)

    • dataset_path : path of the GPT input file
    • dataset_cache : path for saving a cache of processed GPT input (tokenization and other things)
    • use_flowchart : 1 to use the flowchart, 0 when training with only dialog history
    • max_length : max decode length
  • flonet.py : Trains FloNet. If retriever and generator checkpoints are not provided, it trains the NoPretrain version. Combines the arguments of generator.py and retriever.py and adds the two arguments below:

    • si_model_checkpoint : retriever checkpoint file path
    • gpt_model_checkpoint : generator checkpoint folder path
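
For reference, saved-glove-path points at plain-text GloVe files, where each line is a token followed by its space-separated vector components. A minimal sketch of parsing that format, using a tiny inline sample instead of the real glove.6B download:

```python
import io

# Two-line sample in the GloVe text format: token, then vector components.
sample = "the 0.1 0.2 0.3\ncat 0.4 0.5 0.6\n"

def load_glove(fh):
    """Parse a GloVe-format text stream into {token: [float, ...]}."""
    vectors = {}
    for line in fh:
        token, *values = line.rstrip().split(" ")
        vectors[token] = [float(v) for v in values]
    return vectors

embeddings = load_glove(io.StringIO(sample))
print(len(embeddings["cat"]))  # → 3
```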