TensorFlow allows (relatively) easy creation and optimisation of computational graphs for a practically unlimited range of tasks. For each of these tasks, however, one needs:
- An input pipeline
- A model specification
- A training loop
- An evaluation loop
- Metrics
- ...hundreds of other things but I think the point is clear.

To avoid the substantial engineering challenge of implementing all of these components each time you want to train a cool new model, TensorFlow now includes the Estimator and Experiment APIs. On top of this, to help with the input pipeline, the Dataset API now exists as well.

(Figure: architecture for Estimator training; image taken from https://arxiv.org/abs/1708.02637)

This repository aims to provide a clear template on how to use each of these ideas in tandem to create a clean, efficient training and evaluation setup.
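
To make the division of labour concrete, here is a framework-free toy in plain Python. It is not the TensorFlow API and not the code in this repository: `ToyEstimator`, `input_fn`, `model_fn` and the linear task are all hypothetical names chosen for illustration. The point is only that the caller supplies an input pipeline and a model specification, while the estimator object owns the training loop, the evaluation loop and the metrics.

```python
import random


def input_fn(n, seed):
    """Input pipeline: yield (x, y) pairs for the toy task y = 2x + 1."""
    rng = random.Random(seed)
    for _ in range(n):
        x = rng.uniform(-1.0, 1.0)
        yield x, 2.0 * x + 1.0


def model_fn(params, x):
    """Model specification: a single linear unit."""
    return params["w"] * x + params["b"]


class ToyEstimator:
    """Owns the training and evaluation loops (illustrative only)."""

    def __init__(self, model_fn, learning_rate=0.1):
        self.model_fn = model_fn
        self.learning_rate = learning_rate
        self.params = {"w": 0.0, "b": 0.0}

    def train(self, input_fn, epochs=1):
        for _ in range(epochs):
            for x, y in input_fn():
                error = self.model_fn(self.params, x) - y
                # Gradient step on 0.5 * error**2 with respect to w and b.
                self.params["w"] -= self.learning_rate * error * x
                self.params["b"] -= self.learning_rate * error
        return self

    def evaluate(self, input_fn):
        squared = [(self.model_fn(self.params, x) - y) ** 2 for x, y in input_fn()]
        return {"mse": sum(squared) / len(squared)}


estimator = ToyEstimator(model_fn)
# Each call to the lambda builds a fresh pass over the data.
estimator.train(lambda: input_fn(200, seed=0), epochs=20)
metrics = estimator.evaluate(lambda: input_fn(50, seed=1))
print(metrics)
```

Swapping in a different model or dataset means changing only `model_fn` or `input_fn`; the loops stay untouched, which is the same separation the real APIs aim for.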
The code contains the following scripts to be run in order:
1. `experiment/scripts/create_records.py`: Creates `x, y` examples of a biased sin function with noise. It saves these in a `data_dir` as sharded `.tfrecords` files. Training and validation shards are kept in separate folders to ensure data separation from the start.
2. `experiment/scripts/train.py`: Launches an `Experiment`, initialising the training and validation loops and building a distinct `Estimator` instance for each loop. Data is read from the `data_dir`. The model is saved in a `model_dir`.
3. `experiment/scripts/infer.py`: Reloads the `Estimator` from the `model_dir` and creates a plot of the predictions on the training and validation sets.

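
The data-creation step can be sketched without TensorFlow to show the shape of the data and the sharding logic. This is a rough illustration, not the contents of `create_records.py`: the bias value, noise level, 80/20 split, shard counts and file-name pattern are all assumptions, and the actual `.tfrecords` serialisation is left out.

```python
import math
import random


def make_examples(n, bias=0.5, noise_std=0.1, seed=0):
    """Sample (x, y) pairs with y = sin(x) + bias + Gaussian noise (assumed form)."""
    rng = random.Random(seed)
    examples = []
    for _ in range(n):
        x = rng.uniform(0.0, 2.0 * math.pi)
        y = math.sin(x) + bias + rng.gauss(0.0, noise_std)
        examples.append((x, y))
    return examples


def shard(examples, num_shards):
    """Split examples round-robin into num_shards lists."""
    return [examples[i::num_shards] for i in range(num_shards)]


def shard_name(split, index, num_shards):
    """Hypothetical naming scheme: one sub-folder per split."""
    return "{}/shard-{:05d}-of-{:05d}.tfrecords".format(split, index, num_shards)


examples = make_examples(1000)

# Split before sharding so training and validation data never mix.
split_at = int(0.8 * len(examples))
train_shards = shard(examples[:split_at], 4)
valid_shards = shard(examples[split_at:], 2)

train_names = [shard_name("train", i, 4) for i in range(4)]
valid_names = [shard_name("valid", i, 2) for i in range(2)]
print(train_names[0], len(train_shards[0]))
```

Splitting first and sharding second is what keeps the two sets apart from the start; each shard would then be written to its own file under the matching folder in the `data_dir`.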
- Great introductory post by Peter Roelants: Higher-Level APIs in TensorFlow.
- Cheng, H.-T. et al. (2017). TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks. https://arxiv.org/abs/1708.02637