Implementation for paper "Modality-Aware Representation Learning for Zero-shot Sketch-based Image Retrieval"

Primary language: Python | License: MIT

Modality-Aware Representation Learning for Zero-shot Sketch-based Image Retrieval

This is an implementation of the paper "Modality-Aware Representation Learning for Zero-shot Sketch-based Image Retrieval", accepted at WACV 2024.

  • this includes categorical ZS-SBIR
  • this does not include instance-level ZS-SBIR

This repository has been updated to match the version used for our submission. (Keep in mind that there may be slight differences, since our experiments were run in a separate private repository.)

You can find the pretrained model weights at this link.


  • prepare a virtual environment

      conda create --name sbir python=3.8
      conda activate sbir
  • install the required packages

      pip install -r requirements.txt
  • download the required datasets

    • quickdraw-extended
    • sketchy-extended
    • tuberlin-extended
  • update lines 12 to 16 of get_loader.py with the paths to each downloaded dataset

    # example
    data_dir_paths = [
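As a hypothetical illustration of what the list might hold, the sketch below assumes one root directory per dataset, ordered by the dataset index used in config-train.yaml (0 TU-Berlin, 1 QuickDraw, 2 Sketchy). The paths and the `missing_dataset_dirs` helper are placeholders, not part of the repository, so check the actual lines in get_loader.py before copying anything.

```python
from pathlib import Path

# Hypothetical placeholder paths: replace them with where you downloaded each dataset.
# Assumption: entries follow the dataset index used in config-train.yaml
# (0 = TU-Berlin, 1 = QuickDraw, 2 = Sketchy); verify this order in get_loader.py.
data_dir_paths = [
    "/path/to/tuberlin-extended",
    "/path/to/quickdraw-extended",
    "/path/to/sketchy-extended",
]


def missing_dataset_dirs(paths):
    """Return the entries in `paths` that do not point to an existing directory."""
    return [p for p in paths if not Path(p).is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_dirs(data_dir_paths)
    if missing:
        print("Dataset directories not found:", *missing, sep="\n  ")
    else:
        print("All dataset directories found.")
```

Running a check like this before `python train.py` can catch a mistyped path early instead of partway through data loading.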


  • update the configuration file (config-train.yaml)

    • Our code identifies each dataset by an index:

      Idx  Dataset
      0    TU-Berlin
      1    QuickDraw
      2    Sketchy
  • run train.py

      python train.py
  • after successful training, new output folders (outputs/ and wandb/) are created, as shown below

        ├── data
        ├── model
        ├── outputs/
        │   └── 2024-05-10/
        │       └── 22-31-19/
        │           ├── .hydra
        │           ├── model
        │           └── train.log
        ├── setting
        ├── wandb/
        │   └── ...(omitted)
        ├── .gitignore
        └── ...(omitted)
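The outputs/<date>/<time>/ layout above matches Hydra's default run directory pattern (outputs/YYYY-MM-DD/HH-MM-SS), and the .hydra folder indicates the scripts are configured with Hydra. A minimal sketch for locating the most recent run, e.g. to find which run's model weights to use for inference, assuming that default naming; `latest_run_dir` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def latest_run_dir(outputs_root="outputs"):
    """Return the newest outputs/<YYYY-MM-DD>/<HH-MM-SS> run folder, or None.

    Assumes Hydra's default naming, where zero-padded dates and times sort
    chronologically when compared as strings.
    """
    root = Path(outputs_root)
    if not root.is_dir():
        return None
    runs = [
        run
        for day in root.iterdir() if day.is_dir()
        for run in day.iterdir() if run.is_dir()
    ]
    if not runs:
        return None
    # Sorting by (date folder name, time folder name) orders runs chronologically.
    return max(runs, key=lambda run: (run.parent.name, run.name))


if __name__ == "__main__":
    print(latest_run_dir())
```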


  • update the configuration file (config-infer.yaml)

  • run infer.py

      python infer.py
  • this uses the trained model weights to generate logits and saves them in a new infer_output folder:

        ├── ...(omitted)
        ├── outputs/
        │   └── 2024-05-10/
        │       └── 22-31-19/
        │           ├── .hydra
        │           ├── model
        │           ├── infer_output/
        │           │   └── best/
        │           │       └── 0/ <--(dataset index)
        │           │           ├── converted_logits.pt
        │           │           ├── img_logits.pt
        │           │           └── txt_logits.pt
        │           └── train.log
        └── ...(omitted)
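A minimal sketch that builds the paths to the logit files for one dataset, following the infer_output/best/<dataset index>/ layout above and the dataset indices listed for config-train.yaml. `DATASET_INDEX`, `LOGIT_FILES`, and `logit_paths` are hypothetical names; the contents and shapes of the tensors are not documented here, so inspect them before relying on them.

```python
from pathlib import Path

# Dataset indices as listed for config-train.yaml.
DATASET_INDEX = {"TU-Berlin": 0, "QuickDraw": 1, "Sketchy": 2}

# File names shown in the infer_output tree.
LOGIT_FILES = ("converted_logits.pt", "img_logits.pt", "txt_logits.pt")


def logit_paths(run_dir, dataset):
    """Map each logit file name to its path under <run_dir>/infer_output/best/<idx>/."""
    idx = DATASET_INDEX[dataset]
    out_dir = Path(run_dir) / "infer_output" / "best" / str(idx)
    return {name: out_dir / name for name in LOGIT_FILES}


if __name__ == "__main__":
    paths = logit_paths("outputs/2024-05-10/22-31-19", "TU-Berlin")
    for name, path in paths.items():
        print(f"{name}: {path} (exists: {path.is_file()})")
```

Since these are .pt files, they can typically be opened with `torch.load(path, map_location="cpu")` for a quick look on a machine without a GPU.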


  • update the configuration file (config-eval.yaml)

  • run evaluate.py

      python evaluate.py
  • this reports the evaluation metrics and the top 10 retrieved photos for each sketch image; once the script finishes, you can find them in the newly generated eval_result.json file
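For reference, a generic sketch of how a top-10 ranking and precision-style scores can be derived from one sketch's similarity scores against all gallery photos. It uses a common category-level convention (a retrieved photo counts as relevant when its class matches the sketch's class) and is not necessarily how evaluate.py computes its metrics; `top_k`, `precision_at_k`, and `average_precision` are hypothetical names.

```python
def top_k(scores, k=10):
    """Indices of the k highest-scoring photos, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def precision_at_k(scores, photo_labels, query_label, k=10):
    """Fraction of the top-k photos whose class matches the sketch's class."""
    hits = sum(photo_labels[i] == query_label for i in top_k(scores, k))
    return hits / k


def average_precision(scores, photo_labels, query_label):
    """Mean of the precision at each rank where a relevant photo appears."""
    ranking = top_k(scores, len(scores))
    hits, precisions = 0, []
    for rank, i in enumerate(ranking, start=1):
        if photo_labels[i] == query_label:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


if __name__ == "__main__":
    # Toy example: one sketch of class "cat" scored against four photos.
    scores = [0.9, 0.1, 0.8, 0.3]
    labels = ["cat", "dog", "dog", "cat"]
    print(top_k(scores, k=2))  # [0, 2]
    print(precision_at_k(scores, labels, "cat", k=2))  # 0.5
    print(average_precision(scores, labels, "cat"))
```

Averaging `average_precision` over all sketch queries gives mAP, a metric commonly reported for ZS-SBIR.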


    author    = {Lyou, Eunyi and Lee, Doyeon and Kim, Jooeun and Lee, Joonseok},
    title     = {Modality-Aware Representation Learning for Zero-Shot Sketch-Based Image Retrieval},
    booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
    month     = {January},
    year      = {2024},
    pages     = {5646-5655}