STFT inverse, stacked representation
zmolikova opened this issue · 3 comments
zmolikova commented
Hi,
Using `complex_representation='stacked'` for the inverse STFT leads to an error:
```python
import torch
from padertorch.ops import STFT

stft_signal = torch.rand((2, 4, 10, 257, 2))
torch_stft = STFT(512, 20, window_length=40,
                  complex_representation='stacked')
torch_signal = torch_stft.inverse(stft_signal)
```
```
Traceback (most recent call last):
  File "bug.py", line 7, in <module>
    torch_signal = torch_stft.inverse(stft_signal)
  File "/mnt/matylda6/izmolikova/JSALT2020/sse/tools/padertorch/padertorch/ops/_stft.py", line 215, in inverse
    stride=self.shift)
RuntimeError: Expected 3-dimensional input for 3-dimensional weight [512, 1, 40], but got 5-dimensional input of size [2, 6, 10, 257, 1] instead
```
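For context, the error is raised by `torch.nn.functional.conv_transpose1d`, which expects a 3-D input of shape `(batch, in_channels, frames)` when given a 3-D weight of shape `(in_channels, out_channels / groups, kernel_size)`, here `[512, 1, 40]`. The sketch below is not padertorch's implementation; it only illustrates the shape contract. It assumes the coefficients are already in a `(..., frames, channels)` layout with as many channels as the weight expects, folds all leading dimensions into one batch dimension before the call, and unfolds them afterwards. The function name `fold_and_conv_transpose` and the random weight are hypothetical.

```python
import torch
import torch.nn.functional as F


def fold_and_conv_transpose(x: torch.Tensor, weight: torch.Tensor, shift: int) -> torch.Tensor:
    """Hypothetical helper: conv_transpose1d over inputs with arbitrary leading dims.

    x:      (..., frames, channels), where channels == weight.shape[0]
    weight: (channels, 1, kernel_size)
    returns (..., samples), with samples == (frames - 1) * shift + kernel_size
    """
    *leading, frames, channels = x.shape
    # conv_transpose1d needs (N, C_in, L): fold leading dims into N,
    # then move the channel axis in front of the frame axis.
    x = x.reshape(-1, frames, channels).transpose(1, 2)
    y = F.conv_transpose1d(x, weight, stride=shift)  # (N, 1, samples)
    # Restore the original leading dimensions.
    return y.reshape(*leading, y.shape[-1])


x = torch.rand(2, 4, 10, 512)
weight = torch.rand(512, 1, 40)
print(fold_and_conv_transpose(x, weight, shift=20).shape)  # torch.Size([2, 4, 220])
```

Passing the 5-D tensor from the traceback directly to `conv_transpose1d` fails for the same reason, whatever the channel count.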
The problem starts already at padertorch/padertorch/ops/_stft.py, line 201 (commit 8eec9aa), which gives `signal_real` and `signal_imag` a different shape than in the `concat` case.
A quick fix is to unify the representation at the beginning of `inverse`:

```python
if self.complex_representation == 'stacked':
    # (..., bins, 2) -> (..., 2 * bins): real parts followed by imaginary parts
    stft_signal = torch.cat((stft_signal[..., 0], stft_signal[..., 1]),
                            dim=-1)
```

and then treat both representations as `concat`.
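The quick fix works because both layouts hold the same numbers. In `stacked`, real and imaginary parts sit in a trailing axis of size 2. In `concat`, as the `torch.cat` call above assumes, the real parts come first along the last axis, followed by the imaginary parts. The minimal sketch below converts in both directions and checks that nothing is lost. The helper names `stacked_to_concat` and `concat_to_stacked` are hypothetical and are not padertorch functions.

```python
import torch


def stacked_to_concat(stft_signal: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (..., bins, 2) -> (..., 2 * bins), real parts first."""
    return torch.cat((stft_signal[..., 0], stft_signal[..., 1]), dim=-1)


def concat_to_stacked(stft_signal: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (..., 2 * bins) -> (..., bins, 2), inverse of the above."""
    real, imag = torch.chunk(stft_signal, 2, dim=-1)  # split the last axis in half
    return torch.stack((real, imag), dim=-1)


stacked = torch.rand((2, 4, 10, 257, 2))  # same shape as in the report
concat = stacked_to_concat(stacked)
print(concat.shape)  # torch.Size([2, 4, 10, 514])
assert torch.equal(concat_to_stacked(concat), stacked)  # lossless round trip
```

Because the conversion is lossless, `inverse` only needs a single code path after this step.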
jensheit commented
Thank you for this issue. I have opened pull request #68, which solves this problem and adds a test case for the inverse with the stacked representation. It should be merged shortly.
zmolikova commented
Thank you!