Weight mismatch when saving a QAT model with a particular layer pattern
Describe the bug
When trying to save and load InceptionV3 (see the simple model below, which reproduces the problem), there is a mismatch in the number of weights, causing model loading to fail.
System information
TensorFlow version (installed from source or binary): tf-nightly
TensorFlow Model Optimization version (installed from source or binary): Installed from source
Python version: python3.6
Describe the expected behavior
The model can be saved and loaded without error.
Describe the current behavior
Currently, due to the layer pattern in the code below, the QuantizeConfig is discarded and, in turn, the appropriate weights are not created because the layer is never built. This appears to be a known limitation, described here: model_transformer.py#L236
This then results in the following error:
ValueError: Weight count mismatch for layer #4 (named quant_activation in the current model, quant_activation in the save file). Layer expects 3 weight(s). Received 1 saved weight(s)
Code to reproduce the issue
import tensorflow as tf
from tensorflow_model_optimization.python.core.quantization.keras import quantize

input_1 = tf.keras.layers.Input(shape=(28, 28, 1))
conv2d_1 = tf.keras.layers.Conv2D(32, kernel_size=3)(input_1)
activation = tf.keras.layers.Activation("relu")(conv2d_1)
input_2 = tf.keras.layers.Input(shape=(28, 28, 1))
conv2d_2 = tf.keras.layers.Conv2D(32, kernel_size=3)(input_2)
activation_2 = tf.keras.layers.Activation("relu")(conv2d_2)
concat = tf.keras.layers.Concatenate(axis=3)([activation, activation_2])  # the Concatenate layer triggers the mismatch
output = tf.keras.layers.Dense(units=4, activation=tf.keras.activations.softmax)(concat)
model = tf.keras.Model(inputs=[input_1, input_2], outputs=[output])
model = quantize.quantize_model(model)
with quantize.quantize_scope():
    model.save("qat_model.h5")
    new_model = tf.keras.models.load_model("qat_model.h5")
It happens because of the concat operation.
We removed output quantization on the inputs of the concat op (so that only the concat output is quantized), but because a temporary transform was added earlier, the resulting Keras model config doesn't reflect that structure: see default_8bit_transforms.py#L609
So the main cause is:
- We modify the output quantizer of the previous layer during the transform, like:
quantize_config.get_output_quantizers = lambda layer: []
- But this change doesn't update the Keras model config; it only patches the Python object, so it isn't preserved when the model is saved and loaded through the Keras API.
We may have to add a way to persist this change (e.g. add a flag in the Keras config for the QuantizeConfig). The sketch below illustrates why the in-memory patch is lost on a save/load round trip.
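For illustration only, here is a minimal sketch using a hypothetical DemoQuantizeConfig class (not the real tfmot QuantizeConfig API) showing why a monkey-patched method disappears after a config round trip:

# Hypothetical stand-in for a QuantizeConfig-like object; not the real tfmot class.
class DemoQuantizeConfig:
    def __init__(self, num_bits=8):
        self.num_bits = num_bits

    def get_output_quantizers(self, layer):
        return ["output_quantizer"]

    def get_config(self):
        # Only explicit attributes are serialized; patched methods are not.
        return {"num_bits": self.num_bits}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

config = DemoQuantizeConfig()
config.get_output_quantizers = lambda layer: []  # transform-time patch
print(config.get_output_quantizers(None))        # [] -- patch is in effect in memory

restored = DemoQuantizeConfig.from_config(config.get_config())
print(restored.get_output_quantizers(None))      # ['output_quantizer'] -- patch lost after round trip

The same thing happens to the real QuantizeConfig: the lambda assigned during the transform never makes it into the serialized model config, so the reloaded layer builds its quantization weights as if the output quantizer were still present.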
This should be fixed now at head. You can run pip install git+https://github.com/tensorflow/model-optimization.git
to update your installation.
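As a quick sanity check (a suggestion, not from the original thread), re-running the repro script above after updating should save and load without the weight-count error:

# Assumes `model` and `quantize` from the repro script above are still in scope.
with quantize.quantize_scope():
    model.save("qat_model.h5")
    reloaded = tf.keras.models.load_model("qat_model.h5")
print(len(model.weights) == len(reloaded.weights))  # expected: True once the fix is installed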