JAX array conversion failure in Keras model prediction
Qazalbash opened this issue · 3 comments
I have trained a simple deep MLP model and saved it in the `.keras` format. I use JAX-jitted functions for prediction, passing the two inputs stacked with `jax.numpy.column_stack`. I have tried alternatives, including `numpy.column_stack` and setting `JAX_TRACEBACK_FILTERING=off`, but the issue persists. My Keras backend is configured as `KERAS_BACKEND=jax`.
```
  File "/media/project/inference/lippl.py", line 114, in exp_rate_integral
    jnp.exp(mass_model.log_prob(m1q) + self.logVT.predict(m1m2).flatten()),
                                       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/optree/ops.py", line 594, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gradf/miniforge3/envs/gwkenv/lib/python3.11/site-packages/jax/_src/core.py", line 684, in __array__
    raise TracerArrayConversionError(self)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10000,2].
The error occurred while tracing the function likelihood at /media/project/inference/lippl.py:119 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /media/project/inference/lippl.py:107:29 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

  operation a:f32[10000] = pjit[
    name=_uniform
    jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
        e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
        f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
        g:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
        h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
        i:u32[10000] = random_bits[bit_width=32 shape=(10000,)] b
        j:u32[10000] = shift_right_logical i 9
        k:u32[10000] = or j 1065353216
        l:f32[10000] = bitcast_convert_type[new_dtype=float32] k
        m:f32[10000] = sub l 1.0
        n:f32[1] = sub h g
        o:f32[10000] = mul m n
        p:f32[10000] = add o g
        q:f32[10000] = max g p
      in (q,) }
  ] r s t
    from line /media/project/inference/lippl.py:107:13 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /media/project/inference/lippl.py:108:29 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

  operation a:f32[10000] = pjit[
    name=_uniform
    jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
        e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
        f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
        g:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
        h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
        i:u32[10000] = random_bits[bit_width=32 shape=(10000,)] b
        j:u32[10000] = shift_right_logical i 9
        k:u32[10000] = or j 1065353216
        l:f32[10000] = bitcast_convert_type[new_dtype=float32] k
        m:f32[10000] = sub l 1.0
        n:f32[1] = sub h g
        o:f32[10000] = mul m n
        p:f32[10000] = add o g
        q:f32[10000] = max g p
      in (q,) }
  ] r s t
    from line /media/project/inference/lippl.py:108:13 (LogInhomogeneousPoissonProcessLikelihood.exp_rate_integral)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```
> `jax.errors.TracerArrayConversionError`: The numpy.ndarray conversion method `__array__()` was called on traced array with shape float32[10000,2].

This indicates that your model contains an operation that tries to retrieve the eager value of a tensor.
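The same failure can be reproduced in plain JAX, without Keras. Below is a minimal sketch that assumes only `jax` and `numpy` are installed. The function names are hypothetical. Inside `jax.jit` the argument is a tracer that carries only a shape and dtype, so anything that asks for a concrete NumPy array, such as `np.asarray`, raises `TracerArrayConversionError`. Staying in `jax.numpy` keeps the computation symbolic.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def needs_eager_value(x):
    # np.asarray asks the tracer for a concrete NumPy array via __array__(),
    # which JAX refuses while tracing, so this raises TracerArrayConversionError.
    return jnp.sum(np.asarray(x))


@jax.jit
def stays_traced(x):
    # jnp.asarray keeps the value as a tracer, so tracing succeeds.
    return jnp.sum(jnp.asarray(x))


x = jnp.ones((10000, 2), dtype=jnp.float32)  # same shape as in the traceback

try:
    needs_eager_value(x)
    failed = False
except jax.errors.TracerArrayConversionError:
    failed = True

print(failed)           # True
print(stays_traced(x))  # 20000.0
```

Switching from `jax.numpy.column_stack` to `numpy.column_stack` therefore cannot help. The inputs are already tracers, so any NumPy conversion inside the jitted function hits the same error.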
Earlier in the traceback, I see:

```
jnp.exp(mass_model.log_prob(m1q) + self.logVT.predict(m1m2).flatten()),
```
So it sounds like you are calling `predict()` inside a tracing scope. This is impossible. Perhaps you meant to call `self.logVT(m1m2)` instead?
See also: https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call
Oh, and if your `call()` method is stateful in any way, you'll need to use `stateless_call()` instead and manage the state updates manually.