This is the official PyTorch code of the CVPR 2021 papers Learning Graph Embeddings for Compositional Zero-shot Learning and Open World Compositional Zero-Shot Learning. The code provides the implementation of the methods CGE and CompCos, together with other baselines (e.g. SymNet, AoP, TMN, LabelEmbed+, RedWine). It also provides training and testing code for the Open World CZSL setting and the new GQA benchmark.
- Clone the repo
- We recommend using Anaconda for environment setup. To create the environment and activate it, please run:
conda env create --file environment.yml
conda activate czsl
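As an optional sanity check, you can verify that PyTorch is importable inside the freshly created environment; the one-liner below only assumes that the environment ships PyTorch, which this repository requires:
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
If this prints a version number (and True when a GPU is visible), the environment is ready.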
- Go to the cloned repo and open a terminal. Download the datasets and embeddings, specifying the desired path (e.g. DATA_ROOT in the example):
bash ./utils/download_data.sh DATA_ROOT
mkdir logs
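The download script places the data under the path you passed (DATA_ROOT above). A quick way to confirm the download finished is to list that folder; the exact sub-folder names depend on the script, so the comment below is an assumption rather than a guaranteed layout:
ls DATA_ROOT
# expected: one folder per dataset (MitStates, UT-Zappos, the GQA benchmark) plus the word embeddings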
Closed World. To train a model, the command is simply:
python train.py --config CONFIG_FILE
where CONFIG_FILE is the path to the configuration file of the model.
The folder configs contains configuration files for all methods, i.e. CGE in configs/cge, CompCos in configs/compcos, and the other methods in configs/baselines.
To run CGE on MitStates, the command is just:
python train.py --config configs/cge/mit.yml
On UT-Zappos, the command is:
python train.py --config configs/cge/utzappos.yml
Open World. To train CompCos (in the open world scenario) on MitStates, run:
python train.py --config configs/compcos/mit/compcos.yml
To run experiments in the open world setting for a non-open-world method, just add --open_world after the command. E.g., to run SymNet in the open world scenario on MitStates, the command is:
python train.py --config configs/baselines/mit/symnet.yml --open_world
Note: To create a new config, all the available arguments are indicated in flags.py.
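If you want to define your own experiment, a simple starting point is to copy one of the provided configuration files and edit its values rather than writing one from scratch. The sketch below uses the shipped configs/cge/mit.yml; my_mit.yml is just a hypothetical name for your copy:
cp configs/cge/mit.yml configs/cge/my_mit.yml
# edit the copied file (the valid keys mirror the arguments listed in flags.py), then train with it:
python train.py --config configs/cge/my_mit.yml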
Closed World. To test a model, the command is simply:
python test.py --logpath LOG_DIR
where LOG_DIR is the directory containing the logs of a model.
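For instance, if your training run wrote its logs under the logs folder created earlier, the call would look like the following (the sub-directory logs/cge/mit is a hypothetical example and depends on your run):
python test.py --logpath logs/cge/mit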
Open World. To test a model in the open world setting, run:
python test.py --logpath LOG_DIR --open_world
To test a CompCos model in the open world setting with hard masking, run:
python test.py --logpath LOG_DIR_COMPCOS --open_world --hard_masking
If you use this code, please cite:
@inproceedings{naeem2021learning,
title={Learning Graph Embeddings for Compositional Zero-shot Learning},
author={Naeem, MF and Xian, Y and Tombari, F and Akata, Zeynep},
booktitle={34th IEEE Conference on Computer Vision and Pattern Recognition},
year={2021},
organization={IEEE}
}
and
@inproceedings{mancini2021open,
title={Open World Compositional Zero-Shot Learning},
author={Mancini, M and Naeem, MF and Xian, Y and Akata, Zeynep},
booktitle={34th IEEE Conference on Computer Vision and Pattern Recognition},
year={2021},
organization={IEEE}
}
Note: Some of the scripts are adapted from the Attributes as Operators repository. The GCN and GCNII implementations are imported from their respective repositories. If you find those parts useful, please consider citing:
@inproceedings{nagarajan2018attributes,
title={Attributes as operators: factorizing unseen attribute-object compositions},
author={Nagarajan, Tushar and Grauman, Kristen},
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
pages={169--185},
year={2018}
}