PyTorch implementation of Efficient Neural Architecture Search via Parameter Sharing.
ENAS reduces the computational requirement (GPU-hours) of Neural Architecture Search (NAS) by 1000x via parameter sharing between models that are subgraphs within a large computational graph. It achieves state-of-the-art results on Penn Treebank language modeling.
- Python 3.6+
- PyTorch
- tqdm, scipy, imageio, graphviz, tensorboardX
Install prerequisites with:
conda install graphviz
pip install -r requirements.txt
To train ENAS to discover a recurrent cell for RNN:
python main.py --network_type rnn --dataset ptb --controller_optim adam --controller_lr 0.00035 \
--shared_optim sgd --shared_lr 20.0 --entropy_coeff 0.0001
python main.py --network_type rnn --dataset wikitext
To train ENAS to discover a CNN architecture (in progress):
python main.py --network_type cnn --dataset cifar --controller_optim momentum --controller_lr_cosine=True \
--controller_lr_max 0.05 --controller_lr_min 0.0001 --entropy_coeff 0.1
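The `--controller_lr_cosine`, `--controller_lr_max` and `--controller_lr_min` flags indicate a controller learning rate that is cosine-annealed between the two bounds. Below is a minimal sketch of the standard cosine annealing formula, lr(t) = lr_min + ½(lr_max − lr_min)(1 + cos(πt/T)), using the bounds from the command above. The function name, the step counter and the period T (`total_steps`, a hypothetical parameter) are assumptions, and `main.py` may use a different period or warm restarts.

```python
import math


def cosine_lr(step, total_steps, lr_max=0.05, lr_min=0.0001):
    """Single-cycle cosine annealing: lr_max at step 0, lr_min at total_steps.

    Assumption: one cycle of length total_steps with no warm restarts; the
    period and restart behaviour used by main.py may differ.
    """
    progress = min(step, total_steps) / total_steps  # clamp after the cycle ends
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


for step in (0, 250, 500, 750, 1000):
    print(step, round(cosine_lr(step, 1000), 5))
```

The cosine shape keeps the rate near `lr_max` early, decays fastest mid-cycle, and flattens out near `lr_min`.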
Or you can use your own dataset by placing its files like this:
data
├── YOUR_TEXT_DATASET
│   ├── test.txt
│   ├── train.txt
│   └── valid.txt
├── YOUR_IMAGE_DATASET
│   ├── test
│   │   ├── xxx.jpg (name doesn't matter)
│   │   ├── yyy.jpg (name doesn't matter)
│   │   └── ...
│   ├── train
│   │   ├── xxx.jpg
│   │   └── ...
│   └── valid
│       ├── xxx.jpg
│       └── ...
├── image.py
└── text.py
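Before training on a custom dataset it can help to check the folder against the layout above. The two helpers below are hypothetical (they are not part of this repo) and only check what the tree shows: three `.txt` split files for text, and three non-empty split folders for images.

```python
import os
import tempfile

# Hypothetical helpers (not part of this repo) that check a custom dataset
# folder against the layout shown above before training.
TEXT_FILES = ("train.txt", "valid.txt", "test.txt")
IMAGE_SPLITS = ("train", "valid", "test")


def missing_text_files(dataset_dir):
    """Return the split files that are missing from a text dataset folder."""
    return [name for name in TEXT_FILES
            if not os.path.isfile(os.path.join(dataset_dir, name))]


def bad_image_splits(dataset_dir):
    """Return the split folders that are missing or contain no files."""
    problems = []
    for split in IMAGE_SPLITS:
        split_dir = os.path.join(dataset_dir, split)
        if not os.path.isdir(split_dir):
            problems.append(split + " (missing)")
        elif not any(os.path.isfile(os.path.join(split_dir, f)) for f in os.listdir(split_dir)):
            problems.append(split + " (empty)")
    return problems


if __name__ == "__main__":
    # Usage: a throwaway text dataset that is missing its test split.
    with tempfile.TemporaryDirectory() as root:
        text_dir = os.path.join(root, "YOUR_TEXT_DATASET")
        os.makedirs(text_dir)
        for name in ("train.txt", "valid.txt"):
            open(os.path.join(text_dir, name), "w").close()
        print(missing_text_files(text_dir))
```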
To generate a GIF image of generated samples:
python generate_gif.py --model_name=ptb_2018-02-15_11-20-02 --output=sample.gif
More configurations can be found here.
Efficient Neural Architecture Search (ENAS) is composed of two sets of learnable parameters: the controller LSTM parameters θ and the shared parameters ω. These two sets of parameters are trained alternately, and only the trained controller is used to derive novel architectures.
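The alternating scheme can be illustrated with a toy sketch, assuming a made-up search space in which an "architecture" is a single choice among three operations. Following the ENAS paper, phase 1 freezes θ and trains ω on sampled architectures, phase 2 freezes ω and updates θ with REINFORCE and a moving-average baseline, and deriving samples several candidates from the trained policy and keeps the one with the best reward. Simplifications (all assumptions): θ is a vector of softmax logits rather than an LSTM, each op has one scalar shared weight, and the same toy loss is used in both phases, whereas ENAS trains ω on training data and computes rewards on validation data.

```python
import math
import random

# Toy, hypothetical search space: an "architecture" is one choice among three
# operations. Op k owns one shared weight omega[k], and its validation loss is
# (omega[k] - TARGET[k])**2 + FLOOR[k], so op 1 is the best choice.
TARGET = [0.5, -0.5, 0.5]
FLOOR = [1.0, 0.2, 0.6]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def val_loss(op, omega):
    return (omega[op] - TARGET[op]) ** 2 + FLOOR[op]


def search(epochs=10, shared_steps=50, controller_steps=50, seed=0):
    rng = random.Random(seed)
    omega = [0.0, 0.0, 0.0]  # shared parameters, one weight per op
    theta = [0.0, 0.0, 0.0]  # policy logits standing in for the controller LSTM
    baseline = None
    for _ in range(epochs):
        # Phase 1: freeze theta; take SGD steps on the weights of sampled architectures.
        for _ in range(shared_steps):
            op = rng.choices(range(3), weights=softmax(theta))[0]
            omega[op] -= 0.1 * 2 * (omega[op] - TARGET[op])
        # Phase 2: freeze omega; REINFORCE on theta with a moving-average baseline.
        for _ in range(controller_steps):
            probs = softmax(theta)
            op = rng.choices(range(3), weights=probs)[0]
            reward = -val_loss(op, omega)
            baseline = reward if baseline is None else 0.95 * baseline + 0.05 * reward
            advantage = reward - baseline
            for j in range(3):
                # Gradient of log softmax(theta)[op] w.r.t. theta[j] is 1[j == op] - probs[j].
                theta[j] += 0.5 * advantage * ((1.0 if j == op else 0.0) - probs[j])
    return theta, omega


def derive(theta, omega, num_samples=10, seed=1):
    """Sample candidates from the trained policy and keep the one with the best reward."""
    rng = random.Random(seed)
    candidates = rng.choices(range(3), weights=softmax(theta), k=num_samples)
    return max(candidates, key=lambda op: -val_loss(op, omega))


theta, omega = search()
print("policy:", [round(p, 3) for p in softmax(theta)], "derived op:", derive(theta, omega))
```

Because the policy only ever sees rewards computed with the current shared weights, ω must be reasonably trained before the controller's preferences mean much. That is why each epoch starts with the shared-weight phase.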
The controller LSTM decides 1) which activation function to use and 2) which previous node to connect to.
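To make those two decisions concrete, here is a toy sketch based on the ENAS paper's description of the recurrent-cell search space. Each node picks one earlier node as input plus one of four activations (tanh, ReLU, identity, sigmoid), and the cell output is the average of the "loose ends", the nodes that no other node consumes. Assumptions: values are scalars rather than hidden-state vectors, each node has one scalar weight instead of weight matrices, and a uniform random sampler stands in for the trained controller. None of these function names come from this repo.

```python
import math
import random


def _sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# The four activations used in the ENAS paper's recurrent-cell search space.
ACTIVATIONS = {
    "tanh": math.tanh,
    "relu": lambda x: max(0.0, x),
    "identity": lambda x: x,
    "sigmoid": _sigmoid,
}


def sample_cell(num_nodes, rng):
    """Uniform random stand-in for the controller: one (previous node, activation) pair per node."""
    cell = []
    for i in range(num_nodes):
        prev = None if i == 0 else rng.randrange(i)  # node 0 reads x_t and h_{t-1} directly
        cell.append((prev, rng.choice(sorted(ACTIVATIONS))))
    return cell


def loose_ends(cell):
    """Nodes that no later node uses as input; the last node is always one."""
    used = {prev for prev, _ in cell if prev is not None}
    return [i for i in range(len(cell)) if i not in used]


def run_cell(cell, x, h_prev, weights):
    """Scalar toy forward pass; the cell output averages the loose ends."""
    values = []
    for i, (prev, act) in enumerate(cell):
        inp = x + h_prev if prev is None else values[prev]
        values.append(ACTIVATIONS[act](weights[i] * inp))
    ends = loose_ends(cell)
    return sum(values[i] for i in ends) / len(ends)


cell = sample_cell(4, random.Random(0))
print(cell, loose_ends(cell), run_cell(cell, x=0.5, h_prev=0.1, weights=[0.8] * 4))
```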
The RNN cells ENAS discovered for the Penn Treebank and WikiText-2 datasets:
Best discovered ENAS cell for Penn Treebank at epoch 27:
You can see the details of training (e.g. reward, entropy, loss) with:
tensorboard --logdir=logs --port=6006
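The reward, entropy and loss curves are linked in REINFORCE-style controller training. The sketch below shows one common way they combine, assuming the policy entropy acts as a bonus weighted by `--entropy_coeff` and subtracted from the loss. This repo may instead add the entropy to the reward, so treat the function names and that choice as hypothetical.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of one categorical controller decision."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def controller_loss(log_prob, reward, baseline, entropies, entropy_coeff=0.0001):
    """REINFORCE loss with an entropy bonus for one sampled architecture.

    log_prob:  sum of the log-probabilities of all decisions in the sample.
    entropies: entropy of the distribution behind each decision.
    Minimizing this raises log_prob when reward > baseline, while the
    entropy term (weighted by entropy_coeff) discourages early collapse.
    """
    advantage = reward - baseline
    return -advantage * log_prob - entropy_coeff * sum(entropies)


# Four equally likely activation choices give the maximum entropy, ln(4).
print(entropy([0.25, 0.25, 0.25, 0.25]), math.log(4))
```

A falling entropy curve therefore means the controller is becoming more certain about its choices. A larger `--entropy_coeff` (0.1 in the CNN command versus 0.0001 for the RNN) pushes back harder against that.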
The controller LSTM samples 1) which computation operation to use and 2) which previous node to connect to.
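Since the CNN part is still in progress, the following is only a sketch of the macro search space described in the ENAS paper for CIFAR-10. For each layer, the controller picks one of six operations (3x3 and 5x5 convolutions, 3x3 and 5x5 depthwise-separable convolutions, 3x3 max and average pooling) and a subset of earlier layers to connect to. This gives 6^L × 2^(L(L−1)/2) possible networks for L layers. A uniform random sampler stands in for the controller, and the op names are hypothetical; they may differ from this repo.

```python
import random

# The six operations in the ENAS paper's macro search space for CIFAR-10.
OPS = ["conv3x3", "conv5x5", "sep_conv3x3", "sep_conv5x5", "max_pool3x3", "avg_pool3x3"]


def sample_network(num_layers, rng):
    """Uniform random stand-in for the controller: for each layer, pick one
    operation and any subset of the earlier layers to connect to."""
    network = []
    for i in range(num_layers):
        op = rng.choice(OPS)
        inputs = [j for j in range(i) if rng.random() < 0.5]
        network.append((op, inputs))
    return network


def search_space_size(num_layers, num_ops=len(OPS)):
    """num_ops choices per layer; layer i (0-indexed) has 2**i possible input subsets."""
    return num_ops ** num_layers * 2 ** (num_layers * (num_layers - 1) // 2)


for i, (op, inputs) in enumerate(sample_network(4, random.Random(0))):
    print(i, op, inputs)
print(f"{search_space_size(12):.2e}")  # roughly 1.6e29 networks for 12 layers
```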
The CNN network ENAS discovered for the CIFAR-10 dataset:
(in progress)
Taehoon Kim / @carpedm20