`--mixed-precision` doesn't work with img transformer 2
yoinked-h opened this issue · 2 comments
yoinked-h commented
When trying to train with mixed precision (and NATTEN), the positional embedding stays in fp32 instead of being cast to bf16, which causes an error later in the attention forward call.
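For context, here is a minimal sketch of the kind of mismatch being described, assuming a learned fp32 positional bias added inside attention; this is illustrative only, not the actual k-diffusion or NATTEN code. Under bf16 autocast the `Parameter` keeps its fp32 dtype unless it is explicitly cast, and some fused attention kernels reject mixed-dtype inputs rather than upcasting them.

```python
import torch
import torch.nn as nn

class TinyAttnBlock(nn.Module):
    """Illustrative sketch: a block with a learned positional bias stored in fp32."""
    def __init__(self, dim=64, seq=16):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3)
        # Parameters stay fp32 under autocast unless cast explicitly.
        self.pos_bias = nn.Parameter(torch.zeros(seq, seq))

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        attn = q @ k.transpose(-2, -1)  # bf16 under autocast
        # Without the .to(attn.dtype) cast, `attn` is bf16 while `pos_bias` is fp32;
        # fused attention kernels may raise a dtype error instead of upcasting.
        attn = attn + self.pos_bias.to(attn.dtype)
        return attn.softmax(dim=-1) @ v

x = torch.randn(2, 16, 64)
with torch.autocast("cpu", dtype=torch.bfloat16):
    out = TinyAttnBlock()(x)
print(out.dtype)  # torch.bfloat16
```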
crowsonkb commented
I just fixed a similar-sounding problem with the dtypes of the tensors being passed to natten2dav(), which only occurred with very recent versions of NATTEN (commit: 6ab5146). Can you pull and check whether that fixes your problem?
yoinked-h commented
this seems to have fixed it, ty!