google-deepmind/dm-haiku

How to reinitialize the hidden states of RNNs?

qlan3 opened this issue · 0 comments

qlan3 commented

I want to use `initial_state` in this way, but I get an error: `AttributeError: 'Transformed' object has no attribute 'init_hidden_state'`.
What is the best way to do this?

import haiku as hk

class RNN(hk.Module):
  """Thin wrapper around ``hk.LSTM`` with a state-first call signature.

  Args:
    hidden_size: Number of LSTM units (default 4).
    name: Optional Haiku module name.
  """

  def __init__(self, hidden_size=4, name=None):
    super().__init__(name=name)
    # The wrapped recurrent core.
    self.rnn = hk.LSTM(hidden_size)

  def __call__(self, h, x):
    """Run a single LSTM step.

    Args:
      h: Recurrent state from the previous step (e.g. from
        ``init_hidden_state`` or an earlier call).
      x: Input for this step.

    Returns:
      ``(next_state, output)``. This is the reverse of ``hk.LSTM``, which
      returns ``(output, next_state)``; presumably chosen to match a
      ``(carry, x) -> (carry, y)`` scan-style signature -- confirm with callers.
    """
    output, next_state = self.rnn(x, h)
    return next_state, output

  def init_hidden_state(self, batch_size=1):
    """Return the wrapped LSTM's initial state for ``batch_size`` sequences.

    Like any Haiku module method, this can only run inside a function that has
    been passed to ``hk.transform`` (or a sibling transform); it is not exposed
    as an attribute of the resulting ``Transformed`` object.
    """
    return self.rnn.initial_state(batch_size)

# Hidden size shared by every transform below so their RNNs stay consistent.
_HIDDEN_SIZE = 4

# ``hk.transform`` turns a function into a ``Transformed`` pair of pure
# functions (``init`` / ``apply``). Methods of modules created inside that
# function, such as ``RNN.init_hidden_state``, are NOT exposed on it, which is
# why ``model.init_hidden_state`` raised ``AttributeError``. Keep the
# forward-pass transform as-is and transform the initial-state call separately.
model = hk.without_apply_rng(hk.transform(lambda h, x: RNN(_HIDDEN_SIZE)(h, x)))


def _init_hidden_state(batch_size=1):
  """Return the RNN's initial LSTM state for ``batch_size`` sequences."""
  return RNN(_HIDDEN_SIZE).init_hidden_state(batch_size)


init_hidden_state = hk.without_apply_rng(hk.transform(_init_hidden_state))

# ``hk.LSTM.initial_state`` only builds zero arrays and reads no parameters, so
# an empty params mapping is enough here. The ``params`` returned by
# ``model.init`` may be passed instead (unused parameters are ignored). Call
# ``init_hidden_state.apply`` again whenever the state needs to be reset.
h = init_hidden_state.apply({}, 1)