Model input names are not preserved during TFLite conversion when inference_input_type is tf.int8
Closed this issue · 4 comments
Describe the bug
When converting a quantised model with named inputs to TFLite, setting converter.inference_input_type = tf.int8
causes the inputs to be given different names.
System information
TensorFlow version (installed from source or binary): 2.6.0
TensorFlow Model Optimization version: 0.7.0
Python version: 3.7.10
Describe the expected behavior
The input name should be preserved in the generated TFLite file, just as it is during TFLite post-training quantization (PTQ).
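For comparison, here is a rough sketch of the PTQ path I mean (the float model and the random representative dataset below are only for illustration); converting this way keeps the original input name:
import numpy as np
import tensorflow as tf
inp = tf.keras.Input(shape=(4,), batch_size=1, name='my_input')
float_model = tf.keras.Model(inp, tf.keras.layers.Dense(4)(inp))
def representative_dataset():
    for _ in range(10):
        yield [np.random.random((1, 4)).astype(np.float32)]
converter = tf.lite.TFLiteConverter.from_keras_model(float_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.inference_input_type = tf.int8 # Same setting as in the reproduction below
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_input_details()[0]['name']) # The original name survives this path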
Describe the current behavior
The input name gets replaced by a completely different name.
Code to reproduce the issue
import tensorflow as tf
import tensorflow_model_optimization as tfmot
inp = tf.keras.Input(shape=(4,), batch_size=1, name='my_input') # Input name is set here
out = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(4))(inp)
quant_model = tfmot.quantization.keras.quantize_apply(tf.keras.Model(inp, out))
converter = tf.lite.TFLiteConverter.from_keras_model(quant_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.int8 # Without this line the input name is preserved
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print('\nInput name is: ' + interpreter.get_input_details()[0]['name'])
I think it might be possible to use signatures instead. Assigning @abattery for confirmation.
You can use the signature concept to keep the original interface in the final TFLite model.
Thanks for your reply. Is there any chance you could give an example of using these signatures with TFMOT, maybe using the code I provided above? I don't really think the documentation is clear about this.
Something like this should work:
import numpy as np

interpreter = tf.lite.Interpreter(model_content=tflite_model)
print('\nInput name is: ' + interpreter.get_input_details()[0]['name'])

# The signature still uses the original Keras input/output names.
signatures = interpreter.get_signature_list()
print(signatures)
serving_default = interpreter.get_signature_runner('serving_default')

# Reference result from the quantize-aware Keras model on float inputs.
inputs = np.random.random((1, 4)).astype(np.float32)
tf_result = quant_model(inputs)

# Quantize the float inputs by hand using the input tensor's scale and zero point.
input_details = interpreter.get_input_details()
scale = input_details[0]['quantization_parameters']['scales'][0]
zp = input_details[0]['quantization_parameters']['zero_points'][0]
inputs = np.clip(np.round(inputs / scale + zp), -128, 127).astype(np.int8)

# Run inference through the signature, addressing the input by its original name.
output = serving_default(my_input=inputs)
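To sanity-check the result against the Keras model (the output stays float32 here since inference_output_type was not set; the exact output key depends on the output layer's name, so just take the single value):
# Compare the signature runner's float32 output with the Keras reference.
tflite_result = list(output.values())[0]
print('Max abs diff:', np.max(np.abs(tf_result.numpy() - tflite_result)))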