Allow quantization of tied weights
hunse opened this issue · 2 comments
System information
- TensorFlow version (you are using): 2.6.0 (TFMOT 0.7.2)
- Are you willing to contribute it (Yes/No): Potentially, with some advice on how to implement it
Motivation
Numerous common networks use tied weights of some kind (i.e., the same weights are used in more than one place in the model), for example autoencoders, or language models with shared embedding/deembedding weights. Currently, these models are not supported for quantization, because `quantize_apply` uses `keras.models.clone_model` internally, which "will not preserve the uniqueness of shared objects within the model" (as per the docstring).
Describe the feature
The feature is to support quantization of models with tied weights. The same underlying variable would be used in multiple locations (as per the original unquantized model). Ideally, different layers using the same variable would have separate control over quantization (i.e. some layers could have the shared weights be quantized, while others have them unquantized).
Describe how the feature helps achieve the use case
I don't have a clear idea about how this feature should be implemented, which is the motivation for this issue. TensorFlow does have some support for serialization with shared objects (e.g. saving models in the "tf" format, which uses `SharedObjectSavingScope` internally), but I'm not sure if anything is compatible with `clone_model`, or if `quantize_apply` would have to be completely redone in a way that avoids `clone_model`.
Describe how existing APIs don't satisfy your use case (optional if obvious)
I've tried to quantize a model with shared weights, and run into various problems (depending on exactly how I do the sharing). All of these problems are expected, because `clone_model` does not support shared weights and will re-instantiate the model with separate variables for each location where the shared variable occurs.
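To illustrate the underlying mechanism, here is a minimal sketch (using only the public `tf.keras.layers.serialize`/`deserialize` API) of how serializing and deserializing a layer reference independently for each call site, as a per-layer clone function does, produces separate variables instead of one shared variable:

```python
import tensorflow as tf

# An Embedding layer whose variable we would like to share between two
# places in a model.
emb = tf.keras.layers.Embedding(10, 4)
emb(tf.constant([0]))  # call once so the `embeddings` variable is created

# A per-layer clone function serializes and deserializes each layer
# reference independently. Without a shared-object scope, each round trip
# creates a brand-new layer with its own variable, so the weights are no
# longer tied:
config = tf.keras.layers.serialize(emb)
copy1 = tf.keras.layers.deserialize(config)
copy2 = tf.keras.layers.deserialize(config)
copy1(tf.constant([0]))  # build each copy so its variable exists
copy2(tf.constant([0]))

assert copy1 is not copy2
assert copy1.embeddings is not copy2.embeddings
assert emb.embeddings is not copy1.embeddings
```

This is the same duplication that happens inside `quantize_apply`'s clone, just shown in isolation.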
Hi @hunse, thanks for your input!
We haven't considered this feature for shared layers yet, but we will consider including it in our next batch of updates.
Thanks!
Thanks @js1010.
I was able to get something working for my own code base by modifying the `clone_model_with_weights` function that's used in `quantize_apply` to use Keras's `SharedObjectSavingScope` and `SharedObjectLoadingScope`. I also had to modify Keras's `SharedObjectConfig` to do `self[generic_utils.SHARED_OBJECT_KEY] = self.object_id` right away.
```python
import keras
from keras.layers import deserialize as deserialize_layer
from keras.utils import generic_utils


class SharedObjectConfig(generic_utils.SharedObjectConfig):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set the object ID right away, so that objects can be shared even when
        # part of a "clone" operation with a mixed order of serialization and
        # deserialization.
        self[generic_utils.SHARED_OBJECT_KEY] = self.object_id


class SharedObjectSavingScope(generic_utils.SharedObjectSavingScope):
    def create_config(self, base_config, obj):
        """Create a new SharedObjectConfig for a given object."""
        shared_object_config = SharedObjectConfig(base_config, self._next_id)
        self._next_id += 1
        try:
            self._shared_objects_config[obj] = shared_object_config
        except TypeError:
            # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
            # that has not overridden `__hash__`), a `TypeError` will be thrown.
            # We'll just continue on without shared object support.
            pass
        return shared_object_config


def clone_model_with_weights(model_to_clone):
    def clone_fn(layer):
        serial = generic_utils.serialize_keras_object(layer)
        return deserialize_layer(serial)

    with SharedObjectSavingScope(), generic_utils.SharedObjectLoadingScope():
        cloned_model = keras.models.clone_model(model_to_clone, clone_function=clone_fn)

    cloned_model.set_weights(model_to_clone.get_weights())
    return cloned_model
```
I'm using it to support a layer like this, which holds a reference to another layer (a `tf.keras.layers.Embedding` layer), and can then use the transpose of the weights for de-embedding.
```python
import tensorflow as tf
from keras.layers import deserialize as deserialize_layer
from keras.utils.generic_utils import serialize_keras_object


class TiedDeembedding(tf.keras.layers.Layer):
    def __init__(self, embedding_layer, **kwargs):
        super().__init__(**kwargs)
        self.embedding_layer = embedding_layer

    def build(self, input_shape):
        super().build(input_shape)
        self.embeddings = self.embedding_layer.embeddings

    def call(self, x):
        return tf.matmul(x, self.embeddings, transpose_b=True)

    def get_config(self):
        config = super().get_config()
        config.update(
            dict(embedding_layer=serialize_keras_object(self.embedding_layer))
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = config.copy()
        embedding_layer = config["embedding_layer"]
        if isinstance(embedding_layer, dict):
            config["embedding_layer"] = deserialize_layer(
                embedding_layer, custom_objects=custom_objects
            )
        return cls(**config)
```
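For context, a usage sketch of the tied-weights pattern this layer enables (the vocabulary size and embedding width below are made up, and the class is a trimmed copy of the one above, with the serialization methods omitted, so the snippet runs standalone):

```python
import tensorflow as tf


class TiedDeembedding(tf.keras.layers.Layer):
    """Trimmed copy of the layer above (serialization methods omitted)."""

    def __init__(self, embedding_layer, **kwargs):
        super().__init__(**kwargs)
        self.embedding_layer = embedding_layer

    def build(self, input_shape):
        super().build(input_shape)
        # Reuse the embedding layer's variable rather than creating a new one.
        self.embeddings = self.embedding_layer.embeddings

    def call(self, x):
        return tf.matmul(x, self.embeddings, transpose_b=True)


# Tie the output projection of a toy language model to its input embedding.
emb = tf.keras.layers.Embedding(100, 8)
deemb = TiedDeembedding(emb)

tokens = tf.keras.Input(shape=(), dtype="int32")
logits = deemb(emb(tokens))
model = tf.keras.Model(tokens, logits)

# Both layers refer to the same underlying variable, which is exactly what
# `clone_model` (and hence `quantize_apply`) fails to preserve.
assert deemb.embeddings is emb.embeddings
```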