kuprel/min-dalle

Bfloat16 - precision loss

ArulselvanMadhavan opened this issue · 0 comments

Hi, thanks for this library. I found it very helpful in understanding the DALL-E architecture.

https://github.com/kuprel/min-dalle/blob/main/min_dalle/models/dalle_bart_decoder.py#L172

I have been trying to reproduce the results in bfloat16. Everything works fine except for this matmul in the decoder. The difference between the fp32 and bfloat16 results was significant enough to affect the classifier-free guidance output, which in turn affects image generation. Do you have any suggestions for minimizing the discrepancy between the fp32 and bfloat16 results?
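For context, one workaround I'm considering is upcasting just this matmul to fp32 while leaving the rest of the model in bfloat16. A minimal sketch, assuming PyTorch and a hypothetical `matmul_fp32` helper (not part of min-dalle):

```python
import torch

def matmul_fp32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run a single matmul in float32, then cast back."""
    # Upcast both operands so the reduction accumulates in float32;
    # bfloat16 keeps only ~8 mantissa bits, so long dot products lose precision.
    out = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    # Cast back so downstream layers still see the original dtype.
    return out.to(a.dtype)
```

This would keep the memory and speed benefits of bfloat16 everywhere else and pay the fp32 cost only for the one precision-sensitive reduction, but I'm not sure whether that's the right fix here or whether the divergence comes from somewhere else.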

Thank you