/DHVT

This is an official implementation of our NeurIPS 2022 paper "Bridging the Gap Between Vision Transformers and Convolutional Neural Networks on Small Datasets".

Primary LanguagePythonApache License 2.0Apache-2.0

DHVT

This repository contains PyTorch based codes for DHVT, including detailed models, training code and pretrained models for the NeurIPS 2022 paper:

Bridging the Gap Between Vision Transformers and Convolutional Neural Networks on Small Datasets

Zhiying Lu, Hongtao Xie, Chuanbin Liu, Yongdong Zhang

Updates!

Note that we fix the bug when calculating the FLOPs of the models. There are two reasons.

(1) The previous applied toolkit fvcore does not support some operations. And now we change the toolkit to deepspeed, which is more robust. (2) When calculating the FLOPs of CNNs on CIFAR-100, we made mistakes on the resolution. For example, on the previous version of the paper, the image of 32x32 in CIFAR was first pooled to 8x8 and then fed into the first stage of the CNNs. This downsampling operation greatly decreased the performance. The correct way is to remove such downsampling, which means the input resolution of the first stage of CNNs is kept to 32x32. And we now achieve the close results with the original CNN papers. Therefore, the correct FLOPs of CNNs is roughly 16 times as those in our original paper.

The results in the Openreview version and arxiv version have been modified.

Usage

git clone https://github.com/ArieSeirack/DHVT.git

Change directory to the cloned repository by running cd DHVT, install necessary packages, and prepare the datasets.

Preparation

Models are trained using Python3.6 and the following packages

torch==1.9.0
torchvision==0.10.0
timm==0.4.12
tensorboardX==2.4
torchprofile==0.0.4
lmdb==1.2.1
pyarrow==5.0.0
einops==0.4.1

These packages can be installed by running pip install -r requirements.txt.

Data preparation

Download the desired datasets from the following links, and use the scripts in VTs-Drloc to pre-process the DomainNet datasets. For the CIFAR-100 dataset, we recommend using the official dataset reader code from torchvision as in the dataset.py as set the download=True

Dataset Download Link
ImageNet train,val
CIFAR-100 all
Clipart images, train_list, test_list
Infograph images, train_list, test_list
Painting images, train_list, test_list
Quickdraw images, train_list, test_list
Real images, train_list, test_list
Sketch images, train_list, test_list

For the ImageNet-1k dataset, the directory structure is the standard layout for the torchvision datasets.ImageFolder, and the training and validation data is expected to be in the train/ folder and val/ folder respectively.

Except for CIFAR-100, other datasets should be arranged as the following structure:

[dataset_name]
  |__train
  |    |__class1
  |    |    |__www.jpg
  |    |    |__...
  |    |__class2
  |    |    |__xxx.jpg
  |    |    |__...
  |    |__...
  |__val
       |__class1
       |    |__yyy.jpg
       |    |__...
       |__class2
       |    |__zzz.jpg
       |    |__...
       |__...

You can optionally use an LMDB dataset for ImageNet by building it using folder2lmdb.py and passing --use-lmdb to main.py, which may speed up data loading.

Training

We provide three run_code_[dataset].sh file that contains the training hyperparameters.

For example, to train DHVT-Small-CIFAR100-patch4 with 2 GPUs on single node, you can do

CUDA_VISIBLE_DEVICES=0,1 bash run_code_cifar.sh

To train other model variants on other datasets, just follow the above operation. The now variable is to make the directory for output model checkpoints.

Finetuning

Firstly, set the ckpt (the path to the pretrained model checkpoint) and in finetune.sh, and then:

CUDA_VISIBLE_DEVICES=0,1 bash finetune.sh

Pretrained Model Zoo on ImageNet-1k

We provide two DHVT models pretrained on ImageNet 2012.

Method #Params GFLOPs Acc@1 Acc@5 URL
DHVT-T 6.2 1.4 77.6 93.4 (Wait-for-release)
DHVT-S 23.8 5.1 82.3 96.0 (Wait-for-release)

Future Work

  1. Release the pretrained models on ImageNet-1k. (Coming in mid-December)

  2. Recombine the code structure and split the large scripts in vision_transformer.py into multiple smaller ones.

  3. Improve the method to DHVTv2, which is a hierarchical structure and with lower computational costs and higher performance

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgement

We would like to thank the authors of DeiT, timm, VTs-Drloc, XCiT, CeiT and mainly EViT, based on which this codebase was built.

Citation

If you use this code for a paper please cite:

@inproceedings{
lu2022bridging,
title={Bridging the Gap Between Vision Transformers and Convolutional Neural Networks on Small Datasets},
author={Zhiying Lu and Hongtao Xie and Chuanbin Liu and Yongdong Zhang},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=bfz-jhJ8wn}
}