Half precision (float16 or bfloat16) support
CloudyDory opened this issue · 10 comments
Does BrainPy fully support half-precision floating point numbers? I have tried to change some of my own BrainPy code from `brainpy.math.float32` to `brainpy.math.float16` or `brainpy.math.bfloat16`. I explicitly set the dtype of all variables and used a debugger to make sure they are not promoted to `float32`. However, GPU memory consumption and running speed are almost the same as with `float32`.
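Instead of using a debugger, you can look for silent promotions by tracing the function with JAX (which BrainPy builds on) and listing the dtypes that each traced operation produces. The helper name `intermediate_dtypes` is hypothetical and is only for illustration. It inspects just the top-level equations of the jaxpr, so operations nested inside inner `jit` or control-flow sub-jaxprs are not expanded.

```python
import numpy as np
import jax
import jax.numpy as jnp


def intermediate_dtypes(fn, *args):
    """Return the set of dtype names produced by the top-level operations of fn."""
    closed = jax.make_jaxpr(fn)(*args)
    return {str(var.aval.dtype) for eqn in closed.jaxpr.eqns for var in eqn.outvars}


x = jnp.ones(8, dtype=jnp.float16)

# A Python float is weakly typed in JAX, so the product stays float16.
print(intermediate_dtypes(lambda v: v * 0.5, x))
# A NumPy float32 scalar is strongly typed and promotes the product to float32.
print(intermediate_dtypes(lambda v: v * np.float32(0.5), x))
```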
Great! Supporting this requires explicitly casting all parameters to `brainpy.math.float_`. For example, in an HH neuron model the parameter `gNa` should be reinterpreted as `gNa = bm.asarray(gNa, bm.float_)`. Ideally, users would only need to call `brainpy.math.set(float_=bm.float16)`, and all variables would then run in `float16`.
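You can reproduce the effect of an uncast parameter with plain JAX. In this sketch, `FLOAT_` stands in for `bm.float_` after switching to `float16`. The sodium-current function and the NumPy-typed parameter dictionary are hypothetical and are not BrainPy's actual HH implementation. A single `float32` parameter is enough to promote the whole expression to `float32`. Casting it with `asarray` keeps the expression in `float16`, which is the same idea as `bm.asarray(gNa, bm.float_)`.

```python
import numpy as np
import jax.numpy as jnp

FLOAT_ = jnp.float16  # stand-in for bm.float_ once half precision is selected

# Hypothetical HH parameters stored as float32 NumPy scalars (e.g. read from a config).
params = {"gNa": np.float32(120.0), "ENa": np.float32(50.0)}


def sodium_current(V, m, h, gNa, ENa):
    """HH-style sodium current I_Na = gNa * m^3 * h * (V - ENa)."""
    return m**3 * h * gNa * (V - ENa)


V = jnp.full(4, -65.0, dtype=FLOAT_)
m = jnp.full(4, 0.05, dtype=FLOAT_)
h = jnp.full(4, 0.6, dtype=FLOAT_)

# Uncast float32 parameters promote the result to float32.
I_uncast = sodium_current(V, m, h, **params)

# Casting every parameter to the target dtype keeps the computation in float16.
cast_params = {name: jnp.asarray(value, dtype=FLOAT_) for name, value in params.items()}
I_cast = sodium_current(V, m, h, **cast_params)

print(I_uncast.dtype, I_cast.dtype)
```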
One more thing to take care of: the coefficients of the Runge-Kutta methods must also be cast to `brainpy.math.float_`.
Could you let me know how to cast the Runge-Kutta coefficients to `brainpy.math.float_`? It seems that the coefficients are generated automatically.
Yes, the changes have to be made inside the BrainPy framework itself. Note that `dt` should also be cast in the integrators.
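The sketch below shows how a generic explicit Runge-Kutta step can keep everything in the state's dtype. It is not BrainPy's integrator code, which generates its update functions automatically. The function `explicit_rk_step` and the float64 NumPy Butcher tableau are assumptions for illustration. Both the tableau coefficients and `dt` are cast to `y.dtype` before use. Without that cast, NumPy scalars taken from the tableau would promote a `float16` state to a wider float type.

```python
import numpy as np
import jax.numpy as jnp

# Classic RK4 Butcher tableau, stored in float64 NumPy arrays.
RK4_A = np.array([[0.0, 0.0, 0.0, 0.0],
                  [0.5, 0.0, 0.0, 0.0],
                  [0.0, 0.5, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0]])
RK4_B = np.array([1 / 6, 1 / 3, 1 / 3, 1 / 6])


def explicit_rk_step(f, y, dt, A=RK4_A, b=RK4_B):
    """One explicit Runge-Kutta step for dy/dt = f(y), computed in y.dtype."""
    dtype = y.dtype
    # Cast the tableau and the step size so nothing promotes the state.
    A = jnp.asarray(A, dtype=dtype)
    b = jnp.asarray(b, dtype=dtype)
    dt = jnp.asarray(dt, dtype=dtype)
    ks = []
    for i in range(len(b)):
        y_stage = y
        for j in range(i):
            y_stage = y_stage + dt * A[i, j] * ks[j]
        ks.append(f(y_stage))
    increment = sum(b[i] * ks[i] for i in range(len(b)))
    return y + dt * increment


# Usage: one RK4 step of dy/dt = -y in float16 (exact value exp(-0.1) ~ 0.9048).
y0 = jnp.ones(3, dtype=jnp.float16)
y1 = explicit_rk_step(lambda y: -y, y0, 0.1)
print(y1.dtype, y1)
```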
Update: I think GPU memory consumption is mostly determined by JAX, which preallocates 75% of the total GPU memory by default. This may be why I don't see a reduction in memory consumption after switching to FP16. The preallocation can be disabled by calling `brainpy.math.disable_gpu_memory_preallocation()`.
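For reference, JAX documents the `XLA_PYTHON_CLIENT_PREALLOCATE` environment variable for turning off preallocation. It must be set before JAX initializes its GPU backend. This sketch assumes, but does not confirm, that `brainpy.math.disable_gpu_memory_preallocation()` relies on the same mechanism. The sketch also compares buffer sizes: a `float16` or `bfloat16` array needs half the bytes of the same `float32` array. With preallocation enabled, tools such as `nvidia-smi` report the reserved pool rather than these per-buffer sizes, so this saving does not show up there.

```python
import os

# Must be set before JAX creates its GPU backend (ideally before importing jax).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax.numpy as jnp

shape = (1024, 1024)
x32 = jnp.zeros(shape, dtype=jnp.float32)
x16 = jnp.zeros(shape, dtype=jnp.float16)
xbf16 = jnp.zeros(shape, dtype=jnp.bfloat16)

# Bytes held by each buffer: half precision uses 2 bytes per element instead of 4.
print(x32.nbytes, x16.nbytes, xbf16.nbytes)
```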