`STFT` layer output shape deviates from `STFTTflite` layer in batch dimension
PhilippMatthes opened this issue · 5 comments
Use Case
I want to convert a `STFT` layer in my model to a `STFTTflite` layer so I can deploy it to my mobile device. The documentation states that an extra dimension is added to account for complex numbers, but I also encountered a behaviour that is not documented.
Expected Behaviour
```python
import kapre
from tensorflow import keras

input_shape = (2048, 1)  # mono signal

model = keras.models.Sequential()  # TFLite-incompatible model
model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))

tflite_model = keras.models.Sequential()  # TFLite-compatible model
tflite_model.add(kapre.STFTTflite(n_fft=1024, hop_length=512, input_shape=input_shape))
```
`model` has the output shape `(None, 3, 513, 1)`. Therefore, `tflite_model` should have the output shape `(None, 3, 513, 1, 2)`.
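For context, the `3` and `513` in these shapes follow directly from the STFT parameters. A quick sketch of the arithmetic (assuming no padding, which matches the shapes above):

```python
def stft_output_dims(signal_len, n_fft, hop_length):
    """Frame count and one-sided frequency-bin count for an unpadded STFT."""
    n_frames = (signal_len - n_fft) // hop_length + 1  # hops that fit in the signal
    n_bins = n_fft // 2 + 1  # real-valued input -> one-sided spectrum
    return n_frames, n_bins

print(stft_output_dims(2048, 1024, 512))  # -> (3, 513)
```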
Observed Behaviour
The output shape of `tflite_model` is `(1, 3, 513, 1, 2)` instead of `(None, 3, 513, 1, 2)`.
Problem Solution
- If this behaviour is unwanted:
  - Change the model output format so that the batch dimension is correctly shaped.
- Otherwise:
  - Explain in the documentation why the batch dimension is fixed to `1`.
  - Explain in the documentation how to include this layer in models which expect the batch dimension to be `None`.
Hi,
Yeah, the restriction to a batch size of one in tflite is enforced, and it is something I have been trying to address. However, I have been getting seg-faults when running tflite inference, so I have not solved it yet. If you are interested, you can look at my fork for details (work in progress):
https://github.com/kenders2000/kapre/tree/feature/tflite-variable-batch-size
The use case I have developed for this is to train a model using the 'vanilla' `kapre.STFT` layers; then, when you want to convert the model to tflite for deployment, create a new model with the `kapre.STFTTflite` layers and load the weights from the vanilla one. On the mobile device you are restricted to performing a single inference at a time, but this is generally not that restrictive.
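That weight hand-over can be done with Keras's `get_weights()`/`set_weights()`. A minimal sketch using plain `Dense` layers as stand-ins for the real model (assumption: the kapre STFT layers hold no trainable weights, as the `Param # 0` entries in the summary below suggest, so the two models' weight lists line up one-to-one):

```python
import numpy as np
from tensorflow import keras

def build():
    # In the real workflow, the first model would start with kapre.STFT and
    # the second with kapre.STFTTflite; everything after those layers matches.
    return keras.models.Sequential([
        keras.layers.Dense(8, activation="relu", input_shape=(16,)),
        keras.layers.Dense(2, activation="softmax"),
    ])

trained = build()  # stands in for the 'vanilla' training model
deploy = build()   # stands in for the TFLite-compatible copy

# Copy every weight tensor across; shapes must match layer for layer.
deploy.set_weights(trained.get_weights())
```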
I recently tried using `tf.signal.stft`, but as far as I could see it is still not tflite-compatible.
Hope this helps, I agree there is probably a case for improved documentation to explain this.
Cheers
Paul
Thanks for your response @kenders2000!
I have a follow-up question. I was able to create a batch-compatible TFLite classification model as follows:
```python
import kapre
from tensorflow import keras

input_shape = (2048, 1)
n_outputs = 9

inputs = keras.layers.Input(input_shape)
x = kapre.STFTTflite(n_fft=1024, hop_length=512, pad_begin=True)(inputs)
x = kapre.MagnitudeTflite()(x)
x = kapre.MagnitudeToDecibel()(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(n_outputs, activation='softmax')(x)
model = keras.models.Model(inputs, x)
```
With the resulting `model.summary()`:
```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_8 (InputLayer)         [(None, 2048, 1)]         0
_________________________________________________________________
stft_tflite_10 (STFTTflite)  (1, 4, 513, 1, 2)         0
_________________________________________________________________
magnitude_tflite_3 (Magnitud (1, 4, 513, 1)            0
_________________________________________________________________
magnitude_to_decibel_3 (Magn (1, 4, 513, 1)            0
_________________________________________________________________
flatten_1 (Flatten)          (1, 2052)                 0
_________________________________________________________________
dense_1 (Dense)              (1, 9)                    18477
=================================================================
Total params: 18,477
Trainable params: 18,477
Non-trainable params: 0
```
Do you expect this model to run into seg-faults during inference on mobile devices? Note: this is just a simple example model.
Hi,
I can see in the model summary that while the input batch dimension is `None`, the `STFTTflite` layers still have a batch size of 1. So while I would expect this model to convert and run fine (when you provide a batch size of one), you would still need to resize the input dimension of the resulting tflite model to a batch size of 1, e.g. using `resize_tensor_input()`.
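A minimal sketch of that resizing step, using a tiny stand-in `Dense` model rather than the kapre model (assumption: the `resize_tensor_input()` call works the same regardless of what the model contains):

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Tiny stand-in model with a dynamic batch dimension.
inputs = keras.layers.Input((4,))
outputs = keras.layers.Dense(2)(inputs)
model = keras.models.Model(inputs, outputs)

# Convert to a TFLite flatbuffer in memory.
tflite_bytes = tf.lite.TFLiteConverter.from_keras_model(model).convert()

interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
input_index = interpreter.get_input_details()[0]["index"]

# Pin the batch dimension to 1 before allocating tensors and invoking.
interpreter.resize_tensor_input(input_index, [1, 4])
interpreter.allocate_tensors()
interpreter.set_tensor(input_index, np.zeros((1, 4), dtype=np.float32))
interpreter.invoke()
```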
Cheers
Paul
I see, thanks!
Thank you so much, everyone!