keras-team/tf-keras

Using Lambda layers to take different slices of a previous layer's output causes earlier Lambda layers to be overwritten

Closed this issue · 5 comments

Transferring this issue from TensorFlow: tensorflow/tensorflow#62060

Current behavior?

I used two Lambda layers to extract slices from the same input vector. The output of the first Lambda is somehow overwritten by the output of the second Lambda.

Note: I have confirmed this happens whether the Lambda layers take the model's input directly or the output of another layer. I have also confirmed that the problem is present whether or not the Lambda layers are the direct outputs of the model. For the sample code below, however, I've removed the extra layers.

Standalone code to reproduce the issue

import sys

import tensorflow as tf

print(f"{tf.version.VERSION=} {tf.version.GIT_VERSION=} {tf.version.COMPILER_VERSION=}")
print(f"{sys.version=}")

dividers = [0, 2, 5]

assert all(divider >= 0 for divider in dividers)
sizes = [end - start for start, end in zip(dividers[:-1], dividers[1:])]
assert all(size > 0 for size in sizes)
channels = dividers[-1]

i = tf.keras.layers.Input((channels,), name='i')
o = [
    tf.keras.layers.Lambda(lambda x: x[..., start:end],
                           name=f'slice_{start}_{end}')(i)
    for start, end in zip(dividers[:-1], dividers[1:])
]
m = tf.keras.Model(i, o, name='m')
m.build((channels,))
m.summary()
print(f"{m.input_shape=}")
print(f"{m.output_shape=}")
print(f"{m.compute_output_shape(m.input_shape)=}")
x = tf.zeros((1, channels))
print(f"{[y.shape for y in m(x)]=}")
print(f"{[y.shape for y in m.predict(x)]=}")
assert m.output_shape == m.compute_output_shape(m.input_shape)

Relevant log output

tf.version.VERSION='2.10.0' tf.version.GIT_VERSION='v2.10.0-rc3-6-g359c3cdfc5f' tf.version.COMPILER_VERSION='9.3.1 20200408'
sys.version='3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]'
Model: "m"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 i (InputLayer)                 [(None, 5)]          0           []                               
                                                                                                  
 slice_0_2 (Lambda)             (None, 2)            0           ['i[0][0]']                      
                                                                                                  
 slice_2_5 (Lambda)             (None, 3)            0           ['i[0][0]']                      
                                                                                                  
==================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
__________________________________________________________________________________________________
m.input_shape=(None, 5)
m.output_shape=[(None, 2), (None, 3)]
m.compute_output_shape(m.input_shape)=[TensorShape([None, 3]), TensorShape([None, 3])]
[y.shape for y in m(x)]=[TensorShape([1, 3]), TensorShape([1, 3])]
1/1 [==============================] - 0s 286ms/step
[y.shape for y in m.predict(x)]=[(1, 3), (1, 3)]
Traceback (most recent call last):
  File "/home/hosford42/PycharmProjects/LSLAM/tf_bug.py", line 30, in <module>
    assert m.output_shape == m.compute_output_shape(m.input_shape)
AssertionError

@hosford42, who is the original reporter of this issue.

I think the root cause is the combination of a lambda and a for loop (here, the loop inside the list comprehension), which is particularly error-prone.

The loop causes every lambda function captured by a Lambda layer to use the values of start and end from the last loop iteration. You can verify this by inspecting (or calling) lambda_layer.function.
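This is Python's late binding of closures rather than something specific to Keras. The sketch below reproduces it in plain Python, using a list as a stand-in for a tensor (an assumption made only so it runs without TensorFlow). It also appears to explain the log above: Keras calls each Lambda once while the model is being built, inside its own loop iteration, so m.output_shape records the correct (None, 2) and (None, 3); later calls such as m(x), m.predict and compute_output_shape re-run the lambdas after the loop has finished.

```python
# Plain-Python reproduction of the late-binding pitfall (no TensorFlow needed).
dividers = [0, 2, 5]
bounds = list(zip(dividers[:-1], dividers[1:]))  # [(0, 2), (2, 5)]
data = list(range(dividers[-1]))  # [0, 1, 2, 3, 4], stands in for the input tensor

# Each lambda looks up `start` and `end` when it is *called*, not when it is
# defined; by then the comprehension has finished with start=2, end=5.
late = [lambda x: x[start:end] for start, end in bounds]
late_results = [f(data) for f in late]
print(late_results)  # [[2, 3, 4], [2, 3, 4]]

# Default argument values are evaluated when each lambda is defined, so every
# function keeps the values from its own iteration.
bound = [lambda x, start=start, end=end: x[start:end] for start, end in bounds]
bound_results = [f(data) for f in bound]
print(bound_results)  # [[0, 1], [2, 3, 4]]
```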

If I change your model-building code as below (without the for loop), it runs properly.

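import tensorflow as tf

channels = 5  # same value as dividers[-1] in the original snippet

# Each slice gets its own named function, so the bounds are fixed in the
# function body instead of being looked up from loop variables at call time.
def slice_func_0_2(x):
    return x[..., 0:2]

def slice_func_2_5(x):
    return x[..., 2:5]


i = tf.keras.layers.Input((channels,), name='i')
o1 = tf.keras.layers.Lambda(slice_func_0_2, name="slice_0_2")(i)
o2 = tf.keras.layers.Lambda(slice_func_2_5, name="slice_2_5")(i)

o = [o1, o2]
m = tf.keras.Model(i, o, name='m')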
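If you would rather keep the loop, one option (a sketch, not something confirmed in this thread) is to build each slicing function inside a helper, so that start and end are bound as that helper's own parameters on every call. The helper name make_slice is hypothetical; the rest mirrors the reproduction code from the original report.

```python
import tensorflow as tf


def make_slice(start, end):
    # Hypothetical helper: start/end are parameters of this call, so each
    # returned function keeps its own bounds instead of sharing loop variables.
    def slice_fn(x):
        return x[..., start:end]
    return slice_fn


dividers = [0, 2, 5]
channels = dividers[-1]

i = tf.keras.layers.Input((channels,), name='i')
o = [
    tf.keras.layers.Lambda(make_slice(start, end), name=f'slice_{start}_{end}')(i)
    for start, end in zip(dividers[:-1], dividers[1:])
]
m = tf.keras.Model(i, o, name='m')

outputs = m(tf.zeros((1, channels)))
print([y.shape for y in outputs])
```

Default arguments (lambda x, start=start, end=end: x[..., start:end]) or functools.partial bind the values in the same way. tf.keras.layers.Lambda also accepts an arguments dictionary of keyword arguments that it forwards to the wrapped function.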

I am closing this issue since this is a user code error. Feel free to reopen this if there is anything else we need to address.

In general, please take a look at https://gist.github.com/gisbi-kim/2e5648225cc118fc72ac933ef63c2d64 for the pitfalls of using lambdas in a loop.