tensorflow/model-optimization

Weight mismatch when saving a QAT model with a particular layer pattern

Closed this issue · 3 comments

Describe the bug

When trying to save and load InceptionV3 (see the simple model that reproduces this below), there is a mismatch in the number of weights, which causes model loading to fail.

System information

TensorFlow version (installed from source or binary): tf-nightly

TensorFlow Model Optimization version (installed from source or binary): Installed from source

Python version: python3.6

Describe the expected behavior

The model can be saved and loaded without error.

Describe the current behavior

Currently, due to the layer pattern seen in the code below, the QuantizeConfig is discarded, and in turn the appropriate weights are not created because the layer is not built. This seems to be a limitation described here: model_transformer.py#L236

This then results in the following error:

ValueError: Weight count mismatch for layer #4 (named quant_activation in the current model, quant_activation in the save file). Layer expects 3 weight(s). Received 1 saved weight(s)

Code to reproduce the issue

import tensorflow as tf
from tensorflow_model_optimization.python.core.quantization.keras import quantize

# Two parallel Conv2D -> Activation branches joined by a Concatenate layer.
input_1 = tf.keras.layers.Input(shape=(28, 28, 1))
conv2d_1 = tf.keras.layers.Conv2D(32, kernel_size=3)(input_1)
activation = tf.keras.layers.Activation("relu")(conv2d_1)

input_2 = tf.keras.layers.Input(shape=(28, 28, 1))
conv2d_2 = tf.keras.layers.Conv2D(32, kernel_size=3)(input_2)
activation_2 = tf.keras.layers.Activation("relu")(conv2d_2)

concat = tf.keras.layers.Concatenate(axis=3)([activation, activation_2])

output = tf.keras.layers.Dense(units=4, activation=tf.keras.activations.softmax)(concat)

model = tf.keras.Model(inputs=[input_1, input_2], outputs=[output])

model = quantize.quantize_model(model)

# Saving succeeds, but loading fails with the weight-count mismatch above.
with quantize.quantize_scope():
    model.save("qat_model.h5")
    new_model = tf.keras.models.load_model("qat_model.h5")
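
For reference, a minimal diagnostic sketch (not part of the original report) that prints per-layer weight counts of the in-memory QAT model before saving; based on the error above, the quant_activation layers should report only 1 weight here, while the model rebuilt from config at load time expects 3:

# Diagnostic sketch: compare these counts with the "Layer expects N weight(s)"
# numbers in the load error to see which layers are affected.
for layer in model.layers:
    print(layer.name, len(layer.weights))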

@Xhark Could you answer this?

Xhark commented

It happens because of the concat operation.

We removed output quantization on the inputs of the concat op (so that only the concat output is quantized), but because this is done by a transform applied earlier, the resulting Keras model config doesn't reflect that structure: see default_8bit_transforms.py#L609

So the main cause is:

  1. During the transform, we modify the output quantizer of the previous layer like this:
    quantize_config.get_output_quantizers = lambda layer: []
  2. But this change doesn't touch the Keras model config. It only patches the Python object, so it is not preserved when the model is saved and loaded through the Keras API.

We may have to add a way to save these changes (e.g. add a flag in the Keras config for the QuantizeConfig).
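
To illustrate point 2, here is a minimal standalone sketch (FakeQuantizeConfig is a hypothetical stand-in, not the library's class) of why patching only the Python object does not survive the get_config/from_config round trip that Keras performs on save and load:

class FakeQuantizeConfig:
    # Stand-in for a QuantizeConfig that is serialized through
    # get_config()/from_config(), the mechanism Keras uses on save/load.
    def get_output_quantizers(self, layer):
        return ["output_quantizer"]

    def get_config(self):
        return {}  # the monkey-patch below is never captured here

    @classmethod
    def from_config(cls, config):
        return cls(**config)

config = FakeQuantizeConfig()
config.get_output_quantizers = lambda layer: []   # transform patches only the object
print(config.get_output_quantizers(None))         # [] -- patched behavior

restored = FakeQuantizeConfig.from_config(config.get_config())
print(restored.get_output_quantizers(None))       # ['output_quantizer'] -- patch is lost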

This should be fixed now at head. You can use pip install git+https://github.com/tensorflow/model-optimization.git to update your installation.