This is the official PyTorch implementation of our CNSN paper, in which we propose CrossNorm (CN) and SelfNorm (SN), two simple, effective, and complementary normalization techniques to improve generalization robustness under distribution shifts.
@article{tang2021cnsn,
title={CrossNorm and SelfNorm for Generalization under Distribution Shifts},
author={Zhiqiang Tang and Yunhe Gao and Yi Zhu and Zhi Zhang and Mu Li and Dimitris Metaxas},
journal={arXiv preprint arXiv:2102.02811},
year={2021}
}
conda create --name cnsn python=3.7
conda activate cnsn
conda install numpy
conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch
- Download CIFAR-10-C and CIFAR-100-C datasets with:
mkdir -p ./data
curl -O https://zenodo.org/record/2535967/files/CIFAR-10-C.tar
curl -O https://zenodo.org/record/3555552/files/CIFAR-100-C.tar
tar -xvf CIFAR-100-C.tar -C data/
tar -xvf CIFAR-10-C.tar -C data/
- Download ImageNet-C with:
mkdir -p ./data/ImageNet-C
curl -O https://zenodo.org/record/2235448/files/blur.tar
curl -O https://zenodo.org/record/2235448/files/digital.tar
curl -O https://zenodo.org/record/2235448/files/noise.tar
curl -O https://zenodo.org/record/2235448/files/weather.tar
tar -xvf blur.tar -C data/ImageNet-C
tar -xvf digital.tar -C data/ImageNet-C
tar -xvf noise.tar -C data/ImageNet-C
tar -xvf weather.tar -C data/ImageNet-C
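After running the download commands above, it can help to confirm that the archives unpacked where the training scripts will look for them. The following is a minimal sketch, not part of this repository. It assumes the usual layout of the public corruption benchmarks: `data/CIFAR-10-C/<corruption>.npy` plus `labels.npy`, with each array stacking 5 severity levels of 10,000 images, and `data/ImageNet-C/<corruption>/<severity>/...`. The function names are hypothetical.

```python
from pathlib import Path

# Assumption: each CIFAR-*-C .npy file stacks 5 severity levels of 10,000 images.
IMAGES_PER_SEVERITY = 10_000
SEVERITY_NAMES = {"1", "2", "3", "4", "5"}


def list_cifar_c_corruptions(root):
    """Return sorted corruption names found as <name>.npy under root (labels.npy excluded)."""
    return sorted(p.stem for p in Path(root).glob("*.npy") if p.stem != "labels")


def severity_slice(severity, per_level=IMAGES_PER_SEVERITY):
    """Return the slice that selects one severity level (1 to 5) from a stacked array."""
    if not 1 <= severity <= 5:
        raise ValueError("severity must be between 1 and 5")
    start = (severity - 1) * per_level
    return slice(start, start + per_level)


def list_imagenet_c_dirs(root):
    """Return sorted (corruption, severity) pairs for directories shaped like root/<corruption>/<1-5>."""
    pairs = []
    for sev_dir in Path(root).glob("*/*"):
        if sev_dir.is_dir() and sev_dir.name in SEVERITY_NAMES:
            pairs.append((sev_dir.parent.name, int(sev_dir.name)))
    return sorted(pairs)


if __name__ == "__main__":
    # Both calls return an empty list if the directory is missing or empty.
    print(list_cifar_c_corruptions("data/CIFAR-10-C"))
    print(list_imagenet_c_dirs("data/ImageNet-C"))
```

With NumPy installed, `np.load("data/CIFAR-10-C/fog.npy")[severity_slice(3)]` would then select the severity-3 images for one corruption, under the same layout assumption.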
We have included sample scripts in cifar10-scripts, cifar100-scripts, and imagenet-scripts.
For example, there are 5 scripts for CIFAR-100 and WideResNet:
- ./cifar100-scripts/wideresnet/run-cn.sh
- ./cifar100-scripts/wideresnet/run-sn.sh
- ./cifar100-scripts/wideresnet/run-cnsn.sh
- ./cifar100-scripts/wideresnet/run-cnsn-consist.sh (Use CNSN with JSD consistency regularization)
- ./cifar100-scripts/wideresnet/run-cnsn-augmix.sh (Use CNSN with AugMix)
- Pretrained ResNet-50 ImageNet classifiers are available:
- Results of the above 4 ResNet-50 models on ImageNet:
| | +CN | +SN | +CNSN | +CNSN+IBN+AugMix |
|---|---|---|---|---|
| Top-1 err (%) | 23.3 | 23.7 | 23.3 | 22.3 |
| mCE (%) | 75.1 | 73.8 | 69.7 | 62.8 |
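
This README does not define mCE. Assuming it follows the standard ImageNet-C benchmark definition (mean corruption error, where each corruption's error summed over severities is divided by a baseline model's summed error, conventionally AlexNet, and the ratios are averaged), the metric could be computed roughly as sketched below. The function names and the toy numbers are hypothetical and are not results from this repository.

```python
def corruption_error(model_errs, baseline_errs):
    """CE for one corruption: model errors summed over severities / baseline errors summed."""
    return sum(model_errs) / sum(baseline_errs)


def mean_corruption_error(model, baseline):
    """Average CE over corruptions, in percent.

    model, baseline: dicts mapping corruption name -> list of per-severity top-1 error rates.
    """
    ces = [corruption_error(model[c], baseline[c]) for c in model]
    return 100.0 * sum(ces) / len(ces)


if __name__ == "__main__":
    # Toy, made-up error rates (two corruptions, two severities each).
    model = {"fog": [10.0, 20.0], "snow": [30.0, 30.0]}
    baseline = {"fog": [20.0, 40.0], "snow": [60.0, 60.0]}
    print(mean_corruption_error(model, baseline))
```

Under this definition, a value below 100 means the model makes fewer errors on corrupted images than the baseline does, and lower is better.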