
TRPN

Introduction

A PyTorch implementation of the IJCAI 2020 paper "Transductive Relation-Propagation Network for Few-shot Learning". The code is based on Edge-labeling Graph Neural Network for Few-shot Learning (EGNN).

Authors: Yuqing Ma, Shihao Bai, Shan An, Wei Liu, Aishan Liu, Xiantong Zhen and Xianglong Liu

Abstract: Few-shot learning, aiming to learn novel concepts from few labeled examples, is an interesting and very challenging problem with many practical advantages. To accomplish this task, one should concentrate on revealing the accurate relations of the support-query pairs. We propose a transductive relation-propagation graph neural network (TRPN) to explicitly model and propagate such relations across support-query pairs. Our TRPN treats the relation of each support-query pair as a graph node, named relational node, and resorts to the known relations between support samples, including both intra-class commonality and inter-class uniqueness, to guide the relation propagation in the graph, generating the discriminative relation embeddings for support-query pairs. A pseudo relational node is further introduced to propagate the query characteristics, and a fast, yet effective transductive learning strategy is devised to fully exploit the relation information among different queries. To the best of our knowledge, this is the first work that explicitly takes the relations of support-query pairs into consideration in few-shot learning, which might offer a new way to solve the few-shot learning problem. Extensive experiments conducted on several benchmark datasets demonstrate that our method can significantly outperform a variety of state-of-the-art few-shot learning methods.
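The core idea above can be sketched roughly as follows. This is an illustrative simplification, not the authors' exact formulation: each support-query pair becomes a relation node (here built from the elementwise absolute feature difference, a hypothetical pair embedding), and a guidance matrix derived from the known support labels (+1 within a class for intra-class commonality, -1 across classes for inter-class uniqueness) propagates information among the relation nodes.

```python
import numpy as np

def build_relation_nodes(support, query):
    # One relation node per (support, query) pair; the elementwise
    # |s - q| embedding is a hypothetical stand-in for the paper's version.
    return np.abs(support[:, None, :] - query[None, :, :])  # (S, Q, D)

def propagate(relations, support_labels):
    # Guidance matrix from the known support relations:
    # +1 for same-class pairs, -1 for different-class pairs.
    same = support_labels[:, None] == support_labels[None, :]
    guide = np.where(same, 1.0, -1.0)
    guide /= np.abs(guide).sum(axis=1, keepdims=True)  # row-normalize
    S, Q, D = relations.shape
    flat = relations.reshape(S, Q * D)
    return (guide @ flat).reshape(S, Q, D)  # one propagation step

# Toy 5-way 1-shot episode with 3 queries and 8-d features.
rng = np.random.default_rng(0)
support = rng.standard_normal((5, 8))
labels = np.arange(5)  # one shot per class
query = rng.standard_normal((3, 8))
rel = build_relation_nodes(support, query)
rel2 = propagate(rel, labels)
```

In the actual model the propagation weights are learned and a pseudo relational node carries query-side information transductively; the sketch only shows the label-guided propagation structure.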

Requirements

  • Python 3
  • Python packages
    • pytorch 1.0.0
    • torchvision 0.2.2
    • matplotlib
    • numpy
    • pillow
    • tensorboardX

You will also need an NVIDIA GPU with CUDA 9.0 or higher.

Getting started

mini-ImageNet

You can download the mini-ImageNet dataset from here.

tiered-ImageNet

You can download the tiered-ImageNet dataset from here.

Because WRN has a large number of parameters, you can cache the features extracted before the classification layer to speed up training and testing. We provide the features extracted by WRN:

You can also use our pretrained WRN model to generate the features for mini-ImageNet or tiered-ImageNet yourself.
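The caching pattern amounts to running the backbone once over the dataset and saving the penultimate-layer outputs to disk. A minimal sketch, with a dummy extractor standing in for the WRN backbone (the function names and the 640-d feature size are assumptions for illustration):

```python
import os
import tempfile
import numpy as np

def cache_features(extract_fn, images, path):
    """Run the (expensive) backbone once and save penultimate features."""
    feats = np.stack([extract_fn(img) for img in images])
    np.save(path, feats)
    return feats

def load_features(path):
    # Subsequent train/eval runs load the cached array instead of
    # re-running the backbone on every episode.
    return np.load(path)

# Demo with a hypothetical 640-d extractor in place of the real WRN.
fake_extract = lambda img: np.full(640, img.mean(), dtype=np.float32)
imgs = [np.random.rand(84, 84, 3).astype(np.float32) for _ in range(4)]
with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "wrn_feats.npy")
    cache_features(fake_extract, imgs, p)
    feats = load_features(p)
```

With real features you would save one array per split (train/val/test) alongside the labels, and point the data loader at the cached files.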

Training

# ************************** miniImagenet, 5way 1shot  *****************************
$ python3 conv4_train.py --dataset mini --num_ways 5 --num_shots 1 
$ python3 WRN_train.py --dataset mini --num_ways 5 --num_shots 1 

# ************************** miniImagenet, 5way 5shot *****************************
$ python3 conv4_train.py --dataset mini --num_ways 5 --num_shots 5 
$ python3 WRN_train.py --dataset mini --num_ways 5 --num_shots 5 

# ************************** tieredImagenet, 5way 1shot *****************************
$ python3 conv4_train.py --dataset tiered --num_ways 5 --num_shots 1 
$ python3 WRN_train.py --dataset tiered --num_ways 5 --num_shots 1 

# ************************** tieredImagenet, 5way 5shot *****************************
$ python3 conv4_train.py --dataset tiered --num_ways 5 --num_shots 5 
$ python3 WRN_train.py --dataset tiered --num_ways 5 --num_shots 5 

# **************** miniImagenet, 5way 5shot, 20% labeled (semi) *********************
$ python3 conv4_train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4
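The --num_ways and --num_shots flags define the episodic sampling: each episode draws N classes with K support examples and some query examples per class. A minimal sketch of such a sampler (names and structure are illustrative, not the repo's actual data loader):

```python
import random
from collections import defaultdict

def sample_episode(labels, num_ways=5, num_shots=1, num_queries=3, seed=None):
    """Draw an N-way K-shot episode: index lists for support and query sets."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    classes = rng.sample(sorted(by_class), num_ways)
    support, query = [], []
    for c in classes:
        picks = rng.sample(by_class[c], num_shots + num_queries)
        support += picks[:num_shots]   # K labeled shots per class
        query += picks[num_shots:]     # queries to classify
    return support, query

# Toy dataset: 10 classes with 20 samples each.
labels = [c for c in range(10) for _ in range(20)]
s, q = sample_episode(labels, num_ways=5, num_shots=5, num_queries=3, seed=0)
```

In the semi-supervised setting (--num_unlabeled 4), 4 of the 5 shots per class are treated as unlabeled, which is what "20% labeled" refers to.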

You can download our pretrained model from here to reproduce the results of the paper.

Testing

# ************************** miniImagenet, Cway Kshot *****************************
$ python3 conv4_eval.py --test_model your_path --dataset mini --num_ways C --num_shots K 
$ python3 WRN_eval.py --test_model your_path --dataset mini --num_ways C --num_shots K