This repository contains the code for our paper *Guided Cluster Aggregation: A Hierarchical Approach to Generalized Category Discovery*.
The dependencies can be installed via conda:

```shell
conda env create -f environment.yaml
```

To activate the environment, run:

```shell
conda activate gca
```
We use the following datasets in our paper:
- CIFAR10
- CIFAR100
- ImageNet-100 (subset of regular ImageNet)
- CUB200-2011
- Stanford Cars
- FGVC Aircraft
We use two pretrainings in our paper: GCD and PromptCAL. While PromptCAL offers pretrained models for download, GCD does not, so you will have to train the GCD models yourself.
We provide experiment configs in the `./config/experiment` folder. The naming scheme is `./config/experiment/{pretraining}_{dataset}.yaml`. We'll use the `./config/experiment/promptcal_cub.yaml` config as an example.
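As an illustration of the naming scheme, the snippet below enumerates the config paths it would produce. The short dataset names used here are assumptions for illustration; the repository's actual file names may differ:

```python
from itertools import product

# Experiment configs follow ./config/experiment/{pretraining}_{dataset}.yaml.
# The dataset short names below are illustrative guesses, not guaranteed
# to match the repository's actual config file names.
pretrainings = ["gcd", "promptcal"]
datasets = ["cifar10", "cifar100", "imagenet100", "cub", "scars", "aircraft"]

config_paths = [
    f"./config/experiment/{p}_{d}.yaml" for p, d in product(pretrainings, datasets)
]
print(config_paths)
```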
To reproduce the results from our paper, follow these steps.

First, obtain the pretrained model. In this case, simply use the download link from the PromptCAL repository. Place the model in the directory specified in the config. The paths given in the config file are relative to the `./src` folder, so in our case the full path of the pretrained model would be `./pretrained/GCD/promptcal_cub.ckpt`.
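Since config paths are resolved relative to `./src`, a quick way to verify the checkpoint is in the right place is to run a small check from that folder (the message texts here are illustrative, not part of the repository):

```python
from pathlib import Path

# Run this from the ./src folder: the config resolves paths relative to it.
ckpt = Path("./pretrained/GCD/promptcal_cub.ckpt")
if ckpt.exists():
    print(f"Found checkpoint: {ckpt}")
else:
    print(f"Missing checkpoint: {ckpt} - download it via the PromptCAL repository link")
```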
Next, generate the nearest neighbor file. To do so, open a shell in the `./src` folder and run:

```shell
python evaluate.py trainer=debug experiment=promptcal_cub datamodule.data_dir=[data location] lightning_module.knn_file=null
```

The nearest neighbor file will be saved in the `./src/evaluation_results/promptcal_cub` folder. Place it in the directory specified in the config, in this case `./neighbors/cub_promptcal_finetuned.npy`.
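The neighbor file is produced by `evaluate.py`; purely as a sketch of what a k-nearest-neighbor index over image features looks like, the following computes cosine-similarity neighbors with NumPy. The feature matrix, its shape, and `k` are stand-in assumptions, not the repository's actual pipeline:

```python
import numpy as np

# Illustrative only: random stand-ins for backbone features of N images.
rng = np.random.default_rng(0)
features = rng.standard_normal((100, 768)).astype(np.float32)

# Cosine-style kNN: normalize rows, take the top-k most similar (excluding self).
k = 5
normed = features / np.linalg.norm(features, axis=1, keepdims=True)
sims = normed @ normed.T
np.fill_diagonal(sims, -np.inf)               # exclude each point itself
neighbors = np.argsort(-sims, axis=1)[:, :k]  # (N, k) integer indices

np.save("neighbors_example.npy", neighbors)
print(neighbors.shape)
```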
Finally, train the model. To do so, open a shell in the `./src` folder and run:

```shell
python train.py trainer=train experiment=promptcal_cub datamodule.data_dir=[data location]
```

This runs the training with TensorBoard logging. You can also use wandb by adding

```shell
trainer/logger=wandb trainer.logger.entity=[entity] trainer.logger.project=[project]
```

to the command.
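The commands above use Hydra-style `key=value` overrides. As a toy illustration of how a dotted override updates a nested config (this is a simplified sketch, not the actual Hydra machinery, and it omits `/`-style group overrides like `trainer/logger=wandb`):

```python
def apply_override(config: dict, override: str) -> None:
    """Apply a dotted key=value override to a nested config dict."""
    key, value = override.split("=", 1)
    parts = key.split(".")
    node = config
    for part in parts[:-1]:
        node = node.setdefault(part, {})  # walk/create intermediate dicts
    node[parts[-1]] = value

config = {"trainer": {"logger": {"entity": None, "project": None}}}
apply_override(config, "trainer.logger.entity=my_team")
apply_override(config, "trainer.logger.project=gca")
print(config)
```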
You can also run the training on a Slurm cluster by adapting `./config/cluster/example.yaml` and `./config/slurm_config/gca_gcd_example.yaml`. Then run:

```shell
python train.py cluster=example trainer=train experiment=promptcal_cub datamodule.data_dir=[data location]
```
Most of the dataset loading code is based on GCD. The code for the VPT-ViT is based on PromptCAL, as is some of the code for the experiments with less labeled data (e.g. CIFAR100 c10l50). Both repositories are licensed under the MIT license.