fgnt/padertorch

STFT inverse, stacked representation

zmolikova opened this issue · 3 comments

Hi,

using `complex_representation='stacked'` for the inverse STFT leads to an error:

```python
import torch
from padertorch.ops import STFT

stft_signal = torch.rand((2, 4, 10, 257, 2))
torch_stft = STFT(512, 20, window_length=40,
                  complex_representation='stacked')
torch_signal = torch_stft.inverse(stft_signal)
```
```
Traceback (most recent call last):
  File "bug.py", line 7, in <module>
    torch_signal = torch_stft.inverse(stft_signal)
  File "/mnt/matylda6/izmolikova/JSALT2020/sse/tools/padertorch/padertorch/ops/_stft.py", line 215, in inverse
    stride=self.shift)
RuntimeError: Expected 3-dimensional input for 3-dimensional weight [512, 1, 40], but got 5-dimensional input of size [2, 6, 10, 257, 1] instead
```

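The problem already starts at

```python
signal_real, signal_imag = torch.chunk(stft_signal, 2, dim=-1)
```

which gives `signal_real` and `signal_imag` a different shape than in the concat case: with the stacked input, each chunk keeps a trailing axis of size 1 (e.g. `[2, 4, 10, 257, 1]` instead of `[2, 4, 10, 257]`).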
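A minimal shape check illustrating this, assuming the concat layout stores all real parts followed by all imaginary parts along the last axis (so 257 frequency bins become 514 entries). It uses plain `torch` only, not padertorch:

```python
import torch

# Stacked layout: real and imaginary parts live in a trailing axis of size 2.
stacked = torch.rand(2, 4, 10, 257, 2)
# Concat layout (assumed): real parts, then imaginary parts, along the last axis.
concat = torch.cat((stacked[..., 0], stacked[..., 1]), dim=-1)

# Splitting the concat tensor yields one (..., 257) tensor per part.
real_c, imag_c = torch.chunk(concat, 2, dim=-1)
# Splitting the stacked tensor keeps a trailing singleton axis instead.
real_s, imag_s = torch.chunk(stacked, 2, dim=-1)

print(real_c.shape)  # torch.Size([2, 4, 10, 257])
print(real_s.shape)  # torch.Size([2, 4, 10, 257, 1])
```

The values agree; only the extra axis differs, which is what later trips up the 3-D convolution in `inverse`.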

A quick fix is to unify the representation at the beginning of `inverse`:

```python
if self.complex_representation == 'stacked':
    # (..., frequencies, 2) -> (..., 2 * frequencies): real parts first, then imag
    stft_signal = torch.cat((stft_signal[..., 0], stft_signal[..., 1]), dim=-1)
```

and then treat both representations as concat.
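Pulled out as a standalone function, the normalization step might look like the sketch below. `to_concat` is a hypothetical name used only for illustration and is not part of padertorch; inside `inverse` the same logic would run on `stft_signal` before the existing `torch.chunk` call.

```python
import torch


def to_concat(stft_signal: torch.Tensor, complex_representation: str) -> torch.Tensor:
    """Hypothetical helper: convert a 'stacked' STFT tensor to the 'concat' layout.

    'stacked': shape (..., frequencies, 2), last axis holds (real, imag).
    'concat':  shape (..., 2 * frequencies), real parts first, then imag parts.
    Any other representation is returned unchanged.
    """
    if complex_representation == 'stacked':
        stft_signal = torch.cat((stft_signal[..., 0], stft_signal[..., 1]), dim=-1)
    return stft_signal


stacked = torch.rand(2, 4, 10, 257, 2)
unified = to_concat(stacked, 'stacked')
# After unification, chunking behaves exactly as in the concat case.
signal_real, signal_imag = torch.chunk(unified, 2, dim=-1)
print(unified.shape)      # torch.Size([2, 4, 10, 514])
print(signal_real.shape)  # torch.Size([2, 4, 10, 257])
```

Converting once at the top keeps a single code path for the rest of `inverse`, at the cost of one extra concatenation for stacked inputs.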

Thank you for this issue. I have opened pull request #68, which solves this problem and adds a test case for the inverse with the stacked representation. It should be merged shortly.

Solution merged in #68

Thank you!