Liu, Hanxiao, Karen Simonyan, and Yiming Yang. "DARTS: Differentiable Architecture Search." arXiv preprint arXiv:1806.09055 (2018).
- python 3
- pytorch >= 0.4
- graphviz
    - First install using `apt install` and then `pip install`.
- numpy
- tensorboardX
Reduce the batch size if you run out of memory (OOM). Whether OOM occurs depends on your GPU memory size and the genotype.
- Search
python search.py --name cifar10 --dataset cifar10
- Augment
# genotype: from search
python augment.py --name cifar10 --dataset cifar10 --genotype genotype
- With Docker
$ docker run --runtime=nvidia -it khanrc/pytorch-darts:0.1 bash
# you can also run a script directly
$ docker run --runtime=nvidia -it khanrc/pytorch-darts:0.1 python search.py --name cifar10 --dataset cifar10
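The genotype passed to `augment.py` comes from the search output, i.e. the printed `Genotype(...)` text. As a hedged sketch, assuming the genotype is a `namedtuple` with the four fields that appear in the printed genotypes in this README, that text can be turned back into a structured object as follows. The names `Genotype` and `genotype_from_str` here are illustrative and not necessarily this repository's API.

```python
from collections import namedtuple

# Assumed genotype layout, matching the printed genotypes in this README.
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")


def genotype_from_str(s):
    """Rebuild a Genotype from its printed form.

    Only `Genotype` and `range` are exposed to eval(). This blocks
    accidental use of other names but is NOT a security sandbox, so only
    parse strings you trust (e.g. copied from your own search log).
    """
    return eval(s, {"__builtins__": {}}, {"Genotype": Genotype, "range": range})


if __name__ == "__main__":
    # Truncated to one intermediate node per cell for brevity.
    printed = (
        "Genotype(normal=[[('sep_conv_3x3', 0), ('dil_conv_5x5', 1)]], "
        "normal_concat=range(2, 6), "
        "reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)]], "
        "reduce_concat=range(2, 6))"
    )
    g = genotype_from_str(printed)
    print(g.normal[0])            # [('sep_conv_3x3', 0), ('dil_conv_5x5', 1)]
    print(list(g.normal_concat))  # [2, 3, 4, 5]
```

Using `eval` with a restricted namespace keeps the printed form and the parsed form identical, at the cost of requiring trusted input; a stricter alternative would be a dedicated parser.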
The following results were obtained using the default arguments, except for the number of epochs: `--epochs 300` was used for MNIST and Fashion-MNIST.
Dataset | Final validation acc | Best validation acc |
---|---|---|
MNIST | 99.75% | 99.81% |
Fashion-MNIST | 99.27% | 99.39% |
CIFAR-10 | 97.17% | 97.23% |
The final validation accuracy on CIFAR-10, 97.17%, is the same as the number reported in the paper.
# CIFAR10
Genotype(
normal=[[('sep_conv_3x3', 0), ('dil_conv_5x5', 1)], [('skip_connect', 0), ('dil_conv_3x3', 2)], [('sep_conv_3x3', 1), ('skip_connect', 0)], [('sep_conv_3x3', 1), ('skip_connect', 0)]],
normal_concat=range(2, 6),
reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 3), ('max_pool_3x3', 0)], [('skip_connect', 2), ('max_pool_3x3', 0)]],
reduce_concat=range(2, 6)
)
# FashionMNIST
Genotype(
normal=[[('max_pool_3x3', 0), ('dil_conv_5x5', 1)], [('max_pool_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 1), ('sep_conv_3x3', 3)], [('sep_conv_5x5', 4), ('dil_conv_5x5', 3)]],
normal_concat=range(2, 6),
reduce=[[('sep_conv_3x3', 1), ('avg_pool_3x3', 0)], [('avg_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 3), ('avg_pool_3x3', 0)], [('sep_conv_3x3', 2), ('skip_connect', 3)]],
reduce_concat=range(2, 6)
)
# MNIST
Genotype(
normal=[[('sep_conv_3x3', 0), ('dil_conv_5x5', 1)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 1)], [('dil_conv_5x5', 3), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 4), ('dil_conv_5x5', 3)]],
normal_concat=range(2, 6),
reduce=[[('dil_conv_3x3', 0), ('sep_conv_3x3', 1)], [('avg_pool_3x3', 0), ('skip_connect', 2)], [('dil_conv_5x5', 3), ('avg_pool_3x3', 0)], [('dil_conv_3x3', 1), ('max_pool_3x3', 0)]],
reduce_concat=range(2, 6)
)
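How to read these genotypes: under the usual DARTS convention (an assumption about this repository's indexing, consistent with the lists above), each cell has two input nodes, 0 and 1, holding the outputs of the two previous cells, plus four intermediate nodes 2 to 5. `normal[k]` lists the two `(operation, input node)` pairs feeding node `k + 2`, and the cell output concatenates nodes `range(2, 6)`. The hypothetical helper `cell_edges` below flattens one gene into explicit edges and checks that every edge points to an earlier node.

```python
def cell_edges(gene):
    """Flatten one DARTS gene (e.g. `normal` or `reduce`) into edges.

    Assumed convention: nodes 0 and 1 are the outputs of the two previous
    cells, and gene[k] holds the two (op, input_node) pairs that feed
    intermediate node k + 2.
    Returns a list of (src, dst, op) tuples.
    """
    edges = []
    for k, node_inputs in enumerate(gene):
        dst = k + 2
        for op, src in node_inputs:
            # The cell is a DAG: a node may only read from earlier nodes.
            if not 0 <= src < dst:
                raise ValueError(f"node {dst} has invalid input node {src}")
            edges.append((src, dst, op))
    return edges


# Normal cell found on CIFAR-10 (copied from the genotype above).
cifar10_normal = [
    [('sep_conv_3x3', 0), ('dil_conv_5x5', 1)],
    [('skip_connect', 0), ('dil_conv_3x3', 2)],
    [('sep_conv_3x3', 1), ('skip_connect', 0)],
    [('sep_conv_3x3', 1), ('skip_connect', 0)],
]

for src, dst, op in cell_edges(cifar10_normal):
    print(f"{src} -> {dst}: {op}")
```

The first printed line is `0 -> 2: sep_conv_3x3`. Flattening the gene this way also makes it easy to see that three of the eight edges in the CIFAR-10 normal cell are `skip_connect`.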
Search-training phase of Fashion-MNIST
Augment-validation phase of CIFAR-10 and Fashion-MNIST
https://github.com/quark0/darts (official implementation)
- Supports pytorch >= 0.4
- Code that is easy to read and well commented
- Implementation of the architect
    - The original implementation is very slow in pytorch >= 0.4.
- Tested on FashionMNIST / MNIST
- Tensorboard support
- No RNN support

and so on.
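The architect updates the architecture parameters α with the second-order approximation described in the DARTS paper. It takes a virtual SGD step `w' = w - ξ ∇_w L_train(w, α)` and then uses `∇_α L_val(w', α) - ξ ∇²_{α,w} L_train(w, α) ∇_{w'} L_val(w', α)`. The Hessian-vector product is estimated by central differences at `w± = w ± ε ∇_{w'} L_val(w', α)` with `ε = 0.01 / ‖∇_{w'} L_val(w', α)‖`. The scalar sketch below illustrates only that arithmetic on hypothetical toy losses. It is not this repository's architect code, which works on PyTorch tensors, and its virtual step is a plain SGD step without momentum or weight decay.

```python
def unrolled_alpha_grad(w, alpha, xi, grad_w_train, grad_a_train,
                        grad_w_val, grad_a_val, eps_scale=0.01):
    """Second-order DARTS estimate of dL_val/dalpha for scalar w and alpha.

    All four grad_* arguments are callables (w, alpha) -> float.
    """
    # 1) Virtual step: one plain SGD step on the training loss
    #    (no momentum or weight decay in this sketch).
    w_prime = w - xi * grad_w_train(w, alpha)

    # 2) Gradients of the validation loss at the unrolled weights.
    v = grad_w_val(w_prime, alpha)        # dL_val/dw'
    d_alpha = grad_a_val(w_prime, alpha)  # direct dL_val/dalpha term

    # 3) Central-difference Hessian-vector product:
    #    d2L_train/(dalpha dw) * v ~= (g(w + eps*v) - g(w - eps*v)) / (2*eps),
    #    where g = dL_train/dalpha and eps = eps_scale / |v|.
    if v == 0:
        return d_alpha  # the second-order term vanishes
    eps = eps_scale / abs(v)
    hvp = (grad_a_train(w + eps * v, alpha)
           - grad_a_train(w - eps * v, alpha)) / (2 * eps)

    return d_alpha - xi * hvp


# Hypothetical toy losses:
#   L_train(w, a) = 0.5 * (w - a)**2
#   L_val(w, a)   = 0.5 * (w - 1)**2 + 0.1 * a**2
g = unrolled_alpha_grad(
    w=2.0, alpha=0.5, xi=0.1,
    grad_w_train=lambda w, a: w - a,
    grad_a_train=lambda w, a: a - w,
    grad_w_val=lambda w, a: w - 1.0,
    grad_a_val=lambda w, a: 0.2 * a,
)
print(round(g, 6))  # 0.185
```

Because the toy losses are quadratic, the estimate equals the exact derivative of the unrolled objective, `0.2α + ξ(w' - 1) = 0.1 + 0.1 · 0.85 = 0.185`. For real networks, the finite difference is only an approximation.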