CUDA not working?
turian opened this issue
I'm generating MSSTFT losses with random parameters.
My tensors are on CUDA and I call model.cuda(), but I get the following error:
These are the params I passed to this MSSTFT:
{'fft_sizes': [512, 64, 16, 256, 16384],
'hop_sizes': [316, 17, 10, 135, 12809],
'scale': None,
'scale_invariance': False,
'w_mag': 1.762851192526309,
'w_phs': 0.4449890262755162,
'w_sc': 0.0,
'win_lengths': [256, 64, 4, 256, 16384],
'window': 'hamming_window'}
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-59-8179a07b1db0> in <module>()
1 for x1 in x:
2 model.cuda()
----> 3 z = model(x1.view(1, 1, -1), x)
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
/usr/local/lib/python3.7/dist-packages/auraloss/freq.py in forward(self, x, y)
260 mrstft_loss = 0.0
261 for f in self.stft_losses:
--> 262 mrstft_loss += f(x, y)
263 mrstft_loss /= len(self.stft_losses)
264 return mrstft_loss
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
/usr/local/lib/python3.7/dist-packages/auraloss/freq.py in forward(self, x, y)
141 # apply relevant transforms
142 if self.scale is not None:
--> 143 x_mag = torch.matmul(self.fb, x_mag)
144 y_mag = torch.matmul(self.fb, y_mag)
145
RuntimeError: Tensor for 'out' is on CPU, Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for baddbmm)
Hey @turian, I think I know what's going on. Computing the loss on GPU should work.
To me, it looks like this is happening because the Mel filterbanks (self.fb) are not on the GPU when using a MelSTFTLoss. This is a known issue, and the current workaround is to manually move the filterbanks to the correct device beforehand. This will be fixed in the next auraloss release.
I have created an example that runs on both CPU and GPU for me. It demonstrates the problem and the current workaround.
The parameters shown in your example do not seem to be the ones causing the issue, since the error only appears for me when I create a loss with scale="mel" (your traceback also shows the failing call inside the branch guarded by if self.scale is not None).
In this example I create one standard STFT loss and one that uses Mel scaling. If I do not move the self.fb tensors to the correct device manually, I get the same error as you. To address this, below I loop over the STFTLoss objects and move the filterbanks to the correct device before computing the loss. This works for me and produces the same results as on CPU.
import torch
import auraloss
params = {'fft_sizes': [1024],
'hop_sizes': [512],
'scale': None,
'scale_invariance': False,
'n_bins': None,
'w_mag': 1.762851192526309,
'w_phs': 0.0,
'w_sc': 0.0,
'win_lengths': [1024],
'window': 'hamming_window',
'sample_rate': 44100}
melparams = {'fft_sizes': [1024],
'hop_sizes': [512],
'scale': "mel",
'scale_invariance': False,
'n_bins': 64,
'w_mag': 1.762851192526309,
'w_phs': 0.0,
'w_sc': 0.0,
'win_lengths': [1024],
'window': 'hamming_window',
'sample_rate': 44100}
# standard MRSTFT
mrstft = auraloss.freq.MultiResolutionSTFTLoss(**params)
# use mel STFTs
melmrstft = auraloss.freq.MultiResolutionSTFTLoss(**melparams)
x = (torch.rand(1,1,44100) * 2) - 1
y = (torch.rand(1,1,44100) * 2) - 1
# first compute the loss just on CPU tensors
# both work fine
mrstft_loss = mrstft(x, y)
melmrstft_loss = melmrstft(x, y)
print("cpu mrstft: ", mrstft_loss)
print("cpu melmrstft: ", melmrstft_loss)
# -------- GPU ----------
# move data to GPU
x = x.to("cuda:0")
y = y.to("cuda:0")
# compute loss on GPU
mrstft_loss = mrstft(x, y)
print("gpu mrstft: ", mrstft_loss)
# move MelSTFT loss filterbanks to GPU
# this will be done automatically in future auraloss version
for stft_loss in melmrstft.stft_losses:
stft_loss.fb = stft_loss.fb.to("cuda:0")
# compute loss on GPU
melmrstft_loss = melmrstft(x, y)
print("gpu melmrstft: ", melmrstft_loss)
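The loop above hard-codes "cuda:0". A slightly more defensive variant, as a sketch only: move_filterbanks_ is a hypothetical helper name, it assumes the stft_losses and fb attribute names used in the workaround, and it skips any loss whose fb is missing or not a tensor (whether the non-Mel losses define fb at all is not checked here), so it can be applied to both mrstft and melmrstft.

```python
import torch


def move_filterbanks_(mrstft_loss, device):
    """Hypothetical helper: move every per-resolution `fb` tensor to `device`.

    Assumes `mrstft_loss.stft_losses` is an iterable of STFT loss objects that
    may or may not carry an `fb` tensor, as in the workaround above.
    Modifies the losses in place and returns the container for chaining.
    """
    for stft_loss in mrstft_loss.stft_losses:
        fb = getattr(stft_loss, "fb", None)
        if isinstance(fb, torch.Tensor):
            stft_loss.fb = fb.to(device)
    return mrstft_loss


# Pick the device once so the same script also runs on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```

Usage, assuming the objects from the example above: move x and y with x.to(device) and y.to(device), then call move_filterbanks_(melmrstft, device) before melmrstft(x, y).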
The output:
python test.py
cpu mrstft: tensor(1.2394)
cpu melmrstft: tensor(0.5864)
gpu mrstft: tensor(1.2394, device='cuda:0')
gpu melmrstft: tensor(0.5864, device='cuda:0')
Let me know if this workaround works for you. Also, note that the w_phs parameter currently has no effect, since the phase loss term is not implemented yet. Implementing it ran into some issues that PyTorch 1.8 should now handle, but I have not moved over to the latest version yet.