PyTorch implementation of “Towards Accurate and Compact Architectures via Neural Architecture Transformer”.
Figure. The scheme of NAT++. Our NAT++ takes an arbitrary architecture as input and produces an optimized architecture as output. Blue arrows represent the process of architecture optimization; red arrows and boxes denote the computation of the reward and gradients.
Figure. Operation transition scheme of NAT++. (a) operation transition of NAT++; (b) computation costs of different operations. We set both the input and output channels to 128 and the height and width of the input feature maps to 32. Here, sep denotes a separable convolution and dil denotes a dilated separable convolution.
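As a back-of-the-envelope illustration of this cost gap, the sketch below estimates multiply-accumulate (MAC) counts for a single depthwise-separable block under the figure's setting (128 input/output channels, 32x32 feature maps). This is a rough estimate for intuition rather than the profiling code behind the figure; note that in DARTS-style search spaces a sep_conv stacks two such blocks while a dil_conv uses a single block with dilation 2, which is why the separable variants tend to cost roughly twice as much.

def sep_block_macs(c_in, c_out, h, w, k):
    """MACs of one depthwise-separable block: k x k depthwise + 1 x 1 pointwise."""
    depthwise = c_in * h * w * k * k   # one k x k filter per input channel
    pointwise = c_in * c_out * h * w   # 1 x 1 convolution mixing channels
    return depthwise + pointwise

C, H, W = 128, 32, 32                  # setting used in the figure
for k in (3, 5, 7):
    print('{0}x{0} block: {1:.1f}M MACs'.format(k, sep_block_macs(C, C, H, W, k) / 1e6))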
Python>=3.6, PyTorch==0.4.0, torchvision==0.2.1, graphviz==0.10.1, scipy==1.1.0, pygcn
Please follow the guide to install pygcn.
We consider two benchmark classification datasets: CIFAR-10 and ImageNet.
CIFAR-10 can be automatically downloaded by torchvision.
ImageNet needs to be manually downloaded (preferably to an SSD) following the instructions here.
We consider optimizing two kinds of architectures, namely loose-end architectures and fully-concat architectures. More details about these two kinds of architectures can be found in ENAS and DARTS, respectively.
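Loosely speaking, a fully-concat cell (as in DARTS) concatenates all intermediate nodes to form its output, while a loose-end cell (as in ENAS) concatenates only the nodes whose outputs are not consumed by any other node. A minimal sketch of the distinction, using the (op, input_node, output_node) triple format from genotypes.py:

def loose_ends(cell, nodes):
    """Nodes whose outputs are not consumed by any other node in the cell."""
    used = {src for _, src, _ in cell}
    return [n for n in nodes if n not in used]

# For the DARTS normal cell shown below, the intermediate nodes are [2, 3, 4, 5];
# node 2 feeds node 5 (via dil_conv_3x3), so its loose ends are [3, 4, 5].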
Train NAT for fully-concat architectures (here and in the command below, replace $DATA_DIR$ with the path to your dataset):
python train_search.py --data $DATA_DIR$ --op_type FULLY_CONCAT_PRIMITIVES
Train NAT for loose-end architectures:
python train_search.py --data $DATA_DIR$ --op_type LOOSE_END_PRIMITIVES
1. Put the input architectures in genotypes.py as follows:
DARTS = Genotype(
normal=[('sep_conv_3x3', 0, 2), ('sep_conv_3x3', 1, 2), ('sep_conv_3x3', 0, 3), ('sep_conv_3x3', 1, 3), ('sep_conv_3x3', 1, 4),
('skip_connect', 0, 4), ('skip_connect', 0, 5), ('dil_conv_3x3', 2, 5)], normal_concat=[2, 3, 4, 5],
reduce=[('max_pool_3x3', 0, 2), ('max_pool_3x3', 1, 2), ('skip_connect', 2, 3), ('max_pool_3x3', 1, 3), ('max_pool_3x3', 0, 4),
('skip_connect', 2, 4), ('skip_connect', 2, 5), ('max_pool_3x3', 1, 5)], reduce_concat=[2, 3, 4, 5])
2. You can obtain the transformed architecture by passing an architecture as input, e.g., --arch DARTS.
python derive.py --data ./data --arch DARTS --model_path pretrained/fully_connect.pt
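For reference, Genotype follows the DARTS namedtuple convention; below is a minimal sketch of the container and of how derive.py presumably resolves the --arch name (check genotypes.py and derive.py for the exact definitions):

from collections import namedtuple

# DARTS-style genotype container: each entry in `normal`/`reduce` is an
# (operation, input_node, output_node) triple, and `*_concat` lists the
# nodes whose outputs are concatenated to form the cell output.
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# derive.py presumably looks the --arch name up in genotypes.py; the
# classic DARTS pattern is:
#   genotype = eval('genotypes.%s' % args.arch)  # or getattr(genotypes, args.arch)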
Figure. An example of an architecture transformation performed by NAT++.
You can visualize both the input and the transformed architectures by
python visualize.py some_arch
where some_arch should be replaced by any architecture in genotypes.py.
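As a reference for how such a plot can be produced, here is a rough, self-contained sketch using the graphviz package; the actual visualize.py may differ in details such as colors and layout.

from graphviz import Digraph

def plot_cell(cell, concat, filename):
    """Render one cell; each entry of `cell` is (op, input_node, output_node)."""
    g = Digraph(format='pdf')
    # Nodes 0 and 1 are the outputs of the two preceding cells.
    names = {0: 'c_{k-2}', 1: 'c_{k-1}'}
    g.node(names[0], fillcolor='darkseagreen2', style='filled')
    g.node(names[1], fillcolor='darkseagreen2', style='filled')
    for op, src, dst in cell:
        names.setdefault(dst, str(dst))
        g.node(names[dst], fillcolor='lightblue', style='filled')
        g.edge(names[src], names[dst], label=op)
    # The cell output concatenates the nodes listed in `concat`.
    g.node('c_{k}', fillcolor='palegoldenrod', style='filled')
    for node in concat:
        g.edge(names[node], 'c_{k}')
    g.render(filename, view=False)

# Example: plot_cell(DARTS.normal, DARTS.normal_concat, 'darts_normal')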
If you use any part of this code in your research, please cite our conference paper:
@inproceedings{guo2019nat,
  title={NAT: Neural Architecture Transformer for Accurate and Compact Architectures},
  author={Guo, Yong and Zheng, Yin and Tan, Mingkui and Chen, Qi and Chen, Jian and Zhao, Peilin and Huang, Junzhou},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}