Shape error for some use cases of `binary_crossentropy`.
albertaillet opened this issue · 3 comments
Problem
When wrapping the `binary_crossentropy` loss function in another `keras.losses.Loss`, it no longer supports targets with a flat shape and instead requires a shape of the form `(..., 1)`. This does not happen when the loss is simply wrapped in a plain function or in a class with a `__call__()` method.
How to reproduce
The following script reproduces the error.
```python
import keras


def fit_model_with_loss(loss):
    model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
    model.compile(optimizer="sgd", loss=loss, metrics=["accuracy"])
    model.fit(x, y, batch_size=16, epochs=2)


x = keras.random.uniform((32, 1))
y = keras.random.randint((32,), 0, 1)

loss = keras.losses.get("binary_crossentropy")
fit_model_with_loss(loss)  # works fine


def loss_wrapped_with_function(*args, **kwargs):
    return loss(*args, **kwargs)


fit_model_with_loss(loss_wrapped_with_function)  # works fine


class LossWrapper:
    def __init__(self, loss) -> None:
        super().__init__()
        self.loss = loss

    def __call__(self, *args, **kwargs):
        return self.loss(*args, **kwargs)


fit_model_with_loss(LossWrapper(loss))  # works fine


class LossWrapperInherit(keras.losses.Loss):
    def __init__(self, loss) -> None:
        super().__init__()
        self.loss = loss

    def call(self, *args, **kwargs):
        return self.loss(*args, **kwargs)


fit_model_with_loss(LossWrapperInherit(loss))  # gets a shape error
```
The error is the following:

```
  File "/keras/keras/src/losses/loss.py", line 43, in __call__
    losses = self.call(y_true, y_pred)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/keras/reproduce_keras_error.py", line 45, in call
    return self.loss(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/keras/keras/src/losses/losses.py", line 1782, in binary_crossentropy
    ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/keras/keras/src/ops/nn.py", line 1398, in binary_crossentropy
    return backend.nn.binary_crossentropy(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/keras/keras/src/backend/jax/nn.py", line 518, in binary_crossentropy
    raise ValueError(
ValueError: Arguments `target` and `output` must have the same shape. Received: target.shape=(16,), output.shape=(16, 1)
```
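For context on why the backend raises here rather than silently broadcasting: with a flat `(16,)` target and a `(16, 1)` output, generic broadcasting rules would pair every target with every prediction and produce a `(16, 16)` result, i.e. a silently wrong per-sample loss. A minimal NumPy illustration (not Keras code) of that failure mode:

```python
import numpy as np

# A flat target (16,) combined elementwise with a (16, 1) model output:
# broadcasting expands both to (16, 16) instead of giving 16 per-sample
# values, which is why Keras refuses mismatched shapes outright.
y_true = np.zeros((16,))
y_pred = np.zeros((16, 1))

diff = y_true - y_pred
print(diff.shape)  # (16, 16)
```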
Is there a recommended way?
In case this is expected behaviour, what is the recommended way to wrap a loss function as a `keras.losses.Loss` subclass so that it handles both flat and `(..., 1)` target shapes?
Just like in layers, you're supposed to override `call()`, not `__call__()`. You could override `__call__()`, but by doing so you miss out on some built-in functionality, including auto-broadcasting. So you can just override `call()` and call `self.loss(y_true, y_pred)` there.
I agree; however, in the example, the one not working is precisely the one that overrides `call()` in the recommended way, if I am not mistaken.
Hi @albertaillet, I have reproduced the reported error when overriding the `call` method. Attached gist for reference.
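Until this is resolved, one possible workaround is to expand a flat target to `(..., 1)` inside `call()` before delegating to the wrapped loss. The sketch below uses a plain NumPy stand-in rather than a real `keras.losses.Loss` subclass, and the class and helper names are hypothetical; in actual Keras code the same reshape could be done with `keras.ops.expand_dims` inside a `call()` override.

```python
import numpy as np


class FlatTargetLossWrapper:
    """Hypothetical sketch: a Loss-style wrapper whose call() accepts
    both flat (n,) and (n, 1) targets. NumPy stands in for keras.ops."""

    def __init__(self, loss_fn):
        self.loss_fn = loss_fn

    def call(self, y_true, y_pred):
        y_true = np.asarray(y_true, dtype=float)
        y_pred = np.asarray(y_pred, dtype=float)
        # If the target is missing the trailing channel axis, add it:
        # (16,) -> (16, 1), matching an output of shape (16, 1).
        if y_true.ndim == y_pred.ndim - 1 and y_pred.shape[-1] == 1:
            y_true = y_true[..., None]
        return self.loss_fn(y_true, y_pred)


def mse(y_true, y_pred):
    # Stand-in loss: per-sample mean squared error over the last axis.
    return np.mean((y_true - y_pred) ** 2, axis=-1)


wrapper = FlatTargetLossWrapper(mse)
flat = wrapper.call(np.zeros(16), np.full((16, 1), 0.5))
column = wrapper.call(np.zeros((16, 1)), np.full((16, 1), 0.5))
print(flat.shape, column.shape)  # both (16,)
```

Both call styles now produce one loss value per sample, which mirrors what the non-inheriting wrappers in the reproduction script already do.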