Suggestion: Allow bfloat16 use to improve speed/memory usage
Hi, as you may have noticed, I've been developing a fork of this project in an attempt to repurpose it for training models that upscale AM and FM radio recordings. While researching ways to improve speed and memory consumption, I started looking at 16-bit floating-point formats. Standard float16 doesn't appear to have enough range to be useful for this project (I ended up with NaN values relatively quickly), but bfloat16, which uses the same number of exponent bits as float32, seems to work quite well: it speeds up training a bit and significantly reduces memory usage. You can see an example implementation in a recent commit I made. Note that I did have to restructure the code so that the loss calculation is done using standard 32-bit float values (as recommended by PyTorch).
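To illustrate the range difference mentioned above, here is a rough standard-library sketch. It is not code from the fork. The helper names `to_float16` and `to_bfloat16` are made up for this example, and the bfloat16 emulation simply truncates a float32 to its top 16 bits, which is an assumption (real hardware typically rounds to nearest instead).

```python
import math
import struct


def to_float16(x: float) -> float:
    """Round x to IEEE 754 half precision; map overflow to +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values outside the float16 range (max ~65504),
        # so report them as infinity, which is what float16 math would give.
        return math.copysign(math.inf, x)


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping the top 16 bits of the float32 encoding.

    Assumption: plain truncation, used here only to show the value range.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


if __name__ == "__main__":
    for value in (1.001, 1e5, 1e30):
        print(value, to_float16(value), to_bfloat16(value))
```

Values such as 1e5 already overflow float16, and an infinity in an intermediate result can easily turn into NaN later on, while the bfloat16 emulation stays finite at the cost of fewer mantissa bits (1.001 collapses to 1.0). That precision loss is presumably why keeping the loss calculation in float32 is worthwhile.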
Thank you for your contribution!