On JAX, Keras replaces any exception inside the `call` method of a `keras.Model` subclass with a misleading error
burnpanck opened this issue · 2 comments
burnpanck commented
MWE:
```python
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras

class Test(keras.Model):
    def call(self, x):
        raise RuntimeError("Random misspelling deeply nested in the model")

t = Test()
inp = keras.Input(shape=(32, 3))
t(inp)
```
Running the above example causes the following exception to be raised:
```
TypeError: Exception encountered when calling Test.call().
Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 3).
Arguments received by Test.call():
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>
```
On TensorFlow, the exception instead reads:
```
RuntimeError: Exception encountered when calling Test.call().
Could not automatically infer the output shape / dtype of 'test_1' (of type Test). Either the `Test.call()` method is incorrect, or you need to implement the `Test.compute_output_spec() / compute_output_shape()` method. Error encountered:
Random misspelling deeply nested in the model
Arguments received by Test.call():
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>
```
Note that with TensorFlow, the error message contains the original exception string, whereas under JAX the message misleadingly suggests a problem with a shape. Furthermore, the internal frames of the stack trace are erased (this is not visible in the example above, which is kept minimal). If this happens deep inside a model, an unsuspecting user may be sent on a many-hours-long hunt for mismatched shapes that turns up nothing useful.
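To make the difference between the two messages concrete, here is a hypothetical pure-Python sketch. It is not Keras's actual shape-inference code, and `user_call`, `infer_output_shape_misleading` and `infer_output_shape_preserving` are made-up names. It contrasts a wrapper that suppresses the user's exception and reports only a shape complaint (like the JAX message) with one that embeds and chains the original error (like the TensorFlow message).

```python
# Hypothetical sketch: NOT Keras's actual implementation. All names are made up.


def user_call(shape):
    # Stand-in for Test.call() from the MWE: fails for a reason unrelated to shapes.
    raise RuntimeError("Random misspelling deeply nested in the model")


def infer_output_shape_misleading(call, shape):
    try:
        return call(shape)
    except Exception:
        # The original error is suppressed ("from None"), so only a
        # shape-related complaint reaches the user.
        raise TypeError(
            "Shapes must be 1D sequences of concrete values of integer type, "
            f"got {shape}."
        ) from None


def infer_output_shape_preserving(call, shape):
    try:
        return call(shape)
    except Exception as original:
        # The original message is embedded and the exception is chained,
        # so the real cause stays visible in the traceback.
        raise RuntimeError(
            "Could not automatically infer the output shape. "
            f"Error encountered:\n{original}"
        ) from original


for infer in (infer_output_shape_misleading, infer_output_shape_preserving):
    try:
        infer(user_call, (None, 32, 3))
    except Exception as err:
        print(f"{type(err).__name__}: {err}")
```

The first wrapper prints only the `TypeError` about `(None, 32, 3)`, while the second prints a `RuntimeError` whose message still contains "Random misspelling deeply nested in the model".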
burnpanck commented
This was with Keras 3.1.1, JAX 0.4.26, and Python 3.12.
fchollet commented
Good catch. I fixed it at HEAD.