Using `jax.disable_jit()` cannot disable jit for `bm.scan`
Closed this issue · 3 comments
CloudyDory commented
Using the `jax.disable_jit()` context manager can disable jit for `jax.lax.scan`, but it cannot disable jit for `bm.scan`. This makes it hard to debug functions inside `bm.scan`. See the following code for an example:
```python
import jax
import jax.numpy as jnp
import brainpy.math as bm


def cumsum(res, el):
    """
    - `res`: The result from the previous loop.
    - `el`: The current array element.
    """
    res = res + el
    print(res)
    return res, res  # ("carryover", "accumulated")


# jax.lax.scan under disable_jit: print() shows concrete values
a = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
    final, result = jax.lax.scan(cumsum, result_init, a)

# bm.scan under disable_jit: print() still shows tracers (the reported bug)
b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
    final, result = bm.scan(cumsum, result_init, b)
```
The printed output is:

```
1
3
6
11
18
29
42
59
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
```
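The first eight lines are what an eager run of the body looks like. One possible stopgap for debugging a scan body, independent of either scan implementation, is to drive it with a plain Python loop that follows the `(carry, x) -> (carry, y)` contract that `jax.lax.scan` documents. This is a minimal sketch; `python_scan` is a hypothetical helper, not part of JAX or BrainPy, and unlike `lax.scan` it returns the per-step outputs as a Python list rather than a stacked array.

```python
def python_scan(f, init, xs):
    """Hypothetical helper: eager loop with the same (carry, y) contract as jax.lax.scan.

    Because f runs on concrete values, print() inside f shows numbers, not tracers.
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys  # lax.scan would stack ys into an array instead


def cumsum(res, el):
    res = res + el
    print(res)
    return res, res  # ("carryover", "accumulated")


final, result = python_scan(cumsum, 0, [1, 2, 3, 5, 7, 11, 13, 17])
```

Run on the same input as in the report, this prints the same eight concrete values (1 through 59) that `jax.lax.scan` prints under `jax.disable_jit()`.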
chaoming0625 commented
This is a great issue. I added a new PR to fix this: #606
chaoming0625 commented
This has been solved in #606.