This folder contains a couple of scripts that can be used to train models and
convert the result into tflite files.
Currently, only MNIST models are supported, but it should be easy to use
train_mnist.py as a starting point for other problem areas (e.g., CIFAR10).
Run
$ virtualenv --python $(which python3) venv
$ source venv/bin/activate
$ pip install tensorflow
And then, for example,
$ ./train_and_quantize mnist_simple1
Consult models.py to see which models can be trained.
The file train_mnist.py is a script that performs quantization-aware training
given the name of one of the models defined in models.py (see the notes on
models further down).
Example usage:
$ ./train_mnist.py
usage: train_mnist.py [-h] [-m name] [-l] [--epochs epochs]
                      [--checkpoint-dir dir] [--freeze model name]

trains a model with quantization aware training

optional arguments:
  -h, --help            show this help message and exit
  -m name, --model-name name
                        name of model. Must be defined in models.py
  -l, --list-models     lists available models
  --epochs epochs       number of epochs for training. Default is 1
  --checkpoint-dir dir  directory to save checkpoint information. Default is
                        "./chkpt/checkpoints"
  --freeze model name   freezes the model

$ ./train_mnist.py -m simple
[snip 🦀]
60000/60000 [==============================] - 2s 29us/sample - loss: 0.4387 - acc: 0.8832
10000/10000 [==============================] - 1s 54us/sample - loss: 0.2266 - acc: 0.9360
Evaluation results:
test loss 0.22657053579986094
test accuracy 0.936
This generates checkpoints in the folder ./chkpt/, which are needed in the
next step.
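For readers adapting train_mnist.py to another dataset, the command-line interface printed by the usage message above could be reproduced with argparse roughly as follows. This is a minimal sketch reconstructed from the help text only; the function name `build_parser` and the argument types are assumptions, and the real script may be organized differently.

```python
import argparse


def build_parser():
    # Mirrors the options shown in the usage message; `build_parser` is a
    # hypothetical name, not taken from the actual script.
    parser = argparse.ArgumentParser(
        prog="train_mnist.py",
        description="trains a model with quantization aware training",
    )
    parser.add_argument("-m", "--model-name", metavar="name",
                        help="name of model. Must be defined in models.py")
    parser.add_argument("-l", "--list-models", action="store_true",
                        help="lists available models")
    # Assumption: epochs is parsed as an integer.
    parser.add_argument("--epochs", metavar="epochs", type=int, default=1,
                        help="number of epochs for training. Default is 1")
    parser.add_argument("--checkpoint-dir", metavar="dir",
                        default="./chkpt/checkpoints",
                        help="directory to save checkpoint information")
    parser.add_argument("--freeze", metavar="model name",
                        help="freezes the model")
    return parser


# Parse an example command line instead of sys.argv.
args = build_parser().parse_args(["-m", "simple", "--epochs", "3"])
print(args.model_name, args.epochs, args.checkpoint_dir)
```

Options that are not given fall back to the defaults listed in the help text, so the example above keeps the checkpoint directory at ./chkpt/checkpoints.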
The script checkpoint2pb.py takes the checkpoints from the previous step and
creates a frozen graph def file:
$ ./checkpoint2pb.py
usage: checkpoint2pb.py [-h] model_name checkpoints
checkpoint2pb.py: error: the following arguments are required: model_name, checkpoints

$ ./checkpoint2pb.py mnist_simple chkpt/checkpoints
[snip 🦀]
-----------------------
writing mnist_simple.pb
input_arrays: flatten_input
output_arrays: dense_1/Softmax
Notice that the model name needs to be prefixed with the name of the group it
belongs to (here mnist).
Finally, the .pb file can be converted into a fully optimized and quantized
model file:
$ ./pb2tflite.sh
usage: ./pb2tflite.sh [frozen_model.pb] [input_arrays] [output_arrays]

$ ./pb2tflite.sh mnist_simple.pb flatten_input dense_1/Softmax
converting "mnist_simple.pb" to "mnist_simple.tflite"
[snip 🦀]
done!
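The three steps above always follow the same pattern, so they can be generated from a model name. The sketch below only builds and prints the command lines; the function `pipeline_commands` is hypothetical, the underscore between group and model name is inferred from the mnist_simple example, and the default input and output array names are the ones printed by checkpoint2pb.py for that model (other models will likely need different values).

```python
import shlex


def pipeline_commands(model, group="mnist", checkpoints="chkpt/checkpoints",
                      input_arrays="flatten_input",
                      output_arrays="dense_1/Softmax"):
    """Return the train, freeze, and convert commands as argument lists."""
    # checkpoint2pb.py expects the model name prefixed with its group.
    qualified = f"{group}_{model}"
    return [
        ["./train_mnist.py", "-m", model],
        ["./checkpoint2pb.py", qualified, checkpoints],
        ["./pb2tflite.sh", f"{qualified}.pb", input_arrays, output_arrays],
    ]


commands = pipeline_commands("simple")
for cmd in commands:
    # shlex.join (Python 3.8+) renders an argument list as a shell command.
    print("$", shlex.join(cmd))
```

To actually run the steps, each list can be passed to `subprocess.run(cmd, check=True)`, which stops the pipeline as soon as one step fails.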
All models should be defined in models.py. See the existing models for how to
go about defining a new model.
Two models are currently defined for the MNIST dataset:
Two fully connected layers with relu activation.
One convolution and two fully connected layers with relu6 activation.
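As an illustration of what a definition for the first model might look like, here is a minimal Keras sketch. It is not the code from models.py: the builder name, the `MODELS` registry, and the hidden layer width of 128 are assumptions. The softmax on the second layer is inferred from the output array name dense_1/Softmax, and the leading Flatten layer from the input array name flatten_input.

```python
def build_mnist_simple(hidden_units=128):
    """Two fully connected layers: a ReLU hidden layer and a softmax output."""
    # Imported lazily so that listing model names does not load TensorFlow.
    import tensorflow as tf

    return tf.keras.Sequential([
        # MNIST images are 28x28 grayscale; flatten them to 784 values.
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        # Assumption: 128 hidden units; the real model may use another width.
        tf.keras.layers.Dense(hidden_units, activation="relu"),
        # One output per digit class.
        tf.keras.layers.Dense(10, activation="softmax"),
    ])


# Hypothetical registry mapping qualified model names to builder functions.
MODELS = {"mnist_simple": build_mnist_simple}

print(sorted(MODELS))
```

Keeping the builders in a name-to-function mapping like this would make an option such as --list-models easy to support, since the names can be listed without building any model.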