keras-team/keras

On JAX, Keras replaces any exception inside the `call` method of a `keras.Model` subclass with a misleading error

burnpanck opened this issue · 2 comments

MWE:

import os

# Select the JAX backend; this must be set before keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras

class Test(keras.Model):
    def call(self, x):
        raise RuntimeError("Random misspelling deeply nested in the model")

t = Test()

# Symbolic input: calling the model on it traces `call` to infer the output spec.
inp = keras.Input(shape=(32, 3))

t(inp)

Running the above example raises the following exception:

TypeError: Exception encountered when calling Test.call().

Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 3).

Arguments received by Test.call():
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>

On TensorFlow, the exception instead reads:

RuntimeError: Exception encountered when calling Test.call().

Could not automatically infer the output shape / dtype of 'test_1' (of type Test). Either the `Test.call()` method is incorrect, or you need to implement the `Test.compute_output_spec() / compute_output_shape()` method. Error encountered:

Random misspelling deeply nested in the model

Arguments received by Test.call():
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>

Note that on TensorFlow the error message contains the original exception string, whereas on JAX the message misleadingly suggests that there is a problem with a shape. Furthermore, the internal frames of the stack trace get erased (not shown above, to keep the MWE minimal). If this happens deep inside a model, an unsuspecting user may be sent on a many-hours-long hunt for mismatched shapes that turns up nothing useful.
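To illustrate the behaviour I would expect, here is a minimal pure-Python sketch of wrapping an exception while keeping the original message and traceback reachable. It is not the Keras implementation: `infer_output_spec`, `broken_call`, and `demo` are hypothetical names invented for this example, and wrapping in `RuntimeError` is an assumption made for simplicity.

```python
def infer_output_spec(layer_name, call_fn, *args):
    """Hypothetical stand-in for a framework's output-spec inference step."""
    try:
        return call_fn(*args)
    except Exception as e:
        # Embed the original message in the new one, and chain with `from e`
        # so the original exception and its frames stay attached as __cause__.
        raise RuntimeError(
            "Could not automatically infer the output shape / dtype of "
            f"'{layer_name}'. Error encountered:\n\n{e}"
        ) from e


def broken_call(x):
    # Plays the role of a user-defined `call` method containing a bug.
    raise RuntimeError("Random misspelling deeply nested in the model")


def demo():
    """Trigger the wrapped failure and return the resulting exception."""
    try:
        infer_output_spec("test_1", broken_call, object())
    except RuntimeError as err:
        return err
    return None


if __name__ == "__main__":
    err = demo()
    print(err)                  # wrapped message, including the original text
    print(repr(err.__cause__))  # the original exception object
```

With chaining in place, the original message appears in the wrapped error and the original exception remains available through `__cause__`, so the traceback printed by Python includes the frames where the bug actually occurred.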

This was using Keras 3.1.1, JAX 0.4.26, and Python 3.12.

Good catch. I fixed it at HEAD.