LayerNormalization not supporting axis=[-3, -2] with rms_scaling=True
lllllllllaa opened this issue · 2 comments
lllllllllaa commented
Hi,
I wanted to normalize my output over the -2 and -3 axes (image height and width). However, with `rms_scaling=True`, `self.gamma` is not broadcast to the same shape as the layer input, which causes this error:
```
inputs shape: (1, 1920, 1200, 3)
inv shape: (1, 1, 1, 3)
gamma_cast shape: (1920, 1200)
(inputs * inv) shape: (1, 1920, 1200, 3)
```
```
2024-04-30 13:50:54.238379: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: Incompatible shapes: [1,1920,1200,3] vs. [1920,1200]
Traceback (most recent call last):
  File "C:\Users\88bbh\PycharmProjects\AI\tempt.py", line 10, in <module>
    layer(np.zeros((1, 1920, 1200, 3)))
  File "C:\Users\88bbh\PycharmProjects\AI\venv\lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\88bbh\PycharmProjects\AI\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 5983, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Exception encountered when calling LayerNormalization.call().

{{function_node __wrapped__Mul_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [1,1920,1200,3] vs. [1920,1200] [Op:Mul] name:

Arguments received by LayerNormalization.call():
  • inputs=tf.Tensor(shape=(1, 1920, 1200, 3), dtype=float32)
```
Code to reproduce:

```python
import numpy as np
import keras

layer = keras.layers.LayerNormalization(axis=[-3, -2], rms_scaling=True)
layer.build([None, 1920, 1200, 3])
layer(np.zeros((1, 1920, 1200, 3)))
```
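Until this is fixed, the same normalization can be done by hand. Below is a minimal NumPy sketch (the `rms_norm_hw` helper and its `eps` default are my own, not Keras API) that normalizes an NHWC tensor over the height/width axes and reshapes a `(H, W)` gamma so it broadcasts correctly:

```python
import numpy as np

def rms_norm_hw(x, gamma, eps=1e-3):
    # Workaround sketch: RMS-style normalization over the height/width
    # axes (-3, -2) of an NHWC tensor, with a gamma of shape (H, W).
    var = x.var(axis=(-3, -2), keepdims=True)  # shape (N, 1, 1, C)
    inv = 1.0 / np.sqrt(var + eps)
    # Reshape gamma from (H, W) to (1, H, W, 1) so it broadcasts
    # against the (N, H, W, C) input.
    g = gamma.reshape(1, *gamma.shape, 1)
    return x * inv * g

x = np.random.rand(1, 8, 6, 3).astype("float32")
out = rms_norm_hw(x, np.ones((8, 6), dtype="float32"))
print(out.shape)  # (1, 8, 6, 3)
```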
The error is in the LayerNormalization `call()` method (the `print` calls are my debugging additions):

```python
if self.rms_scaling:
    # Calculate outputs with only variance and gamma if rms scaling
    # is enabled.
    # Calculate the variance along self.axis (layer activations).
    variance = ops.var(inputs, axis=self.axis, keepdims=True)
    inv = ops.rsqrt(variance + self.epsilon)
    print("inputs shape:", inputs.shape)
    print("inv shape:", inv.shape)
    print("gamma_cast shape:", self.gamma.shape)
    print("(inputs * inv) shape:", (inputs * inv).shape)
    outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype)
```
The error can be fixed by changing

```python
outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype)
```

to

```python
outputs = inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
```
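For reference, here is a small NumPy sketch of why the multiply fails and why expanding gamma fixes it (shapes scaled down from the report; the `reshape` stands in for what `_broadcast()` does in the layer):

```python
import numpy as np

inputs = np.zeros((1, 4, 5, 3))  # (N, H, W, C)
gamma = np.ones((4, 5))          # built for axis=[-3, -2] -> shape (H, W)

# Direct multiply fails: aligning trailing dims, (4, 5, 3) vs (4, 5)
# puts 3 against 5, which cannot broadcast.
try:
    inputs * gamma
except ValueError as e:
    print("broadcast error:", e)

# Expanding gamma to (1, 4, 5, 1) lets it broadcast over batch and
# channel axes, which is the effect of _broadcast() in the fix.
out = inputs * gamma.reshape(1, 4, 5, 1)
print(out.shape)  # (1, 4, 5, 3)
```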
Please fix it in the next update.
Thank you!
SuryanarayanaY commented
Hi @lllllllllaa ,
Thanks for reporting. I can confirm the issue, and the proposed fix seems correct. The fix is proposed on the attached PR.
Thanks!
google-ml-butler commented