Pruning only works for small batch sizes
vvolhejn opened this issue · 4 comments
Describe the bug
When using prune_low_magnitude(), my model is not pruned unless the batch size is small.
System information
TensorFlow version (installed from source or binary): 2.8.0 installed via pip
TensorFlow Model Optimization version (installed from source or binary): 0.7.2 installed via pip
Python version: 3.9.10
Describe the expected behavior
model_for_pruning.fit should sparsify the model regardless of the batch size.
Describe the current behavior
If the batch size is larger than 2 (this is the threshold in my example, at least), the network is not pruned.
Code to reproduce the issue
Based on the Pruning with Keras tutorial.
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot


def main(batch_size):
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28, 28)),
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10)
    ])
    model.compile(
        loss=tf.keras.losses.MeanSquaredError(),
        optimizer='adam',
        metrics=['accuracy']
    )
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model)
    log_dir = tempfile.mkdtemp()
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
    ]
    model_for_pruning.compile(
        loss=tf.keras.losses.MeanSquaredError(),
        optimizer='adam',
        metrics=['accuracy']
    )
    model_for_pruning.fit(
        np.random.randn(100, 28, 28).astype(np.float32),
        np.random.randn(100, 10).astype(np.float32),
        callbacks=callbacks,
        epochs=2,
        batch_size=batch_size,
        # validation_split=0.1,
        verbose=0,
    )
    weights = model_for_pruning.get_weights()[1]
    # A sanity check to show we're looking at the right weights.
    print(f"(Checking weights of shape {weights.shape})")
    # What part of the weights are zeros?
    print(
        f"Sparsity with batch size {batch_size}:",
        (weights == 0).mean(),
    )


main(batch_size=1)
main(batch_size=2)
main(batch_size=3)
main(batch_size=32)
This prints:
2022-05-25 16:44:16.340347: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
/Users/vaclav/prog/venv/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:233: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
self.pruning_step = self.add_variable(
/Users/vaclav/prog/venv/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:212: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
mask = self.add_variable(
/Users/vaclav/prog/venv/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:219: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
threshold = self.add_variable(
(Checking weights of shape (3, 3, 1, 12))
Sparsity with batch size 1: 0.5
(Checking weights of shape (3, 3, 1, 12))
Sparsity with batch size 2: 0.5
(Checking weights of shape (3, 3, 1, 12))
Sparsity with batch size 3: 0.0
(Checking weights of shape (3, 3, 1, 12))
Sparsity with batch size 32: 0.0
So when the batch size is 1 or 2, everything works fine. But for anything larger, the model is not pruned.
Hi @vvolhejn ,
Since you haven't set the pruning parameters, the default option is applied - ConstantSparsity, with pruning frequency 100. https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py#L141
That means your model is pruned only once every 100 steps. Your example runs fewer than 100 steps when the batch size is larger than 2 (at batch size 3, for instance, it is 34 batches per epoch times 2 epochs = 68 steps), so training finishes before pruning is ever applied. That is why you don't get a pruned result.
Hope this helps,
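To make the step count concrete, here is a small sketch of the arithmetic for the reproduction script. It assumes Keras runs ceil(num_samples / batch_size) batches per epoch and that the first pruning needs 100 steps to have elapsed, as described above. The constant and function names are only illustrative.

```python
import math

NUM_SAMPLES = 100  # rows passed to fit() in the reproduction script
EPOCHS = 2         # epochs used in the reproduction script
FREQUENCY = 100    # default pruning frequency mentioned above


def total_steps(batch_size, num_samples=NUM_SAMPLES, epochs=EPOCHS):
    """Training steps run by fit(): batches per epoch times epochs."""
    return math.ceil(num_samples / batch_size) * epochs


for batch_size in (1, 2, 3, 32):
    steps = total_steps(batch_size)
    print(
        f"batch_size={batch_size}: {steps} steps, "
        f"reaches {FREQUENCY} steps: {steps >= FREQUENCY}"
    )
```

Only batch sizes 1 and 2 reach 100 steps, which matches the sparsity values printed in the issue.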
Thank you for clearing this up. It makes sense explained like this, but it feels a bit unintuitive to me :/ The docs say "frequency: Only apply pruning every frequency steps." which doesn't seem to imply the first pruning happens after frequency steps.
Sorry for the confusion. We hoped that "Only" in the sentence would imply it, but that might not be enough.
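For short runs like the reproduction, one possible workaround is to pass an explicit schedule whose frequency fits inside the run, rather than relying on the default. The sketch below is not an official recommendation. `pick_frequency` and `prune_with_explicit_schedule` are hypothetical helpers, the target sparsity of 0.5 is taken from the output printed in the issue, and it assumes (per the explanation in this thread) that the first pruning happens once `frequency` steps have elapsed.

```python
def pick_frequency(total_steps, default=100):
    """Hypothetical helper: cap the pruning frequency at the run length."""
    return max(1, min(default, total_steps))


def prune_with_explicit_schedule(model, total_steps):
    """Wrap a Keras model with a ConstantSparsity schedule sized to the run."""
    # Imported lazily so pick_frequency can be used without TensorFlow installed.
    import tensorflow_model_optimization as tfmot

    schedule = tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=0.5,  # assumption: matches the 0.5 sparsity seen above
        begin_step=0,
        frequency=pick_frequency(total_steps),
    )
    return tfmot.sparsity.keras.prune_low_magnitude(
        model, pruning_schedule=schedule
    )
```

With batch_size=32 the reproduction runs 4 batches per epoch for 2 epochs, so `prune_with_explicit_schedule(model, total_steps=8)` would use a frequency of 8 instead of 100. Whether pruning that late in such a short run gives a useful model is a separate question.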