Hierarchical convolutional neural network with knowledge complementation for long-tailed classification
Hong Zhao*, Zhengyu Li, Wenwei He, and Yan Zhao
- torch == 1.0.1
- torchvision == 0.2.2_post3
- tensorboardX == 1.8
- Python 3
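The pins above are strict, so it helps to confirm what is actually installed before training. The following is a minimal sketch, not part of the repository. It reads installed versions through `importlib.metadata` and falls back to the `importlib_metadata` backport on Python 3.7 and older, which the pinned torch 1.0.1 needs. The helper names `normalise` and `check_pins` are hypothetical. `0.2.2_post3` is compared using the PEP 440 spelling `0.2.2.post3`.

```python
try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # older interpreters: pip install importlib-metadata
    import importlib_metadata as metadata

# Version pins copied from the requirement list above.
PINS = {
    "torch": "1.0.1",
    "torchvision": "0.2.2_post3",
    "tensorboardX": "1.8",
}


def normalise(version):
    """Lower-case a version string and turn '_' separators into '.'."""
    return version.strip().lower().replace("_", ".")


def check_pins(pins):
    """Return {package: (status, installed_version)} for each pinned package."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = ("missing", None)
            continue
        status = "ok" if normalise(installed) == normalise(wanted) else "mismatch"
        report[name] = (status, installed)
    return report


if __name__ == "__main__":
    for pkg, (status, found) in check_pins(PINS).items():
        print(f"{pkg:>14}: {status} (installed: {found})")
```

Local build tags such as `+cu90` are reported as a mismatch. This is deliberate: the check stays strict rather than guessing which suffixes are harmless.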
The code is developed using the PyTorch framework. We conduct experiments on a single NVIDIA GeForce RTX 2080 Ti GPU. The CUDA and cuDNN versions are 9.0 and 7.1.3, respectively. Other platforms and GPUs have not been fully tested.
# To train long-tailed CIFAR-10 with an imbalance ratio of 50:
python main/train.py --cfg configs/cifar10.yaml
# To validate with the best model:
python main/valid.py --cfg configs/cifar10.yaml
# To debug with CPU mode:
python main/train.py --cfg configs/cifar10.yaml CPU_MODE True
You can change the experimental settings by modifying the corresponding parameters in the YAML file.
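The trailing `CPU_MODE True` in the debug command is a KEY VALUE override appended after `--cfg`, so a setting can be changed without editing the YAML file. Config libraries such as yacs (`CfgNode.merge_from_list`) handle this pattern. The sketch below reimplements the idea with the standard library only. It is not the repository's code. A plain dict stands in for the parsed YAML, and `TRAIN.BATCH_SIZE` is a hypothetical key used only for illustration.

```python
import argparse
import ast
import copy


def parse_args(argv=None):
    """--cfg names the YAML file; everything after it is KEY VALUE overrides."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", required=True, help="path to the YAML config")
    parser.add_argument("opts", nargs=argparse.REMAINDER,
                        help="KEY VALUE pairs that override the YAML config")
    return parser.parse_args(argv)


def parse_value(text):
    """Read a token as a Python literal (True, 64, 0.01, ...) when possible."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text  # keep plain strings such as paths unchanged


def merge_overrides(cfg, opts):
    """Return a copy of the nested dict `cfg` with the KEY VALUE pairs applied."""
    if len(opts) % 2:
        raise ValueError("overrides must come in KEY VALUE pairs")
    merged = copy.deepcopy(cfg)
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")  # dotted keys reach nested sections
        node = merged
        for part in parents:
            node = node[part]
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        node[leaf] = parse_value(raw)
    return merged


if __name__ == "__main__":
    args = parse_args(["--cfg", "configs/cifar10.yaml", "CPU_MODE", "True"])
    # Stand-in for the parsed YAML; TRAIN.BATCH_SIZE is a hypothetical key.
    defaults = {"CPU_MODE": False, "TRAIN": {"BATCH_SIZE": 128}}
    print(merge_overrides(defaults, args.opts))
```

Unknown keys raise an error instead of being added silently. This catches typos in an override before a long training run starts.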
The annotation of a dataset is a dict consisting of two fields: annotations and num_classes. The field annotations is a list of dicts, each with the keys image_id, fpath, im_height, im_width, and category_id.
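The layout just described can be checked programmatically. The same file also shows how imbalanced a split is. The sketch below assumes only the field names given above. The helpers `validate` and `imbalance_ratio` are hypothetical, and the two-class example (paths, sizes, ids) is made up. The imbalance ratio is computed as the size of the largest class divided by the size of the smallest.

```python
from collections import Counter

# Keys every entry of "annotations" carries, per the description above.
ENTRY_KEYS = {"image_id", "fpath", "im_height", "im_width", "category_id"}


def validate(data):
    """Raise ValueError if `data` does not match the annotation layout."""
    for field in ("annotations", "num_classes"):
        if field not in data:
            raise ValueError(f"missing top-level field: {field}")
    for i, entry in enumerate(data["annotations"]):
        missing = ENTRY_KEYS - set(entry)
        if missing:
            raise ValueError(f"entry {i} lacks {sorted(missing)}")


def imbalance_ratio(data):
    """Size of the largest class divided by the size of the smallest one."""
    counts = Counter(entry["category_id"] for entry in data["annotations"])
    return max(counts.values()) / min(counts.values())


# Hypothetical two-class example: 50 images of class 0, 1 image of class 1.
example = {
    "num_classes": 2,
    "annotations": [
        {"image_id": i, "fpath": f"images/{i}.png", "im_height": 32,
         "im_width": 32, "category_id": 0 if i < 50 else 1}
        for i in range(51)
    ],
}

if __name__ == "__main__":
    validate(example)
    print(imbalance_ratio(example))  # 50 / 1 -> 50.0
```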
We provide pretrained BBN models for both the 1x and 2x schedulers on iNaturalist 2018 and iNaturalist 2017.
iNaturalist 2018: Baidu Cloud, Google Drive
iNaturalist 2017: Baidu Cloud, Google Drive
The experimental setup was as follows:
python main.py --dataset cifar10 -a resnet32 --num_classes 10 --imbanlance_rate 0.01 --beta 0.5 --lr 0.01 --epochs 200 -b 64 --momentum 0.9 --weight_decay 5e-3 --resample_weighting 0.0 --label_weighting 1.2 --contrast_weight 4
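In these commands, the `--imbanlance_rate` flag (spelled as in the CLI) takes the values 0.01, 0.02, and 0.1. The sketch below assumes it is the tail-to-head ratio of the exponential per-class profile commonly used to build CIFAR-10-LT and CIFAR-100-LT. This assumption does not come from the repository's code. Under it, 0.01 gives an imbalance ratio of 100 and 0.02 gives 50. The function name `img_num_per_cls` is hypothetical. The full CIFAR-10 training set has 5,000 images per class and CIFAR-100 has 500.

```python
def img_num_per_cls(img_max, num_classes, imb_rate):
    """Per-class training-set sizes for an exponentially long-tailed split.

    Class i keeps img_max * imb_rate ** (i / (num_classes - 1)) images, so
    class 0 keeps img_max and the last class keeps img_max * imb_rate.
    Assumed profile; not taken from the repository.
    """
    return [
        int(img_max * imb_rate ** (i / (num_classes - 1.0)))
        for i in range(num_classes)
    ]


if __name__ == "__main__":
    counts = img_num_per_cls(img_max=5000, num_classes=10, imb_rate=0.01)
    print(counts)
    print("imbalance ratio:", counts[0] / counts[-1])
```

With `imb_rate=0.01` on CIFAR-10, the head class keeps 5,000 images and the tail class keeps 50, a ratio of 100.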
Download the CIFAR-10, CIFAR-100, ImageNet, and iNaturalist18 datasets to GLMC-2023/data. The directory should look like this:
GLMC-2023/data
├── CIFAR-100-python
├── CIFAR-10-batches-py
├── ImageNet
│   ├── train
│   └── val
├── train_val2018
└── data_txt
    ├── ImageNet_LT_val.txt
    ├── ImageNet_LT_train.txt
    ├── iNaturalist18_train.txt
    └── iNaturalist18_val.txt
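A quick check that the data root matches the tree above can save a failed run. This is a minimal sketch that uses the root path `GLMC-2023/data` from the text. The helper `missing_entries` is hypothetical. You only need the entries for the datasets you actually train on, and names are case-sensitive on Linux.

```python
from pathlib import Path

# Expected entries under the data root, copied from the tree above.
EXPECTED = [
    "CIFAR-100-python",
    "CIFAR-10-batches-py",
    "ImageNet/train",
    "ImageNet/val",
    "train_val2018",
    "data_txt/ImageNet_LT_val.txt",
    "data_txt/ImageNet_LT_train.txt",
    "data_txt/iNaturalist18_train.txt",
    "data_txt/iNaturalist18_val.txt",
]


def missing_entries(root, expected=EXPECTED):
    """Return the expected relative paths that do not exist under `root`."""
    base = Path(root)
    return [rel for rel in expected if not (base / rel).exists()]


if __name__ == "__main__":
    missing = missing_entries("GLMC-2023/data")
    print("all data in place" if not missing else f"missing: {missing}")
```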
For CIFAR-10-LT:
python main.py --dataset cifar10 -a resnet32 --num_classes 10 --imbanlance_rate 0.01 --beta 0.5 --lr 0.01 --epochs 200 -b 64 --momentum 0.9 --weight_decay 5e-3 --resample_weighting 0.0 --label_weighting 1.2 --contrast_weight 1
python main.py --dataset cifar10 -a resnet32 --num_classes 10 --imbanlance_rate 0.02 --beta 0.5 --lr 0.01 --epochs 200 -b 64 --momentum 0.9 --weight_decay 5e-3 --resample_weighting 0.0 --label_weighting 1.2 --contrast_weight 1
python main.py --dataset cifar10 -a resnet32 --num_classes 10 --imbanlance_rate 0.1 --beta 0.5 --lr 0.01 --epochs 200 -b 64 --momentum 0.9 --weight_decay 5e-3 --resample_weighting 0.2 --label_weighting 1 --contrast_weight 2
For CIFAR-100-LT:
python main.py --dataset cifar100 -a resnet32 --num_classes 100 --imbanlance_rate 0.01 --beta 0.5 --lr 0.01 --epochs 200 -b 64 --momentum 0.9 --weight_decay 5e-3 --resample_weighting 0.0 --label_weighting 1.2 --contrast_weight 4
python main.py --dataset cifar100 -a resnet32 --num_classes 100 --imbanlance_rate 0.02 --beta 0.5 --lr 0.01 --epochs 200 -b 64 --momentum 0.9 --weight_decay 5e-3 --resample_weighting 0.2 --label_weighting 1.2 --contrast_weight 6
python main.py --dataset cifar100 -a resnet32 --num_classes 100 --imbanlance_rate 0.1 --beta 0.5 --lr 0.01 --epochs 200 -b 64 --momentum 0.9 --weight_decay 5e-3 --resample_weighting 0.2 --label_weighting 1.2 --contrast_weight 4
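The six CIFAR-LT commands above share most of their flags. They differ only in the dataset, the imbalance rate, and the three weighting terms. The sketch below runs them one after another. By default it only prints each command; passing `--run` executes them. It assumes it is launched from the directory containing `main.py`, and `SETTINGS`, `build_command`, and `run_all` are hypothetical names. The values are copied verbatim from the commands above.

```python
import subprocess
import sys

# (dataset, num_classes, imbanlance_rate, resample_weighting,
#  label_weighting, contrast_weight), copied from the commands above.
SETTINGS = [
    ("cifar10", 10, "0.01", "0.0", "1.2", "1"),
    ("cifar10", 10, "0.02", "0.0", "1.2", "1"),
    ("cifar10", 10, "0.1", "0.2", "1", "2"),
    ("cifar100", 100, "0.01", "0.0", "1.2", "4"),
    ("cifar100", 100, "0.02", "0.2", "1.2", "6"),
    ("cifar100", 100, "0.1", "0.2", "1.2", "4"),
]


def build_command(dataset, num_classes, rate, resample, label, contrast):
    """Rebuild one command line, flags in the same order as the README."""
    return [
        "python", "main.py", "--dataset", dataset, "-a", "resnet32",
        "--num_classes", str(num_classes), "--imbanlance_rate", rate,
        "--beta", "0.5", "--lr", "0.01", "--epochs", "200", "-b", "64",
        "--momentum", "0.9", "--weight_decay", "5e-3",
        "--resample_weighting", resample, "--label_weighting", label,
        "--contrast_weight", contrast,
    ]


def run_all(dry_run=True):
    """Print every command; execute them sequentially unless dry_run is set."""
    for setting in SETTINGS:
        cmd = build_command(*setting)
        print(" ".join(cmd))
        if not dry_run:
            # Reuse the current interpreter rather than whatever "python" resolves to.
            subprocess.run([sys.executable] + cmd[1:], check=True)


if __name__ == "__main__":
    run_all(dry_run="--run" not in sys.argv)
```

Running sequentially matches the single-GPU setup described earlier. `check=True` stops the sweep at the first failing run instead of continuing past it.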