Multi-scale Feature Learning Dynamics: Insights for Double Descent
This repository contains the official implementation of:
Multi-scale Feature Learning Dynamics: Insights for Double Descent
Blog post
Check out our interactive blog post here.
Requirements:
To ensure reproducibility, we publish the code, saved logs, and expected results of every experiment. All figures presented in the manuscript can be reproduced with the following requirements:
Python 3.7.10
PyTorch 1.4.0
torchvision 0.5.0
tqdm
matplotlib 3.4.3
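The pinned packages above can be collected in a requirements file (the file name requirements.txt is our choice; note that Python 3.7.10 itself must be provided by the environment, e.g. via conda or pyenv):

```
torch==1.4.0
torchvision==0.5.0
tqdm
matplotlib==3.4.3
```

They can then be installed with `pip install -r requirements.txt`.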
BibTeX
@article{pezeshki2021multi,
  title={Multi-scale Feature Learning Dynamics: Insights for Double Descent},
  author={Pezeshki, Mohammad and Mitra, Amartya and Bengio, Yoshua and Lajoie, Guillaume},
  journal={arXiv preprint arXiv:2112.03215},
  year={2021}
}
Reproducibility:
ResNet experiments on CIFAR-10
ResNet experiments on CIFAR-10 took 12,000 GPU hours on Nvidia V100 GPUs. The code to manage experiments with the Slurm resource-management tool is provided in the README available in the ResNet_experiments folder.
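A minimal Slurm submission sketch (the script name train_resnet.py and the resource values are illustrative assumptions; the actual job scripts are described in the ResNet_experiments README):

```
#!/bin/bash
#SBATCH --gres=gpu:v100:1
#SBATCH --time=48:00:00
#SBATCH --mem=16G
# Hypothetical entry point -- see ResNet_experiments for the real scripts
python train_resnet.py --dataset cifar10
```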
To reproduce each figure of the manuscript:
python fig1.py: The generalization error as training proceeds. (top): The case where only the fast-learning feature or only the slow-learning feature is trained. (bottom): The case where both features are trained with κ=100.
python fig2_ab.py: Heat-map of the empirical generalization error (0-1 classification error) for a ResNet-18 trained on CIFAR-10 with 15% label noise. The X-axis denotes the regularization strength, and the Y-axis represents the training time.
python fig2_cd.py: The same plot with the analytical results of the teacher-student setup. We observe qualitative agreement between the ResNet-18 results and our analytical results.
python fig3.py: Left: Phase diagram of the generalization error as a function of R(t) and Q(t). The trajectories describe the evolution of R(t) and Q(t) as training proceeds. Each trajectory corresponds to a different
Match between theory and experiments
Here, we validate our analytical results by comparing the following three methods:
1. Empirical gradient descent
2. Analytical results: the exact general case (Eq. 9 substituted into Eq. 6)
3. Analytical results: the approximate fast-slow case (Eqs. 12 and 14 substituted into Eq. 6)
python extra_experiments/emp_vs_analytic.py
Further experiments with different setups
We also provide further experiments where we vary the following variables:
n: number of training examples
d: number of total dimensions
p: number of fast learning dimensions
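To build intuition for the roles of n, d, and p, here is a minimal toy sketch (our illustration, not the repository's code) of a linear teacher-student setup in which the first p input dimensions are "fast" features, scaled up by a condition number κ, and are therefore learned earlier under gradient descent:

```python
import numpy as np

def make_multiscale_data(n, d, p, kappa=10.0, noise=0.1, seed=0):
    """Toy data: p fast dimensions (scale kappa) and d - p slow dimensions."""
    rng = np.random.default_rng(seed)
    scales = np.ones(d)
    scales[:p] = kappa                        # fast-learning dimensions
    X = rng.standard_normal((n, d)) * scales
    w_teacher = rng.standard_normal(d) / np.sqrt(d)
    y = X @ w_teacher + noise * rng.standard_normal(n)
    return X, y, w_teacher

def train(X, y, lr=1e-3, steps=500):
    """Plain gradient descent on the squared loss from zero initialization."""
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(steps):
        w -= lr * X.T @ (X @ w - y) / n
    return w

X, y, w_star = make_multiscale_data(n=200, d=50, p=5, kappa=10.0)
w = train(X, y)
```

After 500 steps the fast block has essentially converged while the slow block is still far from the teacher, which is the scale separation the paper's analysis exploits.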
Interactive notebook
To try different setups, check out the following Colab notebook: Link