evolutionaryscale/esm

global flag for bfloat16?


Is there some global flag to make everything bfloat16? I recently had to go through the code and hardcode bfloat16, which seems kind of silly:
[screenshot of the hardcoded bfloat16 edits; see the commit linked below]

Otherwise, I was getting bfloat16 vs. float dtype errors.

(Oops, this is Sergey; I accidentally posted from a different account.) haha

Edits I had to make:
sokrypton@f8a9f0d
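
For comparison, what I was hoping a global flag would amount to is a one-line cast after loading, instead of editing each module. A minimal sketch, assuming the `ESM3.from_pretrained` entry point from the README (the `.to(torch.bfloat16)` cast is a generic PyTorch pattern, not an existing ESM flag):

```python
import torch
from esm.models.esm3 import ESM3

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load once, then cast every floating-point parameter and buffer to
# bfloat16 in one place, rather than hardcoding the dtype per module.
model = ESM3.from_pretrained("esm3_sm_open_v1").to(device)
model = model.to(torch.bfloat16)
```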

How are you running into these errors? Do you have a repro script? It should be running under an autocast context:

```python
torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16)  # type: ignore
```

There's also an example of invoking the raw forward function: https://github.com/evolutionaryscale/esm/blob/10077d8a8e120f632dee0ea25e68008c4993b535/examples/raw_forwards.py
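
For anyone hitting the same thing, here is a minimal standalone sketch of the failure mode and of why the autocast context matters. It uses a plain `nn.Linear` as a stand-in for the ESM modules, so it illustrates the dtype behavior rather than the actual ESM code path:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
layer = nn.Linear(8, 8).to(device)   # float32 weights, PyTorch's default
x = torch.randn(2, 8, dtype=torch.bfloat16, device=device)

# Without autocast, the mixed-dtype matmul raises the same kind of
# "bfloat16 vs float" RuntimeError reported above.
try:
    layer(x)
except RuntimeError as e:
    print("no autocast:", e)

# Under autocast, eligible ops (linear, matmul, ...) cast their inputs to
# the autocast dtype, so float32 weights and bfloat16 inputs mix cleanly.
with torch.autocast(enabled=True, device_type=device, dtype=torch.bfloat16):
    out = layer(x)
print("autocast output dtype:", out.dtype)  # torch.bfloat16
```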


Autocast is not called in encode or decode as far as I can tell, so calling either of them hits a dtype error in the EncodeInputs class. Not sure what changed to break this; both functions were working fine in July without the issue (I think the original weights were all float32?).
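
Until that's fixed upstream, one possible workaround is to open the autocast context yourself around those calls. A minimal sketch: `model` and `protein` stand in for an `ESM3` model and an `ESMProtein` input as in the repo's examples, and this assumes the error really does come from the missing autocast context rather than something else in `EncodeInputs`:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# encode()/decode() don't open an autocast context internally, so wrap
# them at the call site the same way the generation path does.
with torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16):
    tokens = model.encode(protein)
    decoded = model.decode(tokens)
```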