This is the code for the paper "A Geometric Analysis of Neural Collapse with Unconstrained Features".
- We provide the first global optimization landscape analysis of Neural Collapse (NC) – an intriguing empirical phenomenon that arises in the last-layer classifiers and features of neural networks during the terminal phase of training.
- We study the problem based on a simplified unconstrained feature model, which isolates the topmost layers from the classifier of the neural network. In this context, we show that the cross-entropy loss with weight decay has a benign global landscape: the only global minimizers are the Simplex Equiangular Tight Frames (ETFs), while all other critical points are strict saddles whose Hessians exhibit negative curvature directions.
- Our experiments demonstrate that one may fix the last-layer classifier to be a Simplex ETF with feature dimension d = K (the number of classes) for network training, which reduces memory cost by over 20% on ResNet18 without sacrificing generalization performance.
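As a concrete illustration of the Simplex ETF mentioned above, the sketch below builds one in NumPy with the standard construction M = sqrt(K/(K-1)) P (I_K - (1/K) 1 1^T), where P is a d x K matrix with orthonormal columns, and checks the two defining properties: unit-norm columns and a common pairwise cosine of -1/(K-1). The function name `simplex_etf` is our own and is not taken from the repository. For simplicity it requires d >= K (a Simplex ETF also exists for d = K-1).

```python
from typing import Optional

import numpy as np


def simplex_etf(K: int, d: Optional[int] = None, seed: int = 0) -> np.ndarray:
    """Return a d x K matrix whose K columns form a Simplex ETF (sketch, needs d >= K)."""
    d = K if d is None else d
    if d < K:
        raise ValueError("this sketch requires d >= K")
    if d == K:
        P = np.eye(K)  # d = K: no embedding into a larger space needed
    else:
        # Random d x K matrix with orthonormal columns (reduced QR of a Gaussian matrix).
        rng = np.random.default_rng(seed)
        P, _ = np.linalg.qr(rng.standard_normal((d, K)))
    centering = np.eye(K) - np.ones((K, K)) / K  # I_K - (1/K) 1 1^T
    return np.sqrt(K / (K - 1)) * P @ centering


M = simplex_etf(10)          # d = K = 10
G = M.T @ M                  # Gram matrix of the 10 class vectors
off_diag = G[~np.eye(10, dtype=bool)]
print(np.allclose(np.diag(G), 1.0))   # unit-norm columns -> True
print(np.allclose(off_diag, -1 / 9))  # common cosine -1/(K-1) -> True
```

Because P has orthonormal columns, the Gram matrix does not depend on d, so `simplex_etf(10, d=512)` has exactly the same angles as the d = K case.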
Requirements:
- CUDA 11.0
- python 3.8.3
- torch 1.6.0
- torchvision 0.7.0
- scipy 1.5.2
- numpy 1.19.1
By default, the code assumes that the MNIST and CIFAR10 datasets are stored under `~/data/`. If they are not there, they will be downloaded automatically via `torchvision.datasets`. You can change this default location with the `--data_dir` argument defined in `args.py`.
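The data location is an ordinary command-line option. The following is only a minimal sketch of how such a flag can be declared with `argparse`; the real `args.py` may declare it differently, and `build_parser` and `resolve_data_dir` are hypothetical names used for illustration.

```python
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical subset of args.py: only the data-location flag is shown.
    parser = argparse.ArgumentParser(description="data location (sketch)")
    parser.add_argument("--data_dir", type=str, default="~/data/",
                        help="where MNIST/CIFAR10 are stored or will be downloaded")
    return parser


def resolve_data_dir(args: argparse.Namespace) -> str:
    # Expand "~" to the user's home directory and normalize to an absolute path.
    return os.path.abspath(os.path.expanduser(args.data_dir))


args = build_parser().parse_args(["--data_dir", "/tmp/datasets"])
print(resolve_data_dir(args))
# The resolved path would then be used as the dataset root, e.g.
# torchvision.datasets.CIFAR10(root=resolve_data_dir(args), train=True, download=True)
```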
Training with SGD:
$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --batch_size 256 --lr 0.05
Training with Adam:
$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer Adam --batch_size 64 --lr 0.001
Training with L-BFGS:
$ python train_2nd_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer LBFGS --lr 0.1 --history_size 10 --batch_size 2048
Note: during training, a checkpoint is saved for every epoch under `model_weights/<the uid name fed to the above commands>/`, so that the NC phenomenon can be validated later.
Many other training options, e.g., `--epochs` and `--weight_decay`, can be found in `args.py`.
Computing the NC metrics:
$ python validate_NC.py --gpu_id 0 --dataset <mnist or cifar10> --batch_size 256 --load_path <path to the uid name>
After training, the command above computes the four NC metrics defined in the paper and saves them to an output file named `info.pkl`.
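For readers who want to compute the metrics on their own features, here is a minimal NumPy sketch of the first three NC metrics, following our reading of the definitions in the paper: NC1 = (1/K) tr(Σ_W Σ_B^†), NC2 is the distance of the normalized W W^T from the Gram matrix of a Simplex ETF, and NC3 is the distance between the normalized classifier and the normalized centered class means. It is not the implementation in `validate_NC.py`: the function name `nc_metrics` and the row-wise layout of `H` (one feature per row) are our assumptions, NC4 (agreement with the nearest-class-center classifier) is omitted, and the paper should be consulted for the exact definitions.

```python
import numpy as np


def nc_metrics(H, labels, W):
    """Sketch of NC1-NC3 (our reading of the paper's definitions).

    H: (N, d) last-layer features, one sample per row.
    labels: (N,) integer class labels in [0, K).
    W: (K, d) last-layer classifier weights.
    """
    K = W.shape[0]
    h_G = H.mean(axis=0)  # global feature mean
    means = np.stack([H[labels == k].mean(axis=0) for k in range(K)])  # (K, d) class means

    # NC1: within-class variability relative to between-class variability.
    centered = H - means[labels]
    Sigma_W = centered.T @ centered / H.shape[0]
    M_bar = means - h_G  # rows are h_k - h_G
    Sigma_B = M_bar.T @ M_bar / K
    # Sigma_B has rank at most K - 1: pseudo-invert only its top K - 1 eigenpairs
    # so that numerically-zero eigenvalues are not inverted.
    vals, vecs = np.linalg.eigh(Sigma_B)
    vals, vecs = vals[-(K - 1):], vecs[:, -(K - 1):]
    Sigma_B_pinv = (vecs / vals) @ vecs.T
    nc1 = np.trace(Sigma_W @ Sigma_B_pinv) / K

    # NC2: distance of the normalized classifier Gram matrix from a Simplex ETF Gram matrix.
    WWt = W @ W.T
    etf_gram = (np.eye(K) - np.ones((K, K)) / K) / np.sqrt(K - 1)
    nc2 = np.linalg.norm(WWt / np.linalg.norm(WWt) - etf_gram)

    # NC3: distance between the normalized classifier and the normalized centered class means.
    nc3 = np.linalg.norm(W / np.linalg.norm(W) - M_bar / np.linalg.norm(M_bar))
    return nc1, nc2, nc3


# Perfectly collapsed toy example: K = 4 balanced classes whose features sit exactly
# on the vertices of a Simplex ETF, with the classifier equal to those vertices.
K = 4
E = np.sqrt(K / (K - 1)) * (np.eye(K) - np.ones((K, K)) / K)  # rows: ETF vertices
labels = np.repeat(np.arange(K), 5)
H = E[labels]
print(nc_metrics(H, labels, W=E))  # all three are ~0 up to floating-point error

rng = np.random.default_rng(0)
print(nc_metrics(H + 0.1 * rng.standard_normal(H.shape), labels, W=E)[0])  # NC1 > 0 with noise
```

All three quantities decrease toward zero as training enters the collapse regime, which is why they are tracked per epoch over the saved checkpoints.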
Finally, the evolution of the NC metrics, as well as the training/testing accuracy, can be visualized by plotting them:
$ python plot.py
Note: please refer to `plot.py` for details on how each figure in the paper is plotted.
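To make your own plots instead of using `plot.py`, a first step is to see what `info.pkl` contains. The helper below is a hypothetical sketch, not part of the repository. It assumes that `info.pkl` holds a pickled Python dict (check `validate_NC.py` for the actual structure) and only reports each entry's type and length, without relying on any particular key names.

```python
import os
import pickle
import tempfile
from typing import Dict


def summarize_info(path: str) -> Dict[str, str]:
    """Describe each entry of a pickled dict (assumed format of info.pkl)."""
    with open(path, "rb") as f:
        info = pickle.load(f)  # only unpickle files you trust
    if not isinstance(info, dict):
        raise TypeError(f"expected a dict, got {type(info).__name__}")
    summary = {}
    for key, value in info.items():
        desc = type(value).__name__
        if hasattr(value, "__len__"):
            desc += f" (len={len(value)})"
        summary[str(key)] = desc
    return summary


# Self-contained demo on a toy dict with made-up keys (not the real info.pkl contents).
demo = {"some_metric": [0.9, 0.5, 0.1], "num_epochs": 3}
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "info.pkl")
    with open(demo_path, "wb") as f:
        pickle.dump(demo, f)
    print(summarize_info(demo_path))
```

On a real run, call `summarize_info` on the `info.pkl` written by `validate_NC.py` to find the per-epoch series worth plotting.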
Training and validation on `cifar10_random` with an MLP or ResNet18_adapt of a given width (and, for the MLP, depth):
$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset cifar10_random --optimizer SGD --batch_size 64 --lr 0.01 --model <MLP or ResNet18_adapt> --width <specify width for model> --depth <specify depth for MLP> --weight_decay 1e-4
$ python validate_NC.py --gpu_id 0 --dataset cifar10_random --batch_size 1000 --load_path <path to the uid name> --model <MLP or ResNet18_adapt> --width <specify width for model> --depth <specify depth for MLP>
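The `cifar10_random` name suggests CIFAR10 with randomly assigned labels, a standard setting for studying memorization; we have not verified how the repository generates them. As a sketch under that assumption, one common approach draws the random labels once with a fixed seed, so every epoch sees the same random targets. Permuting the original labels, as below, additionally keeps the number of samples per class unchanged, which matters because the NC analysis assumes balanced classes. `random_labels` is a hypothetical helper.

```python
import numpy as np


def random_labels(labels: np.ndarray, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: shuffle labels across samples once, with a fixed seed."""
    rng = np.random.default_rng(seed)
    return rng.permutation(labels)


# Toy stand-in for CIFAR10 targets: 10 classes, 6 samples each.
true_labels = np.repeat(np.arange(10), 6)
shuffled = random_labels(true_labels, seed=1)
print("fraction unchanged:", np.mean(shuffled == true_labels))
print("per-class counts:", np.bincount(shuffled, minlength=10))  # still 6 per class
```

With a torchvision dataset, the same idea would amount to overwriting `dataset.targets` with `random_labels(np.array(dataset.targets)).tolist()` before building the data loader.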
Training with a separate weight decay on the features (`--sep_decay`, with rate `--feature_decay_rate`):
$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --batch_size 64 --lr 0.05 --model <specify model> --weight_decay <specify weight decay> --sep_decay --feature_decay_rate <specify weight decay on features>
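For reference, the regularized objective studied in the paper under the unconstrained feature model, as we read it, penalizes the classifier W, the features H and the bias b with separate weight-decay parameters λ_W, λ_H and λ_b:

f(W, H, b) = (1/N) Σ_i L_CE(W h_i + b, y_i) + (λ_W/2) ||W||_F^2 + (λ_H/2) ||H||_F^2 + (λ_b/2) ||b||_2^2.

The NumPy sketch below evaluates this objective. We assume, without having checked `args.py`, that `--feature_decay_rate` plays the role of λ_H (a decay on the features) while `--weight_decay` applies to the network parameters. The function and argument names are our own.

```python
import numpy as np


def log_softmax(z: np.ndarray) -> np.ndarray:
    # Numerically stable log-softmax over the last axis.
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def ufm_objective(W, H, b, labels, lam_W, lam_H, lam_b):
    """W: (K, d) classifier, H: (N, d) features (one per row), b: (K,), labels: (N,)."""
    logits = H @ W.T + b  # row i holds W h_i + b
    ce = -log_softmax(logits)[np.arange(len(labels)), labels].mean()
    reg = 0.5 * (lam_W * np.sum(W ** 2) + lam_H * np.sum(H ** 2) + lam_b * np.sum(b ** 2))
    return ce + reg


rng = np.random.default_rng(0)
K, d, N = 4, 8, 20
W, H, b = rng.standard_normal((K, d)), rng.standard_normal((N, d)), np.zeros(K)
labels = rng.integers(0, K, size=N)
print(ufm_objective(W, H, b, labels, lam_W=5e-4, lam_H=1e-4, lam_b=0.0))  # illustrative values
```

Keeping λ_W and λ_H as separate arguments mirrors the idea behind a separate feature decay: the two penalties can be tuned independently.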
Fixing the last-layer classifier to a Simplex ETF (`--ETF`):
$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --lr 0.05 --ETF
Setting the feature dimension to d = K = 10 (`--fixdim 10`):
$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --lr 0.05 --fixdim 10
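To make the fixed-classifier setting concrete: with K = 10 classes and feature dimension d = K, the classifier can be a fixed K x K Simplex ETF, so that only the features are learned. The NumPy sketch below is illustrative only. In a PyTorch model one would typically keep such a layer out of training with `requires_grad_(False)`, but we have not checked how the repository does it. It shows that a feature pointing along the k-th ETF direction is assigned to class k, since w_k · w_k = 1 exceeds w_j · w_k = -1/(K-1) for every j ≠ k. `W_fixed` and `predict` are hypothetical names.

```python
import numpy as np

K = 10
# Fixed Simplex ETF classifier with d = K: unit-norm rows with pairwise cosine -1/(K-1).
W_fixed = np.sqrt(K / (K - 1)) * (np.eye(K) - np.ones((K, K)) / K)
W_fixed.setflags(write=False)  # read-only stand-in for "not trained"


def predict(H: np.ndarray) -> np.ndarray:
    """H: (N, K) features, one per row; returns the predicted class indices."""
    return np.argmax(H @ W_fixed.T, axis=1)


# Features pointing along ETF directions 2, 7 and 0 are classified accordingly.
H = 3.0 * W_fixed[[2, 7, 0]]
print(predict(H))  # prints [2 7 0]
```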
Training with the `--SOTA` option (see `args.py`):
$ python train_1st_order.py --gpu_id 0 --uid <saving directory name> --dataset <mnist or cifar10> --optimizer SGD --lr 0.05 --SOTA
For technical details and full experimental results, please check our paper.
@article{zhu2021geometric,
title={A Geometric Analysis of Neural Collapse with Unconstrained Features},
author={Zhihui Zhu and Tianyu Ding and Jinxin Zhou and Xiao Li and Chong You and Jeremias Sulam and Qing Qu},
year={2021},
eprint={2105.02375},
archivePrefix={arXiv},
primaryClass={cs.LG}
}