easy-tensorflow provides an easy way to train and evaluate TensorFlow models. The goal of this project is not to build an off-the-shelf tool for industrial or commercial purposes but to simplify programming with the TensorFlow API. With a standardized pipeline, users do not have to worry about extra bookkeeping steps and can focus entirely on input preprocessing and model engineering.
We use TF-Slim to keep the code concise and flexible, and `tf.train.Example` protocol buffers as the standard input format.
##0. Cifar-10 demo
Clone the repo to your computer:
$ git clone git@github.com:khanhptnk/easy-tensorflow.git
Go to the `cifar10` directory:
$ cd easy-tensorflow/cifar10
Download Cifar-10 Python version:
$ wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Extract downloaded file:
$ tar -zxvf cifar-10-python.tar.gz
Convert raw data to tf.Example protos:
$ python convert_to_tf_example_cifar10.py --input_data_dir=cifar-10-batches-py --output_data_dir=data
Run training:
$ python train_cifar10.py --logdir=/tmp/cifar10
Run evaluation (can be run simultaneously with training):
$ python eval_cifar10.py --logdir=/tmp/cifar10
Monitor stats on TensorBoard:
$ tensorboard --logdir=/tmp/cifar10
Then open any browser (e.g. Chrome) and enter `0.0.0.0:6006` in the address bar to view TensorBoard. (Screenshot: TensorBoard after 30k training iterations.)
##1. Code structure (Cifar-10 example)

         TrainEvalBase             ModelBase        InputReaderBase
            /    \                     |                   |
           /      \                    |                   |
          /        \                   |                   |
     Trainer    Evaluator        ModelCifar10      InputReaderCifar10


- `TrainEvalBase`: abstract base class for training and evaluation.
- `Trainer`: a subclass of `TrainEvalBase`; trains a model with the provided data and loss function.
- `Evaluator`: a subclass of `TrainEvalBase`; computes evaluation metrics on a trained model.
- `ModelBase`: base class for specifying a model architecture. Every subclass must implement two methods: `arg_scope`, which configures the model's layers, and `compute`, which computes the model's outputs from a batch of input examples.
- `ModelCifar10`: a subclass of `ModelBase`; implements `arg_scope` and `compute`.
- `InputReaderBase`: abstract base class for reading input. Every subclass must implement the `read_input` method.
- `InputReaderCifar10`: a subclass of `InputReaderBase`; implements `read_input` and an input preprocessing method.
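
The hierarchy above can be sketched in plain Python. This is only an illustration of the abstract-method pattern, not the repository's code: the constructor and method signatures are assumptions, and `ToyModel` and `ToyReader` are hypothetical stand-ins for `ModelCifar10` and `InputReaderCifar10` that use lists instead of tensors.

```python
import abc


class ModelBase(abc.ABC):
    """Sketch of a model base class; signatures are assumed, not the repo's."""

    def __init__(self, is_training):
        self.is_training = is_training

    @abc.abstractmethod
    def arg_scope(self):
        """Return the configuration shared by the model's layers."""

    @abc.abstractmethod
    def compute(self, inputs):
        """Map a batch of inputs to a batch of outputs."""


class InputReaderBase(abc.ABC):
    """Sketch of an input reader base class."""

    @abc.abstractmethod
    def read_input(self):
        """Return one batch of examples."""


class ToyModel(ModelBase):
    """Hypothetical subclass playing the role of ModelCifar10."""

    def arg_scope(self):
        # Hypothetical layer configuration; the key and value are made up.
        return {"weight_decay": 0.0005}

    def compute(self, inputs):
        # A real model would stack layers here; this one just doubles inputs.
        return [2 * x for x in inputs]


class ToyReader(InputReaderBase):
    """Hypothetical subclass playing the role of InputReaderCifar10."""

    def read_input(self):
        return [1, 2, 3]


model = ToyModel(is_training=True)
outputs = model.compute(ToyReader().read_input())
```

Because the base classes declare their required methods as abstract, forgetting to implement `compute` or `read_input` in a subclass fails at instantiation time rather than in the middle of a training run.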
##2. Define a new model (Cifar-10 example):
To define a new model, we need to create 5 core files (see the `cifar10` directory):

- `convert_to_tf_example_cifar10.py`: converts the Python version of the Cifar-10 data to tf.Example protos.
- `input_reader_cifar10.py`: reads examples from files containing tf.Example protos (records) and assembles them into batches.
- `model_cifar10.py`: specifies the model architecture. It implements `arg_scope` to configure the model's layers (e.g. weight decay, regularization techniques, activation functions) and `compute` to arrange the model's layers (e.g. which layers follow which) so that a batch of inputs is mapped to a batch of outputs.
- `train_cifar10.py`: runs training. It creates a `Trainer` object, specifying a training model object, a loss function, a computation graph, and an input reader object, then invokes the `run` method of the `Trainer` object to start training.
- `eval_cifar10.py`: runs evaluation. It creates an `Evaluator` object, specifying an evaluation model object, a loss function, a computation graph, and an input reader object. As with training, it invokes the `run` method of the `Evaluator` object to start evaluation. NOTE: an evaluation model object is created by setting the `is_training` parameter of `ModelBase` to `False`.

Although there seem to be a lot of files, the amount of code in each file is minimal. For most files, users simply have to copy them to a new directory and make small changes. Most modifications go into `convert_to_tf_example_cifar10.py`, `model_cifar10.py`, and `input_reader_cifar10.py`.
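
The pairing of a training script and an evaluation script described in this section can be illustrated with a small plain-Python sketch. Nothing here is the repository's API: `ToyModel`, `ToyRunner`, `squared_error`, and `read_batch` are hypothetical stand-ins, and the constructor arguments are assumptions. The point is only that both scripts build the same kind of objects, differ in the `is_training` flag, and then call `run`.

```python
class ToyModel:
    """Hypothetical stand-in for a ModelBase subclass."""

    def __init__(self, is_training):
        self.is_training = is_training

    def compute(self, inputs):
        # A mode-dependent scale stands in for layers (such as dropout)
        # that behave differently during training and evaluation.
        scale = 0.5 if self.is_training else 1.0
        return [scale * x for x in inputs]


class ToyRunner:
    """Hypothetical stand-in for Trainer / Evaluator: holds the parts, exposes run()."""

    def __init__(self, model, loss_fn, input_reader):
        self.model = model
        self.loss_fn = loss_fn
        self.input_reader = input_reader

    def run(self):
        inputs, targets = self.input_reader()
        outputs = self.model.compute(inputs)
        return self.loss_fn(outputs, targets)


def squared_error(outputs, targets):
    """Sum of squared differences between outputs and targets."""
    return sum((o - t) ** 2 for o, t in zip(outputs, targets))


def read_batch():
    """Return a fixed (inputs, targets) batch."""
    return [1.0, 2.0], [1.0, 2.0]


# What a training script would do: build a training-mode model and run.
train_loss = ToyRunner(ToyModel(is_training=True), squared_error, read_batch).run()

# What an evaluation script would do: same pieces, but is_training=False.
eval_loss = ToyRunner(ToyModel(is_training=False), squared_error, read_batch).run()
```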
##3. Common TensorFlow concepts:
tf.Example proto: a feature vector that can be thought of as a Python dict. Each element is a (key, value) pair: the key is the feature name, and the value is a list of bytes (string), int64, or float values. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto for more details.
For example, if data are labeled, then each proto has two features: one for the observation (e.g. image) and one for the label.
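
As a rough illustration of the dict analogy, the sketch below describes a labeled example as a Python dict and converts it to a `tf.train.Example`. The feature names `"image"` and `"label"` and the helper names `feature_kind` and `to_tf_example` are assumptions made for this example; the keys used by the repository's conversion script may differ. TensorFlow is imported inside the conversion function so that the dict part can be read and run on its own.

```python
def feature_kind(values):
    """Return which of the three tf.Example value lists fits `values`."""
    if not values:
        raise ValueError("expected a non-empty list of values")
    if all(isinstance(v, bytes) for v in values):
        return "bytes_list"
    if all(isinstance(v, int) and not isinstance(v, bool) for v in values):
        return "int64_list"
    if all(isinstance(v, float) for v in values):
        return "float_list"
    raise TypeError("values must all be bytes, all int, or all float")


def to_tf_example(feature_dict):
    """Convert {feature name: list of values} into a tf.train.Example."""
    import tensorflow as tf  # imported lazily; only needed for the conversion

    features = {}
    for name, values in feature_dict.items():
        kind = feature_kind(values)
        if kind == "bytes_list":
            features[name] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=values))
        elif kind == "int64_list":
            features[name] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=values))
        else:
            features[name] = tf.train.Feature(
                float_list=tf.train.FloatList(value=values))
    return tf.train.Example(features=tf.train.Features(feature=features))


# A labeled example: one feature for the observation, one for the label.
# The key names and the byte string are placeholders.
labeled_example = {
    "image": [b"\x00\x01\x02"],  # raw image bytes
    "label": [3],                # integer class id
}
```

With TensorFlow installed, `to_tf_example(labeled_example).SerializeToString()` yields the serialized bytes of the proto.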