
Shape error for some use cases of `binary_crossentropy`.

albertaillet opened this issue · 3 comments


When the binary_crossentropy loss function is wrapped in another keras.losses.Loss, it no longer supports targets with a flat shape and requires a shape of the form (..., 1). This does not happen when it is simply wrapped in a function or in a class with a __call__() method.

How to reproduce

The following script can be used to reproduce this error.

import keras

def fit_model_with_loss(loss):
    model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
    model.compile(optimizer="sgd", loss=loss, metrics=["accuracy"])
    model.fit(x, y, batch_size=16, epochs=2)

x = keras.random.uniform((32, 1))
y = keras.random.randint((32,), 0, 2)  # maxval is exclusive, so 2 yields labels in {0, 1}

loss = keras.losses.get("binary_crossentropy")
fit_model_with_loss(loss)  # works fine

def loss_wrapped_with_function(*args, **kwargs):
    return loss(*args, **kwargs)

fit_model_with_loss(loss_wrapped_with_function)  # works fine

class LossWrapper:
    def __init__(self, loss) -> None:
        self.loss = loss

    def __call__(self, *args, **kwargs):
        return self.loss(*args, **kwargs)

fit_model_with_loss(LossWrapper(loss))  # works fine

class LossWrapperInherit(keras.losses.Loss):
    def __init__(self, loss) -> None:
        super().__init__()  # initialize the base Loss (name, reduction, dtype)
        self.loss = loss

    def call(self, *args, **kwargs):
        return self.loss(*args, **kwargs)

fit_model_with_loss(LossWrapperInherit(loss))  # gets a shape error

The error is the following:

File "/keras/keras/src/losses/", line 43, in __call__
    losses = self.call(y_true, y_pred)
  File "/keras/", line 45, in call
    return self.loss(*args, **kwargs)
  File "/keras/keras/src/losses/", line 1782, in binary_crossentropy
    ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
  File "/keras/keras/src/ops/", line 1398, in binary_crossentropy
    return backend.nn.binary_crossentropy(
  File "/keras/keras/src/backend/jax/", line 518, in binary_crossentropy
    raise ValueError(
ValueError: Arguments `target` and `output` must have the same shape. Received: target.shape=(16,), output.shape=(16, 1)

Is there a recommended way?

If this is expected behaviour, what is the recommended way to wrap a loss function as a keras.losses.Loss class and handle both flat and (..., 1) target shapes?

Just like in layers, you're supposed to override call(), not __call__(). You could override __call__(), but by doing so you miss a bit of built-in functionality, including auto-broadcasting. So you can just override call() and call self.loss(y_true, y_pred) there.

I agree, however in the example, I think the one not working is the one that overrides call() in the recommended way, if I am not mistaken.

Hi @albertaillet ,

I have reproduced the reported error when overriding the call method. A gist is attached for reference.