*Figure: The overall framework of Context-Aware Graph.*

*Figure: Knowledge distillation between the CAG and Img-Only models.*
This is a PyTorch implementation of Context-Aware Graph Inference with Knowledge Distillation for Visual Dialog.
If you use this code in your research, please consider citing:
This code is implemented in PyTorch v0.3.1 and provides out-of-the-box support for CUDA 9 and cuDNN 7.
- Download the VisDial v1.0 dialog json files and images from here.
- Download the word counts file for VisDial v1.0 train split from here.
- Use Faster R-CNN extracted image features, available from here.
- Download pre-trained GloVe word vectors from here.
- We collected a specific subset of VisDial v1.0 val, called VisDial v1.0 (val-yn) (described in our paper), which is provided in the `subdataset` folder.
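As a sketch of how the downloaded GloVe vectors plug into the model, they can be loaded into an embedding matrix for the VisDial vocabulary roughly as below. This is illustrative, not the repo's actual code: `word2ind` is a hypothetical word-to-index mapping built from the word-counts file, and the standard GloVe text format (one token followed by `dim` floats per line) is assumed.

```python
import numpy as np

def load_glove(glove_path, word2ind, dim=300):
    """Build an embedding matrix from a GloVe text file.

    Words missing from the GloVe file keep a small random
    initialization; known words get their pretrained vector.
    (Hypothetical helper; names and defaults are illustrative.)
    """
    emb = np.random.uniform(-0.1, 0.1, (len(word2ind), dim)).astype(np.float32)
    with open(glove_path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split(" ")
            word, vec = parts[0], parts[1:]
            if word in word2ind and len(vec) == dim:
                emb[word2ind[word]] = np.asarray(vec, dtype=np.float32)
    return emb
```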
Train the CAG model as:

`python train/train.py --cuda --encoder=CAGraph`

Train the Img-Only model as:

`python train/train.py --cuda --encoder=Img_only`
First, use the pre-trained Img-Only model to generate soft labels:

`python train/soft_labels.py --model_path [path_to_root]/save/pretrained_img_only.pth --cuda`
Then, fine-tune the pre-trained CAG model as:

`python train/train_distill.py --model_path [path_to_root]/save/pretrained_cag.pth --softlabel ./soft_labels.h5 --cuda`
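During fine-tuning, the teacher's soft labels are combined with the usual hard answer targets. A minimal NumPy sketch of a standard Hinton-style distillation loss follows; the function name, temperature `T`, and blend weight `alpha` are illustrative assumptions, not the repo's actual implementation.

```python
import numpy as np

def softmax(x, T=1.0):
    # numerically stable softmax with temperature T
    z = np.exp((x - x.max(axis=-1, keepdims=True)) / T)
    return z / z.sum(axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, hard_target, T=2.0, alpha=0.5):
    """Blend KL(teacher || student) on softened distributions with
    cross-entropy against the ground-truth answer index.
    (Illustrative sketch; not the repo's loss function.)"""
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    # KL term, scaled by T^2 to keep gradient magnitudes comparable
    kd = np.sum(p_t * (np.log(p_t) - np.log(p_s))) * T * T
    # standard cross-entropy against the hard target
    ce = -np.log(softmax(student_logits)[hard_target])
    return alpha * kd + (1 - alpha) * ce
```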
Evaluation of a trained model checkpoint can be done as follows:

`python eval/evaluate.py --model_path [path_to_root]/save/XXXXX.pth --cuda`
This will generate an EvalAI submission file; submit the JSON file to the online evaluation server to get results on v1.0 test-std.
Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
---|---|---|---|---|---|---|
CAG | 56.64 | 63.49 | 49.85 | 80.63 | 90.15 | 4.11 |
CAG-Distill | 57.77 | 64.62 | 51.28 | 80.58 | 90.23 | 4.05 |
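For reference, the retrieval metrics in the table (all except NDCG) follow directly from the 1-indexed rank of the ground-truth answer among the 100 candidate answers; a small sketch (the function name is illustrative):

```python
import numpy as np

def retrieval_metrics(gt_ranks):
    """Compute MRR, R@1/5/10 (in %), and mean rank from the
    1-indexed ranks of the ground-truth answers. Lower Mean is
    better; higher is better for the other metrics."""
    r = np.asarray(gt_ranks, dtype=float)
    return {
        "MRR": float((1.0 / r).mean()),
        "R@1": float((r <= 1).mean() * 100),
        "R@5": float((r <= 5).mean() * 100),
        "R@10": float((r <= 10).mean() * 100),
        "Mean": float(r.mean()),
    }
```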
- This code began with jiasenlu/visDial.pytorch. We thank the developers for doing most of the heavy lifting.