BatchFormer: Learning to Explore Sample Relationships for Robust Representation Learning (CVPR 2022)

Introduction

This is the official PyTorch implementation of BatchFormer for long-tailed recognition, domain generalization, compositional zero-shot learning, and contrastive learning.

[Figure: Sample relationship exploration for robust representation learning]

Main Results

Long-Tailed Recognition

ImageNet-LT
Method                All(R10)  Many(R10)  Med(R10)  Few(R10)  All(R50)  Many(R50)  Med(R50)  Few(R50)
RIDE (3 experts) [1]  44.7      57.0       40.3      25.5      53.6      64.9       50.4      33.2
+BatchFormer          45.7      56.3       42.1      28.3      54.1      64.3       51.4      35.1
PaCo [2]              -         -          -         -         57.0      64.8       55.9      39.1
+BatchFormer          -         -          -         -         57.4      62.7       56.7      42.1

iNaturalist 2018
Method            All   Many  Medium  Few
RIDE (3 experts)  72.5  68.1  72.7    73.2
+BatchFormer      74.1  65.5  74.5    75.8
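The Many/Med/Few columns break accuracy down by how many training images each class has. A common convention in the long-tailed literature, assumed here rather than stated in this README, is Many-shot for more than 100 training images, Few-shot for fewer than 20, and Medium-shot in between. The sketch below computes such split accuracies from per-class training counts; `shot_accuracy`, its thresholds and the toy data are hypothetical and not part of this repository.

```python
import numpy as np


def shot_accuracy(train_counts, y_true, y_pred, many_thr=100, few_thr=20):
    """Class-averaged top-1 accuracy overall and on Many/Medium/Few-shot classes.

    Assumed convention: Many-shot classes have more than `many_thr` training
    images, Few-shot classes fewer than `few_thr`, Medium-shot the rest.
    """
    train_counts = np.asarray(train_counts)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    num_classes = len(train_counts)

    # Per-class accuracy on the test set
    hits = np.bincount(y_true, weights=(y_true == y_pred).astype(float),
                       minlength=num_classes)
    totals = np.bincount(y_true, minlength=num_classes)
    present = totals > 0  # ignore classes with no test samples
    class_acc = np.zeros(num_classes)
    class_acc[present] = hits[present] / totals[present]

    groups = {
        "All": present,
        "Many": present & (train_counts > many_thr),
        "Medium": present & (train_counts >= few_thr) & (train_counts <= many_thr),
        "Few": present & (train_counts < few_thr),
    }
    # Mean over the classes in each group (NaN if a group is empty)
    return {name: float(class_acc[mask].mean()) if mask.any() else float("nan")
            for name, mask in groups.items()}


# Hypothetical toy setup: one class per group, two test images per class
train_counts = [500, 50, 5]
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 0, 0, 2]
print(shot_accuracy(train_counts, y_true, y_pred))
# All ≈ 0.667, Many = 1.0, Medium = 0.5, Few = 0.5
```

On a class-balanced test set, the class-averaged "All" value equals ordinary top-1 accuracy.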

Contrastive Learning

Method        Epochs  Top-1  Pretrained
MoCo-v2 [3]   200     67.5
+BatchFormer  200     74.1   download
MoCo-v3 [4]   100     68.9
+BatchFormer  100     69.9   download

Domain Generalization

Backbone: ResNet-18
Method        PACS  VLCS  OfficeHome  Terra
SWAD [5]      82.9  76.3  62.1        42.1
+BatchFormer  83.7  76.9  64.26       44.8

Compositional Zero-Shot Learning

Method        MIT-States(AUC)  MIT-States(HM)  UT-Zap50K(AUC)  UT-Zap50K(HM)  C-GQA(AUC)  C-GQA(HM)
CGE* [6]      6.3              20.0            31.5            46.5           3.7         14.9
+BatchFormer  6.7              20.6            34.6            49.0           3.8         15.5
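AUC and HM summarize a trade-off between seen and unseen compositions. Assuming the generalized CZSL protocol commonly used with CGE, a calibration bias on unseen-composition scores is swept; HM is the best harmonic mean of seen and unseen accuracy along that sweep, and AUC is the area under the seen-versus-unseen accuracy curve (the table appears to report both as percentages). The sketch below assumes the accuracy pairs from such a sweep are already available; `czsl_summary` and the toy curve are hypothetical.

```python
import numpy as np


def czsl_summary(seen_acc, unseen_acc):
    """Return (best harmonic mean, AUC) for a seen/unseen accuracy curve.

    Each (seen_acc[i], unseen_acc[i]) pair is assumed to come from one value
    of a calibration bias sweep, with accuracies given as fractions.
    """
    s = np.asarray(seen_acc, dtype=float)
    u = np.asarray(unseen_acc, dtype=float)

    # Harmonic mean at each bias value (0 where both accuracies are 0)
    denom = s + u
    hm = np.divide(2 * s * u, denom, out=np.zeros_like(denom), where=denom > 0)

    # Trapezoidal area under the curve, integrating seen accuracy over unseen accuracy
    order = np.argsort(u)
    u_sorted, s_sorted = u[order], s[order]
    auc = np.sum((u_sorted[1:] - u_sorted[:-1]) * (s_sorted[1:] + s_sorted[:-1]) / 2)
    return float(hm.max()), float(auc)


# Hypothetical curve: raising the bias trades seen accuracy for unseen accuracy
hm, auc = czsl_summary([0.6, 0.5, 0.3, 0.0], [0.0, 0.3, 0.5, 0.6])
print(hm, auc)  # HM = 0.375, AUC ≈ 0.26
```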

Reference

  1. Long-tailed recognition by routing diverse distribution-aware experts. In ICLR, 2021.
  2. Parametric contrastive learning. In ICCV, 2021.
  3. Improved baselines with momentum contrastive learning. arXiv preprint arXiv:2003.04297, 2020.
  4. An empirical study of training self-supervised vision transformers. In ICCV, 2021.
  5. Domain generalization by seeking flat minima. In NeurIPS, 2021.
  6. Learning graph embeddings for compositional zero-shot learning. In CVPR, 2021.

PyTorch Code

The proposed BatchFormer can be implemented in a few lines as follows:

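```python
import torch


def BatchFormer(x, y, encoder, is_training):
    # x: input features with shape [N, C]; y: labels with shape [N]
    # encoder: torch.nn.TransformerEncoderLayer(C, 4, C, 0.5), i.e.
    #          d_model=C, nhead=4, dim_feedforward=C, dropout=0.5
    if not is_training:
        # BatchFormer is bypassed at test time
        return x, y
    pre_x = x
    # [N, C] -> [N, 1, C]: with the default batch_first=False, the encoder sees
    # the N samples as one sequence, so self-attention runs across the batch
    x = encoder(x.unsqueeze(1)).squeeze(1)
    # Stack the original and batch-refined features; duplicate labels to match
    x = torch.cat([pre_x, x], dim=0)
    y = torch.cat([y, y], dim=0)
    return x, y
```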
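The BatchFormer snippet can be exercised end to end with a small classifier. The sketch below wraps the same steps in a module and runs one training step and one inference pass. `BatchFormerClassifier`, the feature size of 16, the 5 classes and the random inputs are hypothetical choices for illustration, not part of this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchFormerClassifier(nn.Module):
    """Hypothetical wrapper: features -> BatchFormer (training only) -> linear classifier."""

    def __init__(self, feat_dim: int, num_classes: int):
        super().__init__()
        # Same encoder configuration as in the snippet: (C, 4, C, 0.5)
        self.encoder = nn.TransformerEncoderLayer(feat_dim, 4, feat_dim, 0.5)
        # One classifier shared by the original and the refined features
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, feats, labels=None):
        if self.training:
            pre = feats
            # [N, C] -> [N, 1, C]: a sequence of N samples with batch size 1
            refined = self.encoder(feats.unsqueeze(1)).squeeze(1)
            feats = torch.cat([pre, refined], dim=0)
            labels = torch.cat([labels, labels], dim=0)
        return self.classifier(feats), labels


torch.manual_seed(0)
model = BatchFormerClassifier(feat_dim=16, num_classes=5)
feats = torch.randn(8, 16)            # stand-in for backbone features of 8 images
labels = torch.randint(0, 5, (8,))

model.train()
logits, targets = model(feats, labels)
loss = F.cross_entropy(logits, targets)
loss.backward()
print(logits.shape)                   # torch.Size([16, 5])

model.eval()
with torch.no_grad():
    test_logits, _ = model(feats)
print(test_logits.shape)              # torch.Size([8, 5])
```

Because the concatenated training batch also contains the untouched features `pre_x`, the classifier is trained directly on encoder-free features, which is consistent with skipping the encoder at inference time.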

Citation

If you find this repository helpful, please consider citing:

@inproceedings{hou2022batch,
    title={BatchFormer: Learning to Explore Sample Relationships for Robust Representation Learning},
    author={Hou, Zhi and Yu, Baosheng and Tao, Dacheng},
    booktitle={CVPR},
    year={2022}
}

Feel free to contact "zhou9878 at uni dot sydney dot edu dot au" if you have any questions.