tensorflow/model-optimization

`prune_low_magnitude` can only prune an object of the following types: tf.keras.models.Sequential, tf.keras functional model, tf.keras.layers.Layer, list of tf.keras.layers.Layer. You passed an object of type: KerasTensor.


I am trying to prune an ESPCN super-resolution model, but pruning fails with the error above.

```python
inputs = keras.Input(shape=(None, None, channels))
x = prune.prune_low_magnitude(tf.keras.layers.Conv2D(64, 5, **conv_args)(inputs), **pruning_params)  # <---- error raised here
x = prune.prune_low_magnitude(layers.Conv2D(32, 3, **conv_args)(x), **pruning_params)
x = prune.prune_low_magnitude(layers.Conv2D(channels * (upscale_factor ** 2), 3, **conv_args)(x), **pruning_params)
```

The line marked with the arrow raises the error. Could someone please help me with this? I have been stuck on it for the past 4 days.

Xhark commented

Hi, could you please try this change?

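```python
# Before: the layer is called first, so a tensor is passed to prune_low_magnitude
prune.prune_low_magnitude(tf.keras.layers.Conv2D(64, 5, **conv_args)(inputs), **pruning_params)
# After: wrap the layer, then call the wrapped layer on the inputs
prune.prune_low_magnitude(tf.keras.layers.Conv2D(64, 5, **conv_args), **pruning_params)(inputs)
```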

Thank you so much for your reply, Kim. Yes, your suggestion worked.

Thanks again.

Thanks! Let us know if you run into any further issues.