Epoch_wise_Double_Descent

Official implementation of "Multi-scale Feature Learning Dynamics: Insights for Double Descent".

Primary language: Jupyter Notebook. License: MIT.

Multi-scale Feature Learning Dynamics: Insights for Double Descent

This repository contains the official implementation of "Multi-scale Feature Learning Dynamics: Insights for Double Descent".

Blog post

Check out our interactive blog post here.

Requirements:

To ensure reproducibility, we publish the code, saved logs, and expected results of every experiment. All figures presented in the manuscript should be reproducible with the following requirements:

Python 3.7.10
PyTorch 1.4.0
torchvision 0.5.0
tqdm
matplotlib 3.4.3
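A quick way to compare a local environment against the versions listed above is to read each package's `__version__` attribute. This is a minimal sketch, not a script shipped with the repository; the names `EXPECTED`, `installed_version` and `check_environment` are hypothetical. It avoids `importlib.metadata` because that module only exists from Python 3.8 onward, while the listed interpreter is 3.7.10.

```python
import importlib
import platform

# Versions taken from the requirements list above (tqdm is unpinned).
EXPECTED = {
    "torch": "1.4.0",
    "torchvision": "0.5.0",
    "matplotlib": "3.4.3",
    "tqdm": None,
}


def installed_version(module_name):
    """Return the module's __version__, 'unknown' if it has none, or None if not importable."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", "unknown")


def check_environment(expected=EXPECTED, python_version="3.7.10"):
    """Return (name, expected, found) for every package that does not match."""
    mismatches = []
    if platform.python_version() != python_version:
        mismatches.append(("python", python_version, platform.python_version()))
    for name, want in expected.items():
        found = installed_version(name)
        # Drop local build tags such as "+cu101" before comparing.
        if found is None or (want is not None and found.split("+")[0] != want):
            mismatches.append((name, want, found))
    return mismatches


if __name__ == "__main__":
    for name, want, found in check_environment():
        print(f"{name}: expected {want or 'any'}, found {found or 'not installed'}")
```

An empty result means the interpreter and all listed packages match.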

BibTeX

@article{pezeshki2021multi,
  title={Multi-scale Feature Learning Dynamics: Insights for Double Descent},
  author={Pezeshki, Mohammad and Mitra, Amartya and Bengio, Yoshua and Lajoie, Guillaume},
  journal={arXiv preprint arXiv:2112.03215},
  year={2021}}

Reproducibility:

ResNet experiments on CIFAR-10

ResNet experiments on CIFAR-10 took 12,000 GPU hours on NVIDIA V100 GPUs. The code for managing these experiments with the Slurm workload manager is provided in the README of the ResNet_experiments folder.
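Slurm job arrays expose each task's index through the `SLURM_ARRAY_TASK_ID` environment variable, which is a common way to spread a sweep like this over many GPUs. The sketch below shows one way a task index could be mapped to a single hyperparameter setting; the grid values, the config keys and `config_for_task` are hypothetical and are not taken from the ResNet_experiments folder.

```python
import itertools
import os

# Hypothetical sweep values: placeholders, not the grid used in the paper.
REG_STRENGTHS = [0.0, 1e-4, 5e-4, 1e-3]
SEEDS = [0, 1, 2]


def config_for_task(task_id, reg_strengths=REG_STRENGTHS, seeds=SEEDS):
    """Map a Slurm array index to one (regularization strength, seed) pair."""
    grid = list(itertools.product(reg_strengths, seeds))
    if not 0 <= task_id < len(grid):
        raise ValueError(f"task_id {task_id} is outside the grid of size {len(grid)}")
    reg_strength, seed = grid[task_id]
    return {"reg_strength": reg_strength, "seed": seed}


if __name__ == "__main__":
    # Slurm sets SLURM_ARRAY_TASK_ID for each task of an array job
    # (e.g. one submitted with `sbatch --array=0-11 ...`); default to 0 locally.
    task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))
    print(config_for_task(task_id))
```

With 4 strengths and 3 seeds the grid has 12 entries, so the matching array range is `0-11`.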

To reproduce each figure of the manuscript

python fig1.py: Generalization error as training proceeds. Top: only the fast-learning feature or only the slow-learning feature is trained. Bottom: both features are trained, with $\kappa = 100$.
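The fast/slow setting behind this figure can be sketched as a linear teacher-student model trained with full-batch gradient descent. This is a minimal sketch under our own assumptions, not the repository's fig1.py: the student sees its inputs through a diagonal modulation matrix F whose first p entries are 1 and whose remaining d - p entries are 1/kappa (so F has condition number kappa), and `fast_slow_curves` and its `train` argument are hypothetical. The paper's exact parameterization of the modulation matrix may differ.

```python
import numpy as np


def fast_slow_curves(n=100, d=100, p=50, kappa=10.0, noise_std=0.5,
                     lr=0.1, steps=2000, train="both", seed=0):
    """Population error of a linear student over full-batch GD steps.

    Assumed setup (hypothetical, for illustration only):
      * teacher: y = w_star . x + noise, with x ~ N(0, I_d)
      * student: y_hat = w . (F x), F = diag(1 on the first p "fast" dims,
        1/kappa on the remaining d - p "slow" dims), so cond(F) = kappa
      * train: "fast", "slow" or "both" selects which weights are updated
    """
    if not 0 <= p <= d:
        raise ValueError("p must lie in [0, d]")
    rng = np.random.default_rng(seed)
    w_star = rng.standard_normal(d) / np.sqrt(d)           # teacher weights
    x = rng.standard_normal((n, d))                        # training inputs
    y = x @ w_star + noise_std * rng.standard_normal(n)    # noisy labels

    scales = np.ones(d)
    scales[p:] = 1.0 / kappa                               # diagonal of F
    feats = x * scales                                     # rows are F x

    mask = np.ones(d)                                      # which weights train
    if train == "fast":
        mask[p:] = 0.0
    elif train == "slow":
        mask[:p] = 0.0
    elif train != "both":
        raise ValueError("train must be 'fast', 'slow' or 'both'")

    w = np.zeros(d)
    errors = np.empty(steps + 1)
    for t in range(steps + 1):
        # For fresh x ~ N(0, I) and a noise-free target,
        # E[(w_star . x - w . F x)^2] = ||w_star - F w||^2.
        errors[t] = np.sum((w_star - scales * w) ** 2)
        if t < steps:
            # Gradient of the halved mean squared training error.
            grad = feats.T @ (feats @ w - y) / n
            w -= lr * mask * grad
    return errors
```

Calling it with `train="fast"`, `train="slow"` and `train="both"` gives the three training regimes described for Figure 1. Under this parameterization the slow directions move at a rate proportional to 1/kappa², so using kappa=100 as in the figure requires roughly 100 times more steps than the default kappa=10.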

python fig2_ab.py: Heat map of the empirical generalization error (0-1 classification error) of a ResNet-18 trained on CIFAR-10 with 15% label noise. The x-axis denotes the regularization strength and the y-axis the training time.
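The label-noise setting and the 0-1 error metric can be illustrated with two small helpers. This is a sketch under our own assumptions: each corrupted label is moved to a different class chosen uniformly at random, which may not match how the repository injects noise (some implementations also allow a corrupted label to keep its original class). `corrupt_labels` and `zero_one_error` are hypothetical names.

```python
import numpy as np


def corrupt_labels(labels, noise_rate=0.15, num_classes=10, seed=0):
    """Return (noisy_labels, noisy_idx): a copy of `labels` in which a
    `noise_rate` fraction of entries is moved to a different, random class."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels).copy()
    num_noisy = int(round(noise_rate * len(labels)))
    noisy_idx = rng.choice(len(labels), size=num_noisy, replace=False)
    # An offset in [1, num_classes - 1], taken modulo num_classes,
    # guarantees that every corrupted label differs from the original.
    offsets = rng.integers(1, num_classes, size=num_noisy)
    labels[noisy_idx] = (labels[noisy_idx] + offsets) % num_classes
    return labels, noisy_idx


def zero_one_error(logits, labels):
    """Fraction of examples whose arg-max prediction differs from the label."""
    return float(np.mean(np.argmax(logits, axis=1) != np.asarray(labels)))
```

For CIFAR-10's 50,000 training labels, a 15% noise rate corrupts 7,500 of them.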

python fig2_cd.py: The same heat map computed from the analytical results of the teacher-student model, allowing a qualitative comparison between the ResNet-18 results and our analytical results.

python fig3.py: Left: Phase diagram of the generalization error as a function of R(t) and Q(t). The trajectories show the evolution of R(t) and Q(t) as training proceeds. Each trajectory corresponds to a different $\kappa$, the condition number of the modulation matrix, which describes the ratio of the rates at which the two sets of features are learned. Right: The corresponding generalization curves for different $\kappa$, plotted over training time.
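For a linear teacher and student, the generalization error depends on the weights only through a few overlaps, which is why it can be drawn as a phase diagram in R and Q. A minimal sketch, assuming a teacher $y = w^* \cdot x$, a student $\hat{y} = w \cdot x$ and inputs $x \sim \mathcal{N}(0, I_d)$, with our own normalization (no $1/d$ factor and no modulation matrix, both of which may differ from the paper's definitions): $R = w^* \cdot w$, $Q = w \cdot w$, $T = w^* \cdot w^*$, so that $E_g = \tfrac{1}{2}\mathbb{E}[(y - \hat{y})^2] = \tfrac{1}{2}\lVert w^* - w \rVert^2 = \tfrac{1}{2}(T - 2R + Q)$. The function names are hypothetical.

```python
import numpy as np


def order_parameters(w_student, w_teacher):
    """R: student-teacher overlap, Q: squared student norm, T: squared teacher norm."""
    R = float(w_teacher @ w_student)
    Q = float(w_student @ w_student)
    T = float(w_teacher @ w_teacher)
    return R, Q, T


def gen_error_from_RQ(R, Q, T):
    """Half mean squared error on fresh N(0, I) inputs for a linear teacher/student."""
    return 0.5 * (T - 2.0 * R + Q)


def monte_carlo_error(w_student, w_teacher, num_samples=200_000, seed=0):
    """Estimate the same quantity by sampling fresh Gaussian inputs."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((num_samples, w_teacher.shape[0]))
    return 0.5 * float(np.mean((x @ w_teacher - x @ w_student) ** 2))
```

Tracking `order_parameters` along a gradient-descent run gives one (R(t), Q(t)) trajectory; with T fixed, the error surface in the (R, Q) plane is the plane given by `gen_error_from_RQ`.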

Match between theory and experiments

Here, we validate our analytical results by comparing the following three methods:

1. Empirical gradient descent
2. Analytical results for the exact general case (Eq. 9 substituted into Eq. 6)
3. Analytical results for the approximate fast-slow case (Eqs. 12 and 14 substituted into Eq. 6)

python extra_experiments/emp_vs_analytic.py
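The idea of checking empirical gradient descent against an analytical description can be illustrated in the simplest setting: for full-batch gradient descent on least squares started from zero, the iterate after t steps has a closed form in the eigenbasis of the input covariance. The sketch below is that textbook formula, not the paper's Eqs. 6, 9, 12 or 14, and `gd_empirical` and `gd_closed_form` are hypothetical names. With $\Sigma = X^\top X / n = V \operatorname{diag}(\lambda) V^\top$ and $b = X^\top y / n$, the recursion $w_{t+1} = w_t - \eta(\Sigma w_t - b)$ gives $w_t = V \operatorname{diag}\big((1 - (1 - \eta\lambda_i)^t)/\lambda_i\big) V^\top b$, with the factor replaced by $\eta t$ where $\lambda_i = 0$.

```python
import numpy as np


def gd_empirical(x, y, lr, steps):
    """Full-batch gradient descent on L(w) = ||x w - y||^2 / (2n), starting at w = 0."""
    n, d = x.shape
    w = np.zeros(d)
    for _ in range(steps):
        w -= lr * x.T @ (x @ w - y) / n
    return w


def gd_closed_form(x, y, lr, steps, eps=1e-12):
    """The same iterate computed directly from the eigendecomposition of x^T x / n."""
    n = x.shape[0]
    lam, V = np.linalg.eigh(x.T @ x / n)
    b = V.T @ (x.T @ y / n)                       # b expressed in the eigenbasis
    nonzero = lam > eps
    safe_lam = np.where(nonzero, lam, 1.0)        # avoid dividing by ~0
    # Geometric-series gain per eigendirection; it tends to lr * steps as lam -> 0.
    gain = np.where(nonzero, (1.0 - (1.0 - lr * lam) ** steps) / safe_lam, lr * steps)
    return V @ (gain * b)
```

The two functions should agree to numerical precision whenever `lr` is small enough for gradient descent to be stable (`lr * lam.max() < 2`), in both the under- and over-parameterized regimes.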

Previous experiments with different setups

We also provide further experiments where we vary the following variables:

n: number of training examples
d: total number of dimensions
p: number of fast learning dimensions

Four variants of Figure 1

Four variants of Figure 3

Interactive notebook

To try different setups, check out the following Colab notebook: Link