RF torch `lstm` fails with torch amp option.
LucaG1 opened this issue · 6 comments
Hi,
I'm currently trying to train a transducer model using RF. I use the `torch_amp="bfloat16"` option from previous setups. In the predictor I use an `rf.LayerNorm` followed by an `rf.LSTM`. I think this fails because the LayerNorm uses float32 and the LSTM float16. The `lstm` function of the RF PyTorch backend fails here:
`returnn/returnn/torch/frontend/_backend.py`, line 2021 at commit `bc0d5bf`
This is the error message:
ValueError: tensor_raw_tensor_setter: tensor.dtype != raw_tensor.dtype, from tensor dtype 'float32' and raw_tensor dtype 'float16'
When I disable torch mixed precision, I get past this error. However, I'm not sure how this should be handled correctly. Maybe this is missing a cast? Or can mixed precision not be used in this case?
The config I used can be found here: `/u/luca.gaudino/setups/2023-08-10--rf-librispeech/debug/train_rnnt_rf_ls960.config`
You can run it via `/u/luca.gaudino/bin/returnn_launcher_nocuda.sh /u/luca.gaudino/setups/2023-08-10--rf-librispeech/returnn/rnn.py /u/luca.gaudino/setups/2023-08-10--rf-librispeech/debug/train_rnnt_rf_ls960.config` on a 24 GB GPU node.
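For context, here is a rough sketch of the predictor pattern described above. The module layout, dims, and call signatures are illustrative assumptions, not the actual setup from the config:

```python
import returnn.frontend as rf


class Predictor(rf.Module):
    """LayerNorm followed by LSTM, as in the failing predictor (sketch only)."""

    def __init__(self, in_dim: rf.Dim, out_dim: rf.Dim):
        super().__init__()
        self.layer_norm = rf.LayerNorm(in_dim)
        self.lstm = rf.LSTM(in_dim, out_dim)

    def __call__(self, source: rf.Tensor, *, state, spatial_dim: rf.Dim):
        x = self.layer_norm(source)
        # With torch_amp enabled, the LayerNorm output stays float32 while the LSTM
        # kernel runs in reduced precision, which is where the dtype check trips.
        return self.lstm(x, state=state, spatial_dim=spatial_dim)
```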
This problem is not specific to returnn_frontend; it also happens with pure PyTorch. I solved it by wrapping the LSTM cell call in `with torch.autocast(device_type="cuda", enabled=False):`.
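A minimal sketch of that workaround in plain PyTorch (dimensions, module names, and the bfloat16 choice are made up for illustration):

```python
import torch

layer_norm = torch.nn.LayerNorm(512).cuda()
lstm_cell = torch.nn.LSTMCell(512, 512).cuda()

x = torch.randn(8, 512, device="cuda")
h = torch.zeros(8, 512, device="cuda")
c = torch.zeros(8, 512, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = layer_norm(x)  # LayerNorm output is kept in float32 under autocast
    # Disable autocast only for the cell call, so its inputs, hidden state,
    # and weights all agree on float32.
    with torch.autocast(device_type="cuda", enabled=False):
        h, c = lstm_cell(y.float(), (h, c))
```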
Correction: this is only a problem when using the LSTM cell type, not the sequence one, so in your case this really might just be a check problem in returnn_frontend.
Yea, RF assumes float32 here but got float16. In a couple of other cases, I just overwrite the dtype with whatever Torch has returned, i.e. basically removing the check; e.g. see the `softmax` function. Basically, just add this before you assign `out.raw_tensor`:
out.dtype = TorchBackend.get_dtype_name_raw(out_raw)
Do you want to commit this or should I?
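To make that concrete, a hypothetical excerpt of how this would sit in the `lstm` backend function (the `out` / `out_raw` names follow the suggestion above; the surrounding code is omitted and assumed):

```python
# Under AMP, out_raw can come back as float16/bfloat16 even though `out` was
# declared float32, so take over whatever dtype Torch actually produced
# before assigning the raw tensor (same pattern as in the softmax function).
out.dtype = TorchBackend.get_dtype_name_raw(out_raw)
out.raw_tensor = out_raw
```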
I just pushed it now.
Thanks. I think the same thing has to be done for the states. Should I just commit this, or do you want to amend the previous commit?
Why did you reopen this? It's not fixed yet? You mean it also needs to be done for the states? Yes, just add another commit for this directly to master then.
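For reference, the analogous dtype takeover for the returned LSTM state might look roughly like this (the state tensor and raw tensor names here are assumptions, not the actual code in `_backend.py`):

```python
# Overwrite the declared dtype of the state tensors with whatever Torch
# returned, then assign the raw tensors, mirroring the fix for the output.
new_state_h.dtype = TorchBackend.get_dtype_name_raw(new_h_raw)
new_state_h.raw_tensor = new_h_raw
new_state_c.dtype = TorchBackend.get_dtype_name_raw(new_c_raw)
new_state_c.raw_tensor = new_c_raw
```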