keras-team/keras

Overriding `Layer.forward` unexpectedly changes the signature of `Layer.__call__` under torch backend

LarsKue opened this issue · 5 comments

In torch, one typically defines what a layer's `__call__` does by overriding its `forward` method. Under Keras, we instead override the `call` method.

I would not expect overriding `forward` to have any effect on how a `keras.Layer` is called, even under the torch backend, since that is the purpose of `call`. However, when `forward` is overridden, it seems to take priority over the overridden `Layer.call`: the arguments passed to the layer are routed to `forward` instead of `call`.
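The sketch below is a simplified, dependency-free model of one dispatch chain that would produce this behavior. It is not the actual Keras or PyTorch source. It assumes that, on the torch backend, `keras.Layer` subclasses `torch.nn.Module`, that `Layer.__call__` eventually defers to `torch.nn.Module.__call__` (which invokes `self.forward`), and that the default `forward` routes to `call`. Under those assumptions, a subclass that defines its own `forward` intercepts the call before `call` is reached. The classes `Module`, `Layer`, `ClashingLayer`, `RenamedLayer` and the helpers `_forward`/`_inverse` are hypothetical names used only for illustration.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module: __call__ dispatches to forward."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Layer(Module):
    """Hypothetical stand-in for keras.Layer on the torch backend (assumed chain)."""

    def __call__(self, *args, **kwargs):
        # Keras bookkeeping (building, masking, ...) is omitted in this model.
        return super().__call__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # Assumed default: the torch-side forward routes back to Keras' call().
        return self.call(*args, **kwargs)

    def call(self, *args, **kwargs):
        raise NotImplementedError


class ClashingLayer(Layer):
    def call(self, xz, inverse=False):
        return self.inverse(xz) if inverse else self.forward(xz)

    def forward(self, x):
        # Shadows Layer.forward, so Module.__call__ lands here and call() is skipped.
        return ("forward", x)

    def inverse(self, z):
        return ("inverse", z)


class RenamedLayer(Layer):
    def call(self, xz, inverse=False):
        return self._inverse(xz) if inverse else self._forward(xz)

    def _forward(self, x):
        # Hypothetical helper name chosen so it does not shadow forward.
        return ("forward", x)

    def _inverse(self, z):
        return ("inverse", z)


try:
    ClashingLayer()(1.0, inverse=True)
except TypeError as err:
    print("ClashingLayer raised TypeError:", err)

print(RenamedLayer()(1.0, inverse=True))  # ('inverse', 1.0)
```

In this model, `ClashingLayer` raises the same kind of `TypeError` as the report, while `RenamedLayer` reaches `call` as intended. Renaming the helper methods is only one possible workaround: it sidesteps the name collision rather than changing how the layer dispatches.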

Minimal Example:

```python
import os
os.environ["KERAS_BACKEND"] = "torch"

import keras


class MyLayer(keras.Layer):
    def call(self, xz, inverse: bool = False):
        if inverse:
            return self.inverse(xz)
        return self.forward(xz)

    def forward(self, x):
        pass

    def inverse(self, z):
        pass


layer = MyLayer()

x = keras.ops.zeros((128, 2))

# TypeError: MyLayer.forward() got an unexpected keyword argument 'inverse'
layer(x, inverse=True)
```

My Keras version: 3.3.3