This repo provides boilerplate code to train CNNs using TensorFlow's Estimator and Dataset APIs. Several popular CNN architectures from different model zoos can be imported, and more can be added in the same way. Features:
- multi-headed outputs
- multi-GPU training
- examples for importing CNNs from tf-slim, tensornets and tensorflow/models/official
- the full pipeline, from raw images to predictions
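The first step of that pipeline is turning a folder of raw images into labelled train and validation lists. The sketch below assumes a hypothetical layout with one sub-directory per class under `--root_path` (e.g. `/my_images/cat/001.jpg`) and uses the `--train_fraction` value to split. `index_images` and `train_val_split` are illustrative helpers, not functions from this repo; check `main.py` for the layout it actually expects.

```python
import os
import random

# Hypothetical: which file extensions count as images.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def index_images(root_path):
    """Return (file_path, class_id) pairs and the sorted class names.

    Assumes one sub-directory per class directly under root_path.
    """
    class_names = sorted(
        d for d in os.listdir(root_path)
        if os.path.isdir(os.path.join(root_path, d))
    )
    samples = []
    for class_id, name in enumerate(class_names):
        class_dir = os.path.join(root_path, name)
        for fname in sorted(os.listdir(class_dir)):
            if os.path.splitext(fname)[1].lower() in IMAGE_EXTENSIONS:
                samples.append((os.path.join(class_dir, fname), class_id))
    return samples, class_names


def train_val_split(samples, train_fraction=0.8, seed=0):
    """Shuffle with a fixed seed, then split into train and validation lists."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n_train = int(round(len(shuffled) * train_fraction))
    return shuffled[:n_train], shuffled[n_train:]
```

Seeding the shuffle keeps the split identical across runs, so a validation image never leaks into training when a run is resumed.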
```shell
python main.py \
--root_path /my_images/ \
--model_save_path ./data/model_run \
--model small_cnn \
--max_epoch 10 \
--batch_size 64 \
--image_size 50 \
--num_gpus 0 \
--num_cpus 2 \
--train_fraction 0.8 \
--color_augmentation True \
--weight_decay 0.001
```
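The flags in the example command map naturally onto an `argparse` parser. The sketch below is hypothetical: the repo's actual parser in `main.py` may differ, and the defaults here are simply the values from the example command. One detail worth copying regardless: `type=bool` in argparse turns any non-empty string, including `"False"`, into `True`, so a small converter such as the illustrative `str2bool` is needed for flags like `--color_augmentation`.

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False'-style strings (argparse's type=bool would not)."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser():
    """Hypothetical parser mirroring the example flags; defaults are illustrative."""
    parser = argparse.ArgumentParser(description="Train a CNN (sketch).")
    parser.add_argument("--root_path", required=True)
    parser.add_argument("--model_save_path", default="./data/model_run")
    parser.add_argument("--model", default="small_cnn")
    parser.add_argument("--max_epoch", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--image_size", type=int, default=50)
    parser.add_argument("--num_gpus", type=int, default=0)
    parser.add_argument("--num_cpus", type=int, default=2)
    parser.add_argument("--train_fraction", type=float, default=0.8)
    parser.add_argument("--color_augmentation", type=str2bool, default=True)
    parser.add_argument("--weight_decay", type=float, default=0.001)
    return parser
```

Usage: `args = build_parser().parse_args()` reads the flags from the command line; passing a list such as `["--root_path", "/my_images/"]` instead is handy in tests.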
The code has been tested with TensorFlow 1.8.
Thanks to the following model zoos:
- tf-slim models (https://github.com/tensorflow/models/tree/master/research/slim)
- ResNet implementation (https://github.com/tensorflow/models/tree/master/official/resnet)
- tensornets (https://github.com/taehoonlee/tensornets)