tensorflow/model-optimization

Unable to prune/quantize multiple layers at the same time

sachinkmohan opened this issue

Describe the bug
Unable to prune/quantize multiple layers at the same time

System information

TensorFlow version (installed from source or binary): 2.4.0

TensorFlow Model Optimization version (installed from source or binary): 0.7.2

Python version: 3.6.9

Describe the expected behavior
Ability to prune/quantize multiple layers like Conv2D and Dense layers at the same time.

Describe the current behavior
The current API supports pruning/quantization of only one layer type at a time (either Conv2D or Dense layers), not both.

Code to reproduce the issue
Provide reproducible code that is the bare minimum necessary to generate the problem.

For pruning: only Conv2D layers (to be clear, the problem is the inability to combine both layer types)
Code ref (modified): https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide

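import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Create a base model (`setup_model` and `pretrained_weights` come from the setup section of the guide linked above)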
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `prune_low_magnitude` to make only the
# Conv2D layers train with pruning.
def apply_pruning_to_conv2d(layer):
  if isinstance(layer, tf.keras.layers.Conv2D):
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_conv2d`
# to the layers of the model.
model_for_pruning = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_pruning_to_conv2d,
)

model_for_pruning.summary()

For quantization: only Dense layers (to be clear, the problem is the inability to combine both layer types)
Code ref: https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide

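import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Create a base model (`setup_model` and `pretrained_weights` come from the setup section of the guide linked above)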
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `quantize_annotate_layer` to annotate that only the 
# Dense layers should be quantized.
def apply_quantization_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense` 
# to the layers of the model.
annotated_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_quantization_to_dense,
)

# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
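Both snippets above select layers with an `isinstance` check against a single class. As a minimal sketch of the expected behavior, the hypothetical helper `make_clone_function` below (not part of tfmot) builds a `clone_function` that wraps every layer matching any class in a tuple, relying only on the fact that Python's `isinstance` accepts a tuple of classes. The TensorFlow usage appears only in comments and is untested against TensorFlow 2.4.0 / tfmot 0.7.2; stand-in classes check the selection logic without TensorFlow.

```python
from typing import Callable, Iterable


def make_clone_function(layer_types: Iterable[type], wrap: Callable) -> Callable:
    """Hypothetical helper: build a clone_function for tf.keras.models.clone_model.

    Layers that are instances of any class in `layer_types` are passed to
    `wrap`; every other layer is returned unchanged, mirroring the
    single-type helpers in the snippets above.
    """
    types = tuple(layer_types)  # isinstance() accepts a tuple of classes

    def clone_function(layer):
        if isinstance(layer, types):
            return wrap(layer)
        return layer

    return clone_function


# Intended use with Keras (assumption, not executed here):
#
# prune_both = make_clone_function(
#     (tf.keras.layers.Conv2D, tf.keras.layers.Dense),
#     tfmot.sparsity.keras.prune_low_magnitude,
# )
# model_for_pruning = tf.keras.models.clone_model(base_model, clone_function=prune_both)
#
# annotate_both = make_clone_function(
#     (tf.keras.layers.Conv2D, tf.keras.layers.Dense),
#     tfmot.quantization.keras.quantize_annotate_layer,
# )
# annotated_model = tf.keras.models.clone_model(base_model, clone_function=annotate_both)
# quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)


# Self-check with stand-in layer classes (no TensorFlow needed).
class FakeConv2D:
    pass


class FakeDense:
    pass


class FakeFlatten:
    pass


mark = make_clone_function((FakeConv2D, FakeDense), lambda layer: ("wrapped", layer))
conv, dense, flat = FakeConv2D(), FakeDense(), FakeFlatten()
print(mark(conv)[0], mark(dense)[0], mark(flat) is flat)  # wrapped wrapped True
```

Returning unmatched layers unchanged follows the pattern of the helpers in the reproduction code; only the type check changes from one class to a tuple.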