How to reinitialize the hidden states of RNNs?
qlan3 opened this issue · 0 comments
qlan3 commented
I want to use `initial_state` in this way, but I get an error: `AttributeError: 'Transformed' object has no attribute 'init_hidden_state'`.
What is the best way to do this?
import haiku as hk
class RNN(hk.Module):
    """Single-step LSTM cell that takes and returns its recurrent state first.

    Args:
        hidden_size: Number of LSTM units passed to ``hk.LSTM``.
        name: Optional Haiku module name.
    """

    def __init__(self, hidden_size=4, name=None):
        super().__init__(name=name)
        self.rnn = hk.LSTM(hidden_size)

    def __call__(self, h, x):
        """Advance the LSTM by one step.

        Args:
            h: Current recurrent state (e.g. from ``init_hidden_state``).
            x: Input for this step.

        Returns:
            A ``(new_state, output)`` pair. This is the reverse of
            ``hk.LSTM.__call__``, which yields ``(output, new_state)``.
        """
        step_output, next_state = self.rnn(x, h)
        return next_state, step_output

    def init_hidden_state(self, batch_size=1):
        """Return the wrapped LSTM's initial state for ``batch_size`` examples.

        Like every Haiku module method, this can only be invoked from inside a
        function passed to ``hk.transform``; the resulting ``hk.Transformed``
        object does not expose it.
        """
        lstm = self.rnn
        return lstm.initial_state(batch_size)
# Shared so that every transformed function below builds an identically
# configured RNN.
HIDDEN_SIZE = 4

# `hk.transform` returns an `hk.Transformed`, which only carries the pure
# `init` and `apply` functions. It does NOT forward methods of modules created
# inside the transformed function, which is why `model.init_hidden_state(...)`
# raised `AttributeError`.
model = hk.without_apply_rng(hk.transform(lambda h, x: RNN(HIDDEN_SIZE)(h, x)))

# Any module method (here `RNN.init_hidden_state`) must itself run inside a
# transformed function, so give the initial state its own transform.
# `hk.LSTM.initial_state` only builds zero arrays and creates no parameters,
# so an empty params mapping is enough for `apply`.
initial_state = hk.without_apply_rng(
    hk.transform(lambda batch_size: RNN(HIDDEN_SIZE).init_hidden_state(batch_size))
)

# `h` can now be fed to `model.init(rng, h, x)` / `model.apply(params, h, x)`.
# Calling `initial_state.apply({}, batch_size)` again re-initializes the state.
h = initial_state.apply({}, 1)