Machine Learning project on few-shot learning with the Prototypical Network meta-learning method.
In the following .py modules, set which GPU you want to work on:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 0 or 1 if two GPUs are available
Run the train.py module to train a model and carry out a new experiment. The following variables can be modified to train in different scenarios:
# backbone: Conv4, Conv6, ResNet10, ResNet18, ResNet34
model = 'Conv6'
# CUB, omniglot, cross_char
dataset = 'CUB'
# number of classes per episode during training
train_n_way = 5
# number of classes per episode during testing (validation)
test_n_way = 5
# number of labeled samples per class, same as n_support
n_shot = 1
# perform data augmentation or not during training
train_aug = True
# Save frequency
save_freq = 50
# Starting epoch
start_epoch = 0
# Stopping epoch
stop_epoch = -1
# resume from the previously trained model with the largest epoch
resume = False
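To make the roles of `train_n_way` and `n_shot` concrete, here is a minimal sketch of how a single few-shot episode could be drawn from a list of image labels. It is not taken from this repo: the function name `sample_episode` and the `n_query` parameter (number of query images per class) are assumptions introduced for illustration.

```python
import random


def sample_episode(image_labels, n_way=5, n_shot=1, n_query=16, seed=None):
    """Return (support, query) lists of (image_index, episode_label) pairs.

    Hypothetical helper: n_query is an assumed parameter, not a variable
    listed in this README.
    """
    rng = random.Random(seed)

    # Group image indices by their class label.
    by_class = {}
    for idx, label in enumerate(image_labels):
        by_class.setdefault(label, []).append(idx)

    # Only classes with enough images can take part in an episode.
    eligible = sorted(c for c, idxs in by_class.items()
                      if len(idxs) >= n_shot + n_query)
    classes = rng.sample(eligible, n_way)

    support, query = [], []
    for episode_label, cls in enumerate(classes):
        picked = rng.sample(by_class[cls], n_shot + n_query)
        # The first n_shot images are the labeled support set,
        # the remaining n_query images are to be classified.
        support += [(i, episode_label) for i in picked[:n_shot]]
        query += [(i, episode_label) for i in picked[n_shot:]]
    return support, query


# Toy data: 10 classes with 20 images each.
labels = [c for c in range(10) for _ in range(20)]
support, query = sample_episode(labels, n_way=5, n_shot=1, n_query=3, seed=0)
```

With `n_way=5` and `n_shot=1`, the support set holds exactly one labeled image for each of five classes, which is the 5W-1S setting used in the result tables.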
At the end of the run, the checkpoints directory will contain the best model, best_model.tar (checkpoints are not included in the repo because of their large size).
With the variables set in the same way as for training, the features must then be extracted from the .tar model into an .hdf5 file, which is found in the features directory at the end of the run.
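Once features are available, a Prototypical Network classifies each query by its nearest class prototype, where a prototype is the mean of the support features of that class. The sketch below illustrates this on plain Python lists. The function names are hypothetical, and squared Euclidean distance is assumed as the metric (the usual choice for Prototypical Networks); it is not a copy of this repo's code.

```python
def compute_prototypes(support_features, support_labels):
    """Average the support feature vectors of each class (hypothetical helper)."""
    sums, counts = {}, {}
    for feat, label in zip(support_features, support_labels):
        acc = sums.setdefault(label, [0.0] * len(feat))
        for d, value in enumerate(feat):
            acc[d] += value
        counts[label] = counts.get(label, 0) + 1
    return {label: [v / counts[label] for v in acc]
            for label, acc in sums.items()}


def classify(query_feature, prototypes):
    """Return the label of the prototype closest to the query feature."""
    def squared_distance(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(prototypes,
               key=lambda label: squared_distance(query_feature, prototypes[label]))


# Toy 2-way 2-shot example with 2-dimensional features.
support_features = [[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [10.0, 12.0]]
support_labels = [0, 0, 1, 1]
prototypes = compute_prototypes(support_features, support_labels)
prediction_a = classify([1.0, 1.0], prototypes)
prediction_b = classify([9.0, 9.0], prototypes)
```

Here the two prototypes are `[0.0, 1.0]` and `[10.0, 11.0]`, so the first query is assigned to class 0 and the second to class 1.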
After training the models for CUB in the configurations (1S, 5S) x (NoAugmentation, Augmentation) x (Conv4, Conv6, ResNet10, ResNet18, ResNet34), those for Omniglot with 1-shot and 5-shot (always with Conv4), those for Omniglot->EMNIST with 1-shot and 5-shot (always with Conv4), and finally CUB ResNet18 5W-5S NoAdaptation and CUB ResNet18 5W-5S Adaptation, you can run test.py, which prints the following tables and produces the plots included in the report.
ssh://er******@m***.d****.u***.**:*****/u*/b*/python3 -u /home/er******/test.py
Test omniglot, Conv4, n_shot = 1
600 Test Acc = 97.92% +- 0.28%
Test omniglot, Conv4, n_shot = 5
600 Test Acc = 99.37% +- 0.12%
Test cross_char, Conv4, n_shot = 1
600 Test Acc = 73.29% +- 0.78%
Test cross_char, Conv4, n_shot = 5
600 Test Acc = 86.95% +- 0.58%
--------------------------------------------------------------------
Omni 5W-1S     Omni 5W-5S     Omni->Emnist 5W-1S  Omni->Emnist 5W-5S
--------------------------------------------------------------------
97.92 +- 0.28  99.37 +- 0.12  73.29 +- 0.78       86.95 +- 0.58
--------------------------------------------------------------------
Test without augmentiation, Conv4, n_shot = 1
600 Test Acc = 52.16% +- 0.92%
Test without augmentiation, Conv6, n_shot = 1
600 Test Acc = 53.20% +- 0.95%
Test without augmentiation, ResNet10, n_shot = 1
600 Test Acc = 59.97% +- 0.93%
Test without augmentiation, ResNet18, n_shot = 1
600 Test Acc = 61.52% +- 0.97%
Test without augmentiation, ResNet34, n_shot = 1
600 Test Acc = 59.69% +- 0.98%
Test without augmentiation, Conv4, n_shot = 5
600 Test Acc = 67.39% +- 0.72%
Test without augmentiation, Conv6, n_shot = 5
600 Test Acc = 67.71% +- 0.74%
Test without augmentiation, ResNet10, n_shot = 5
600 Test Acc = 72.46% +- 0.73%
Test without augmentiation, ResNet18, n_shot = 5
600 Test Acc = 73.66% +- 0.69%
Test without augmentiation, ResNet34, n_shot = 5
600 Test Acc = 74.43% +- 0.72%
Test with augmentiation, Conv4, n_shot = 1
600 Test Acc = 50.86% +- 0.92%
Test with augmentiation, Conv6, n_shot = 1
600 Test Acc = 65.67% +- 1.01%
Test with augmentiation, ResNet10, n_shot = 1
600 Test Acc = 73.16% +- 0.87%
Test with augmentiation, ResNet18, n_shot = 1
600 Test Acc = 74.18% +- 0.90%
Test with augmentiation, ResNet34, n_shot = 1
600 Test Acc = 74.56% +- 0.92%
Test with augmentiation, Conv4, n_shot = 5
600 Test Acc = 76.37% +- 0.69%
Test with augmentiation, Conv6, n_shot = 5
600 Test Acc = 81.74% +- 0.61%
Test with augmentiation, ResNet10, n_shot = 5
600 Test Acc = 85.83% +- 0.49%
Test with augmentiation, ResNet18, n_shot = 5
600 Test Acc = 86.51% +- 0.51%
Test with augmentiation, ResNet34, n_shot = 5
600 Test Acc = 87.91% +- 0.46%
------------------------------------------------------------------------------------------
Backbone:        Conv4          Conv6          ResNet10       ResNet18       ResNet34
------------------------------------------------------------------------------------------
CUB 5W-1S NoAug  52.16 +- 0.92  53.2 +- 0.95   59.97 +- 0.93  61.52 +- 0.97  59.69 +- 0.98
CUB 5W-5S NoAug  67.39 +- 0.72  67.71 +- 0.74  72.46 +- 0.73  73.66 +- 0.69  74.43 +- 0.72
CUB 5W-1S Aug    50.86 +- 0.92  65.67 +- 1.01  73.16 +- 0.87  74.18 +- 0.9   74.56 +- 0.92
CUB 5W-5S Aug    76.37 +- 0.69  81.74 +- 0.61  85.83 +- 0.49  86.77 +- 0.49  87.91 +- 0.46
------------------------------------------------------------------------------------------
Test CUB, ResNet18, n_shot = 5, adaptation = False
600 Test Acc = 86.77% +- 0.49%
Test CUB, ResNet18, n_shot = 5, adaptation = True
600 Test Acc = 86.51% +- 0.50%
--------------------------------------------------------------
CUB ResNet18 5W-5S NoAdaptation  CUB ResNet18 5W-5S Adaptation
--------------------------------------------------------------
86.77 +- 0.49                    86.51 +- 0.5
--------------------------------------------------------------
Process finished with exit code 0
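Each line of the form `600 Test Acc = X% +- Y%` summarizes the accuracies of 600 test episodes. The README does not state how the `+-` term is computed; the sketch below assumes it is the half-width of a 95% confidence interval on the mean (1.96 times the standard error), a common convention in few-shot learning. The function name `summarize` and the toy accuracies are made up for illustration.

```python
import math
import statistics


def summarize(accuracies):
    """Return (n, mean, half_width) for a list of per-episode accuracies in %.

    Assumption: half_width = 1.96 * std / sqrt(n), i.e. a 95% confidence
    interval under a normal approximation.
    """
    n = len(accuracies)
    mean = statistics.mean(accuracies)
    std = statistics.pstdev(accuracies)  # population standard deviation
    half_width = 1.96 * std / math.sqrt(n)
    return n, mean, half_width


# Toy example with 4 episodes instead of 600.
episode_accuracies = [80.0, 90.0, 100.0, 90.0]
n, mean, half_width = summarize(episode_accuracies)
report_line = "%d Test Acc = %4.2f%% +- %4.2f%%" % (n, mean, half_width)
```

Under this assumption the interval shrinks with the square root of the number of episodes, which is why results over 600 episodes have sub-percent margins.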
import numpy as np
import torch
import torch.optim
import glob
import os
import torch.utils.data.sampler
import random
from matplotlib import pyplot as plt
from matplotlib.ticker import PercentFormatter
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional
from PIL import ImageEnhance
import torchvision.transforms as transforms
import json
import h5py
All datasets in this project were taken from CUB-200-2011, Omniglot, Omniglot2, and EMNIST. To accompany the images, three JSON files (base, val, novel) were created with the following fields:
{"label_names": ["class0","class1","..."], "image_names": ["filepath1","filepath2","..."],"image_labels":["l1","l2","l3","..."]}
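As an illustration of this layout, the following sketch builds one such split from a mapping of class names to image paths and serializes it with the standard `json` module. The helper name `build_split`, the example paths, and the use of integer class indices for `image_labels` are assumptions; the actual files in this project may encode labels differently.

```python
import json


def build_split(files_by_class):
    """Build a dict with label_names, image_names and image_labels.

    files_by_class maps a class name to a list of image file paths.
    Labels are encoded as integer indices into label_names (an assumption).
    """
    label_names = sorted(files_by_class)
    image_names, image_labels = [], []
    for label, name in enumerate(label_names):
        for path in files_by_class[name]:
            image_names.append(path)
            image_labels.append(label)
    return {
        "label_names": label_names,
        "image_names": image_names,
        "image_labels": image_labels,
    }


# Hypothetical file paths, for illustration only.
split = build_split({
    "class0": ["images/class0/a.jpg", "images/class0/b.jpg"],
    "class1": ["images/class1/c.jpg"],
})
serialized = json.dumps(split)
```

The string in `serialized` can then be written to a file such as base.json, val.json, or novel.json.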
The implementation was taken and adapted from the following repos: data handling and backbones, ProtoNet and Omniglot handling, structure and setup.
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. Please make sure to update tests as appropriate.
Edoardo Re, 2021