/gRDA-Optimizer

Primary LanguageJupyter Notebook

gRDA-Optimizer

"Generalized Regularized Dual Averaging" is an optimizer that can learn a small sub-network during training, if one starts from an overparameterized dense network.

If you find our method useful, please consider to cite our paper:

@inproceedings{chao2021dp,
author = {Chao, Shih-Kang and Wang, Zhanyu and Xing, Yue and Cheng, Guang},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
pages = {13986--13998},
publisher = {Curran Associates, Inc.},
title = {Directional Pruning of Deep Neural Networks},
url = {https://proceedings.neurips.cc/paper/2020/file/a09e75c5c86a7bf6582d2b4d75aad615-Paper.pdf},
volume = {33},
year = {2020},
}


Figure: The DP is the asymptotic directional pruning solution computed with gRDA, which lies in the same minimum valley as the SGD on the training loss landscape. DP has a sparsity 90.3% and a test accuracy 76.81%, while the SGD solution has no zero elements and has a test accuracy 76.6%. This example uses wide ResNet28x10 on CIFAR-100. See the paper for more detail.


Figure: training curve and sparsity based on the simple 6-layer CNN provided in the Keras tutorial https://keras.io/examples/cifar10_cnn/. The experiments are done using lr = 0.005 for SGD, SGD momentum and gRDAs. c = 0.005 for gRDA. lr = 0.005 and 0.001 for Adagrad and Adam, respectively.

Requirements

Keras version >= 2.2.5
Tensorflow version >= 1.14.0

How to use

There are three hyperparameters: Learning rate (lr), sparsity control mu (mu), and initial sparse control constant (c) in gRDA optimizer.

  • lr: as a rule of thumb, use the learning rate that works for the SGD without momentum. Scale the learning rate with the batch size.
  • mu: 0.5 < mu < 1. Greater mu will make the parameters more sparse. Selecting it in the set {0.501,0.51,0.55} is generally recommended.
  • c: a small number, e.g. 0 < c < 0.005. Greater c causes the model to be more sparse, especially at the early stage of training. c usually has small effect on the late stage of training. We recommend to first fix mu, then search for the largest c that preserves the testing accuracy with 1-5 epochs.

Keras

Suppose the loss function is the categorical crossentropy,

from grda import GRDA

opt = GRDA(lr = 0.005, c = 0.005, mu = 0.7)
model.compile(optimizer = opt, loss='categorical_crossentropy', metrics=['accuracy'])

Tensorflow

from grda_tensorflow import GRDA

n_epochs = 20
batch_size = 10
batches = 50000/batch_size # CIFAR-10 number of minibatches

opt = GRDA(learning_rate = 0.005, c = 0.005, mu = 0.51)
opt_r = opt.minimize(R_loss, var_list = r_vars)
with tf.Session(config=session_conf) as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(n_epochs + 1):
        for b in range(batches):
            sess.run([R_loss, opt_r], feed_dict = {data: train_x, y: train_y})

PyTorch

You can check the test file mnist_test_pytorch.py. The essential part is below.

from grda_pytorch import gRDA

optimizer = gRDA(model.parameters(), lr=0.005, c=0.1, mu=0.5)
# loss.backward()
# optimizer.step()

See mnist_test_pytorch.py for an illustration on customized learning rate schedule.

PlaidML

Be cautious that it can be unstable with Mac when GPU is implemented, see plaidml/plaidml#168.

To run, define the softthreshold function in the plaidml backend file (plaidml/keras):

def softthreshold(x, t):
     x = clip(x, -t, t) * (builtins.abs(x) - t) / t
     return x

In the main file, add the following before importing other libraries

import plaidml.keras
plaidml.keras.install_backend()

from grda_plaidml import GRDA

Then the optimizer can be used in the same way as Keras.

Experiments

ResNet-50 on ImageNet

These PyTorch models are based on the official implementation: https://github.com/pytorch/examples/blob/master/imagenet/main.py

lr schedule c mu epoch sparsity top1 accuracy file size link
fix lr=0.1 (SGD, no momentum or weight decay) / / 89 / 68.71 98MB link
fix lr=0.1 0.005 0.55 145 91.54 69.76 195MB link
fix lr=0.1 0.005 0.51 136 87.38 70.35 195MB link
fix lr=0.1 0.005 0.501 145 86.03 70.60 195MB link
lr=0.1 (ep1-140) lr=0.01 (after ep140) 0.005 0.55 150 91.59 73.24 195MB link
lr=0.1 (ep1-140) lr=0.01 (after ep140) 0.005 0.51 146 87.28 73.14 195MB link
lr=0.1 (ep1-140) lr=0.01 (after ep140) 0.005 0.501 148 86.09 73.13 195MB link