init2winit

A JAX/Flax codebase for running deterministic, scalable, and well-documented deep learning experiments, with a particular emphasis on neural network initialization, optimization, and tuning experiments.

There is not yet a stable version (nor an official release of this library). All APIs are subject to change.

This is a research project, not an official Google product.

Installation

The current development version requires Python 3.6-3.8.

To install the latest development version inside a virtual environment, run

python3 -m venv env-i2w
source env-i2w/bin/activate
pip install --upgrade pip
pip install "git+https://github.com/google/init2winit.git#egg=init2winit"
pip install --upgrade jax jaxlib==0.1.66+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

where cuda111 corresponds to the installed CUDA version (here, CUDA 11.1). For more JAX installation information, see the JAX README.
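
To check that JAX and the CUDA-enabled jaxlib installed correctly, the following minimal snippet (not part of the codebase) should print the JAX version and list the available devices:

import jax

print(jax.__version__)
# On a working CUDA install this should list GPU devices; on a CPU-only
# machine it falls back to a single CPU device.
print(jax.devices())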

Usage

An example MNIST experiment can be run with the following command:

python3 main.py \
    --experiment_dir=/tmp/test_mnist \
    --model=fully_connected \
    --dataset=mnist \
    --num_train_steps=10

For local debugging, we recommend using the fake dataset:

python3 main.py \
    --experiment_dir=/tmp/test_fake \
    --num_train_steps=10 \
    --dataset=fake \
    --hparam_overrides='{"input_shape": [28, 28, 1], "output_shape": [10]}'

The --hparam_overrides flag accepts a serialized JSON object of hyperparameter names and values to override. See the flags defined in main.py for more information on possible configurations.
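
For more complex overrides, it can be easier to build the JSON string in Python than to hand-quote it in the shell. A minimal sketch, using only the hyperparameter names from the example above:

import json

# Build the override dict in Python, then serialize it for the flag.
overrides = {"input_shape": [28, 28, 1], "output_shape": [10]}
flag_value = json.dumps(overrides)
print(flag_value)  # pass this string as --hparam_overrides='<this string>'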

See the dataset_lib and model_lib directories for currently implemented datasets and models.
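
To enumerate them programmatically, one option is to list the modules in each package. This is a sketch, assuming only that dataset_lib and model_lib are importable packages; module names roughly track what is implemented, and the authoritative registries live inside the packages themselves:

import pkgutil

from init2winit import dataset_lib
from init2winit import model_lib

# List the submodules of each package as a rough inventory of the
# implemented datasets and models.
print(sorted(m.name for m in pkgutil.iter_modules(dataset_lib.__path__)))
print(sorted(m.name for m in pkgutil.iter_modules(model_lib.__path__)))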

Citing

To cite this repository:

@software{init2winit2021github,
  author = {Justin M. Gilmer and George E. Dahl and Zachary Nado and Priya Kasimbeg and Sourabh Medapati},
  title = {{init2winit}: a JAX codebase for initialization, optimization, and tuning research},
  url = {http://github.com/google/init2winit},
  version = {0.0.2},
  year = {2023},
}

For a list of references to the models, datasets, and techniques implemented in this codebase, see references.md.

Contributors

Contributors (past and present):

  • Ankush Garg
  • Behrooz Ghorbani
  • Cheolmin Kim
  • David Cardoze
  • George E. Dahl
  • Justin M. Gilmer
  • Michal Badura
  • Priya Kasimbeg
  • Rohan Anil
  • Sourabh Medapati
  • Sneha Kudugunta
  • Varun Godbole
  • Zachary Nado
  • Vlad Feinberg
  • Derrick Xin
  • Naman Agarwal
  • Daniel Suo
  • Bilal Khan
  • Jeremy Cohen
  • Kacper Krasowiak