keunwoochoi/kapre

Testing STFT/Magnitude against STFTTflite/MagnitudeTflite

daniel-deychakiwsky opened this issue · 2 comments

I'm testing the STFT → Magnitude layers against their TFLite versions (STFTTflite → MagnitudeTflite) and getting failures for the same set of arguments. Am I doing something incorrectly here, or is this expected for some reason?

from kapre import (
    STFT,
    ApplyFilterbank,
    Magnitude,
    MagnitudeTflite,
    MagnitudeToDecibel,
    STFTTflite,
)
from tensorflow.keras.models import Sequential


def get_melgram_layer(
    n_fft: int,
    win_length: int,
    hop_length: int,
    window_name: str,
    pad_begin: bool,
    pad_end: bool,
    sample_rate: int,
    n_mels: int,
    mel_f_min: int,
    mel_f_max: int,
    mel_htk: bool,
    mel_norm: str,
    return_decibel: bool,
    db_amin: float,
    db_ref_value: float,
    db_dynamic_range: float,
    input_data_format: str,
    output_data_format: str,
    name: str,
    for_device: bool,
) -> Sequential:
    """Build a ``Sequential`` made of an STFT layer followed by a magnitude layer.

    When ``for_device`` is true the model is assembled from ``STFTTflite`` and
    ``MagnitudeTflite``; otherwise from ``STFT`` and ``Magnitude``. Both
    variants receive exactly the same STFT keyword arguments.

    The mel filterbank and decibel stages are currently commented out, so
    ``sample_rate``, ``n_mels``, ``mel_f_min``, ``mel_f_max``, ``mel_htk``,
    ``mel_norm``, ``return_decibel``, ``db_amin``, ``db_ref_value`` and
    ``db_dynamic_range`` are accepted but have no effect on the returned model.

    Args:
        n_fft, win_length, hop_length, window_name, pad_begin, pad_end:
            Forwarded unchanged to the STFT layer.
        input_data_format, output_data_format: Forwarded unchanged to the
            STFT layer.
        name: Name given to the returned ``Sequential``.
        for_device: Selects the ``*Tflite`` layer pair when true.

    Returns:
        The assembled ``Sequential`` model.
    """
    model = Sequential(name=name)

    # Choose the layer pair up front so the STFT arguments are written once.
    if for_device:
        stft_cls, magnitude_cls = STFTTflite, MagnitudeTflite
    else:
        stft_cls, magnitude_cls = STFT, Magnitude

    stft_kwargs = dict(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window_name=window_name,
        pad_begin=pad_begin,
        pad_end=pad_end,
        input_data_format=input_data_format,
        output_data_format=output_data_format,
    )
    model.add(stft_cls(**stft_kwargs))
    model.add(magnitude_cls())

    # Disabled while comparing the STFT/magnitude outputs on their own:
    # model.add(
    #     ApplyFilterbank(
    #         type="mel",
    #         filterbank_kwargs={
    #             "sample_rate": sample_rate,
    #             "n_freq": n_fft // 2 + 1,
    #             "n_mels": n_mels,
    #             "f_min": mel_f_min,
    #             "f_max": mel_f_max,
    #             "htk": mel_htk,
    #             "norm": mel_norm,
    #         },
    #         data_format=output_data_format,
    #     )
    # )
    # if return_decibel:
    #     model.add(
    #         MagnitudeToDecibel(
    #             ref_value=db_ref_value, amin=db_amin, dynamic_range=db_dynamic_range
    #         )
    #     )
    return model
import numpy as np


def test_get_melgram_layer():
    """Check that the regular and Tflite layer stacks agree on an all-ones input.

    Both models are built from the same arguments, differing only in
    ``for_device``, and their outputs are compared with
    ``np.testing.assert_allclose`` at its default tolerances.
    """
    shared_kwargs = dict(
        n_fft=2048,
        win_length=1024,
        hop_length=1024,
        window_name="hann_window",
        pad_begin=False,
        pad_end=False,
        sample_rate=22050,
        n_mels=256,
        mel_f_min=0,
        mel_f_max=22050 // 2,
        mel_htk=False,
        mel_norm="slaney",
        return_decibel=True,
        db_amin=1e-05,
        db_ref_value=1.0,
        db_dynamic_range=150.0,
        input_data_format="channels_last",
        output_data_format="channels_last",
        name="log_mel_spectrogram",
    )

    # Constant (all-ones) input; presumably (batch, time, channel) given
    # "channels_last" -- confirm against the layer documentation.
    fake_audio = np.ones((1, 22050, 1))

    training_layer = get_melgram_layer(**shared_kwargs, for_device=False)
    training_melgram = training_layer(fake_audio)

    edge_serving_layer = get_melgram_layer(**shared_kwargs, for_device=True)
    edge_serving_melgram = edge_serving_layer(fake_audio)

    np.testing.assert_allclose(training_melgram, edge_serving_melgram)
>       np.testing.assert_allclose(training_melgram, edge_serving_melgram)
E       AssertionError: 
E       Not equal to tolerance rtol=1e-07, atol=0
E       
E       Mismatched elements: 21480 / 21525 (99.8%)
E       Max absolute difference: 0.00228808
E       Max relative difference: 2939.9194
E        x: array([[[[5.120000e+02],
E                [4.345991e+02],
E                [2.560000e+02],...
E        y: array([[[[5.120000e+02],
E                [4.345991e+02],
E                [2.560000e+02],...

The test above used a DC (constant) signal. If I use a more realistic signal, the difference is significantly reduced. I'm not sure whether anything can be added to close the remaining gap.

Hi. There has been some discussion on this. In general, it's possible to make them close enough with the parameters tuned, though how to do so may not be apparent.

One thing: it's better to compare them without the decibel scaling, as it exaggerates the numerical difference where abs(x) < 1.

https://groups.google.com/a/ismir.net/g/community/c/LiVRv4I7asw/m/H6Ag-MxGAQAJ

https://colab.research.google.com/drive/1ptS1UkpHa-dW8w7WEf8xTE63mEQg8NQZ

tensorflow/tensorflow#32373

tensorflow/tensorflow#15134