This repository contains the code to train neural networks and compute the various measures/norms reported in the following paper:

**Towards Understanding the Role of Over-Parametrization in Generalization of Neural Networks**
Behnam Neyshabur, Zhiyuan Li, Srinadh Bhojanapalli, Yann LeCun, Nathan Srebro
- Install Python 3.6 and PyTorch 0.4.1.
- Clone the repository:

  ```
  git clone https://github.com/bneyshabur/over-parametrization.git
  ```
- As a simple example, the following command trains a two-layer fully connected feedforward network with 1000 hidden units on the CIFAR10 dataset and then computes several measures/norms on the learned network:

  ```
  python main.py --dataset CIFAR10 --nunits 1000
  ```
- `--no-cuda`: disables CUDA training
- `--datadir`: path to the directory that contains the datasets (default: datasets)
- `--dataset`: name of the dataset (options: MNIST | CIFAR10 | CIFAR100 | SVHN, default: CIFAR10). If the dataset is not in the specified directory, it will be downloaded.
- `--nunits`: number of hidden units (default: 1024)
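The options above can be sketched with `argparse`; this is a hypothetical reconstruction for illustration, and the actual flag parsing in `main.py` may differ:

```python
import argparse

# Hypothetical sketch of the command-line options listed above;
# the real main.py may define them differently.
parser = argparse.ArgumentParser(description='Train a two-layer network and compute measures')
parser.add_argument('--no-cuda', action='store_true',
                    help='disables CUDA training')
parser.add_argument('--datadir', default='datasets',
                    help='path to the directory that contains the datasets')
parser.add_argument('--dataset', default='CIFAR10',
                    choices=['MNIST', 'CIFAR10', 'CIFAR100', 'SVHN'],
                    help='name of the dataset')
parser.add_argument('--nunits', type=int, default=1024,
                    help='number of hidden units')

# Parse the example command from above: --dataset CIFAR10 --nunits 1000
args = parser.parse_args(['--dataset', 'CIFAR10', '--nunits', '1000'])
```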
After training the network, several norms/measures will be computed and reported for the trained network. Please see the file `measures.py` for an explanation of each measure. We also compute and report the following generalization bounds:
- `VC bound`: generalization bound based on the VC dimension by Harvey et al. 2017
- `L1max bound`: generalization bound by Bartlett and Mendelson 2002
- `Fro bound`: generalization bound by Neyshabur et al. 2015
- `Spec_L1 bound`: generalization bound by Bartlett et al. 2017
- `Spec_Fro bound`: generalization bound by Neyshabur et al. 2018
- `Our bound`: the generalization bound proposed in this paper
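To give a flavor of the quantities these measures and bounds build on, here is a minimal NumPy sketch computing per-layer Frobenius and spectral norms and a Frobenius-norm-based capacity term in the spirit of the `Fro bound`. The function names, the toy network, and the omitted constants are illustrative assumptions, not the repository's API; see `measures.py` for the exact definitions used in the paper.

```python
import numpy as np

def layer_norms(weights):
    """Frobenius and spectral norms of each weight matrix."""
    fro = [np.linalg.norm(w, 'fro') for w in weights]
    spec = [np.linalg.norm(w, 2) for w in weights]  # largest singular value
    return fro, spec

def fro_capacity(weights, margin, m):
    """Capacity term in the spirit of the Frobenius-norm bound of
    Neyshabur et al. 2015: the product of layer Frobenius norms,
    scaled by the output margin and sample size (constants omitted)."""
    prod = np.prod([np.linalg.norm(w, 'fro') for w in weights])
    return prod / (margin * np.sqrt(m))

# Toy two-layer network: 3072 = 32*32*3 CIFAR10 inputs, 1000 hidden units,
# 10 classes, with random weights standing in for a trained model.
rng = np.random.default_rng(0)
weights = [0.01 * rng.standard_normal((1000, 3072)),
           0.01 * rng.standard_normal((10, 1000))]

fro, spec = layer_norms(weights)
capacity = fro_capacity(weights, margin=1.0, m=50000)
```

Note that the real bounds differ in which norms they combine and how they depend on depth, width, and margin; this sketch only shows the shared "product of layer norms over margin times sqrt(m)" structure.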