speechbrain/speechbrain

dtype mismatch in AttentiveStatisticsPooling with FP16 training mode

MM-0712 opened this issue · 1 comment

Describe the bug

attn = torch.cat([x, mean, std], dim=1)

If the model is trained in FP16 or BF16 mode, this line raises a dtype mismatch error, because `mean` and `std` do not match the dtype of `x`.
One solution is to cast the statistics with `.to(x.dtype)` before the concatenation.
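A minimal sketch of the proposed fix (the helper name `cat_with_stats` is hypothetical, not SpeechBrain API): cast `mean` and `std` to `x`'s dtype so the concatenation stays in half precision under mixed-precision training.

```python
import torch

def cat_with_stats(x, mean, std):
    # Hypothetical helper illustrating the fix: the pooled statistics may be
    # float32 while x is FP16/BF16, so cast them to x.dtype before torch.cat.
    return torch.cat([x, mean.to(x.dtype), std.to(x.dtype)], dim=1)

# x in half precision, statistics in float32, as can happen in FP16 training.
x = torch.randn(2, 4, 10, dtype=torch.float16)
mean = torch.randn(2, 4, 10)  # float32
std = torch.randn(2, 4, 10)   # float32

attn = cat_with_stats(x, mean, std)
```

The result keeps `x`'s dtype and the expected channel dimension, instead of failing (or silently promoting to float32) on the mixed-dtype `torch.cat`.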

Expected behaviour

None

To Reproduce

None

Environment Details

No response

Relevant Log Output

No response

Additional Context

No response