/FocusNet-Classifying-better-by-focusing-on-confusing-classes

The PyTorch implementation of our Pattern Recognition 2022 paper, FocusNet, on ILSVRC2012

FocusNet

The implementation of our Pattern Recognition 2022 paper: "FocusNet: Classifying better by focusing on confusing classes"

Paper: https://www.sciencedirect.com/science/article/abs/pii/S003132032200190X?via%3Dihub

Note:

  • This repository is based mainly on "ImageNet training in PyTorch", so its documentation is a useful reference.
  • The first version of our architecture was named ClonalNet; we renamed it FocusNet after the second revision. Every occurrence of clonalnet below therefore refers to FocusNet.

ImageNet training in PyTorch

This implements training of popular model architectures, such as ResNet, AlexNet, and VGG on the ImageNet dataset.

Requirements

  • Install PyTorch (pytorch.org)
  • pip install -r requirements.txt
  • Note: the requirements.txt in this repository differs from the official one. If something goes wrong, use the official requirements.txt instead.
  • Download the ImageNet dataset from http://www.image-net.org/
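
Before training, it can help to confirm that the dataset folder has the layout the data loader expects. The upstream "ImageNet training in PyTorch" example reads a train and a val subfolder under --data, each containing one folder per class. The sketch below assumes clonalnet_main.py keeps that convention, which this README does not confirm; check_imagenet_layout is a hypothetical helper and is not part of this repository.

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return a list of problems found under `root`.

    Assumption: the dataset follows the upstream PyTorch ImageNet example,
    i.e. root/train/<class>/... and root/val/<class>/...
    An empty list means the layout looks usable.
    """
    root = Path(root)
    problems = []
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        # Each class is expected to live in its own subfolder.
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        if not class_dirs:
            problems.append(f"no class subfolders in: {split_dir}")
    return problems


if __name__ == "__main__":
    for problem in check_imagenet_layout("/path/to/ILSVRC2012"):
        print(problem)
```

An empty result only means the two split folders exist and contain subfolders; it does not verify the images themselves.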

Training

To train our network, run clonalnet_main.py with the desired model architecture and the path to the ImageNet dataset:

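python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet18 --seed 42 --gpu 0 -ebc
python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet34 --seed 42 --gpu 0 -ebc
python clonalnet_main.py --data /path/to/ILSVRC2012 -a mobilenet_v2 --seed 42 --gpu 0 -ebc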

The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs.
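
The schedule described above can be written as a small function of the epoch index. This is a minimal sketch for illustration, not the code used in clonalnet_main.py; the function name step_lr and its parameter names are assumptions.

```python
def step_lr(epoch, base_lr=0.1, step_size=30, gamma=0.1):
    """Learning rate for a given 0-based epoch.

    Starts at `base_lr` and is multiplied by `gamma` (i.e. divided by 10
    with the defaults) once every `step_size` epochs.
    """
    return base_lr * (gamma ** (epoch // step_size))


if __name__ == "__main__":
    # Print the rate at the first epoch of each 30-epoch stage.
    for epoch in (0, 30, 60):
        print(epoch, step_lr(epoch))
```

With the defaults, epochs 0 to 29 use 0.1, epochs 30 to 59 use 0.01, and epochs 60 to 89 use 0.001. In PyTorch the same decay can be obtained with torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1).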

Validation

To evaluate our network, run clonalnet_main.py with the desired model architecture and the path to the ImageNet dataset:

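python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet18 --seed 42 --gpu 0 -ebc -e --resume clonalnet_resnet18_model_best.pth.tar
python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet34 --seed 42 --gpu 0 -ebc -e --resume clonalnet_resnet34_model_best.pth.tar
python clonalnet_main.py --data /path/to/ILSVRC2012 -a mobilenet_v2 --seed 42 --gpu 0 -ebc -e --resume clonalnet_mobilenet_v2_model_best.pth.tar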
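
The results in this README are reported as Acc@1 and Acc@5, i.e. top-1 and top-5 accuracy. For readers unfamiliar with the metric, the following is a minimal pure-Python sketch of how top-k accuracy is computed from per-class scores. It is illustrative only: topk_accuracy is a hypothetical helper, not the function used by clonalnet_main.py.

```python
def topk_accuracy(scores, targets, ks=(1, 5)):
    """Top-k accuracy in percent.

    scores:  list of per-sample score lists, one score per class
    targets: list of ground-truth class indices, same length as scores
    Returns a dict mapping each k to the percentage of samples whose
    target is among the k highest-scoring classes.
    """
    if not scores or len(scores) != len(targets):
        raise ValueError("scores and targets must be non-empty and equal in length")
    hits = {k: 0 for k in ks}
    for row, target in zip(scores, targets):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        for k in ks:
            if target in ranked[:k]:
                hits[k] += 1
    return {k: 100.0 * hits[k] / len(targets) for k in ks}
```

A prediction counts as correct under Acc@5 when the true class is anywhere among the five highest-scoring classes, so Acc@5 is never lower than Acc@1.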

Logs

clonal_resnet18_from_scratch.log and clonal_resnet34_from_scratch.log are the training logs of clonalnet_resnet18 and clonalnet_resnet34, respectively.

Baseline

To validate the baseline results, please run:

# resnet18 / resnet34
python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a resnet18 --seed 10 -e --pretrained --gpu 0
python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a resnet34 --seed 10 -e --pretrained --gpu 0
# mobilenet_v2
python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a mobilenet_v2 --seed 10 -e --pretrained --gpu 0 --resume models/_pytorch_pretrained_checkpoints/baseline_mobilenet_v2_model_best.pth.tar

Results on ILSVRC2012

Model                      Acc@1 (%)  Acc@5 (%)  Checkpoint
ResNet18                   69.760     89.082     PyTorch Pre-trained
ClonalNet (r18)            70.422     89.562     Baidu, code: 1234; Google Drive
ResNet34                   73.310     91.420     PyTorch Pre-trained
ClonalNet (r34)            74.366     91.884     Baidu, code: 1234; Google Drive
MobileNet_v2               65.558     86.744     Baidu, code: 1234; Google Drive
ClonalNet (MobileNet_v2)   66.300     87.232     Baidu; Google Drive
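
To make the comparison with the baselines explicit, the gains can be computed directly from the numbers in the table. The short sketch below only restates the table; the dictionary layout and the function name accuracy_gains are illustrative choices.

```python
# (Acc@1, Acc@5) pairs copied from the results table.
RESULTS = {
    "resnet18": {"baseline": (69.760, 89.082), "clonalnet": (70.422, 89.562)},
    "resnet34": {"baseline": (73.310, 91.420), "clonalnet": (74.366, 91.884)},
    "mobilenet_v2": {"baseline": (65.558, 86.744), "clonalnet": (66.300, 87.232)},
}


def accuracy_gains(results):
    """Return {backbone: (Acc@1 gain, Acc@5 gain)} in percentage points."""
    gains = {}
    for name, row in results.items():
        gains[name] = tuple(
            round(clonal - base, 3)
            for base, clonal in zip(row["baseline"], row["clonalnet"])
        )
    return gains


if __name__ == "__main__":
    for name, (top1, top5) in accuracy_gains(RESULTS).items():
        print(f"{name}: +{top1:.3f} Acc@1, +{top5:.3f} Acc@5")
```

From the table, the Acc@1 gain is 0.662 points for ResNet18, 1.056 for ResNet34, and 0.742 for MobileNet_v2; the Acc@5 gains are 0.480, 0.464, and 0.488 points.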

You can also download more checkpoints here: Baidu, code: 1234; Google Drive.

Reference

If you find our work helpful, please cite it:

@article{zhang2022focusnet,
  title={FocusNet: Classifying better by focusing on confusing classes},
  author={Zhang, Xue and Sheng, Zehua and Shen, Hui-Liang},
  journal={Pattern Recognition},
  pages={108709},
  year={2022},
  publisher={Elsevier}
}