Gi-Cheon Kang, Sungdong Kim*, Jin-Hwa Kim*, Donghyun Kwak*, Byoung-Tak Zhang
(* Equal Contribution)
If you use this code or preprocessed data in your research, please consider citing:
@inproceedings{kang2023dialog,
title={The Dialog Must Go On: Improving Visual Dialog via Generative Self-Training},
author={Kang, Gi-Cheon and Kim, Sungdong and Kim, Jin-Hwa and Kwak, Donghyun and Zhang, Byoung-Tak},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2023},
pages={6746-6756}
}
- Setup and Dependencies
- Download Data
- Pre-trained Checkpoints
- Training
- Adaptation to Discriminative Visual Dialog
- Visual Dialog Generation
- Evaluation
- Adversarial Robustness Study
- Demo
- Acknowledgements
- License
This code is implemented in PyTorch v1.7.1+ and provides out-of-the-box support for CUDA 11+ and cuDNN 7+. Anaconda/Miniconda is recommended for setting up this codebase:
- Install the Anaconda or Miniconda distribution based on Python 3.8+ from their downloads site.
- Clone this repository and create an environment:
git clone https://www.github.com/gicheonkang/gst-visdial
conda create -n gst python=3.8 -y
# activate the environment and install all dependencies
conda activate gst
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
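To verify the installation, an optional sanity check can be run inside the gst environment; this is just a quick look at the PyTorch build, not part of the original setup steps.
# confirm the expected PyTorch build and GPU visibility
import torch
print(torch.__version__)          # should report 1.7.1+cu110
print(torch.cuda.is_available())  # should be True on a CUDA 11 machine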
- Download the preprocessed original VisDial data, collected by Das et al. It includes Faster R-CNN bounding box image features of the MSCOCO dataset (80G) and preprocessed json files for dialog (2G).
chmod +x scripts/download_preprocessed_human_visdial.sh
- We also release the machine-generated VisDial data, which consists of Faster R-CNN bounding box image features for a subset of the Conceptual Captions 12M dataset (nearly 2.4T for 3.6M images) and the corresponding machine-generated dialog data.
- If you only want to use the machine-generated dialog data along with the images, download the json files for the dialog data. The json files contain URLs for the image data.
chmod +x scripts/download_preprocessed_machine_visdial.sh
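To get a feel for the dialog json, a minimal inspection sketch is shown below; the file path is a placeholder, so adjust it to wherever the download script places the json files.
# load a downloaded dialog json and inspect its top-level structure
import json

path = "data/cc12m/cc12m_dialogs.json"  # hypothetical path; adjust to the downloaded file
with open(path) as f:
    data = json.load(f)
print(type(data))
print(len(data))
# the entries carry the machine-generated dialogs plus image URLs,
# so the images can be fetched separately with any HTTP client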
Please download the checkpoints to the checkpoints/ directory.
| Model | Trained Data | Link |
|---|---|---|
| Questioner | VisDial v1.0 | Download |
| Teacher | VisDial v1.0 | Download |
| Student | VisDial v1.0 + CC12M with Synthetic Dialogs (iter3) | Download |
| Student (Discriminative) | VisDial v1.0 + CC12M with Synthetic Dialogs (iter3) | Download |
| Base Model from VisDial-BERT | CC3M + VQA | Download |
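Optionally, a downloaded checkpoint can be inspected on CPU; this assumes the files are standard torch.save checkpoints (the file name below is taken from the commands later in this README).
# peek inside a downloaded checkpoint without a GPU
import torch

ckpt = torch.load("checkpoints/teacher_v1.0.ckpt", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])  # e.g., model / optimizer state entries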
Teacher model and questioner model training. Nearly 54GB of GPU memory is required to train the model. The argument -model enc_dec_a specifies the encoder-decoder answerer model, and -model enc_dec_q specifies the encoder-decoder questioner model.
# Teacher model training
python train_gen.py \
-mode vd_train \
-start_path checkpoints/basemodel \
-model enc_dec_a \
-gpu_ids 0 1 2 3
# Questioner model training
python train_gen.py \
-mode vd_train \
-start_path checkpoints/basemodel \
-model enc_dec_q \
-gpu_ids 0 1 2 3
Student model training consists of two steps: (1) training on the synthetically generated visual dialog dataset and (2) finetuning on the original visual dialog dataset. The argument -chunk denotes the number of data chunks to use (default: 30), and -select_data enables the perplexity-based data selection method (a conceptual sketch of this selection is shown after the commands below). After training on the synthetic dialog data, the student model is finetuned on the original visual dialog data.
# training on the synthetic visual dialog dataset
python train_gen.py \
-mode cc12m_train \
-select_data \
-start_path checkpoints/basemodel \
-save_path checkpoints/iter1/ \
-chunk 30 \
-gpu_ids 0 1 2 3 \
-iter 1
# finetuning on the original visual dialog dataset
python train_gen.py \
-mode vd_train \
-continue \
-start_path checkpoints/iter1/cc12m_train_30_3.ckpt \
-save_path checkpoints/iter1/ \
-chunk 30 \
-gpu_ids 0 1 2 3
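For intuition, the sketch below illustrates the general idea behind perplexity-based data selection (an illustration only, not the repository's implementation): score each synthetic dialog by the answerer's average per-token negative log-likelihood, convert it to perplexity, and keep the lowest-perplexity fraction.
import math

def select_by_perplexity(samples, avg_nll_per_sample, keep_ratio=0.5):
    # samples: list of synthetic dialogs
    # avg_nll_per_sample: average per-token negative log-likelihood of each dialog
    # keep_ratio: fraction of lowest-perplexity dialogs to keep (hypothetical value)
    ppl = [math.exp(nll) for nll in avg_nll_per_sample]
    ranked = sorted(zip(samples, ppl), key=lambda pair: pair[1])
    cutoff = int(len(ranked) * keep_ratio)
    return [sample for sample, _ in ranked[:cutoff]]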
A "discriminative" visual dialog model requires answer candidates for each question, but our proposed approach only generates the ground-truth answer. Hence, we propose tricks to train the discriminative model. Based on the encoder-decoder model pre-trained on the synthetic dataset, we finetune the encoder model on the original visdial dataset. Please see our paper (Appendix B) for more details.
python train_disc.py \
-mode vd_train \
-continue \
-model enc_only_a \
-batch_size 40 \
-train_dense \
-num_negative_samples 5 \
-start_path checkpoints/x30_start_iter3.ckpt \
-save_path checkpoints/disc \
-chunk 30 \
-gpu_ids 0 1 2 3
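For intuition only, one generic way to build answer candidates for discriminative training is to pair the generated ground-truth answer with answers drawn at random from other dialogs. The sketch below shows that idea; it is not necessarily the exact scheme described in Appendix B.
import random

def build_candidates(gt_answer, answer_pool, num_negative_samples=5):
    # gt_answer: the generated ground-truth answer
    # answer_pool: answers collected from other dialogs
    negatives = random.sample([a for a in answer_pool if a != gt_answer],
                              num_negative_samples)
    return [gt_answer] + negatives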
Visual dialog generation given image features and captions. The questioner and the teacher alternately generate the visual questions and the corresponding answers, respectively. A conceptual sketch of this alternating loop is shown after the command below.
You can generate your own visual dialog dataset by simply feeding Bottom-up Attention features and caption data. We extracted the image features using the docker container.
python generate.py \
-mode cc12m_gen \
-cc12m_image_feats data/cc12m/features/cc12m_img_feat_0.lmdb/ \
-cc12m_caption data/cc12m/captions/cc12m_filtered_0.json \
-start_path_q checkpoints/questioner_v1.0.ckpt \
-start_path_a checkpoints/teacher_v1.0.ckpt \
-save_name cc12m_dialogs_0.txt \
-save_path data/gen_dialog \
-gpu_ids 0 1
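For reference, the alternating generation described above can be pictured as the loop below; questioner.ask and teacher.answer are hypothetical interfaces used for illustration, not this repository's API.
# conceptual sketch of the alternating question-answer generation loop
def generate_dialog(questioner, teacher, image_feat, caption, num_rounds=10):
    history = [caption]
    dialog = []
    for _ in range(num_rounds):
        question = questioner.ask(image_feat, history)           # questioner generates a question
        answer = teacher.answer(image_feat, history, question)   # teacher answers it
        dialog.append((question, answer))
        history.extend([question, answer])
    return dialog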
Evaluation of the student model on the VisDial v1.0 validation split. Validation scores can be checked in an offline setting. If you want to evaluate the model on the test split, change the mode to vd_eval_test and submit the resulting text file to the EvalAI online evaluation server. Evaluation on the VisDial v0.9 validation split is also available; please add -vd_version 0.9.
python evaluate_gen.py \
-mode vd_eval_val \
-start_path checkpoints/student_v1.0_iter3.ckpt \
-save_path results \
-save_name gen.txt \
-gpu_ids 0 1 2 3
Evaluation for the discriminative model is as follows.
python evaluate_disc.py \
-mode vd_eval_val \
-start_path checkpoints/student_v1.0_iter3_disc_dense.ckpt \
-save_path results \
-save_name disc.txt \
-gpu_ids 0 1 2 3
We propose three different adversarial attacks for VisDial: (1) the FGSM attack, (2) a coreference attack, and (3) a random token attack. The FGSM attack perturbs input visual features, and the others attack the dialog history (textual inputs).
Simply run the command below for the FGSM attack:
python evaluate_gen_attack.py \
-mode vd_eval_val \
-attack fgsm \
-start_path checkpoints/student_v1.0_iter3.ckpt \
-save_path results \
-save_name fgsm.txt \
-gpu_ids 0 1 2 3
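For reference, FGSM perturbs the input in the direction of the sign of the loss gradient. The sketch below illustrates the idea on visual features; it is an illustration, not the repository's exact attack code, and epsilon is a hypothetical value.
import torch

def fgsm_perturb(img_feats, loss, epsilon=0.01):
    # img_feats: visual features with requires_grad=True
    # loss: scalar loss computed from the model's output on img_feats
    grad = torch.autograd.grad(loss, img_feats)[0]
    # step in the direction that increases the loss
    return (img_feats + epsilon * grad.sign()).detach()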
For the textual attacks, preprocessing is required. Download the counter-fitted word embeddings and run the preprocessing code below.
python comp_cos_sim_mat.py counter-fitted-vectors.txt
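For intuition, this preprocessing builds a cosine-similarity matrix over the counter-fitted word vectors, which the textual attacks use to find near-synonym substitutions. A rough sketch of that computation follows; the output file name is a placeholder and the actual script may differ in details.
import numpy as np

# load counter-fitted word vectors: each line is "word v1 v2 ... vd"
words, vecs = [], []
with open("counter-fitted-vectors.txt") as f:
    for line in f:
        parts = line.rstrip().split()
        words.append(parts[0])
        vecs.append(np.asarray(parts[1:], dtype=np.float32))

emb = np.stack(vecs)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)   # L2-normalize each vector
cos_sim = emb @ emb.T                                # pairwise cosine similarities (memory-heavy for the full vocabulary)
np.save("cos_sim_counter_fitting.npy", cos_sim)      # placeholder output name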
Then, run the script:
python evaluate_gen_attack.py \
-mode vd_eval_val \
-attack coreference \
-visdial_processed_val data/visdial/visdial_1.0_val_crowdsourced.json \
-visdial_processed_val_dense_annotations data/visdial/visdial_1.0_val_dense_annotations_processed_crowdsourced.json \
-start_path checkpoints/student_v1.0_iter3.ckpt \
-save_path results \
-save_name coreference.txt \
-gpu_ids 0 1 2 3
We provide an interactive demo to easily show our model's generated answers. Simply run the command below and enter an image id from the VisDial v1.0 validation images.
python inference.py
We use VisDial-BERT as our reference code. Thanks!
MIT License