keunwoochoi/kapre

`STFT` layer output shape deviates from `STFTTflite` layer in batch dimension

PhilippMatthes opened this issue · 5 comments

Use Case

I want to convert an STFT layer in my model to an STFTTflite layer so I can deploy the model to my mobile device. The documentation explains that an extra dimension is added to account for complex numbers, but I also encountered behaviour that is not documented.

Expected Behaviour

import kapre
from tensorflow import keras

input_shape = (2048, 1)  # mono signal

model = keras.models.Sequential()  # TFLite-incompatible model
model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))

tflite_model = keras.models.Sequential()  # TFLite-compatible model
tflite_model.add(kapre.STFTTflite(n_fft=1024, hop_length=512, input_shape=input_shape))

model has the output shape (None, 3, 513, 1). Therefore, tflite_model should have the output shape (None, 3, 513, 1, 2).

Observed Behaviour

The output shape of tflite_model is (1, 3, 513, 1, 2) instead of (None, 3, 513, 1, 2).
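
This can be verified directly on the two models (a minimal check, mirroring the shapes above):

print(model.output_shape)         # (None, 3, 513, 1)
print(tflite_model.output_shape)  # (1, 3, 513, 1, 2) -- batch dimension fixed to 1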

Problem Solution

  • If this behaviour is unwanted:
    • Change the layer's output shape so that the batch dimension remains None.
  • Otherwise:
    • Explain in the documentation why the batch dimension is fixed to 1.
    • Explain in the documentation how to include this layer in models that expect the batch dimension to be None.

Hi,

Yeah, the restriction to a batch size of one in TFLite is enforced, and it is something I have been trying to address, but I have been getting segfaults when running TFLite inference, so I have not solved it yet. If you are interested, you can look at my fork for details (work in progress):

https://github.com/kenders2000/kapre/tree/feature/tflite-variable-batch-size

The workflow I have developed for this is to train a model using the 'vanilla' kapre.STFT layers; then, when you want to convert the model to TFLite for deployment, create a new model with the kapre.STFTTflite layers and load the weights from the vanilla one, as sketched below. On the mobile device you are restricted to performing a single inference at a time, but this is generally not that restrictive.
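
A minimal sketch of that recipe (train_model, deploy_model, the training data, and the file name are hypothetical; it assumes the two models are identical apart from the *Tflite layers, so their weight lists line up):

import tensorflow as tf

# train_model / deploy_model: hypothetical models sharing one architecture.
# 1. Train with the vanilla layers (kapre.STFT, kapre.Magnitude, ...).
train_model.fit(x_train, y_train)

# 2. Rebuild the same architecture with the TFLite-compatible layers
#    (kapre.STFTTflite, kapre.MagnitudeTflite, ...) and copy the
#    trained weights across.
deploy_model.set_weights(train_model.get_weights())

# 3. Convert the TFLite-compatible twin for deployment.
converter = tf.lite.TFLiteConverter.from_keras_model(deploy_model)
with open('model.tflite', 'wb') as f:
    f.write(converter.convert())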

I recently tried using tf.signal.stft, but as far as I could see it is still not TFLite compatible.

Hope this helps. I agree there is probably a case for improved documentation to explain this.

Cheers

Paul

Thanks for your response @kenders2000!

I have a follow-up question. I was able to create a batch-compatible TFLite classification model as follows:

import kapre
from tensorflow import keras

input_shape = (2048, 1)  # mono signal, as above
n_outputs = 9            # matches the Dense layer in the summary below

inputs = keras.layers.Input(input_shape)
x = kapre.STFTTflite(n_fft=1024, hop_length=512, pad_begin=True)(inputs)
x = kapre.MagnitudeTflite()(x)
x = kapre.MagnitudeToDecibel()(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(n_outputs, activation='softmax')(x)
model = keras.models.Model(inputs, x)

With the resulting model.summary():

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_8 (InputLayer)         [(None, 2048, 1)]         0         
_________________________________________________________________
stft_tflite_10 (STFTTflite)  (1, 4, 513, 1, 2)         0         
_________________________________________________________________
magnitude_tflite_3 (Magnitud (1, 4, 513, 1)            0         
_________________________________________________________________
magnitude_to_decibel_3 (Magn (1, 4, 513, 1)            0         
_________________________________________________________________
flatten_1 (Flatten)          (1, 2052)                 0         
_________________________________________________________________
dense_1 (Dense)              (1, 9)                    18477     
=================================================================
Total params: 18,477
Trainable params: 18,477
Non-trainable params: 0

Do you expect this model to run into segfaults during inference on mobile devices?

Note: this is just a minimal example model for illustration.

Hi,

I can see in the model summary that while the input batch dimension is None, the STFTTflite layer still has a batch size of 1. So while I would expect this model to convert and run fine (when you provide a batch size of one), you would still need to resize the input dimension of the resulting TFLite model to have a batch size of 1.

E.g. using the TFLite interpreter's resize_tensor_input(), as in the sketch below.
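
A rough sketch of that step (the conversion is illustrative, following the model above; the shapes match its (None, 2048, 1) input):

import numpy as np
import tensorflow as tf

# Convert the Keras model (built with the *Tflite layers) to a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_bytes = converter.convert()

# Load it and pin the batch dimension to 1 before allocating tensors.
interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
input_index = interpreter.get_input_details()[0]['index']
interpreter.resize_tensor_input(input_index, [1, 2048, 1])
interpreter.allocate_tensors()

# Run a single inference (one example at a time on-device).
waveform = np.zeros((1, 2048, 1), dtype=np.float32)
interpreter.set_tensor(input_index, waveform)
interpreter.invoke()
output = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])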

Cheers

Paul

I see, thanks!

Thank you so much, everyone!