
On JAX, Keras replaces any exception inside `call` method of `keras.Model` subclass with misleading error

burnpanck opened this issue · 2 comments


import os

os.environ["KERAS_BACKEND"] = "jax"

import keras

class Test(keras.Model):
    def call(self, x):
        raise RuntimeError("Random misspelling deeply nested in the model")

t = Test()

inp = keras.Input(shape=(32, 3))
t(inp)  # calling the model on the symbolic input triggers the error


Running the above example raises the following exception:

TypeError: Exception encountered when calling

Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 3).

Arguments received by
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>

On TensorFlow, the exception instead reads:

RuntimeError: Exception encountered when calling

Could not automatically infer the output shape / dtype of 'test_1' (of type Test). Either the `` method is incorrect, or you need to implement the `Test.compute_output_spec() / compute_output_shape()` method. Error encountered:

Random misspelling deeply nested in the model

Arguments received by
  • args=('<KerasTensor shape=(None, 32, 3), dtype=float32, sparse=None, name=keras_tensor_1>',)
  • kwargs=<class 'inspect._empty'>

Note that with TensorFlow, the error message contains the original exception string, whereas under JAX the message strongly and misleadingly suggests a problem with a shape. Furthermore, the internal frames of the stack trace get erased (not shown in the example above, to keep the MWE minimal). If this happens deep inside a model, an unsuspecting user may be sent off on a many-hours-long hunt for mismatched shapes that doesn't turn up anything useful.
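To illustrate the behaviour I would expect, here is a minimal pure-Python sketch of a wrapper that keeps the original exception reachable. It is not Keras's actual code: `failing_call`, `infer_output_spec` and `root_cause` are hypothetical names made up for this example, and it only assumes standard Python exception chaining (`raise ... from e`).

```python
def failing_call(x):
    # Stand-in for a user-defined `call` method that contains a bug.
    raise RuntimeError("Random misspelling deeply nested in the model")


def infer_output_spec(fn, x):
    """Hypothetical wrapper, NOT the real Keras implementation."""
    try:
        return fn(x)
    except Exception as e:
        # Include the original message in the new one and chain with `from`,
        # so the original exception stays reachable via __cause__.
        raise RuntimeError(
            f"Could not infer the output spec. Error encountered:\n\n{e}"
        ) from e


def root_cause(exc):
    """Follow __cause__ / __context__ links back to the first exception."""
    seen = set()
    while id(exc) not in seen:
        seen.add(id(exc))
        nxt = exc.__cause__ or exc.__context__
        if nxt is None:
            break
        exc = nxt
    return exc


try:
    infer_output_spec(failing_call, None)
except RuntimeError as wrapped:
    original = root_cause(wrapped)
    print(type(original).__name__, "-", original)
```

With chaining like this, the user-facing message still names the real failure, and a helper such as `root_cause` can recover the original exception object when the outer message is unhelpful. Whether the JAX code path chains the original exception at all is something I have not verified.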

This was using Keras 3.1.1, JAX 0.4.26, and Python 3.12.

Good catch. I fixed it at HEAD.