/pt.darts

PyTorch Implementation of DARTS: Differentiable Architecture Search


DARTS: Differentiable Architecture Search

Liu, Hanxiao, Karen Simonyan, and Yiming Yang. "DARTS: Differentiable Architecture Search." arXiv preprint arXiv:1806.09055 (2018).
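DARTS makes architecture search differentiable by relaxing the discrete choice of operation on each edge into a softmax-weighted mixture of all candidate operations, where the mixture weights (the architecture parameters alpha) are learned by gradient descent. The following is a minimal numpy sketch of one such mixed edge. The four 1-D toy operations are hypothetical stand-ins for the real convolution and pooling candidates, and the names `mixed_op` and `CANDIDATE_OPS` are illustrative, not taken from this repository.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    z = x - np.max(x)
    e = np.exp(z)
    return e / e.sum()


# Hypothetical toy candidate operations on a 1-D feature vector.
# The real search space uses convolutions, pooling, skip_connect and 'none' (zero).
CANDIDATE_OPS = {
    "none": lambda x: np.zeros_like(x),
    "skip_connect": lambda x: x,
    "avg_pool_3": lambda x: np.convolve(x, np.ones(3) / 3, mode="same"),
    "max_pool_3": lambda x: np.array([x[max(i - 1, 0):i + 2].max() for i in range(len(x))]),
}


def mixed_op(x, alpha):
    """Continuous relaxation of one edge: softmax(alpha)-weighted sum of all candidate ops."""
    weights = softmax(alpha)
    return sum(w * op(x) for w, op in zip(weights, CANDIDATE_OPS.values()))
```

With `alpha = np.zeros(4)`, every operation receives weight 0.25. As one entry of alpha grows, the edge approaches that single operation. After search, each edge is discretized by keeping its strongest operation other than 'none'.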

Requirements

  • Python 3
  • PyTorch >= 0.4
  • graphviz
    • Install the system package with apt install first, then the Python package with pip install.
  • numpy
  • tensorboardX

Run example

Reduce the batch size if an out-of-memory (OOM) error occurs. Whether this happens depends on your GPU memory size and on the genotype.

  • Search
python search.py --name cifar10 --dataset cifar10
  • Augment
# genotype: the genotype found by search
python augment.py --name cifar10 --dataset cifar10 --genotype genotype
  • With Docker
$ docker run --runtime=nvidia -it khanrc/pytorch-darts:0.1 bash

# you can also run a script directly
$ docker run --runtime=nvidia -it khanrc/pytorch-darts:0.1 python search.py --name cifar10 --dataset cifar10
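The augment step takes the genotype produced by search. This README does not show how augment.py parses the --genotype argument. The following is a hedged sketch that assumes the argument is the printed `Genotype(normal=..., normal_concat=range(...), ...)` string, presumably passed in quotes on the command line. It parses the string with Python's `ast` module instead of `eval`, so arbitrary expressions are rejected. The `Genotype` namedtuple mirrors the field names in the found architectures below, and `parse_genotype` is a hypothetical helper.

```python
import ast
from collections import namedtuple

# Assumed structure, mirroring the genotypes printed in this README.
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")


def _eval_node(node):
    """Evaluate a literal (lists, tuples, str, int) or a range(...) call without eval()."""
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "range":
        return range(*[ast.literal_eval(arg) for arg in node.args])
    return ast.literal_eval(node)


def parse_genotype(text):
    """Parse a 'Genotype(normal=..., ...)' string into a Genotype namedtuple."""
    expr = ast.parse(text.strip(), mode="eval").body
    if not (isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name)
            and expr.func.id == "Genotype"):
        raise ValueError("expected a Genotype(...) expression")
    # Each keyword argument (normal=..., reduce_concat=...) becomes a namedtuple field.
    fields = {kw.arg: _eval_node(kw.value) for kw in expr.keywords}
    return Genotype(**fields)
```

Restricting the input to a single `Genotype(...)` call whose arguments are literals or `range(...)` keeps the parser small. It also keeps a mistyped or malicious command-line argument from executing code.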

Results

The following results were obtained using the default arguments, except for the number of epochs: --epochs 300 was used for MNIST and Fashion-MNIST.

Dataset         Final validation accuracy   Best validation accuracy
MNIST           99.75%                      99.81%
Fashion-MNIST   99.27%                      99.39%
CIFAR-10        97.17%                      97.23%

The final validation accuracy on CIFAR-10, 97.17%, is the same number reported in the paper.

Found architectures

# CIFAR10
Genotype(
    normal=[[('sep_conv_3x3', 0), ('dil_conv_5x5', 1)], [('skip_connect', 0), ('dil_conv_3x3', 2)], [('sep_conv_3x3', 1), ('skip_connect', 0)], [('sep_conv_3x3', 1), ('skip_connect', 0)]],
    normal_concat=range(2, 6),
    reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 3), ('max_pool_3x3', 0)], [('skip_connect', 2), ('max_pool_3x3', 0)]],
    reduce_concat=range(2, 6)
)

# FashionMNIST
Genotype(
    normal=[[('max_pool_3x3', 0), ('dil_conv_5x5', 1)], [('max_pool_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 1), ('sep_conv_3x3', 3)], [('sep_conv_5x5', 4), ('dil_conv_5x5', 3)]],
    normal_concat=range(2, 6),
    reduce=[[('sep_conv_3x3', 1), ('avg_pool_3x3', 0)], [('avg_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 3), ('avg_pool_3x3', 0)], [('sep_conv_3x3', 2), ('skip_connect', 3)]],
    reduce_concat=range(2, 6)
)

# MNIST
Genotype(
    normal=[[('sep_conv_3x3', 0), ('dil_conv_5x5', 1)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 1)], [('dil_conv_5x5', 3), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 4), ('dil_conv_5x5', 3)]],
    normal_concat=range(2, 6),
    reduce=[[('dil_conv_3x3', 0), ('sep_conv_3x3', 1)], [('avg_pool_3x3', 0), ('skip_connect', 2)], [('dil_conv_5x5', 3), ('avg_pool_3x3', 0)], [('dil_conv_3x3', 1), ('max_pool_3x3', 0)]],
    reduce_concat=range(2, 6)
)
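In each genotype above, every inner list describes one intermediate node of the cell as two (operation, input index) pairs. Indices 0 and 1 are the cell's two inputs (the outputs of the previous two cells), index i >= 2 is an earlier intermediate node, and `range(2, 6)` concatenates the four intermediate nodes into the cell output. The DARTS paper derives this discrete form from the learned alphas by keeping, for each node, the k = 2 strongest incoming edges from distinct nodes, ranked by their best operation other than 'none'. The following is a numpy sketch of that rule. The `PRIMITIVES` ordering and the `derive_cell` helper are assumptions for illustration, not this repository's code.

```python
import numpy as np

# Hypothetical primitive list; 'none' (the zero op) is never kept in the final cell.
PRIMITIVES = ["none", "max_pool_3x3", "avg_pool_3x3", "skip_connect",
              "sep_conv_3x3", "sep_conv_5x5", "dil_conv_3x3", "dil_conv_5x5"]


def derive_cell(alphas, k=2):
    """Discretize one cell.

    alphas[i] has shape (i + 2, len(PRIMITIVES)): one row of op logits for the
    edge from each earlier node (2 inputs + i intermediate nodes) into node i.
    """
    gene = []
    for node_alpha in alphas:
        # Softmax over all ops (including 'none') on each incoming edge.
        w = np.exp(node_alpha - node_alpha.max(axis=1, keepdims=True))
        w = w / w.sum(axis=1, keepdims=True)
        w = w[:, 1:]                          # drop the 'none' column
        best_op = w.argmax(axis=1)            # strongest non-zero op on each edge
        best_w = w.max(axis=1)                # and its weight
        top_edges = np.argsort(-best_w)[:k]   # k strongest edges, one per input node
        gene.append([(PRIMITIVES[best_op[j] + 1], int(j)) for j in top_edges])
    return gene
```

Running this on the normal-cell and reduction-cell alphas separately yields the `normal` and `reduce` lists. With four intermediate nodes, the concat range is `range(2, 6)`.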

Architecture progress

[Figure: CIFAR-10 architecture progress; panels: normal cell, reduction cell]

[Figure: MNIST architecture progress; panels: normal cell, reduction cell]

[Figure: Fashion-MNIST architecture progress; panels: normal cell, reduction cell]
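The progress figures show cells as directed graphs, and graphviz is listed in the requirements, presumably for drawing them. The following is a minimal sketch, assuming the genotype layout shown above, that turns one cell into a Graphviz DOT description using plain string formatting. The node labels `c_{k-2}`, `c_{k-1}` and `c_{k}` and the helper names are hypothetical, and this is not the repository's plotting code.

```python
def node_name(idx):
    """Indices 0 and 1 are the two cell inputs; indices >= 2 are intermediate nodes 0, 1, ..."""
    return {0: "c_{k-2}", 1: "c_{k-1}"}.get(idx, str(idx - 2))


def cell_to_dot(gene, concat):
    """Return a Graphviz DOT string for one cell (e.g. genotype.normal, genotype.normal_concat)."""
    lines = ["digraph cell {", "  rankdir=LR;"]
    for i, edges in enumerate(gene):
        target = node_name(i + 2)
        for op, src in edges:
            # One labelled edge per (operation, input index) pair.
            lines.append(f'  "{node_name(src)}" -> "{target}" [label="{op}"];')
    for idx in concat:
        # Intermediate nodes in the concat range feed the cell output.
        lines.append(f'  "{node_name(idx)}" -> "c_{{k}}";')
    lines.append("}")
    return "\n".join(lines)
```

The resulting string can be rendered with the graphviz Python package, for example `graphviz.Source(dot).render("normal", format="png")`.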

Plots

[Figure: Search-training phase of Fashion-MNIST]

[Figure: Augment-validation phase of CIFAR-10 and Fashion-MNIST; panels: CIFAR-10, Fashion-MNIST]

Reference

https://github.com/quark0/darts (official implementation)

Main differences from the reference code

  • Supports PyTorch >= 0.4
  • Easy-to-read, commented code
  • New implementation of the architect
    • The original implementation is very slow in PyTorch >= 0.4.
  • Tested on Fashion-MNIST and MNIST
  • TensorBoard support
  • No RNN cell search

and so on.
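The architect computes the gradient of the validation loss with respect to the architecture parameters alpha. The DARTS paper approximates it with one unrolled weight step w' = w - xi * grad_w L_train(w, alpha). It then replaces the expensive second-order term with a central finite difference, using w± = w ± eps * v, where v = grad_{w'} L_val(w', alpha) and eps = 0.01 / ||v||. The following toy numpy sketch of that approximation is not this repository's architect code. The quadratic losses and every function name are hypothetical, chosen so the exact gradient is available for comparison.

```python
import numpy as np

# Hypothetical toy problem with known gradients:
#   L_train(w, a) = 0.5 * ||w - M a||^2
#   L_val(w, a)   = 0.5 * ||w - t||^2 + 0.5 * lam * ||a||^2
rng = np.random.default_rng(0)
n, m = 5, 3
M = rng.normal(size=(n, m))
t = rng.normal(size=n)
lam = 0.1


def grad_w_train(w, a):
    return w - M @ a


def grad_a_train(w, a):
    return -M.T @ (w - M @ a)


def grad_w_val(w, a):
    return w - t


def grad_a_val(w, a):
    return lam * a


def darts_arch_grad(w, a, xi):
    """Second-order DARTS approximation of d/da L_val(w - xi * grad_w L_train(w, a), a)."""
    w_prime = w - xi * grad_w_train(w, a)   # one-step unrolled weights w'
    v = grad_w_val(w_prime, a)              # dL_val/dw' evaluated at (w', a)
    eps = 0.01 / np.linalg.norm(v)          # step size suggested in the paper
    # Central finite difference for the mixed second derivative of L_train times v.
    hvp = (grad_a_train(w + eps * v, a) - grad_a_train(w - eps * v, a)) / (2 * eps)
    return grad_a_val(w_prime, a) - xi * hvp


def exact_arch_grad(w, a, xi):
    """Closed form for the toy problem, where w' = (1 - xi) w + xi M a."""
    w_prime = w - xi * grad_w_train(w, a)
    return lam * a + xi * M.T @ (w_prime - t)
```

Because the toy training loss is quadratic, the finite difference is exact up to rounding, and `darts_arch_grad` matches `exact_arch_grad`. Setting `xi = 0` drops the second-order term and gives the paper's cheaper first-order approximation. This finite-difference form needs only two extra gradient evaluations rather than a full Hessian.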