global flag for bfloat16?
(Oops, this is Sergey; I accidentally posted from a different account, haha.)
Edits I had to make:
sokrypton@f8a9f0d
How are you running into these errors? Do you have a repro script? It should be running under an autocast context:
Line 529 in 10077d8
There's also an example of invoking the raw forward function: https://github.com/evolutionaryscale/esm/blob/10077d8a8e120f632dee0ea25e68008c4993b535/examples/raw_forwards.py
As far as I can tell, autocast is not called in encode or decode, so calling either one raises a dtype error in the EncodeInputs class. I'm not sure what changed to break this, because both functions worked fine in July (I think the original weights were all float32?).
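For anyone hitting the same thing, here is a minimal sketch of the kind of mismatch being described and of wrapping the call in an autocast context yourself. It does not use the esm API at all: `layer` is a hypothetical stand-in for a model whose weights are bfloat16, and `x` is a hypothetical float32 input. It runs on CPU so it can be tried anywhere; on a GPU you would presumably pass device_type="cuda" instead.

```python
import torch
from torch import nn

# Hypothetical stand-in for a model whose weights were loaded as bfloat16.
layer = nn.Linear(4, 2).to(torch.bfloat16)

# Hypothetical float32 input, which is what torch creates by default.
x = torch.randn(3, 4)


def call_without_autocast() -> str:
    """Call the layer directly; mismatched dtypes surface as a RuntimeError."""
    try:
        layer(x)
    except RuntimeError as err:
        return f"failed: {err}"
    return "ok"


def call_with_autocast() -> torch.Tensor:
    """Call the layer inside an autocast context so inputs are cast to bfloat16."""
    with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        return layer(x)


print(call_without_autocast())
out = call_with_autocast()
print(out.dtype, tuple(out.shape))
```

If encode and decode really do run outside autocast, wrapping those calls in the same kind of context manager on the caller's side might be a workaround until it is handled inside the library, but that is an assumption and not something confirmed in this thread.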