Per-tensor QAT model Conv2d+BN+relu folding issue
Describe the bug
I need to train a QAT (per-tensor) model and then convert it to TFLite, but I get the Conv2D+BN+ReLU folding issue described here.
System information
TensorFlow version (installed from source or binary): 2.15.0
TensorFlow Model Optimization version (installed from source or binary): 0.8.0
Python version: 3.10.12
Describe the expected behavior
A 1-layer CNN (Conv2D + BN + ReLU) is folded and converted to TFLite after QAT in per-tensor mode, without the computation graph being split into multiple Quantize-Dequantize segments.
Describe the current behavior
After folding the 1-layer CNN (Conv2D + BN + ReLU), the folded layer is left unquantized.
Code to reproduce the issue
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import \
    default_8bit_quantize_scheme

quantize_apply = tfmot.quantization.keras.quantize_apply
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model


def train_qat_convert_tflite(per_tensor):
    # A minimal 1-layer CNN: Conv2D + BatchNormalization + ReLU.
    model = keras.Sequential([
        keras.layers.InputLayer(input_shape=(128, 128, 3)),
        keras.layers.Conv2D(3, 3, padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.Activation('relu'),
        keras.layers.Softmax(),
    ])

    # Annotate and apply QAT; disable_per_axis=True should force
    # per-tensor quantization.
    annotated_model = quantize_annotate_model(model)
    q_aware_model = quantize_apply(
        annotated_model,
        scheme=default_8bit_quantize_scheme.Default8BitQuantizeScheme(
            disable_per_axis=per_tensor))

    q_aware_model.compile(
        optimizer='Adam',
        loss=keras.losses.MeanAbsoluteError(),
        metrics=['accuracy'],
    )
    # Train on random data just to exercise the QAT graph.
    q_aware_model.fit(
        x=tf.random.normal((128, 128, 128, 3)),
        y=tf.random.normal((128, 128, 128, 3)),
        batch_size=16,
        epochs=1,
    )
    q_aware_model.save(f'{per_tensor=}.h5')

    # Convert to a quantized TFLite model.
    converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    quantized_tflite_model = converter.convert()
    with open(f'{per_tensor=}.tflite', 'wb') as f:
        f.write(quantized_tflite_model)


train_qat_convert_tflite(per_tensor=True)
train_qat_convert_tflite(per_tensor=False)
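A quick way to check which scheme actually ended up in the converted model is to dump each tensor's quantization scales (a sketch below, using the filename the script above writes); per-tensor mode should yield a single scale per quantized tensor:

import tensorflow as tf

def dump_tflite_quant_params(path):
    # List every tensor's quantization scales in the converted model.
    interpreter = tf.lite.Interpreter(model_path=path)
    interpreter.allocate_tensors()
    for detail in interpreter.get_tensor_details():
        scales = detail['quantization_parameters']['scales']
        if scales.size:  # skip tensors with no quantization parameters
            print(detail['name'], 'num_scales:', scales.size)

dump_tflite_quant_params('per_tensor=True.tflite')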
Additional context
I tried the workaround from #552, but with a simple 1-layer CNN (see the code above) there are no custom layers, so the if statement in the _replace function evaluates to False and execution falls through to the next line.
I also see that in the Keras .h5 model the BN layer is quantized per-channel: in both cases the quantization parameters are tensors rather than the scalars expected in per-tensor mode.
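This can be checked in the saved QAT model (a sketch, assuming the quantizer variables keep the usual min/max name suffixes, and using q_aware_model from the script above) by printing the shapes of the quantizer variables; any non-scalar shape indicates per-channel quantization:

# Print the shape of every quantizer min/max variable in the QAT model;
# non-scalar shapes mean the quantizer is still per-channel.
for weight in q_aware_model.weights:
    if weight.name.endswith(('min:0', 'max:0')):
        print(weight.name, tuple(weight.shape))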