
[NeurIPS 2022] Source code for our paper "Escaping Saddle Points for Effective Generalization on Class-Imbalanced Data"


Escaping Saddle Points for Effective Generalization on Class-Imbalanced Data

Harsh Rangwani*, Sumukh K Aithal*, Mayank Mishra, R. Venkatesh Babu

This is the official PyTorch implementation for our NeurIPS'22 paper: Escaping Saddle Points for Effective Generalization on Class-Imbalanced Data [OpenReview] [arXiv]

Abstract

[Figure: Long Tail Saddle Points]

Real-world datasets exhibit imbalances of varying types and degrees. Several techniques based on re-weighting and margin adjustment of the loss are often used to enhance the performance of neural networks, particularly on minority classes. In this work, we analyze the class-imbalanced learning problem by examining the loss landscape of neural networks trained with re-weighting and margin-based techniques. Specifically, we examine the spectral density of the Hessian of the class-wise loss, through which we observe that the network weights converge to a saddle point in the loss landscapes of minority classes. Following this observation, we also find that optimization methods designed to escape from saddle points can be effectively used to improve generalization on minority classes. We further theoretically and empirically demonstrate that Sharpness-Aware Minimization (SAM), a recent technique that encourages convergence to flat minima, can be effectively used to escape saddle points for minority classes. Using SAM results in a 6.2% increase in accuracy on the minority classes over the state-of-the-art Vector Scaling Loss, leading to an overall average increase of 4% across imbalanced datasets.

TLDR: Tail class loss landscape converges to a saddle point in imbalanced datasets and SAM can effectively escape from these solutions.

Getting started

  • Requirements

    • pytorch 1.9.1
    • torchvision 0.10.1
    • wandb 0.12.2
    • timm 0.5.5
    • prettytable 2.2.0
    • scikit-learn
    • matplotlib
    • tensorboardX
  • Installation

git clone https://github.com/val-iisc/Saddle-LongTail.git
cd Saddle-LongTail
pip install -r requirements.txt

We use Weights and Biases (wandb) to track our experiments and results. To track your own experiments with wandb, create a new project under your account and change the project and entity arguments of wandb.init in the .py file for each experiment. To disable wandb tracking, omit the --log_results flag.
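The wandb setup described above can be sketched roughly as follows. This is an illustrative helper, not code from this repository: the function name init_tracking and the placeholder project and entity values are assumptions, and only the wandb.init keyword arguments project, entity, and config are taken from the wandb API.

```python
from typing import Any, Dict, Optional


def init_tracking(log_results: bool, project: str, entity: str,
                  config: Dict[str, Any]) -> Optional[Any]:
    """Start a wandb run only when logging is requested (hypothetical helper)."""
    if not log_results:
        # Mirrors omitting --log_results: no tracking, wandb is never imported.
        return None
    import wandb  # imported lazily so the script also runs without wandb installed
    return wandb.init(project=project, entity=entity, config=config)


if __name__ == "__main__":
    # Placeholder names; replace them with your own wandb project and account.
    run = init_tracking(False, "my-project", "my-entity", {"rho": 0.8})
    print("tracking enabled:", run is not None)
```

Keeping the import inside the function is a design choice for the sketch: it lets runs without logging work on machines where wandb is not installed.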

  • Datasets

    The datasets used in the repository can be downloaded by following the instructions from the following links: The CIFAR datasets are downloaded automatically to the data/ folder if they are not already available.

Training

Sample command to train on the CIFAR-10 LT dataset with LDAM+DRW+SAM.

python cifar_train_sam.py --gpu 0 --imb_type exp --imb_factor 0.01 --loss_type LDAM --train_rule DRW --rho 0.8 --rho_schedule none --log_results --dataset cifar10 --seed 0

Sample command to train on the ImageNet-LT dataset with LDAM+DRW+SAM.

python imnet_train_sam.py --gpu 0 --imb_type exp --imb_factor 0.01 --data_path <Path-to-Dataset> --loss_type LDAM --train_rule DRW --dataset imagenet -b 256 --epochs 90 --arch resnet50 --cos_lr --rho_schedule step --lr 0.2 --seed 0 --rho_steps 0.05 0.1 0.5 0.5 --log_results --wd 2e-4 --margin 0.3

All the commands to reproduce the experiments are available in run.sh
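The --rho flag in the training commands sets the radius of the SAM perturbation. As a rough, framework-free illustration of a single SAM update (this is not the repository's PyTorch code; sam_step, grad_fn, and the toy quadratic loss are assumptions made for the example), the step first moves the weights by rho along the normalized gradient, then applies the gradient computed at that perturbed point to the original weights:

```python
import math
from typing import Callable, List

Vector = List[float]


def sam_step(w: Vector, grad_fn: Callable[[Vector], Vector],
             lr: float, rho: float, eps: float = 1e-12) -> Vector:
    """One SAM update on plain Python lists (illustrative sketch)."""
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g))
    # Ascent step: perturb the weights by rho along the normalized gradient.
    w_adv = [wi + rho * gi / (norm + eps) for wi, gi in zip(w, g)]
    # Descent step: use the gradient at the perturbed point to update w.
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


if __name__ == "__main__":
    # Toy loss 0.5 * ||w||^2, whose gradient is w itself.
    def quadratic_grad(w: Vector) -> Vector:
        return list(w)

    print(sam_step([3.0, 4.0], quadratic_grad, lr=0.1, rho=0.5))
```

With rho set to 0 the update reduces to ordinary gradient descent, which makes it easy to compare the two on a toy problem.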

Results and Checkpoints of Models

We show results on the CIFAR-10 LT, CIFAR-100 LT, ImageNet-LT and iNaturalist-18 datasets. Complete results are available in the paper.

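| Dataset | Method | Accuracy (%) | Checkpoints |
|---|---|---|---|
| CIFAR-10 LT (IF=100) | LDAM+DRW+SAM | 81.9 | ckpt |
| CIFAR-10 LT (IF=100) | CE+DRW+SAM | 80.6 | ckpt |
| CIFAR-100 LT (IF=100) | LDAM+DRW+SAM | 45.4 | ckpt |
| CIFAR-100 LT (IF=100) | CE+DRW+SAM | 44.6 | ckpt |
| ImageNet-LT | LDAM+DRW+SAM | 53.1 | ckpt |
| ImageNet-LT | CE+DRW+SAM | 47.1 | ckpt |
| iNaturalist-18 | LDAM+DRW+SAM | 70.1 | ckpt |
| iNaturalist-18 | CE+DRW+SAM | 65.3 | ckpt |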

Class-Wise Hessian Analysis

We also release the code to compute the spectral density and analyse the loss landscape of the trained models.

Sample command below:

python hessian_analysis.py --gpu 0 --seed 1 --exp_str sample --resume <checkpoint_path> --dataloader_hess train --log_results

Running this command computes the eigenvalue spectral density of the per-class loss and plots the class-wise spectral density along with the maximum eigenvalue and the trace of the Hessian.
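To give a feel for the two scalar quantities reported above, here is a small sketch that estimates the largest-magnitude eigenvalue by power iteration and the trace by Hutchinson's estimator, using only Hessian-vector products. It operates on a tiny explicit matrix rather than a network Hessian, and the helper names (hvp_from_matrix, top_eigenvalue, hutchinson_trace) are illustrative assumptions, not functions from this repository or from PyHessian.

```python
import math
import random
from typing import Callable, List

Vector = List[float]
HVP = Callable[[Vector], Vector]


def hvp_from_matrix(H: List[Vector]) -> HVP:
    """Wrap an explicit symmetric matrix as a Hessian-vector product."""
    return lambda v: [sum(h * x for h, x in zip(row, v)) for row in H]


def top_eigenvalue(hvp: HVP, dim: int, iters: int = 100, seed: int = 0) -> float:
    """Power iteration: converges to the eigenvalue of largest magnitude."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    eig = 0.0
    for _ in range(iters):
        norm = math.sqrt(sum(x * x for x in v))
        v = [x / norm for x in v]
        hv = hvp(v)
        eig = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient v^T H v
        v = hv
    return eig


def hutchinson_trace(hvp: HVP, dim: int, samples: int = 50, seed: int = 0) -> float:
    """Average v^T H v over random +/-1 vectors to estimate the trace."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        total += sum(a * b for a, b in zip(v, hvp(v)))
    return total / samples


if __name__ == "__main__":
    # A saddle-like toy Hessian: one negative eigenvalue among positive ones.
    hvp = hvp_from_matrix([[2.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 0.5]])
    print(top_eigenvalue(hvp, 3), hutchinson_trace(hvp, 3))
```

A negative eigenvalue in the spectrum, as in the toy matrix here, is the signature of a saddle point that the class-wise analysis looks for.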

Overview of the arguments

Generally, all Python scripts in the project take the following flags:

  • -a: Architecture of the backbone. (resnet32|resnet50)
  • --dataset: Dataset (cifar10|cifar100)
  • --imb_type: Imbalance type (exp|step).
  • --imb_factor: Imbalance factor (ratio of the number of samples in the minority class to that in the majority class). Default: 0.01
  • --epochs: Number of training epochs. Default: 200.
  • --loss_type: Loss Type (CE|LDAM|VS)
  • --gpu: GPU id to use.
  • --rho: $\rho$ value in SAM (Applicable to SAM runs).
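As a rough illustration of how --imb_type and --imb_factor shape the training set, the sketch below computes per-class sample counts. It assumes the convention used in the LDAM-DRW CIFAR loaders (exponential decay across classes for exp, a two-level split for step); that convention, the function name long_tail_counts, and the img_max parameter are assumptions for the example rather than details confirmed by this README.

```python
from typing import List


def long_tail_counts(num_classes: int, img_max: int,
                     imb_type: str = "exp", imb_factor: float = 0.01) -> List[int]:
    """Per-class sample counts for a long-tailed split (illustrative sketch)."""
    if imb_type == "exp":
        # Counts decay geometrically from img_max down to img_max * imb_factor.
        return [int(img_max * imb_factor ** (i / (num_classes - 1.0)))
                for i in range(num_classes)]
    if imb_type == "step":
        # First half of the classes keep img_max; the rest are cut to the minority size.
        half = num_classes // 2
        return [img_max] * half + [int(img_max * imb_factor)] * (num_classes - half)
    raise ValueError(f"unknown imb_type: {imb_type!r}")


if __name__ == "__main__":
    # CIFAR-10 has 5000 training images per class before subsampling.
    print(long_tail_counts(10, 5000, "exp", 0.01))
    print(long_tail_counts(10, 5000, "step", 0.01))
```

Under this convention an imbalance factor of 0.01 means the rarest class has one hundredth of the samples of the most frequent class, which corresponds to IF=100 in the results table.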

Acknowledgement

Our implementation is based on the LDAM and VS-Loss codebases. We use the PyTorch implementation of SAM from https://github.com/davda54/sam. We refer to PyHessian for the computation of the eigenvalue spectral density and the loss landscape analysis. We thank the authors for publicly releasing their source code.

Citation

If you find our paper or codebase useful, please consider citing us as:

@inproceedings{
rangwani2022escaping,
title={Escaping Saddle Points for Effective Generalization on Class-Imbalanced Data},
author={Harsh Rangwani and Sumukh K Aithal and Mayank Mishra and Venkatesh Babu Radhakrishnan},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=9DYKrsFSU2}
}