tensorflow/model-optimization

Cannot joblib serialize pruned models

eangius opened this issue

Description

I have a SciKeras estimator in a pipeline that trains, deploys and predicts successfully for my needs. However, as soon as I add model pruning as described in this documentation, the model can no longer be serialized with joblib, and the following messages are raised:

/project/.venv/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:212: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
  mask = self.add_variable(

/project/.venv/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:219: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
  threshold = self.add_variable(

/project/.venv/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:233: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
  self.pruning_step = self.add_variable(

...

_pickle.PicklingError:Can't pickle <function Layer.add_loss.<locals>._tag_callable at 0x1ef3143a0>:  it's not found as keras.engine.base_layer.Layer.add_loss.<locals>._tag_callable

It seems that something within the PruneLowMagnitude class may be causing this.
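The `<locals>` in the qualified name `Layer.add_loss.<locals>._tag_callable` means the offending object is a function defined inside another function. joblib serializes through Python's `pickle` machinery, which stores plain functions by reference to their importable name, so nested functions like this cannot be pickled. To narrow down which attribute of the fitted estimator holds such an object, a small debugging helper along these lines could be run on `model1` before calling `joblib.dump`. This is a sketch, not part of joblib, SciKeras or TensorFlow Model Optimization: `find_unpicklable_attributes` and `Toy` are hypothetical names, and `Toy` only mimics the shape of the failing closure.

```python
import pickle


def find_unpicklable_attributes(obj):
    """Return the names of obj's instance attributes that pickle cannot serialize.

    Hypothetical debugging helper: each entry of vars(obj) is pickled on its own,
    so one failing attribute does not hide the others.
    """
    failing = []
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception:  # PicklingError, AttributeError or TypeError, depending on the value
            failing.append(name)
    return failing


class Toy:
    """Hypothetical stand-in for a fitted estimator; not a SciKeras class."""

    def __init__(self):
        self.epochs = 5  # plain data pickles fine

        def _tag_callable():  # nested function, similar in shape to Layer.add_loss.<locals>._tag_callable
            return 0.0

        self.loss_fn = _tag_callable  # cannot be pickled by reference


print(find_unpicklable_attributes(Toy()))  # ['loss_fn']
```

In the reproduction script this would be called as `find_unpicklable_attributes(model1)`. The helper only inspects top-level attributes, so a failing attribute such as a Keras model would then need to be examined further.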

System information

  • TensorFlow version 2.9.1 (so the issue discussed in this topic does not apply directly, but it may be related)
  • TensorFlow Model Optimization version 0.7.2
  • Python version 3.8
  • SciKeras version 0.8.0

Code to reproduce the issue

import joblib
from sklearn.datasets import make_classification
from sklearn.base import clone
from scikeras.wrappers import KerasClassifier
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow_model_optimization.sparsity.keras import (
    UpdatePruningStep, prune_low_magnitude, strip_pruning
)

# Define the dummy dataset & common settings 
X, y = make_classification(
    n_samples=8000,
    n_features=512,
    n_informative=20,
    n_classes=3,
    random_state=0,
)
compression_level = 6
epochs = 5

# Define a dummy neural-net arch
def setup_model(meta):
    model = Sequential()
    model.add(Dense(meta["n_features_in_"], input_shape=meta["X_shape_"][1:]))
    model.add(Dense(10))
    model.add(Dense(meta["n_classes_"], activation='softmax'))
    return model

arch = KerasClassifier(
    setup_model,
    loss="sparse_categorical_crossentropy",
)

# Define wrapper class that does pruning.
class PrunedClassifier(KerasClassifier):
    def __init__(self, base_estimator, **kwargs):
        super().__init__(**kwargs)
        self.base_estimator = base_estimator

    def fit(self, X, y, **kwargs):
        # 1. Fit an unpruned copy of the base estimator.
        self.base_estimator = clone(self.base_estimator).fit(X, y)
        # 2. Fit this wrapper; its Keras model is built by _keras_build_fn below.
        super().fit(X, y, **kwargs)
        # 3. Apply strip_pruning to the base estimator's model.
        self.base_estimator.model_ = strip_pruning(self.base_estimator.model_)
        return self

    def predict(self, X):
        # Predictions are delegated to the base estimator.
        return self.base_estimator.predict(X)

    def _keras_build_fn(self):
        # Wrap the already-trained base model for pruning, then compile it.
        model = prune_low_magnitude(self.base_estimator.model_)
        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer='adam',
        )
        return model


# Baseline arch trains & serializes well
model0 = clone(arch).fit(X, y)
joblib.dump(model0, "model0.gz", compress=compression_level)


# Treatment arch trains well but does not serialize
model1 = PrunedClassifier(
    base_estimator=clone(arch),
    epochs=epochs,
    callbacks=[
        UpdatePruningStep(),
    ],
).fit(X, y)
joblib.dump(model1, "model1.gz", compress=compression_level)