
Awesome Optimization Toolkit (try it in Colab)

This library illustrates the performance of different optimizers on different datasets. It also lets users add their own datasets and optimizers and compare them against existing methods.

Quick links to sections on this page
🔍 Quick Start 📜 Optimizers Implemented 🏁 Leaderboard
🔏 Adding an optimizer 🔏 Adding a dataset 🔏 Adding a model

Quick Start

Run MNIST experiments with these three steps.

1. Install requirements

pip install -r requirements.txt

2. Train and Validate

python trainval.py -e mnist -d results -sb results -r 1 -v results.ipynb

Argument Descriptions:

-e  [Experiment group to run, e.g. 'mnist', 'cifar10', or 'cifar100']
-sb [Directory where the experiments are saved]
-d  [Directory where the datasets are saved]
-r  [Flag for whether to reset the experiments]
-j  [Scheduler for launching the experiments; use None to run them on the local machine]
-v  [File name of the Jupyter notebook saved for visualizing the results]
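
For instance, to run the cifar10 experiment group instead and save everything under an illustrative results directory (omitting -j so the jobs run on the local machine):

python trainval.py -e cifar10 -d data -sb results -r 1 -v results.ipynb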

3. Visualize the Results

Open results.ipynb and run the first cell to visualize the results.

Adding an optimizer

As an example, let's add RMSProp to the MNIST list of experiments.

  1. Define a new optimizer in src/optimizers/<new_optimizer>.py.
  2. Add a branch for opt_name = "<new_optimizer>" to the optimizer constructor in src/optimizers/__init__.py.

For example,

elif opt_name == "seg":
    opt = sls_eg.SlsEg(params, n_batches_per_epoch=n_batches_per_epoch)
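
Following the same pattern, a branch for the RMSProp example could look like the sketch below, assuming torch.optim is available in src/optimizers/__init__.py (the learning rate is illustrative, not a value from the repository):

elif opt_name == "rmsprop":
    # torch.optim.RMSprop used as a stand-in; lr is an illustrative hyperparameter
    opt = torch.optim.RMSprop(params, lr=1e-3)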
  3. Add the RMSProp hyperparameters to the "mnist" experiment group:
EXP_GROUPS["mnist"] += [{"name":"RMSProp"}]
  4. Launch the experiment with this command:
python trainval.py -e mnist -d results -sb results

Adding a dataset

As an example, let's add the MNIST dataset.

Define a new dataset and its corresponding transformations in src/datasets/__init__.py for dataset_name = "<new_dataset>".

For example,

if dataset_name == "mnist":
    # flatten each 1x28x28 MNIST image into a 784-dimensional vector
    view = torchvision.transforms.Lambda(lambda x: x.view(-1).view(784))
    dataset = torchvision.datasets.MNIST(datadir, train=train_flag,
                                         download=True,
                                         transform=torchvision.transforms.Compose([
                                             torchvision.transforms.ToTensor(),
                                             torchvision.transforms.Normalize((0.5,), (0.5,)),
                                             view,
                                         ]))
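
As a quick standalone sanity check (independent of the repository code) that this transform pipeline produces flat 784-dimensional tensors, the same transforms can be exercised directly; the data directory below is illustrative:

import torchvision

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
    torchvision.transforms.Lambda(lambda x: x.view(-1)),  # flatten 1x28x28 -> 784
])
dataset = torchvision.datasets.MNIST("data", train=True, download=True, transform=transform)
x, y = dataset[0]
print(x.shape)  # torch.Size([784])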

Adding a model

As an example, let's add the DenseNet121 model.

  1. Define the metrics, loss function, and accuracy function in src/models/classifiers.py.
  2. Define the base model in the get_classifier(clf_name, train_set) function in src/models/base_classifiers.py (a sketch is given after the configuration example below).
  3. Define the experiment configuration you would like to run. The dataset, models, optimizers, and hyperparameters can all be defined in the experiment configurations (see https://github.com/haven-ai/optimization-benchmark/blob/main/src/models/base_classifiers.py#L341).
EXP_GROUPS['new_benchmark'] = {"dataset": [<dataset_name>],
                               "model_base": [<network_name>],
                               "opt": [<optimizer_dict>]}

Train using the following command

python trainval.py -e new_benchmark -v 1 -d ../results -sb ../results

Optimizers Implemented

| Name | Conference/Journal | Implemented |
| ---- | ------------------ | ----------- |
| Adam | ICLR 2015 | Yes (opt=adam) |
| SGD with Goldstein | Numer. Math. 1962 | Yes (opt=sgd_goldstein) |
| SGD with Armijo line search | Pac. J. Math. 1966 | Yes (opt=sgd_armijo) |
| SGD_nesterov | Proc. USSR Acad. Sci. 1983 | Yes (opt=sgd_nesterov) |
| SGD_polyak | USSR Comput. Math. Math. Phys. 1963 | Yes (opt=sgd_polyak) |
| Adagrad | JMLR 2011 | Yes (opt=adagrad) |
| SSN | PMLR 2020 | Yes (opt=ssn) |
| SGD | Ann. Math. Stat. 1952 | Yes (opt=sgd) |
| RMSprop | Generating Sequences With Recurrent Neural Networks (2014) | Yes (opt=rmsprop) |
| Adabound | ICLR 2019 | Yes (opt=adabound) |
| Amsbound | ICLR 2019 | Yes (opt=amsbound) |
| SPS | AISTATS 2021 | Yes (opt=sps) |
| Lookahead | NeurIPS 2019 | Yes (opt=lookahead) |
| Radam | ICLR 2020 | Yes (opt=radam) |
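
Each optimizer in the table is selected through the "opt" field of an experiment group. A purely illustrative configuration (the model name "mlp" is an assumption, not taken from the repository) could be:

EXP_GROUPS['mnist_armijo'] = {"dataset": ["mnist"],
                              "model_base": ["mlp"],
                              "opt": [{"name": "sgd_armijo"}]}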

Leaderboard

Check out the optimizers in Google Colab. This section is continually updated with the latest optimizers on standard benchmarks.

Result plots are included for the following benchmarks: synthetic, ijcnn, rcv1, mushrooms, w8a, MNIST - MLP, CIFAR10 - ResNet34, CIFAR100 - ResNet34, and Pascal - fcn8_vgg16.