Register LMUCell with Keras
bmorcos commented
If the `LMUCell` is wrapped in another layer (e.g. `RNN`), the wrapping layer cannot be round-tripped through `get_config()`/`from_config()`, because `LMUCell` is a custom object that Keras does not know about. For example:
```python
import keras_lmu
from tensorflow.keras.layers import RNN, Dense

# Build an LMU layer
dt = 1e-3
activation = "tanh"
dropout = 0.2
lmu_layer = RNN(
    keras_lmu.LMUCell(
        memory_d=10,
        order=8,
        theta=10 / dt,
        hidden_cell=Dense(1024, activation),
        hidden_to_memory=False,
        memory_to_memory=False,
        input_to_hidden=False,
        dropout=dropout,
    ),
    return_sequences=True,
)

# Test serialization
lmu_layer.from_config(
    lmu_layer.get_config(),
)
```
This fails with `ValueError: Unknown layer: LMUCell`.
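The failure comes from how `RNN` stores its cell: `get_config()` records the wrapped cell as a nested dict that includes the cell's class name, and `from_config()` has to turn that name back into a class. The sketch below uses the built-in `SimpleRNNCell` purely as a stand-in (it is not part of the original report) because Keras can already resolve that name; a wrapped `LMUCell` is stored the same way, but its name is not in Keras's registry, so the lookup fails.

```python
import tensorflow as tf

# A built-in cell stands in for LMUCell; Keras can resolve its name without help.
layer = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4), return_sequences=True)

config = layer.get_config()
# The wrapped cell is serialized as a nested dict carrying its class name.
cell_class_name = config["cell"]["class_name"]
print(cell_class_name)

# from_config looks that name up to rebuild the cell. RNN.from_config pops
# "cell" from the dict it receives, so pass a shallow copy to keep `config` intact.
restored = tf.keras.layers.RNN.from_config(dict(config))
print(type(restored.cell).__name__)
```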
The quick fix is to tell Keras about the `LMUCell` via `custom_objects`:
```python
# Test serialization
lmu_layer.from_config(
    lmu_layer.get_config(),
    custom_objects={"LMUCell": keras_lmu.LMUCell},  # <-- This is key
)
```
Although this allows the `LMUCell` to be properly (de)serialized, it requires direct access to the call that deserializes the layer, which may be difficult when additional scripts are built on top of the `RNN`.
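If threading `custom_objects` through every call is the problem, Keras also provides `tf.keras.utils.custom_object_scope`, a context manager that makes a name-to-class mapping visible to all deserialization inside the `with` block, including calls such as `tf.keras.models.load_model` made by other code. The sketch below uses a hypothetical `MinimalCell` (a plain tanh recurrence standing in for `LMUCell`, which is not used here) so that it needs only TensorFlow; with `keras_lmu`, the mapping would be `{"LMUCell": keras_lmu.LMUCell}`. It still requires wrapping the outermost call, so it narrows the access problem rather than removing it.

```python
import tensorflow as tf


class MinimalCell(tf.keras.layers.Layer):
    """Hypothetical stand-in for a custom cell such as LMUCell."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units  # RNN requires cells to expose state_size

    def build(self, input_shape):
        self.kernel = self.add_weight(name="kernel", shape=(input_shape[-1], self.units))

    def call(self, inputs, states):
        # A plain tanh recurrence; only serialization matters in this sketch.
        output = tf.tanh(tf.matmul(inputs, self.kernel) + states[0])
        return output, [output]

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


layer = tf.keras.layers.RNN(MinimalCell(4), return_sequences=True)
config = layer.get_config()

# Any deserialization inside this block can resolve "MinimalCell", even in
# code that never receives a custom_objects argument.
with tf.keras.utils.custom_object_scope({"MinimalCell": MinimalCell}):
    restored = tf.keras.layers.RNN.from_config(config)
```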
It seems like there is a way to register custom objects with Keras, and that may be the proper general solution; I just don't have time to test it out right now!
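One way that registration could look, assuming the fix lives in `keras_lmu` itself, is Keras's `tf.keras.utils.register_keras_serializable` decorator on the class definition. The sketch below applies it to a hypothetical `MinimalCell` under a placeholder package name `hypothetical_pkg`; because the class is registered when it is defined, `RNN.from_config` resolves the cell with no `custom_objects` argument and no scope. From user code, assigning `tf.keras.utils.get_custom_objects()["LMUCell"] = keras_lmu.LMUCell` before deserializing should register the class under its plain name instead; that variant is not exercised here.

```python
import tensorflow as tf


# Registration runs once, when the class is defined (for a library, at import
# time). "hypothetical_pkg" is a placeholder, not keras_lmu's actual package name.
@tf.keras.utils.register_keras_serializable(package="hypothetical_pkg")
class MinimalCell(tf.keras.layers.Layer):
    """Hypothetical stand-in for a custom cell such as LMUCell."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units

    def build(self, input_shape):
        self.kernel = self.add_weight(name="kernel", shape=(input_shape[-1], self.units))

    def call(self, inputs, states):
        output = tf.tanh(tf.matmul(inputs, self.kernel) + states[0])
        return output, [output]

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


layer = tf.keras.layers.RNN(MinimalCell(4), return_sequences=True)
# The global registry resolves the cell without any extra arguments.
restored = tf.keras.layers.RNN.from_config(layer.get_config())
```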
Aside: for completeness/reference, using the `LMU` layer (instead of, for example, the `LMUCell` wrapped in an `RNN`) serializes fine:
```python
lmu_layer_builtin = keras_lmu.LMU(
    memory_d=10,
    order=8,
    theta=10 / dt,
    hidden_cell=Dense(1024, activation),
    hidden_to_memory=False,
    memory_to_memory=False,
    input_to_hidden=False,
    dropout=dropout,
    return_sequences=True,
)
lmu_layer_builtin.from_config(
    lmu_layer_builtin.get_config(),
)
```