/dsn_fewshot

Primary language: Python. License: MIT.

Adaptive Subspaces for Few-Shot Learning

This repository contains the code for the CVPR 2020 paper
Adaptive Subspaces for Few-Shot Learning.

Our pipeline:

Comparison with previous methods:

Robustness on toy data: subspaces vs. prototypes
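
As a rough illustration of the subspace-versus-prototype comparison named above, the NumPy sketch below classifies a query in two ways: by its residual to a per-class subspace (fitted with an SVD of the mean-centered support features), and by its distance to the class mean. This is not the repository's code. The function names, the toy 2-D data, and the use of NumPy rather than PyTorch are assumptions made for illustration only.

```python
import numpy as np


def class_subspace(support, dim):
    """Fit a `dim`-dimensional subspace to one class's support features.

    Returns the class mean and an orthonormal basis of shape (feature_dim, dim).
    """
    mean = support.mean(axis=0)
    centered = support - mean
    # Rows of vt are the principal directions of the centered support set.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return mean, vt[:dim].T


def subspace_distance(query, mean, basis):
    """Squared norm of the part of (query - mean) lying outside the subspace."""
    diff = query - mean
    residual = diff - basis @ (basis.T @ diff)
    return float(residual @ residual)


def prototype_distance(query, mean):
    """Squared Euclidean distance to the class mean (the prototype)."""
    diff = query - mean
    return float(diff @ diff)


def predict_subspace(query, supports, dim=1):
    dists = [subspace_distance(query, *class_subspace(s, dim)) for s in supports]
    return int(np.argmin(dists))


def predict_prototype(query, supports):
    dists = [prototype_distance(query, s.mean(axis=0)) for s in supports]
    return int(np.argmin(dists))


# Hypothetical toy data: class 0 varies along x, class 1 varies along y.
TOY_SUPPORTS = [
    np.array([[-2.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [2.0, 0.0]]),
    np.array([[3.0, -2.0], [3.0, -1.0], [3.0, 1.0], [3.0, 2.0]]),
]
TOY_QUERY = np.array([4.0, 0.0])

if __name__ == "__main__":
    print("subspace prediction: ", predict_subspace(TOY_QUERY, TOY_SUPPORTS))
    print("prototype prediction:", predict_prototype(TOY_QUERY, TOY_SUPPORTS))
```

On this toy data the two rules disagree: the query lies on the line spanned by class 0 but is closer to the mean of class 1, so the subspace rule picks class 0 and the prototype rule picks class 1.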

OVERVIEW

Requirements:

  • PyTorch 1.0 or above
  • Python 3.6

The code is organized by backbone, with each backbone in its own folder:

  • Conv-4: used for two datasets, mini-ImageNet and OpenMIC.
  • ResNet-12: used for three datasets, mini-ImageNet, tiered-ImageNet, and CIFAR-FS.

DATASET

** Adopted from Kwonjoon Lee

USAGE

Conv-4

Train mini-ImageNet:

python3 train_subspace_discriminative.py --data-path 'yourdatafolder'

Evaluate mini-ImageNet:

python3 test_subspace_discriminative.py --data-path 'yourdatafolder'

Train OpenMIC:

python3 train_subspace_museum.py --data-path 'yourdatafolder'

ResNet-12

Note: Training with ResNet-12 requires 4 GPUs with ~10 GB of memory each.

Set the image folders:

_IMAGENET_DATASET_DIR = './miniimagenet/' (in data/mini_imagenet.py)
_TIERED_IMAGENET_DATASET_DIR = '/tieredimagenet/' (in data/tiered_imagenet.py)
_CIFAR_FS_DATASET_DIR = './cifar/CIFAR-FS/' (in data/CIFAR_FS.py)
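
Before launching a long training run, it can help to confirm that these folders exist. The snippet below is a minimal sketch of such a check and is not part of the repository. The helper name `missing_dataset_dirs` is hypothetical, and the paths are simply the defaults shown above, so edit them to match your own layout.

```python
from pathlib import Path

# Default paths as listed in this README; adjust to your own setup.
DATASET_DIRS = {
    "miniImageNet": "./miniimagenet/",
    "tieredImageNet": "/tieredimagenet/",
    "CIFAR_FS": "./cifar/CIFAR-FS/",
}


def missing_dataset_dirs(dirs):
    """Return the names of datasets whose folder does not exist."""
    return [name for name, path in dirs.items() if not Path(path).is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_dirs(DATASET_DIRS)
    if missing:
        print("Missing dataset folders for:", ", ".join(missing))
    else:
        print("All dataset folders found.")
```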

Train mini-ImageNet:

  python3 train.py --gpu 0,1,2,3 --save-path "./experiments/miniImageNet_subspace" --train-shot 15 \
  --head Subspace --network ResNet --dataset miniImageNet --eps 0.1

Evaluate mini-ImageNet:

  python3 test.py --gpu 0 --load ./experiments/miniImageNet_subspace/best_model.pth --episode 1000 \
  --way 5 --shot 5 --query 15 --head Subspace --network ResNet --dataset miniImageNet
Supported values for --dataset: miniImageNet, tieredImageNet, CIFAR_FS
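
To evaluate on several datasets in turn, the command above can be generated programmatically. The sketch below only builds and prints the command lines, reusing the flags shown in this README. The helper `build_test_command` and the per-dataset checkpoint paths for tieredImageNet and CIFAR_FS are hypothetical, so substitute the paths you actually passed to --save-path.

```python
import shlex

# Hypothetical checkpoint locations; only the miniImageNet path appears in this README.
CHECKPOINTS = {
    "miniImageNet": "./experiments/miniImageNet_subspace/best_model.pth",
    "tieredImageNet": "./experiments/tieredImageNet_subspace/best_model.pth",
    "CIFAR_FS": "./experiments/CIFAR_FS_subspace/best_model.pth",
}


def build_test_command(dataset, checkpoint, gpu="0"):
    """Build the argument list for test.py using the flags from the README."""
    return [
        "python3", "test.py",
        "--gpu", gpu,
        "--load", checkpoint,
        "--episode", "1000",
        "--way", "5", "--shot", "5", "--query", "15",
        "--head", "Subspace",
        "--network", "ResNet",
        "--dataset", dataset,
    ]


if __name__ == "__main__":
    for dataset, checkpoint in CHECKPOINTS.items():
        cmd = build_test_command(dataset, checkpoint)
        # Print a shell-safe command line; pass `cmd` to subprocess.run to execute it.
        print(" ".join(shlex.quote(arg) for arg in cmd))
```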

Citation:

@inproceedings{simon2020dsn,
        author       = {C. Simon and P. Koniusz and R. Nock and M. Harandi},
        title        = {Adaptive Subspaces for Few-Shot Learning},
        booktitle    = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
        year         = {2020}
        }

Acknowledgement

Thanks to the authors of the following codebases:

Prototypical Network

MetaOpt