better stack trace for concretization inside scan
dlwh commented
JAX exceptions raised while tracing a function passed to scan give no indication of where in the user's code the error actually occurred. It would be better if we could catch these somehow and report a more useful stack trace.
For example, something like:
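One possible direction, as a rough sketch rather than a proposed patch: catch the concretization error around the call to the user's function, walk the traceback, and re-raise with the innermost frame that lies outside the jax package. The helper name `with_source_hint` and the message format are hypothetical, and the sketch uses plain `jax.lax.scan` rather than haliax's own scan.

```python
import os
import traceback

import jax
import jax.numpy as jnp
from jax import lax
from jax.errors import ConcretizationTypeError

# Directory of the installed jax package, used to skip JAX-internal frames.
_JAX_DIR = os.path.dirname(os.path.abspath(jax.__file__))


def with_source_hint(fn):
    """Hypothetical helper: wrap a scan body so tracer errors name the user line."""

    def wrapped(carry, x):
        try:
            return fn(carry, x)
        except ConcretizationTypeError as e:
            # Keep only frames outside the jax package; the last one is the
            # innermost line of non-JAX code that triggered the error.
            frames = [
                f for f in traceback.extract_tb(e.__traceback__)
                if not os.path.abspath(f.filename).startswith(_JAX_DIR)
            ]
            last = frames[-1]
            raise RuntimeError(
                "concretization error inside scan body at "
                f"{last.filename}:{last.lineno} in {last.name}: {last.line or ''}"
            ) from e

    return wrapped


def body(carry, x):
    # Deliberately invalid: a Python `if` on a value that is traced inside scan.
    if jnp.equal(x, 4):
        carry = carry + 1.0
    return carry, None


message = None
try:
    lax.scan(with_source_hint(body), 0.0, jnp.arange(8))
except RuntimeError as err:
    message = str(err)
print(message)
```

The printed message names `body` and the line of the offending `if`, which is the information missing from the filtered traceback. By analogy, a similar try/except could sit around the point where the scan wrapper calls the user's function, though whether that composes cleanly with checkpointing is untested here.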
scan(self.foo, ...)
def foo(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], layer_idx, *, key):
    k1, k2, k3, k4 = haliax.jax_utils.maybe_rng_split(key, 4)
    attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
    attn_output = self.resid_dropout(attn_output, key=k2)
    x = x + attn_output
    ff_output = self.mlp(self.ln_2(x), key=k3)
    ff_output = self.resid_dropout(ff_output, key=k4)
    x = x + ff_output
    #import ipdb; ipdb.set_trace()
    if jnp.equal(layer_idx.array, 4):
        #x = x + 0.01*jnp.sin(x*1e2)
        x = x + 0.01*hax.sin(x*1e2)
    return x
produced:
    carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
  File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/hof.py", line 83, in wrapped_fn
    carry, y = f(carry, *args, **kwargs)
  File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/hof.py", line 124, in scan_compatible_fn
    return fn(carry, *args, **kwargs), None
  File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/jax_utils.py", line 69, in wrapper
    dynamic_out, static_out = checkpointed_fun(static, dynamic)
jax.errors.ConcretizationTypeError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function new_fun at /nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py:357 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument dyn_args[0][0][3][<flat index 0>].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception: