Model session is an object-oriented pattern that makes it easier to run and serialize TensorFlow models. It streamlines graph and session management as well as serialization and deserialization, allowing you to focus on the details of your model.
This project contains a ModelSession base class that defines the pattern
and an example MNISTModelSession that implements it for an MNIST digits
classifier.
To train an MNIST model, run
python mnist_model.py train
This will download MNIST data and train a model, serializing it in a
checkpoint directory.
If you run this command again, it will reload the serialized model and
continue training from where it left off.
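
The resume-on-rerun behaviour can be pictured with a small framework-free sketch. This is not the code in mnist_model.py: the function names, the state.json file, and the dictionary standing in for a model are all hypothetical, and JSON stands in for a TensorFlow checkpoint.

```python
import json
import os
import tempfile


def load_or_create(directory):
    """Return the saved training state if present, otherwise a fresh one."""
    path = os.path.join(directory, "state.json")  # hypothetical file name
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    return {"iteration": 0}


def train(directory, steps):
    """Run `steps` pretend training steps, then save the state."""
    state = load_or_create(directory)
    state["iteration"] += steps  # stand-in for real optimizer updates
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, "state.json"), "w") as f:
        json.dump(state, f)
    return state["iteration"]


checkpoint_dir = os.path.join(tempfile.mkdtemp(), "checkpoint")
first_run = train(checkpoint_dir, steps=100)
second_run = train(checkpoint_dir, steps=100)  # picks up where the first run stopped
```

The second call finds the saved state and adds to its iteration count instead of starting from zero.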
Once a model has been trained, you can run it on a test set with
python mnist_model.py test
See python mnist_model.py --help for more options.
The MNIST model is a two-layer convolutional network with dropout, based on the example in the TensorFlow tutorial. It has four parameters:
- The sizes of the first and second convolutional layers
- The learning rate
- The keep probability for the fully-connected layer
We wish to specify each parameter only once and have its value read back transparently during deserialization. The model session pattern allows each to be specified at the point in the model's lifecycle when it is required.
The sizes of the two convolutional layers are passed to the
ModelSession.create method that is used to create a new model.
These become part of the model's structure and do not need to be
specified again when the model is deserialized.
The learning rate, on the other hand, can be specified at training time.
Furthermore, we would like to be able to change its value during training,
so it is a parameter of the MNISTModelSession.train method and is passed
into the model via a feed dictionary.
Similarly, the keep probability takes on a user-specified value during training
and the value 1 during testing, so it too is a parameter of the train
method and is passed into the model via a feed dictionary.
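
As a rough illustration of training-time parameters travelling through a feed dictionary, the sketch below builds plain Python dictionaries of the kind that would be handed to a session. The function names and the keys ("x", "y", "learning_rate", "keep_probability") are assumptions made for this example, not names taken from the project.

```python
def training_feed(images, labels, learning_rate, keep_probability):
    """Build the feed for one training step (keys are hypothetical names)."""
    return {
        "x": images,
        "y": labels,
        "learning_rate": learning_rate,        # may differ from step to step
        "keep_probability": keep_probability,  # user-specified while training
    }


def evaluation_feed(images, labels):
    """Build the feed for testing: dropout is disabled, nothing is learned."""
    return {"x": images, "y": labels, "keep_probability": 1.0}


# Lower the learning rate partway through training without rebuilding the model.
feeds = [
    training_feed([[0.0]], [1], learning_rate=rate, keep_probability=0.5)
    for rate in (1e-3, 1e-4)
]
evaluation = evaluation_feed([[0.0]], [1])
```

Because both values arrive with each step rather than being baked into the graph, neither has to be stored with the serialized model.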
To use the model session pattern, inherit from the ModelSession base
class and create your graph in a method overriding its static create_graph
method.
You should never call a ModelSession constructor directly.
Instead, call the create and restore class methods when you are creating
a new model or reloading a serialized one, respectively.
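
The shape of the pattern can be sketched without TensorFlow. The classes below are hypothetical stand-ins, not the project's ModelSession: a dictionary plays the role of the graph, and a graph.json file (an assumed name) plays the role of the checkpoint.

```python
import json
import os
import tempfile


class ModelSessionSketch:
    """Framework-free stand-in for the pattern described above."""

    def __init__(self, graph):
        # Intended to be reached only through create() and restore().
        self.graph = graph

    @classmethod
    def create(cls, **structure):
        """Build a brand-new model from its structural parameters."""
        return cls(cls.create_graph(**structure))

    @classmethod
    def restore(cls, directory):
        """Reload a saved model; its structure is read back from disk."""
        with open(os.path.join(directory, "graph.json")) as f:
            return cls(json.load(f))

    @staticmethod
    def create_graph(**structure):
        raise NotImplementedError("subclasses define the graph here")

    def save(self, directory):
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, "graph.json"), "w") as f:
            json.dump(self.graph, f)


class TinySession(ModelSessionSketch):
    @staticmethod
    def create_graph(layer_1, layer_2):
        # A dict stands in for a TensorFlow graph in this sketch.
        return {"layer_1": layer_1, "layer_2": layer_2}


directory = tempfile.mkdtemp()
TinySession.create(layer_1=16, layer_2=32).save(directory)
restored = TinySession.restore(directory)  # no layer sizes needed here
```

Structural parameters appear only in the call to create; restore recovers them from what was saved.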
Code in your model should refer to elements of your model's graph by name,
since this information is reloaded during deserialization.
If your graph contains a variable called iteration, its value will be
used as the global_step suffix of the checkpoint file name whenever the
model is saved.
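
TensorFlow's saver appends the global step to the checkpoint prefix, separated by a hyphen. The helper below only mimics that naming rule in plain Python; the function name and the model.ckpt prefix are assumptions for illustration, and a dictionary stands in for the graph's variables.

```python
def checkpoint_name(prefix, graph_variables):
    """Append the iteration count to the prefix when the variable exists."""
    step = graph_variables.get("iteration")
    return prefix if step is None else "{}-{}".format(prefix, step)


with_step = checkpoint_name("model.ckpt", {"iteration": 1500})
without_step = checkpoint_name("model.ckpt", {"weights": [0.1]})
```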