This repo is an exploration of some few-shot learning techniques applied to the Fashion Products dataset available here. First, we try to classify the products using a basic model, with some experiments designed to improve classification on rare classes.
To avoid extra dependencies, we decided not to use a third-party "training loop wrapper" library such as fastai or ignite; our basic training loop logic is in the `fewshot.trainer` module (trainer.py).
All data processing logic is in the `fewshot.data` module (data.py). It relies on instantiating a `FashionData` class with a `top20` flag, which determines whether the datasets contain the 20 most frequent classes or the remaining, rarer classes. This object runs all preprocessing and instantiates training and test datasets inheriting from PyTorch's `Dataset` class.
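The `top20` split described above boils down to ranking classes by frequency. Here is a minimal plain-Python sketch of that idea; `split_by_frequency` and `filter_indices` are hypothetical helpers (not the actual `FashionData` code), and the labels are a toy list rather than the real dataset.

```python
from collections import Counter


def split_by_frequency(labels, k=20):
    """Hypothetical helper: split class names into the k most frequent and the rest."""
    counts = Counter(labels)
    # most_common() orders classes from most to least frequent
    ranked = [cls for cls, _ in counts.most_common()]
    return set(ranked[:k]), set(ranked[k:])


def filter_indices(labels, keep):
    """Indices of the samples whose label belongs to the `keep` set of classes."""
    return [i for i, label in enumerate(labels) if label in keep]


# Toy labels (not the real dataset): 5 + 3 + 2 + 1 samples
labels = ["Tshirts"] * 5 + ["Watches"] * 3 + ["Jeans"] * 2 + ["Lipstick"]
frequent, rare = split_by_frequency(labels, k=2)
print(sorted(frequent))              # ['Tshirts', 'Watches']
print(sorted(rare))                  # ['Jeans', 'Lipstick']
print(filter_indices(labels, rare))  # [8, 9, 10]
```

In a PyTorch setting, index lists like these could be wrapped with `torch.utils.data.Subset` to build the frequent-class and rare-class datasets.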
All training was done with basic train-time augmentation (horizontal flipping and random cropping).
The final training runs were all done in task1.ipynb, which also shows the training process. We outline the main results from this notebook below.
We start by training the model on the most frequent classes. The training and model selection procedure was as follows:
- Train the model for 20 epochs
- Select the two best models based on accuracy weighted by class size
- Report the model with the best top-1 accuracy averaged across classes
In hindsight, this procedure could have been improved by selecting directly on the desired accuracy metric (which would depend on the business application) rather than going back and forth between weighted and unweighted accuracy.
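To make the two metrics in this procedure concrete, here is a small sketch of per-class top-k accuracy, its unweighted and class-size-weighted averages, and the two-stage checkpoint selection. It assumes predictions are already available as ranked lists of class indices; all function names and checkpoint numbers are hypothetical.

```python
from collections import defaultdict


def per_class_topk(targets, ranked_preds, k=1):
    """Per-class top-k accuracy plus the number of samples in each class."""
    hits, totals = defaultdict(int), defaultdict(int)
    for target, preds in zip(targets, ranked_preds):
        totals[target] += 1
        hits[target] += target in preds[:k]  # True counts as 1
    return {c: hits[c] / totals[c] for c in totals}, dict(totals)


def average_accuracy(per_class):
    """Unweighted mean over classes: every class counts equally."""
    return sum(per_class.values()) / len(per_class)


def weighted_accuracy(per_class, totals):
    """Mean weighted by class size, which equals plain sample-level accuracy."""
    n = sum(totals.values())
    return sum(per_class[c] * totals[c] / n for c in per_class)


def select_checkpoint(checkpoints):
    """checkpoints: (name, weighted_acc, average_acc) tuples.
    Keep the two best by weighted accuracy, then pick the best class-averaged one."""
    best_two = sorted(checkpoints, key=lambda ck: ck[1], reverse=True)[:2]
    return max(best_two, key=lambda ck: ck[2])[0]


targets = [0, 0, 0, 1]
ranked_preds = [[0, 1], [0, 1], [1, 0], [0, 1]]
top1, totals = per_class_topk(targets, ranked_preds, k=1)
print(average_accuracy(top1), weighted_accuracy(top1, totals))  # ~0.333 0.5
checkpoints = [("epoch_5", 0.80, 0.60), ("epoch_12", 0.84, 0.58), ("epoch_17", 0.83, 0.62)]
print(select_checkpoint(checkpoints))  # epoch_17
```

In the toy example the single sample of class 1 is misclassified, so the class-averaged accuracy drops to 1/3 while the weighted accuracy stays at 0.5: this is exactly the kind of disagreement that makes switching between the two metrics awkward.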
We repeated this procedure across four losses:
- unweighted cross entropy
- class-weighted cross entropy
- unweighted focal loss (see Focal Loss for Dense Object Detection)
- class-weighted focal loss
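For reference, focal loss scales the cross-entropy term by (1 - p_t)^γ so that well-classified examples contribute less. The sketch below computes it for a single sample in plain Python; the per-class weight plays the role of the α factor from the paper, and the inverse-frequency weighting is an assumption about what "class-weighted" means here, not necessarily what the repo does.

```python
import math


def focal_loss(probs, target, gamma=2.0, class_weights=None):
    """Focal loss for one sample: -w_t * (1 - p_t) ** gamma * log(p_t).

    probs: softmax probabilities; target: index of the true class.
    With gamma=0 and no class weights this is plain cross entropy.
    """
    p_t = probs[target]
    w_t = 1.0 if class_weights is None else class_weights[target]
    return -w_t * (1.0 - p_t) ** gamma * math.log(p_t)


def inverse_frequency_weights(class_counts):
    """Weights proportional to 1 / class size, rescaled so they average to 1."""
    inv = [1.0 / count for count in class_counts]
    scale = len(inv) / sum(inv)
    return [w * scale for w in inv]


probs = [0.9, 0.1]
ce = focal_loss(probs, target=0, gamma=0.0)    # -log(0.9), about 0.105
fl = focal_loss(probs, target=0)               # (0.1 ** 2) * ce, about 0.00105
weights = inverse_frequency_weights([90, 10])  # about [0.2, 1.8]
print(ce, fl, weights)
```

The confident prediction (p_t = 0.9) is down-weighted by a factor of 100 relative to cross entropy with γ = 2, which is what lets the loss focus on hard examples.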
The detailed results are in task1.ipynb; here is a summary of top-1 and top-5 accuracies averaged across classes:
Loss | Top-1 Accuracy (%) | Top-5 Accuracy (%) |
---|---|---|
Unweighted cross entropy | 84.4 | 94.1 |
Class-weighted cross entropy | 84.8 | 94.7 |
Unweighted Focal Loss | 84.7 | 94.1 |
Class-weighted Focal Loss | 85.2 | 94.4 |
We therefore selected the class-weighted focal loss model as the best model. With more time, I would have liked to check whether these differences (under one point of top-1 accuracy) are actually statistically significant.
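One way to run that check on a shared test set is an exact McNemar test on paired predictions, which only looks at the samples where the two models disagree. A minimal sketch follows; note that it tests sample-level accuracy, not the class-averaged metric in the table (a paired bootstrap over test samples would be one option for that), and the correctness flags are made up for illustration.

```python
from math import comb


def mcnemar_exact(correct_a, correct_b):
    """Two-sided exact McNemar p-value from paired per-sample correctness flags."""
    only_a = sum(1 for a, b in zip(correct_a, correct_b) if a and not b)
    only_b = sum(1 for a, b in zip(correct_a, correct_b) if b and not a)
    n = only_a + only_b
    if n == 0:
        return 1.0  # the models never disagree
    # Under H0 each disagreement favours either model with probability 0.5: Binomial(n, 0.5)
    k = min(only_a, only_b)
    tail = sum(comb(n, i) for i in range(k + 1)) / 2 ** n
    return min(1.0, 2 * tail)


# Hypothetical flags: 50 samples both get right, 10 only A gets right, 2 only B gets right
model_a = [True] * 50 + [True] * 10 + [False] * 2
model_b = [True] * 50 + [False] * 10 + [True] * 2
print(mcnemar_exact(model_a, model_b))  # 0.03857421875
```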
We then took the best model from the training above and fine-tuned it on the rare classes using the same procedure. We compare the results of this full procedure to results obtained without any fine-tuning (all with the class-weighted focal loss):
Method | Top-1 Accuracy (%) | Top-5 Accuracy (%) | Weighted Top-1 Accuracy (%) |
---|---|---|---|
With fine-tuning | 30.9 | 45.8 | 40.8 |
Without fine-tuning | 17.4 | 39.0 | 27.6 |
Beyond the extensions suggested, here are some things I would have liked to try:
- fix the training and model selection procedure to rely on a single metric rather than going back and forth between two
- use more data augmentation. I almost added some color/saturation variations but thought they could be problematic for some classes; for example, Jeans would become a problem category if I changed their colors. I erred on the side of caution and didn't add them. Similarly, some of these items, like watches or lipsticks, are quite geometry-dependent, so I didn't use any further affine transformations.
- change the fine-tuning strategy to a two-step strategy: first train only the FC layer, then train all layers together
- use a smarter LR scheduling policy, like one-cycle
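For the one-cycle idea in the last bullet, the sketch below computes the learning rate at each step: a cosine warm-up from `max_lr / div` to `max_lr`, then a cosine decay to a much smaller value. The defaults loosely follow the shape of PyTorch's `torch.optim.lr_scheduler.OneCycleLR` (learning rate only, no momentum cycling), but this is a standalone illustration, not that class.

```python
import math


def cosine_interp(start, end, frac):
    """Cosine interpolation: returns `start` at frac=0 and `end` at frac=1."""
    return end + (start - end) * (1 + math.cos(math.pi * frac)) / 2


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3, div=25.0, final_div=1e4):
    """Learning rate at `step` (0-based) of a one-cycle schedule."""
    start_lr = max_lr / div
    end_lr = start_lr / final_div
    warmup_steps = max(1, int(pct_start * total_steps))
    if step < warmup_steps:
        # phase 1: rise from start_lr to max_lr
        return cosine_interp(start_lr, max_lr, step / warmup_steps)
    # phase 2: anneal from max_lr down to end_lr at the last step
    decay_steps = max(1, total_steps - 1 - warmup_steps)
    frac = min(1.0, (step - warmup_steps) / decay_steps)
    return cosine_interp(max_lr, end_lr, frac)


lrs = [one_cycle_lr(s, total_steps=100, max_lr=0.1) for s in range(100)]
print(lrs[0], lrs[30], lrs[99])  # about 0.004, 0.1, 4e-07
```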
Next, the idea is to try out three few-shot learning papers on the Fashion Products dataset: Prototypical Networks, Matching Networks, and Model-Agnostic Meta-Learning (MAML).
A series of blog posts, with an accompanying GitHub repo, shows the application of these methods to the two archetypal few-shot datasets, Omniglot and miniImageNet.
For now, we've reimplemented the dataset logic for the Fashion Products dataset and added an episode visualization function in `fewshot.proto.sampler`. The corresponding exploration notebook is here.
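For readers unfamiliar with episodic training: an n-way, k-shot episode picks n classes, then k support and q query samples per class, relabelled 0..n-1. The sketch below is a plain-Python stand-in for what an episodic sampler does; it is not the code in `fewshot.proto.sampler`, and the function name and arguments are hypothetical.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way, k_shot, q_queries, rng=random):
    """Sample one episode as (support, query) lists of (sample_index, episode_label)."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # only classes with enough samples for both the support and the query set qualify
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shot + q_queries]
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], k_shot + q_queries)  # no overlap within a class
        support += [(i, episode_label) for i in picked[:k_shot]]
        query += [(i, episode_label) for i in picked[k_shot:]]
    return support, query


# Toy labels: 10 classes with 6 samples each
labels = [c for c in range(10) for _ in range(6)]
support, query = sample_episode(labels, n_way=5, k_shot=2, q_queries=3, rng=random.Random(0))
print(len(support), len(query))  # 10 15
```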
Our first iteration of the training loop, in the exploration notebook, reaches these results with a ResNet-18 architecture and a 100-dimensional embedding space:
Problem | Top-1 Val. Accuracy (%) |
---|---|
2-shot, 20-way | 72.9 |
1-shot, 20-way | 61.1 |
2-shot, 5-way | 88.4 |
1-shot, 5-way | 84.4 |
Comparing these results to those in the Prototypical Networks paper, the 'difficulty' of this fashion dataset seems to lie somewhere between Omniglot and miniImageNet. That said, these are extremely good results (maybe too good; it is worth checking for leakage between training and validation data).
For now, we bundle prototypical and matching networks together, as matching networks are essentially prototypical networks without the averaging and with cosine distance (which the prototypical networks paper shows to perform worse than Euclidean distance). It would be nice to check this on this dataset!
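The two classification rules being bundled together can be compared side by side on toy 2-D embeddings. This is a sketch under simplifying assumptions: real embeddings would come from the ResNet-18 encoder, the class names are made up, and the matching-network rule omits the full context embeddings from the original paper, keeping only softmax attention over cosine similarity to each support example.

```python
import math


def mean(vectors):
    return [sum(xs) / len(xs) for xs in zip(*vectors)]


def sq_euclidean(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def proto_predict(support, query):
    """Prototypical rule: softmax over negative squared Euclidean distance to class means."""
    classes = sorted(support)
    protos = [mean(support[c]) for c in classes]
    return dict(zip(classes, softmax([-sq_euclidean(query, p) for p in protos])))


def matching_predict(support, query):
    """Simplified matching rule: cosine attention over individual support examples,
    with the attention weights summed per class (no averaging into prototypes)."""
    pairs = [(c, x) for c in sorted(support) for x in support[c]]
    attention = softmax([cosine(query, x) for _, x in pairs])
    out = {c: 0.0 for c in support}
    for (c, _), a in zip(pairs, attention):
        out[c] += a
    return out


support = {"bag": [[1.0, 0.0], [0.8, 0.2]], "shoe": [[0.0, 1.0], [0.1, 0.9]]}
query = [0.9, 0.1]
print(proto_predict(support, query))
print(matching_predict(support, query))
```

Both rules pick "bag" here; the interesting comparison on the real dataset is how much the cosine-based rule loses relative to the Euclidean prototypes.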
To improve our results on this task, I'm thinking of a couple of new approaches:
- Better metric for distances on the embedding manifold: The embedding space is definitely not flat, so plain old L2 distance seems unlikely to be the most appropriate choice. Inspired by Natural Gradient Descent, whose idea is to normalize gradients by the curvature of the space as estimated by the Fisher information matrix (see e.g. Martens 2014), we would like to compute distances to the prototypes using a better metric than L2. That said, empirical methods for computing the Fisher seem to work quite poorly (see Limitations of the Empirical Fisher Approximation).
  The idea would then be to resort to a heuristic normalization (maybe a simple diagonal pre-conditioning like in RMSProp) to compute the distances.
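A minimal sketch of the diagonal idea: estimate a per-dimension scale from the embeddings in an episode and divide each squared coordinate difference by it, i.e. a diagonal Mahalanobis distance. Using the per-dimension variance as the scale is an assumption for illustration (an RMSProp-style running average of squared values would be another option), and the toy numbers are made up to show the effect.

```python
def diagonal_scale(embeddings, eps=1e-8):
    """Per-dimension variance of the embeddings (plus eps), used as a diagonal metric."""
    n = len(embeddings)
    columns = list(zip(*embeddings))
    means = [sum(col) / n for col in columns]
    return [sum((x - m) ** 2 for x in col) / n + eps for col, m in zip(columns, means)]


def scaled_sq_distance(a, b, scale):
    """Squared distance with each coordinate difference divided by its scale."""
    return sum((x - y) ** 2 / s for x, y, s in zip(a, b, scale))


def plain_sq_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


# Toy episode: dimension 0 varies a lot, dimension 1 barely moves
episode = [[0.0, 0.0], [10.0, 0.2], [-10.0, -0.2]]
scale = diagonal_scale(episode)
query = [4.0, 0.0]
protos = {"a": [0.0, 0.0], "b": [4.0, 0.5]}
nearest_l2 = min(protos, key=lambda c: plain_sq_distance(query, protos[c]))
nearest_scaled = min(protos, key=lambda c: scaled_sq_distance(query, protos[c], scale))
print(nearest_l2, nearest_scaled)  # b a
```

A difference of 0.5 along a dimension that barely varies counts for more than a difference of 4 along a dimension that varies widely, which flips the nearest prototype.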
- Data augmentation strategies: It would be interesting to see the impact of traditional data augmentation strategies. In addition, two interesting areas to explore would be:
  - Using adversarial examples as augmentation, as in the semi-supervised approach proposed in Virtual Adversarial Training; this technique had the best results in this evaluation of deep semi-supervised learning
  - Using mixup: it seems particularly important for the embedding space to be linearly separable, and mixup data augmentation might help with that
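Mixup itself is only a few lines: draw λ from a Beta(α, α) distribution and take the same convex combination of two inputs and of their one-hot labels. The sketch below works on plain lists; `alpha=0.2` is an illustrative default, not a value tuned on this dataset, and combining mixup with episodic training would need design decisions not covered here.

```python
import random


def mixup(x1, y1, x2, y2, alpha=0.2, rng=random):
    """Mix two samples (feature lists) and their one-hot labels with lam ~ Beta(alpha, alpha)."""
    lam = rng.betavariate(alpha, alpha)
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam


x, y, lam = mixup([1.0, 2.0, 3.0], [1.0, 0.0], [3.0, 2.0, 1.0], [0.0, 1.0], rng=random.Random(0))
print(lam, x, y)  # y is [lam, 1 - lam]: a soft label between the two classes
```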
That said, I'll probably look at MAML before implementing these two ideas.
Our results on MAML (which are sparse!) are in this notebook. This method is conceptually more interesting than the prototypical networks approach, as we don't really train a network on many examples to produce good 'representations' of the data. Rather, we train a network to be better at learning. This involves taking the gradient of a gradient, which makes it significantly trickier to implement. We relied entirely on @oscarknagg's implementation here, modified to fit the Fashion Products dataset in maml.py.
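To see what "the gradient of a gradient" means without any autograd machinery, here is a one-parameter toy: a linear model y = w·x, one inner SGD step per task, and the exact meta-gradient obtained by differentiating through that step. Everything here (model, tasks, learning rates) is a made-up illustration, not the setup in maml.py; with a quadratic loss the second derivative is a constant, which keeps the chain rule explicit.

```python
def loss(w, data):
    """Mean squared error of the one-parameter model y_hat = w * x."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def grad(w, data):
    """dL/dw for the loss above."""
    return sum(2 * x * (w * x - y) for x, y in data) / len(data)


def hess(w, data):
    """d2L/dw2; constant in w because the loss is quadratic."""
    return sum(2 * x * x for x, _ in data) / len(data)


def adapted_val_loss(w, train, val, inner_lr):
    """Validation loss after one inner SGD step on the training (support) data."""
    return loss(w - inner_lr * grad(w, train), val)


def maml_meta_grad(w, train, val, inner_lr, first_order=False):
    """Derivative of adapted_val_loss with respect to the initial weight w."""
    w_adapted = w - inner_lr * grad(w, train)  # inner-loop step
    outer = grad(w_adapted, val)               # dL_val / dw_adapted
    if first_order:
        return outer                           # first-order approximation: skip the 2nd derivative
    # chain rule: dw_adapted/dw = 1 - inner_lr * d2L_train/dw2 (the "gradient of the gradient")
    return outer * (1 - inner_lr * hess(w, train))


# Toy tasks y = a * x with different slopes a; each is (train_set, val_set)
tasks = [([(1.0, a), (2.0, 2 * a)], [(1.5, 1.5 * a)]) for a in (0.5, 1.0, 1.5)]
w = 0.0
for _ in range(200):  # outer (meta) loop
    meta_grad = sum(maml_meta_grad(w, tr, va, inner_lr=0.1) for tr, va in tasks) / len(tasks)
    w -= 0.05 * meta_grad
print(w)  # close to 1.0, the mean slope across the three tasks
```

Setting `first_order=True` drops the `(1 - inner_lr * hess)` factor, which corresponds to the first-order approximation discussed in the MAML paper; in deep networks that factor involves a Hessian-vector product, which is where most of the implementation difficulty comes from.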
Unlike the prototypical networks setup, the models have to be adapted to the task (notably by implementing the `functional_forward` method), so it isn't straightforward to use off-the-shelf ResNets; we're using the traditional stacked convolutional network typical of these tasks. This might make the comparison between methods unfair.
What I've implemented so far is a rough (but working!) training loop in the notebook. The validation accuracy is definitely increasing, but the algorithm is quite slow to train. Will add results here once finished!
Problem | Top-1 Val. Accuracy (%) |
---|---|
5-shot, 10-way | 71.8 (after 30 epochs) |
Wish I had more time to get some actual comparison numbers for MAML and to implement the two extra ideas in the prototypical networks section! Might get to it when I have more time...