/mela

Learn a global representation for few-shot learning without needing global classes

Introduction

This repository accompanies the paper “Robust Meta-Representation Learning via Global Label Inference and Classification”. It contains the code necessary to run the different steps of the pipeline as outlined in the appendix of the paper.

Organization

Workspace directory

The code relies on the shell variable WORKSPACE being set and exported so that it points to a directory which houses the saved models and datasets. We suggest placing this directory either in your home directory or wherever you keep your datasets. The workspace directory has two sub-directories: checkpoint, in which we keep the model weights for the pipeline, and metaL_data, which stores all of the base datasets we use.

Below is an example of the populated workspace directory output from tree -d

You can initialize this directory in your home directory by running

make initialize_workspace
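
The same layout can also be checked or created from Python. The following is a minimal sketch, not the repository's own code: the helper name `init_workspace` is hypothetical, and it assumes only the `WORKSPACE` variable and the two sub-directory names `checkpoint` and `metaL_data` described above.

```python
import os
from pathlib import Path

# Sub-directory names taken from the workspace description in this README.
SUBDIRS = ("checkpoint", "metaL_data")


def init_workspace(env=None):
    """Create the workspace sub-directories and return the workspace path.

    Hypothetical helper: reads WORKSPACE from `env` (defaults to os.environ).
    """
    env = os.environ if env is None else env
    root = env.get("WORKSPACE")
    if not root:
        raise RuntimeError("Set and export the WORKSPACE shell variable first")
    workspace = Path(root).expanduser()
    for name in SUBDIRS:
        # exist_ok=True makes the call safe to repeat on a populated workspace
        (workspace / name).mkdir(parents=True, exist_ok=True)
    return workspace
```

Calling `init_workspace()` with no arguments uses the current shell environment and fails early with a clear message if WORKSPACE is not set.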

Pipeline

The pipeline that implements the model consists of the following files in the top directory of the repo, run in this order

  1. learn_meta_repr.py
  2. learn_labeler.py
  3. fine_tune.py

To evaluate the models, run meta_eval.py; to train the supervised oracle models, run sup_baseline.py. See Pipeline steps for details on how to do this correctly.

config

The default arguments of each pipeline step are defined in the yaml file in the config directory that has the same base name as the step's Python file. We use Hydra as the configuration framework.
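
As a rough illustration of this naming convention, the sketch below maps a pipeline script to its default configuration file. The `.yaml` extension and the helper name `default_config_path` are assumptions; check the config directory for the actual file names.

```python
from pathlib import Path


def default_config_path(script, repo_root="."):
    """Return the assumed default config path for a pipeline script.

    Example: learn_meta_repr.py -> <repo_root>/config/learn_meta_repr.yaml
    (the .yaml suffix is an assumption, not confirmed by the repository).
    """
    base = Path(script).stem  # drop any directory part and the .py suffix
    return Path(repo_root) / "config" / f"{base}.yaml"
```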

dataset

Contains the dataset functionality including generating the datasets.

models

Defines the models and loading functionality.

logs

Contains the log files that are written to when running a pipeline step.

Data

Few-shot datasets

We provide a makefile with rules for downloading and untarring all of the datasets used. The datasets are provided in a form that is ready to use with the code, as long as they can be found in the right directories. The makefile assumes that the command line tools wget and tar are available on your platform. If this is not the case, you can download and untar the datasets manually using this link to the data of the repository.

To download the datasets, please run

make download_all

and to untar them into the right directory, run

make unpack_datasets
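
If tar is not available on your platform, the unpacking step can be approximated with Python's standard library. This is a sketch under assumptions, not a replacement for the make rule: the helper name `unpack_dataset` is hypothetical, both paths are supplied by you, and the archive is assumed to be an ordinary (optionally compressed) tar file whose contents belong directly under the target directory.

```python
import tarfile
from pathlib import Path


def unpack_dataset(archive, target_dir):
    """Extract a downloaded dataset archive into target_dir.

    Hypothetical stand-in for the tar step of `make unpack_datasets`.
    Only extract archives you trust, since tar members can carry arbitrary paths.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    # Mode "r:*" lets tarfile detect gzip/bz2/xz compression automatically.
    with tarfile.open(archive, mode="r:*") as tar:
        tar.extractall(path=target)
    # Return the top-level entries so the caller can check what was unpacked.
    return sorted(p.name for p in target.iterdir())
```

The target directory would typically be the metaL_data sub-directory of your workspace, but the exact layout expected by the code is defined by the makefile.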

mini-60 and tiered-780

These datasets were generated using the code in /dataset/ximagenet-tools, which can produce few-shot base datasets derived from ImageNet on the fly (built off mini-imagenet-tools). In particular, we made sure that the meta-train sets of mini-60 and tiered-780 do not overlap with the test sets of miniImageNet and tieredImageNet. If you want to generate these datasets or other ImageNet-derived datasets, see the sub-project README.md.

mixed and H-aircraft

To generate the datasets in mixed we rely on the initial data preparation pipeline of the meta-dataset repo. In order to generate the necessary sub-datasets follow the instructions for

to generate the datasets in the form of a directory with the data format being tfrecords. Move the resulting datasets in $RECORDS to $WORKSPACE/metaL_data using

mv -v $RECORDS/* $WORKSPACE/metaL_data

Now, the datasets may be processed into the right form by running

python $MELA/dataset/tfrecords_to_pickle.py --dataset $DATASET

To create the H-aircraft dataset, you only need to run tfrecords_to_pickle.py with the argument $DATASET = aircraft. The resulting datasets are written as pickle files under the respective directories in $WORKSPACE, from where they are loaded by the pipeline.

Dependencies

Makefile

The makefile contains functionality for initializing the workspace directory and for downloading and unpacking all of the datasets. It assumes that your OS has the command line tools wget and tar installed.

Python

We provide a conda env.yml file which can be used to install the necessary packages. Note that the code requires a GPU and a CPU with enough RAM to run.

Running the code

Running the pipeline steps relies on using the same dataset at each step, together with other arguments that should be consistent throughout (for example, the architecture used). Most steps produce artifacts in the form of log files, found in logs, and models saved with torch, which are written to $WORKSPACE/checkpoints/label_learn/saves/.
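
One way to keep the shared arguments consistent is to define them once and pass them to every step. The sketch below only builds the command lines and does not run anything. It assumes Hydra-style `key=value` command-line overrides; the helper name `build_pipeline_commands` is hypothetical and the example values are illustrative rather than recommended settings.

```python
import sys

# Order of the pipeline steps as listed in the Pipeline section.
PIPELINE_STEPS = ("learn_meta_repr.py", "learn_labeler.py", "fine_tune.py")


def build_pipeline_commands(shared, per_step=None):
    """Return one argv list per pipeline step, in pipeline order.

    `shared` holds overrides applied to every step; `per_step` optionally
    maps a script name to extra overrides for that step only.
    """
    per_step = per_step or {}
    commands = []
    for script in PIPELINE_STEPS:
        # Step-specific overrides win over shared ones with the same key.
        overrides = {**shared, **per_step.get(script, {})}
        args = [f"{key}={value}" for key, value in overrides.items()]
        commands.append([sys.executable, script, *args])
    return commands


if __name__ == "__main__":
    shared = {"dataset": "mini60", "n_ways": 5, "n_shots": 5}
    for cmd in build_pipeline_commands(shared):
        # Pass cmd to subprocess.run(cmd, check=True) to actually execute it.
        print(" ".join(cmd))
```

Note that later steps additionally need the file name of the model saved by the previous step (for example pretrained_labeler for learn_labeler.py), which can be supplied through `per_step`.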

Main arguments

The arguments shared between the pipeline steps are as follows

logger_name
(string) Base name of the logger file; usually this should be kept fixed
trial
(string) Numbering scheme of runs (if you run things several times with similar arguments)
dataset
(string) Dataset to use. Options: mixed, h_aircraft, miniImageNet, tieredImageNet, mini60, tiered780
sample_shape
(few_shot / flat) if the dataset should be a few-shot dataset with tasks or a flat supervised learning dataset
fixed_db
(bool) Deterministically sample the tasks of the dataset
no_replacement
(bool) If we are to use the no replacement dataset sampling (called GFSL in the paper)
sim_imbalance
Not used
n_ways
(int) Number of classes in each task
n_shots
(int) Number of samples per class in support set
n_queries
(int) Number of samples per class in query set
val_n_ways
(int) Number of classes during test / validation time. Normally set to n_ways
val_n_shots
(int) Number of samples per class during test / validation time. Normally set to n_shots
model
(string) Architecture to use: resnet12 or resnet18
feat_dim
(int) Dimensionality of feature space, needs to be set according to model used (usually 640)
lam
(float) Regularization strength used in Ridge Regression
train_db_size
(int) Number of tasks / samples in the dataset (overridden by no_replacement or if bigger than the underlying size of the dataset)
test_db_size
(int) Number of tasks to validate over for each sub-dataset
num_workers
(int) Number of workers used
epochs
(int) Number of Epochs to train for
normalize_lam
Not used
data_aug
(bool) Whether to use data augmentation
rotate_aug
(bool) Whether to use rotation augmentation
SGD
(bool) Whether to use SGD instead of Adam / AdamW
learning_rate
(float) Learning rate
lr_decay_epochs
(string) String of the form “e1,e2,…,em” where each “ei” is an epoch at which we multiply the learning rate by lr_decay_rate (below)
lr_decay_rate
(float) Decay rate used for annealing the learning rate
weight_decay
(float) Weight decay to use
momentum
(float) Momentum parameter of the optimization algorithm
test_C
(float) C (inverse regularization strength) to use in the logistic regression classifier during test / validation time
use_bias
(bool) Include bias in logistic regression classifier
is_norm
(bool) Normalize each feature mapped instance to have unit norm at test time
progress
(bool) Show progress bar
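
To make the interaction between learning_rate, lr_decay_epochs and lr_decay_rate concrete, here is a minimal sketch of a step schedule that follows the descriptions above. It is not the repository's implementation: the function names are hypothetical, and whether the decay takes effect at the start or at the end of a listed epoch is an assumption (here: from the listed epoch onward).

```python
def parse_decay_epochs(spec):
    """Parse a string such as "60,80" into a sorted list of ints."""
    return sorted(int(part) for part in spec.split(",") if part.strip())


def learning_rate_at(epoch, learning_rate, lr_decay_epochs, lr_decay_rate):
    """Learning rate in effect at `epoch` under a step decay schedule.

    The base rate is multiplied by lr_decay_rate once for every listed
    decay epoch that has already been reached.
    """
    decay_epochs = parse_decay_epochs(lr_decay_epochs)
    n_decays = sum(1 for e in decay_epochs if epoch >= e)
    return learning_rate * lr_decay_rate ** n_decays
```

For instance, with lr_decay_epochs set to "60,80", the learning rate is multiplied by lr_decay_rate once from epoch 60 and a second time from epoch 80.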

Pipeline steps

learn_meta_repr.py

Learn representation using only the locally available labels of each task. Specific arguments

pretrained_model
(string) Name of file of saved model to load (normally not used in this step)

This step will output a log file to the log directory and a saved model file to the save directory in the $WORKSPACE.

learn_labeler.py

Using the representation learned in learn_meta_repr.py (in the form of the saved model file), learn a labeler that infers global labels for the few-shot dataset. This turns the few-shot dataset into a flat dataset, which is then used to train a model with supervised learning. Specific arguments

label_recovery_model
(string) Same as model, the model used in the learn_meta_repr.py step
train_model
(string) The model to use when training after we have inferred a flat dataset
pretrained_labeler
(string) Name of file of saved feature map model learned in learn_meta_repr.py to load for inferring labels using labeler
pretrained_centroids
(string) Name of file of saved centroids if we are to load these directly from a previous run of this step
pretrained_model
Not used
K
(int) Number of initial centroids to use for the labelling algorithm
std_factor
(float) Aggression factor (in terms of standard deviation) for how aggressively we should prune centroids in labelling algorithm
data_aug
(bool) Whether to use data augmentation for labelling step (should be set to the value of the arguments used to produce the saved model from learn_meta_repr.py, usually false)
rotate_aug
(bool) Whether to use rotation augmentation for labelling step (should be set to the value of the arguments used to produce the saved model from learn_meta_repr.py, usually false)
sup_data_aug
(bool) Whether to use data augmentation when doing supervised training using the inferred flat dataset
sup_rotate_aug
(bool) Whether to use rotation augmentation when doing supervised training using the inferred flat dataset

This step will output a log file to the log directory and a saved model and centroids file to the save directory in the $WORKSPACE.
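
The idea of turning few-shot tasks into a flat dataset via centroids can be illustrated with plain nearest-centroid assignment. This is a deliberately simplified stand-in, not the labelling algorithm of the paper: it assumes the embeddings and centroids are already given as equal-length sequences of floats, the function names are hypothetical, and the centroid pruning controlled by std_factor is not reproduced here.

```python
import math


def nearest_centroid(embedding, centroids):
    """Index of the centroid closest to `embedding` (Euclidean distance)."""
    return min(range(len(centroids)),
               key=lambda k: math.dist(embedding, centroids[k]))


def infer_flat_dataset(tasks, centroids):
    """Turn few-shot tasks into one flat list of (embedding, global_label).

    `tasks` is a list of tasks, each a list of embeddings. Task-local labels
    are ignored; the global label is the index of the nearest centroid.
    """
    flat = []
    for task in tasks:
        for embedding in task:
            flat.append((embedding, nearest_centroid(embedding, centroids)))
    return flat
```

In this simplified picture, K would correspond to the initial number of entries in `centroids`, and the resulting list of pairs plays the role of the flat dataset used for supervised training.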

fine_tune.py

Fine-tune the model from learn_labeler.py using a residual MLP on top of the frozen feature map output by learn_labeler.py. Specific arguments

This step will output a log file to the log directory and a saved model file to the save directory in the $WORKSPACE.
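
The general shape of a residual MLP on top of a frozen feature map can be sketched without any deep learning library. The toy example below is an assumption about that general shape (one hidden layer, ReLU, no biases) and not the architecture used in fine_tune.py; all function names are hypothetical, and in practice this would be a trainable torch module.

```python
def relu(vector):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, x) for x in vector]


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) with a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def residual_mlp(features, w1, w2):
    """Return features + W2 @ relu(W1 @ features).

    `features` is the output of the frozen feature map; only w1 and w2
    would be trained during fine-tuning.
    """
    hidden = relu(matvec(w1, features))
    update = matvec(w2, hidden)
    return [f + u for f, u in zip(features, update)]
```

A property of this form is that if `w2` is all zeros, the output equals the frozen features, so fine-tuning can start from the behavior of the previous step.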

meta_eval.py

Evaluate a saved model in the 5-way, 1-shot and 5-shot few-shot settings of a dataset. Specific arguments

pretrained_model
(string) Name of file of saved feature map model to evaluate

Outputs the results to a log file.
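
For readers unfamiliar with the N-way, K-shot terminology, the following sketch samples a single few-shot task from a flat list of labeled samples. It illustrates the roles of n_ways, n_shots and n_queries only; the function name `sample_task` is hypothetical and this is not the dataset code of the repository.

```python
import random
from collections import defaultdict


def sample_task(samples, n_ways, n_shots, n_queries, rng=None):
    """Sample one few-shot task from (item, label) pairs.

    Returns (support, query), each a list of (item, local_label), where
    local labels run from 0 to n_ways - 1 within the task.
    """
    rng = rng or random.Random()
    by_class = defaultdict(list)
    for item, label in samples:
        by_class[label].append(item)
    # Only classes with enough samples for support plus query can be used.
    eligible = [c for c, items in by_class.items()
                if len(items) >= n_shots + n_queries]
    if len(eligible) < n_ways:
        raise ValueError("not enough classes with n_shots + n_queries samples")
    support, query = [], []
    for local_label, cls in enumerate(rng.sample(eligible, n_ways)):
        picked = rng.sample(by_class[cls], n_shots + n_queries)
        support += [(item, local_label) for item in picked[:n_shots]]
        query += [(item, local_label) for item in picked[n_shots:]]
    return support, query
```

The 5-way, 1-shot setting corresponds to n_ways=5 and n_shots=1, and the 5-shot setting to n_shots=5; a classifier is fit on the support set and scored on the query set of each task.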

sup_baseline.py

Train a supervised baseline that has access to the true labels.

Cite us

@article{wang_2023,
	doi = {10.1109/tpami.2023.3328184},
	url = {https://doi.org/10.1109%2Ftpami.2023.3328184},
	year = 2023,
	publisher = {Institute of Electrical and Electronics Engineers ({IEEE})},
	pages = {1--16},
	author = {Ruohan Wang and John Isak Texas Falk and Massimiliano Pontil and Carlo Ciliberto},
	title = {Robust Meta-Representation Learning via Global Label Inference and Classification},
	journal = {{IEEE} Transactions on Pattern Analysis and Machine Intelligence}
}

License

MIT