Stripping Quantized Model
Khalil-2020 opened this issue · 1 comments
Describe the bug
I need help with this one, please:
I want to re-implement the "strip_pruning" function defined at this link (line 222): https://github.com/tensorflow/model-optimization/blob/v0.7.2/tensorflow_model_optimization/python/core/sparsity/keras/prune.py#L222-L270
But this time I want to apply it to a quantized model, so that I can do the following: apply quantization to a model, then strip the quantized model, and then apply pruning (instead of applying pruning, stripping the model, and then applying quantization, as in the guide on the TensorFlow page).
Code to reproduce the issue
def stripping_quantize(model):
    """Return a clone of ``model`` with its quantize wrappers removed.

    Every layer whose class is named ``QuantizeWrapper`` or
    ``QuantizeWrapperV2`` is replaced by the layer it wraps
    (``wrapper.layer``). Nested models are processed recursively. All other
    layers are passed through unchanged.

    NOTE(review): the unwrapped layer is the very same object the wrapper
    held, not a fresh copy. Presumably the wrapper's forward pass has
    replaced that layer's weight attributes (e.g. ``kernel``) with
    fake-quantized tensors, which would explain the
    "'Tensor' object has no attribute 'assign'" error seen when the result
    is pruned -- confirm against the installed
    tensorflow_model_optimization version before relying on this function.

    Args:
        model: A ``tf.keras.Model`` whose layers may be quantize wrappers.

    Returns:
        A new ``tf.keras.Model`` built by ``tf.keras.models.clone_model``.

    Raises:
        ValueError: If ``model`` is not a ``tf.keras.Model``.
    """
    if not isinstance(model, tf.keras.Model):
        raise ValueError(
            'Expected model to be a tf.keras.Model instance but got: '
            f'{model!r}')

    # Matched by class name so the wrapper classes need not be imported.
    wrapper_class_names = ('QuantizeWrapper', 'QuantizeWrapperV2')

    def _strip_quant_wrap(layer):
        # Recurse into nested models so their wrappers are stripped too.
        if isinstance(layer, tf.keras.Model):
            return tf.keras.models.clone_model(
                layer, input_tensors=None, clone_function=_strip_quant_wrap)

        if type(layer).__name__ not in wrapper_class_names:
            return layer

        inner = layer.layer
        # Carry the input shape over when only the wrapper recorded it.
        if (not hasattr(inner, '_batch_input_shape')
                and hasattr(layer, '_batch_input_shape')):
            inner._batch_input_shape = layer._batch_input_shape
        return inner

    return tf.keras.models.clone_model(
        model, input_tensors=None, clone_function=_strip_quant_wrap)
# Unwrap the quantized model; the result is what pruning is applied to next.
model_q = stripping_quantize(quantized_model)
When I apply pruning to model_q, I get the following error:
Screenshots
AttributeError Traceback (most recent call last)
Input In [41], in <cell line: 1>()
----> 1 model_for_pruning = tf.keras.models.clone_model(
2 model_q,
3 clone_function=apply_pruning_to_layers,)
File ~\anaconda3\envs\base\lib\site-packages\keras\models.py:456, in clone_model(model, input_tensors, clone_function)
453 return _clone_sequential_model(
454 model, input_tensors=input_tensors, layer_fn=clone_function)
455 else:
--> 456 return _clone_functional_model(
457 model, input_tensors=input_tensors, layer_fn=clone_function)
File ~\anaconda3\envs\base\lib\site-packages\keras\models.py:197, in _clone_functional_model(model, input_tensors, layer_fn)
193 model_configs, created_layers = _clone_layers_and_model_config(
194 model, new_input_layers, layer_fn)
195 # Reconstruct model from the config, using the cloned layers.
196 input_tensors, output_tensors, created_layers = (
--> 197 functional.reconstruct_from_config(model_configs,
198 created_layers=created_layers))
199 metrics_names = model.metrics_names
200 model = Model(input_tensors, output_tensors, name=model.name)
File ~\anaconda3\envs\pbase\lib\site-packages\keras\engine\functional.py:1338, in reconstruct_from_config(config, custom_objects, created_layers)
1336 while layer_nodes:
1337 node_data = layer_nodes[0]
-> 1338 if process_node(layer, node_data):
1339 layer_nodes.pop(0)
1340 else:
1341 # If a node can't be processed, stop processing the nodes of
1342 # the current layer to maintain node ordering.
File ~\anaconda3\envs\base\lib\site-packages\keras\engine\functional.py:1282, in reconstruct_from_config..process_node(layer, node_data)
1279 if not layer._preserve_input_structure_in_config:
1280 input_tensors = (
1281 base_layer_utils.unnest_if_single_tensor(input_tensors))
-> 1282 output_tensors = layer(input_tensors, **kwargs)
1284 # Update node index map.
1285 output_index = (tf.nest.flatten(output_tensors)[0].
1286 _keras_history.node_index)
File ~\anaconda3\envs\base\lib\site-packages\keras\utils\traceback_utils.py:67, in filter_traceback..error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
File ~\anaconda3\envs\base\lib\site-packages\tensorflow\python\autograph\impl\api.py:692, in convert..decorator..wrapper(*args, **kwargs)
690 except Exception as e: # pylint:disable=broad-except
691 if hasattr(e, 'ag_error_metadata'):
--> 692 raise e.ag_error_metadata.to_exception(e)
693 else:
694 raise
AttributeError: Exception encountered when calling layer "prune_low_magnitude_stem_conv" (type PruneLowMagnitude).
in user code:
File "C:\Users\ASUS\anaconda3\envs\base\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_wrapper.py", line 288, in call *
self.add_update(self.pruning_obj.weight_mask_op())
File "C:\Users\ASUS\anaconda3\envs\base\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_impl.py", line 254, in weight_mask_op *
return tf.group(self._weight_assign_objs())
File "C:\Users\ASUS\anaconda3\envs\base\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_impl.py", line 225, in update_var *
return tf_compat.assign(variable, reduced_value)
File "C:\Users\ASUS\anaconda3\envs\base\lib\site-packages\tensorflow_model_optimization\python\core\keras\compat.py", line 28, in assign *
return ref.assign(value, name=name)
AttributeError: 'Tensor' object has no attribute 'assign'
Call arguments received:
• inputs=tf.Tensor(shape=(None, 151, 151, 3), dtype=float32)
• training=False
• kwargs=<class 'inspect._empty'>
Additional context
If there is any way to apply quantization and then pruning in TensorFlow, I would like to know how.
THANK YOU!
From https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/quantization/keras/quantize.py
You can use the following function to strip the quantization wrappers from model layers.
def extract_original_model(model_to_unwrap):
    """Extracts original model by removing wrappers.

    Clones ``model_to_unwrap`` with ``tf.keras.models.clone_model``. Each
    layer whose name contains ``"quant"`` and which exposes a ``layer``
    attribute is replaced by that wrapped layer. Every other layer is
    returned unchanged.

    Args:
        model_to_unwrap: The Keras model whose wrapper layers are removed.

    Returns:
        The cloned model with the matching wrappers replaced by the layers
        they wrap.
    """

    def _unwrap(layer):
        # Wrappers are recognised by name rather than by class, so the
        # wrapper classes do not have to be imported here.
        if 'quant' not in layer.name:
            return layer
        # A layer can have "quant" in its name without wrapping anything;
        # pass it through instead of failing on the missing attribute.
        if not hasattr(layer, 'layer'):
            return layer
        return layer.layer

    return tf.keras.models.clone_model(
        model_to_unwrap, input_tensors=None, clone_function=_unwrap)