Multi-label image recognition with Graph Convolution Network and its variants

Multi-label Image Recognition with GCN and its Variants

Overview

This repo contains an implementation (and a few variants) of the paper "Multi-label Image Recognition with Graph Convolutional Networks". It was created from the Lightning-Hydra-Template.

In general, the model combines a CNN-based image representation extractor with a GCN-based label embedding module. Figure 1 shows the architecture of the model.

Figure 1. The overall architecture of the model.
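The combination described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the code in this repo: the class names `GraphConv` and `TinyMLGCN`, the tiny convolutional backbone (standing in for ResNet-50), the random label embeddings, and the identity adjacency matrix are all assumptions made to keep the example small. It assumes, following the paper, that the GCN output acts as one classifier per label and is applied to the image features with a dot product.

```python
import torch
import torch.nn as nn


class GraphConv(nn.Module):
    """One graph convolution layer: H' = A @ H @ W."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.weight = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        return adj @ self.weight(h)


class TinyMLGCN(nn.Module):
    """Hypothetical toy model: CNN image features + two-layer GCN over labels."""

    def __init__(self, num_labels: int = 20, embed_dim: int = 16, feat_dim: int = 32):
        super().__init__()
        # Stand-in backbone (assumption); a real setup would use e.g. ResNet-50.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, feat_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
        )
        self.gc1 = GraphConv(embed_dim, 24)
        self.gc2 = GraphConv(24, feat_dim)
        self.act = nn.LeakyReLU(0.2)
        # Placeholders (assumptions): random label embeddings, identity adjacency.
        self.register_buffer("label_emb", torch.randn(num_labels, embed_dim))
        self.register_buffer("adj", torch.eye(num_labels))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        feats = self.backbone(images)                      # (batch, feat_dim)
        w = self.act(self.gc1(self.label_emb, self.adj))   # (num_labels, 24)
        w = self.gc2(w, self.adj)                          # (num_labels, feat_dim)
        return feats @ w.t()                               # (batch, num_labels)


model = TinyMLGCN()
logits = model(torch.randn(2, 3, 64, 64))
print(logits.shape)
```

Each row of `logits` holds one score per label, so a multi-label loss such as `nn.BCEWithLogitsLoss` can be applied directly.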

Dataset

Currently, a ready-to-use pre-processor is available for VOCDetection. More datasets will be added soon.
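For multi-label training, each image's annotations need to become a fixed-length 0/1 vector. The sketch below shows one way to do that, assuming targets shaped like the parsed-XML dictionaries that torchvision's `VOCDetection` returns (`target["annotation"]["object"]` holding entries with a `"name"` key). The helper `to_multi_hot` is hypothetical and is not taken from this repo's pre-processor.

```python
# The 20 Pascal VOC object classes, in alphabetical order.
VOC_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
CLASS_TO_IDX = {name: i for i, name in enumerate(VOC_CLASSES)}


def to_multi_hot(target: dict) -> list:
    """Convert a VOC-style annotation dict into a multi-hot label vector."""
    objects = target["annotation"]["object"]
    # Be defensive (assumption): a single object may arrive as a bare dict.
    if isinstance(objects, dict):
        objects = [objects]
    labels = [0] * len(VOC_CLASSES)
    for obj in objects:
        # Repeated instances of a class still set the same position to 1.
        labels[CLASS_TO_IDX[obj["name"]]] = 1
    return labels


# Hand-written example target, not real dataset output.
example = {"annotation": {"object": [{"name": "dog"}, {"name": "person"}, {"name": "dog"}]}}
labels = to_multi_hot(example)
print(labels)
```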

Installation

You need Python 3.7 or higher. I highly recommend creating a virtual environment with a tool such as venv or Conda. For example:

# clone project
git clone https://github.com/thanhtvt/multi-label-gcns.git
cd multi-label-gcns

# [OPTIONAL] create conda environment
conda create -n mlgcn python=3.8
conda activate mlgcn

# install requirements
pip install -r requirements.txt

Results

These results are not fully optimized. Updates will be added in the (unknown :D) future.

| Model | Params | Dataset | Accuracy | F1 | Checkpoint |
| --- | --- | --- | --- | --- | --- |
| ResNet-50 + 2xGCN | 25.9M | VOC2007 | 97.8% | 85.2% | model |
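Accuracy and F1 can differ a lot in multi-label settings, because most labels are absent from most images and every correctly predicted absence counts toward accuracy. The sketch below illustrates this with per-label accuracy and micro-averaged F1 on toy data. How this repo actually computes the two metrics is not stated here, so the averaging choices and the function name `multilabel_scores` are assumptions.

```python
def multilabel_scores(y_true, y_pred):
    """Return (per-label accuracy, micro-averaged F1) for 0/1 label matrices."""
    tp = fp = fn = tn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            if t == 1 and p == 1:
                tp += 1
            elif t == 0 and p == 1:
                fp += 1
            elif t == 1 and p == 0:
                fn += 1
            else:
                tn += 1
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return accuracy, f1


# Toy data: two images, five labels; one positive label is missed.
y_true = [[1, 0, 0, 0, 0], [0, 1, 0, 0, 1]]
y_pred = [[1, 0, 0, 0, 0], [0, 0, 0, 0, 1]]
accuracy, f1 = multilabel_scores(y_true, y_pred)
print(f"accuracy={accuracy:.2f}, f1={f1:.2f}")
```

In this toy case, one missed label out of ten decisions barely moves accuracy but costs F1 noticeably, since F1 ignores the true negatives.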

🚀 Quick start

Train

To train the model with the default configuration, run:

# train on CPU
python src/train.py trainer=cpu

# train on GPU
python src/train.py trainer=gpu

To train the model with an experiment configuration chosen from the configs/experiment folder, run:

python src/train.py experiment=multi-label_base

To override any parameter from the command line, run:

python src/train.py logger=csv trainer.max_epochs=10
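Conceptually, a dotted override such as `trainer.max_epochs=10` replaces one value inside a nested configuration. The pure-Python sketch below only illustrates that idea. It is not how Hydra is implemented, the helper `apply_overrides` and the default values are hypothetical, and in Hydra a bare `group=name` override such as `logger=csv` typically selects a config file from that group rather than setting a plain string.

```python
import copy


def apply_overrides(config: dict, overrides: list) -> dict:
    """Apply 'a.b.c=value' style overrides to a nested dict (illustration only)."""
    result = copy.deepcopy(config)  # leave the defaults untouched
    for item in overrides:
        key, _, raw = item.partition("=")
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            node = node.setdefault(part, {})
        # Naive typing (assumption): digit-only strings become ints.
        node[leaf] = int(raw) if raw.isdigit() else raw
    return result


# Hypothetical defaults, not this repo's actual configuration.
defaults = {"logger": "tensorboard", "trainer": {"max_epochs": 100, "accelerator": "cpu"}}
cfg = apply_overrides(defaults, ["logger=csv", "trainer.max_epochs=10"])
print(cfg)
```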

Evaluate

To evaluate the model with the default configuration, run:

python src/eval.py

To evaluate the model with a chosen checkpoint, run:

python src/eval.py ckpt_path=best.ckpt