tensorflow/model-optimization

How to apply tfmot to a non-tf.keras.Model-type model?

felicitywang1 opened this issue · 4 comments

Hi,
This is more of a question than a feature request, and I don't know where else to post it.
I'm trying to perform quantization-aware training on a model that is not a tf.keras.Model, but a wrapper Model class with forward(), loss(), and trainable_variables() methods and some layers of tf.keras.layers.Layer type. Backpropagation is done with tf.GradientTape and optimizer.apply_gradients() outside this Model class, roughly as sketched below.
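For concreteness, the setup looks roughly like this (a simplified sketch; MyModel and its methods are just illustrations of my wrapper class, not real APIs):

```python
import tensorflow as tf

class MyModel:  # plain Python class, NOT a tf.keras.Model
    def __init__(self):
        # the building blocks are ordinary Keras layers
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(10)

    def forward(self, x):
        return self.dense2(self.dense1(x))

    def loss(self, x, y):
        logits = self.forward(x)
        return tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                y, logits, from_logits=True))

    def trainable_variables(self):
        return (self.dense1.trainable_variables +
                self.dense2.trainable_variables)

# backpropagation happens outside the class
model = MyModel()
optimizer = tf.keras.optimizers.Adam()

def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = model.loss(x, y)
    variables = model.trainable_variables()
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
```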
Is this possible? I was thinking this might be doable with some of the lower-level functions in this repo, but I can't say for sure and don't really know how much work it would be. It would be great to get some advice directly from the developers.
Thank you very much.

Hi @felicitywang1

This answer might not be correct since I don't know your model exactly.

If your model works with tf.keras.models.clone_model(your_model), you can try applying quantization by manually removing the model-type checking code in the API. (https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/quantization/keras/quantize.py#L330-L338)
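Note that cloning is the real constraint, not just the type check: the quantization API clones the model internally, so you can test this up front. A quick sketch (your_model stands for the model in question):

```python
import tensorflow as tf

# If this raises, removing the type checks in quantize.py won't help,
# because the quantization API clones the model internally anyway.
try:
    tf.keras.models.clone_model(your_model)
    print("cloneable - patching out the type checks may be enough")
except Exception as e:  # subclassed / non-Keras models typically fail here
    print("not cloneable:", e)
```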

Thank you @rino20 for the suggestion. No, the model cannot be cloned with tf.keras.models.clone_model(), so I'm afraid simply removing the model-type checks won't help. I don't believe there's a clear solution to our problem. Closing for now.

Reopening with an added question: the model is not a Keras model and cannot use the tf.keras.models.clone_model() function. Is there any way to apply tfmot to it? E.g., defining another model of tf.keras.Model type and manually loading the trained weights into it, or somehow registering an arbitrary model as a tf.keras.Model? Any ideas? Thank you!

MOT officially supports only Keras models. But you can try manually loading the weights if you want to apply MOT techniques. We have wrapper classes for applying MOT to a model; you could build a new model with those wrappers around each layer and load your weights into it, along the lines of the sketch below.
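Roughly like this (an untested sketch; my_model.layers is a stand-in for however your wrapper class exposes its Keras layers, in matching order):

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

annotate = tfmot.quantization.keras.quantize_annotate_layer

# 1. Rebuild the same architecture as a genuine tf.keras model,
#    annotating every layer that should be quantized.
annotated_model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(784,)),
    annotate(tf.keras.layers.Dense(128, activation='relu')),
    annotate(tf.keras.layers.Dense(10)),
])

# 2. Copy the trained weights over, layer by layer. The annotate
#    wrapper adds no variables of its own, so the shapes line up.
for dst, src in zip(annotated_model.layers, my_model.layers):
    dst.set_weights(src.get_weights())

# 3. quantize_apply replaces each annotated layer with the QAT wrapper
#    (fake-quant ops); the weights set above carry over.
q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
```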

Not an exact example, but you may find some information on how to apply quantization to a nested model here: https://github.com/tensorflow/models/blob/master/official/projects/qat/vision/modeling/factory.py