Context-Aware Graph Inference with Knowledge Distillation for Visual Dialog

The overall framework of the Context-Aware Graph (CAG).

Knowledge distillation between CAG and Img-Only models.

This is a PyTorch implementation of Context-Aware Graph Inference with Knowledge Distillation for Visual Dialog.

If you use this code in your research, please consider citing:

Requirements

This code is implemented in PyTorch v0.3.1 and works out of the box with CUDA 9 and cuDNN 7.

Data

  1. Download the VisDial v1.0 dialog json files and images from here.
  2. Download the word counts file for VisDial v1.0 train split from here.
  3. Extract image features with Faster R-CNN, available from here.
  4. Download pre-trained GloVe word vectors from here.
  5. We collected a specific subset of the VisDial v1.0 val split, called VisDial v1.0 (val-yn) and described in our paper. It is provided in the folder subdataset.

Pre-train

Train the CAG model as:

python train/train.py --cuda --encoder=CAGraph

Train the Img-Only model as:

python train/train.py --cuda --encoder=Img_only

Distillation

First, use the pre-trained Img-Only model to generate soft labels:

python train/soft_labels.py --model_path [path_to_root]/save/pretrained_img_only.pth --cuda
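As a rough illustration of what a soft label is, the sketch below turns a teacher model's raw scores over the candidate answers into a probability distribution with a temperature-scaled softmax. This is the generic knowledge-distillation recipe, not necessarily what train/soft_labels.py does; the function name, the temperature parameter, and the example scores are all assumptions made for illustration. It is written in plain Python so it runs without PyTorch.

```python
import math

def soft_labels(scores, temperature=1.0):
    """Turn raw candidate-answer scores into a probability distribution."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher scores for four candidate answers
# (a real VisDial round has 100 candidates).
teacher_scores = [2.0, 1.0, 0.5, -1.0]
sharp = soft_labels(teacher_scores, temperature=1.0)
smooth = soft_labels(teacher_scores, temperature=4.0)
```

A higher temperature spreads probability mass more evenly across candidates, so `smooth` gives the top answer less weight than `sharp` while keeping the same ordering.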

Then, fine-tune the pre-trained CAG model as:

python train/train_distill.py --model_path [path_to_root]/save/pretrained_cag.pth --softlabel ./soft_labels.h5 --cuda
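To make the fine-tuning objective more concrete, here is a minimal sketch of a common distillation loss: a weighted sum of the cross-entropy against the ground-truth answer and the cross-entropy against the teacher's soft labels. The exact loss used by train/train_distill.py is not given here, so the function names, the `alpha` weight, and the formulation itself are assumptions. The code is plain Python; in the training code the same arithmetic would be done on tensors.

```python
import math

def log_softmax(scores):
    """Numerically stable log-softmax over a list of scores."""
    peak = max(scores)
    log_total = peak + math.log(sum(math.exp(s - peak) for s in scores))
    return [s - log_total for s in scores]

def distill_loss(student_scores, soft_target, gt_index, alpha=0.5):
    """Mix hard-label and soft-label cross-entropy (alpha is hypothetical)."""
    log_probs = log_softmax(student_scores)
    # Standard cross-entropy against the ground-truth answer index.
    hard = -log_probs[gt_index]
    # Cross-entropy against the teacher's soft-label distribution.
    soft = -sum(p * lp for p, lp in zip(soft_target, log_probs))
    return alpha * soft + (1.0 - alpha) * hard

# Hypothetical student scores and teacher soft labels for three candidates.
loss = distill_loss([1.0, 0.0, 0.0], [0.7, 0.2, 0.1], gt_index=0)
```

With `alpha=0.0` the loss reduces to ordinary cross-entropy on the ground-truth answer, and with `alpha=1.0` the student is trained purely to match the teacher's distribution.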

Evaluation

Evaluate a trained model checkpoint as follows:

python eval/evaluate.py --model_path [path_to_root]/save/XXXXX.pth --cuda

This generates an EvalAI submission file in JSON format, which you can submit to the online evaluation server to get results on the VisDial v1.0 test-std split.
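For orientation, the sketch below shows how one entry of such a submission file might be built: candidate scores are converted to ranks (rank 1 for the highest score) and stored with the image and round identifiers. The field names `image_id`, `round_id`, and `ranks` follow the publicly documented VisDial challenge format as best understood here; treat them, the helper name, and the example values as assumptions, since the output of eval/evaluate.py may differ.

```python
import json

def scores_to_ranks(scores):
    """Rank 1 goes to the highest-scoring candidate answer."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranks = [0] * len(scores)
    for position, candidate in enumerate(order, start=1):
        ranks[candidate] = position
    return ranks

# Hypothetical scores for one question (a real round has 100 candidates).
entry = {
    "image_id": 123456,  # placeholder id, not a real VisDial image
    "round_id": 1,       # dialog round being answered
    "ranks": scores_to_ranks([0.1, 2.3, -0.4, 1.7]),
}
submission_json = json.dumps([entry])
```

Here the second candidate has the highest score and so receives rank 1; the full file would be a JSON list with one such entry per evaluated dialog round.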

Model        NDCG   MRR    R@1    R@5    R@10   Mean
CAG          56.64  63.49  49.85  80.63  90.15  4.11
CAG-Distill  57.77  64.62  51.28  80.58  90.23  4.05
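The retrieval columns in the table can be read through the following sketch, which computes MRR, R@k, and mean rank from the rank a model assigns to the ground-truth answer in each round. The function name and the example ranks are hypothetical. NDCG is omitted because it additionally requires the dense relevance annotations for the candidate answers.

```python
def retrieval_metrics(gt_ranks, ks=(1, 5, 10)):
    """Compute MRR, R@k (both in percent) and mean rank from 1-based ranks."""
    n = len(gt_ranks)
    metrics = {
        # Mean reciprocal rank: higher is better.
        "MRR": 100.0 * sum(1.0 / r for r in gt_ranks) / n,
        # Mean rank of the ground-truth answer: lower is better.
        "Mean": sum(gt_ranks) / n,
    }
    for k in ks:
        # Fraction of rounds where the ground truth is in the top k.
        metrics["R@%d" % k] = 100.0 * sum(r <= k for r in gt_ranks) / n
    return metrics

# Hypothetical ground-truth ranks for four dialog rounds.
example = retrieval_metrics([1, 2, 4, 12])
```

For these four ranks, R@1 is 25.0, R@5 and R@10 are both 75.0, and the mean rank is 4.75. Higher is better for every column except Mean.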

Acknowledgements