brainpy/BrainPy

Using `jax.disable_jit()` cannot disable jit for `bm.scan`



The `jax.disable_jit()` context manager disables jit for `jax.lax.scan`, but not for `bm.scan`. This makes it hard to debug functions inside `bm.scan`. The following code shows an example:

```python
import jax
import jax.numpy as jnp
import brainpy.math as bm

def cumsum(res, el):
    """
    - `res`: The result from the previous loop.
    - `el`: The current array element.
    """
    res = res + el
    print(res)
    return res, res  # ("carryover", "accumulated")

# jax.lax.scan under disable_jit: print() shows concrete values
a = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
    final, result = jax.lax.scan(cumsum, result_init, a)

# bm.scan under disable_jit: print() still shows tracers (the bug reported here)
b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
    final, result = bm.scan(cumsum, result_init, b)
```

The printed output is:

```
1
3
6
11
18
29
42
59
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
```
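Under `jax.disable_jit()`, `jax.lax.scan` runs its body as an ordinary Python loop over concrete values, which is why the first run prints numbers, while in the `bm.scan` run the body receives tracers instead. A minimal sketch of that loop, following the reference semantics given in the `jax.lax.scan` docstring, is shown below. `python_scan` is a hypothetical debugging helper (not part of JAX or BrainPy), and it only handles a single array `xs` with no `length` or `reverse` options.

```python
import numpy as np

def python_scan(f, init, xs):
    """Hypothetical helper: run a scan body as a plain Python loop.

    Mirrors the reference semantics of jax.lax.scan for a single array `xs`;
    no tracing happens, so print() inside `f` shows concrete values.
    """
    carry = init
    ys = []
    for x in xs:                # iterate over the leading axis
        carry, y = f(carry, x)  # carry is threaded through each step
        ys.append(y)            # per-step outputs are collected
    return carry, np.stack(ys)  # stack outputs along a new leading axis

def cumsum(res, el):
    res = res + el
    print(res)
    return res, res

final, result = python_scan(cumsum, 0, np.array([1, 2, 3, 5, 7, 11, 13, 17]))
```

This prints the same eight values as the `jax.lax.scan` run above, and `final` is the last accumulated sum.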

This is a great issue. I have added a new PR to fix it: #606.


This has been solved in #606.
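Separately from disabling jit, `jax.debug.print` can be used to inspect values inside a traced scan body: it is staged into the computation and prints runtime values rather than tracers. The sketch below uses `jax.lax.scan` under `jax.jit`. That it behaves the same inside `bm.scan` is an assumption based on BrainPy building on JAX tracing, not something confirmed in this thread.

```python
import jax
import jax.numpy as jnp

def cumsum(res, el):
    res = res + el
    # Staged into the computation: prints the runtime value, not a tracer.
    jax.debug.print("res = {}", res)
    return res, res

@jax.jit
def run(xs):
    # Initial carry matches the dtype of xs so the carry type stays fixed.
    init = jnp.zeros((), dtype=xs.dtype)
    return jax.lax.scan(cumsum, init, xs)

final, result = run(jnp.array([1, 2, 3, 5, 7, 11, 13, 17]))
```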