brainpy/BrainPy

Half precision (float16 or bfloat16) support

CloudyDory opened this issue · 10 comments

Does BrainPy fully support half-precision floating point numbers? I have tried changing some of my own BrainPy code from brainpy.math.float32 to brainpy.math.float16 or brainpy.math.bfloat16 (by explicitly setting the dtype of all variables and using a debugger to make sure they are not promoted to float32), but the GPU memory consumption and running speed are almost the same as with float32.
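For reference, the kind of change I made looks roughly like this (the variable and parameter names are only illustrative):

```python
import brainpy.math as bm

# Explicitly request float16 storage for a state variable and a parameter.
# Even so, mixed-precision arithmetic elsewhere can silently promote
# intermediate results back to float32.
V = bm.Variable(bm.zeros(1000, dtype=bm.float16))
g_Na = bm.asarray(120., dtype=bm.float16)
```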

Great! Supporting this requires explicitly casting all parameters to brainpy.math.float_. For example, for an HH neuron model, its parameter gNa should be cast as gNa = bm.asarray(gNa, bm.float_). Ideally, users could simply call brainpy.math.set(float_=bm.float16), and all variables would then run with float16 types.
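A minimal sketch of this, assuming an HH-style model (the concrete parameter values are only illustrative):

```python
import brainpy.math as bm

# Set the default float type once; bm.float_ then resolves to float16.
bm.set(float_=bm.float16)

# Cast the model parameters to the default float type so that arithmetic
# with float16 state variables is not promoted back to float32.
gNa = bm.asarray(120., dtype=bm.float_)
gK = bm.asarray(36., dtype=bm.float_)
gL = bm.asarray(0.3, dtype=bm.float_)
```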

One more thing that needs to be taken care of is that the coefficients of the Runge-Kutta methods should also be cast to the brainpy.math.float_ type, as sketched below.
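Concretely, wherever an integrator builds its Butcher tableau, something along these lines would be needed (just a sketch; the actual attribute names inside BrainPy's integrator classes may differ):

```python
import brainpy.math as bm

# Hypothetical tableau of the classical RK4 method. Storing the
# coefficients as bm.float_ (rather than Python floats / float64 arrays)
# keeps the stage computations in the chosen half precision.
rk4_b = [bm.asarray(x, dtype=bm.float_) for x in (1/6, 1/3, 1/3, 1/6)]
rk4_c = [bm.asarray(x, dtype=bm.float_) for x in (0.0, 0.5, 0.5, 1.0)]
```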

Could you let me know how to cast the Runge-Kutta coefficients to brainpy.math.float_? It seems that the coefficients are generated automatically.

Yes, these changes need to be made inside the BrainPy framework. Note that dt should also be cast within the integrators.
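For instance, inside an integrator step this would amount to something like the following sketch (assuming bm.get_dt() returns the global step size):

```python
import brainpy.math as bm

# Cast the step size to the default float type before it enters the
# update equations; a Python float dt would otherwise promote the
# whole right-hand side back to a wider dtype.
dt = bm.asarray(bm.get_dt(), dtype=bm.float_)
```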

Update: I think the GPU memory consumption is mostly determined by JAX, which preallocates 75% of the total GPU memory by default. This may be why I don't see a reduction in memory consumption after switching to FP16.
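If it helps, JAX's preallocation behavior can also be controlled directly through environment variables (they must be set before JAX is imported):

```python
import os

# Allocate GPU memory on demand instead of preallocating it up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Alternatively, cap the preallocated fraction (the default is 0.75):
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

import brainpy.math as bm  # import only after setting the variables
```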

The preallocation can be disabled by calling brainpy.math.disable_gpu_memory_preallocation().
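For example, called before any GPU arrays are created:

```python
import brainpy.math as bm

# Disable JAX's default preallocation so the reported GPU usage
# reflects the actual footprint of the float16 model.
bm.disable_gpu_memory_preallocation()
```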