Overriding `Layer.forward` unexpectedly changes the signature of `Layer.__call__` under torch backend
LarsKue opened this issue
LarsKue commented
In PyTorch, one typically defines what happens when a layer is called (its `__call__`) by overriding the `forward` method. In Keras, we instead override the `call` method.

I would not expect overriding `forward` to have any effect on how a `keras.Layer` is called, even under the torch backend, since that is exactly the purpose of `call`. However, it seems that when the `forward` method is overridden, it takes priority over the overridden `Layer.call`.
Minimal Example:
```python
import os

os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import keras


class MyLayer(keras.Layer):
    def call(self, xz, inverse: bool = False):
        if inverse:
            return self.inverse(xz)
        return self.forward(xz)

    def forward(self, x):
        pass  # body is irrelevant to the bug

    def inverse(self, z):
        pass  # body is irrelevant to the bug


layer = MyLayer()
x = keras.ops.zeros((128, 2))

# TypeError: MyLayer.forward() got an unexpected keyword argument 'inverse'
layer(x, inverse=True)
```
My Keras version: 3.3.3