This repo contains the code for the paper at CVPR'22.
Check INSTALL.md for installation instructions.
Check DATASET.md for instructions of dataset preprocessing.
For training the SGG models, you first need to download the pre-trained object detector for the Visual genome dataset. You can download the pretrained Faster R-CNN from the link.
After you download the Faster R-CNN model, please extract all the files to the directory /home/username/checkpoints/pretrained_faster_rcnn
. To train your own Faster R-CNN model, please follow the instructions given by KaihuaTang.
There are three standard protocols: (1) Predicate Classification (PredCls): taking ground truth bounding boxes and labels as inputs, (2) Scene Graph Classification (SGCls) : using ground truth bounding boxes without labels, (3) Scene Graph Detection (SGDet): detecting SGs from scratch. We use two switches MODEL.ROI_RELATION_HEAD.USE_GT_BOX
and MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL
to select the protocols.
For Predicate Classification (PredCls), we need to set:
MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL True
For Scene Graph Classification (SGCls):
MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False
For Scene Graph Detection (SGDet):
MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False
To select the backbone SGG model, use the respective model in the argument MODEL.ROI_RELATION_HEAD.PREDICTOR
.
For Unbiased-Causal-TDE Model:
MODEL.ROI_RELATION_HEAD.PREDICTOR CausalAnalysisPredictor
For Neural-MOTIFS Model:
MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor
For Iterative-Message-Passing(IMP) Model (Note that SOLVER.BASE_LR should be changed to 0.001 in SGCls, or the model won't converge):
MODEL.ROI_RELATION_HEAD.PREDICTOR IMPPredictor
For VCTree Model:
MODEL.ROI_RELATION_HEAD.PREDICTOR VCTreePredictor
The default settings are under configs/e2e_relation_X_101_32_8_FPN_1x.yaml
and maskrcnn_benchmark/config/defaults.py
. The priority is command > yaml > defaults.py
For Training the Motif Model with TDE, we always set MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_TYPE
to be 'none' during training, because causal effect analysis is only applicable to the inference/test phase.
Training Command for MOTIF-TDE-Sum Model : (SGCls, Causal, TDE, SUM Fusion, MOTIFS Model)
- The model is first trained only on the set of implicit relations using the following command for the first 30000 iterations.
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port 10026 --nproc_per_node=2 tools/relation_train_net_ours.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.SELF_TRAIN_LOSS none MODEL.ROI_RELATION_HEAD.PREDICTOR CausalAnalysisPredictor MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_TYPE none MODEL.ROI_RELATION_HEAD.CAUSAL.FUSION_TYPE sum MODEL.ROI_RELATION_HEAD.CAUSAL.CONTEXT_LAYER motifs SOLVER.IMS_PER_BATCH 12 TEST.IMS_PER_BATCH 2 DTYPE "float16" SOLVER.MAX_ITER 30000 SOLVER.VAL_PERIOD 2000 SOLVER.CHECKPOINT_PERIOD 2000 GLOVE_DIR /home/user/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/user/checkpoints/pretrained_faster_rcnn/model_final.pth OUTPUT_DIR /home/user/checkpoints/causal-motifs-implicit-relations-sgcls-exmp
where GLOVE_DIR
is the directory used to save glove initializations, MODEL.PRETRAINED_DETECTOR_CKPT
is the pretrained Faster R-CNN model you want to load, OUTPUT_DIR
is the output directory used to save checkpoints for the first and the log.
- Then, the model is trained further on the imputed and original labels with Manifold Mixup.
Create a folder causal-motifs-all-relations-sgcls-exmp
and copy the the file last_checkpoint
from causal-motifs-implicit-relations-sgcls-exmp
to this folder.
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port 10026 --nproc_per_node=2 tools/relation_train_net_ours.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.SELF_TRAIN_LOSS kl MODEL.ROI_RELATION_HEAD.MANIFOLD_MIXUP True MODEL.ROI_RELATION_HEAD.IMP_MANIFOLD_MIXUP True MODEL.ROI_RELATION_HEAD.PREDICTOR CausalAnalysisPredictor MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_TYPE none MODEL.ROI_RELATION_HEAD.CAUSAL.FUSION_TYPE sum MODEL.ROI_RELATION_HEAD.CAUSAL.CONTEXT_LAYER motifs SOLVER.IMS_PER_BATCH 12 TEST.IMS_PER_BATCH 2 DTYPE "float16" SOLVER.MAX_ITER 50000 SOLVER.VAL_PERIOD 2000 SOLVER.CHECKPOINT_PERIOD 2000 GLOVE_DIR /home/user/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/user/checkpoints/pretrained_faster_rcnn/model_final.pth OUTPUT_DIR /home/user/checkpoints/causal-motifs-all-relations-sgcls-exmp
Test Example for the model trained above: (SGCls, Causal, TDE, SUM Fusion, MOTIFS Model)
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10028 --nproc_per_node=1 tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.SELF_TRAIN_LOSS none MODEL.ROI_RELATION_HEAD.PREDICTOR CausalAnalysisPredictor MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_TYPE TDE MODEL.ROI_RELATION_HEAD.CAUSAL.FUSION_TYPE sum MODEL.ROI_RELATION_HEAD.CAUSAL.CONTEXT_LAYER motifs TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR /home/user/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/user/checkpoints/causal-motifs-all-relations-sgcls-exmp OUTPUT_DIR /home/user/checkpoints/causal-motifs-all-relations-sgcls-exmp
Our codebase is built upon the repository provided by KaihuaTang.