Cannot joblib serialize pruned models
eangius opened this issue · 0 comments
eangius commented
Description
I have a scikeras estimator used in a pipeline that trains, deploys & predicts successfully for my needs, but as soon as I implement model pruning as per this documentation, the model can no longer be serialized with joblib and raises the following messages:
/project/.venv/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:212: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
mask = self.add_variable(
/project/.venv/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:219: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
threshold = self.add_variable(
/project/.venv/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:233: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
self.pruning_step = self.add_variable(
...
_pickle.PicklingError: Can't pickle <function Layer.add_loss.<locals>._tag_callable at 0x1ef3143a0>: it's not found as keras.engine.base_layer.Layer.add_loss.<locals>._tag_callable
It looks like something within the PruneLowMagnitude class may be causing this.
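To narrow that down (I have not run this exact snippet; it is only a sketch of the suspicion), one could bypass scikeras entirely and joblib-dump a bare Keras model before and after wrapping it with prune_low_magnitude. If the wrapper alone triggers the same PicklingError, scikeras is not the culprit:
import joblib
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude

# Small stand-in model; the shapes here are arbitrary.
plain = Sequential([Dense(10, input_shape=(512,)), Dense(3, activation="softmax")])
pruned = prune_low_magnitude(plain)  # wraps every layer in PruneLowMagnitude

joblib.dump(plain, "plain.gz")    # expected to work, like the unpruned baseline below
joblib.dump(pruned, "pruned.gz")  # hypothesis: fails with the same PicklingError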
System information
TensorFlow version: 2.9.1, so this topic is not an issue but may be related.
TensorFlow Model Optimization version: 0.7.2
Python version: 3.x
SciKeras version: 0.8.0
Code to reproduce the issue
import joblib
from sklearn.datasets import make_classification
from sklearn.base import clone
from scikeras.wrappers import KerasClassifier
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow_model_optimization.sparsity.keras import (
    UpdatePruningStep, prune_low_magnitude, strip_pruning
)
# Define the dummy dataset & common settings
X, y = make_classification(
    n_samples=8000,
    n_features=512,
    n_informative=20,
    n_classes=3,
    random_state=0,
)
compression_level = 6
epochs = 5
# Define a dummy neural-net arch
def setup_model(meta):
    model = Sequential()
    model.add(Dense(meta["n_features_in_"], input_shape=meta["X_shape_"][1:]))
    model.add(Dense(10))
    model.add(Dense(meta["n_classes_"], activation='softmax'))
    return model
arch = KerasClassifier(
    setup_model,
    loss="sparse_categorical_crossentropy",
)
# Define wrapper class that does pruning.
class PrunedClassifier(KerasClassifier):
    def __init__(self, base_estimator, **kwargs):
        super().__init__(**kwargs)
        self.base_estimator = base_estimator
        return

    def fit(self, X, y, **kwargs):
        self.base_estimator = clone(self.base_estimator).fit(X, y)
        super().fit(X, y, **kwargs)
        self.base_estimator.model_ = strip_pruning(self.base_estimator.model_)
        return self

    def predict(self, X):
        return self.base_estimator.predict(X)

    def _keras_build_fn(self):
        model = prune_low_magnitude(self.base_estimator.model_)
        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer='adam',
        )
        return model
# Baseline arch trains & serializes well
model0 = clone(arch).fit(X, y)
joblib.dump(model1, "model0.gz", compress=compression_level)
# Treatment arch trains well but does not serialize
model1 = PrunedClassifier(
    base_estimator=clone(arch),
    epochs=epochs,
    callbacks=[
        UpdatePruningStep(),
    ],
).fit(X, y)
joblib.dump(model1, "model1.gz", compress=compression_level)