google-deepmind/dm-haiku

Is it impossible to turn a sequence of identical instance blocks into a compiled loop?

cmunna0052 opened this issue · 2 comments

My apologies if this is a trivial question, but I work with transformers and have a number of situations where I create a sequence of encoder or decoder blocks. These end up getting initialized by something like

```python
self.blocks = [EncoderBlock(config) for _ in range(config.num_blocks)]
```

and then called with something like

```python
for block in self.blocks:
    x = block(x)
```

What is interesting is that the blocks all compute fundamentally the same function, just with different parameter values. That seems like a natural fit for an `hk.scan` call, which would trace and compile the block body once instead of once per block and so speed up compilation. However, I don't see a way to make this work that isn't particularly ugly or hacky. Is there an easy implementation that I am missing, or is there no real intention to support this?
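The intuition in the question can be checked with plain JAX: if every block computes the same function, the per-block weights can be stacked along a new leading axis and `jax.lax.scan` can iterate over that axis, so the block body is traced once rather than once per block. This is a minimal sketch with a hypothetical single-matrix `block`, not Haiku's implementation.

```python
import jax
import jax.numpy as jnp


def block(x, w):
    # Hypothetical block: the same function for every layer; only `w` differs.
    return jnp.tanh(x @ w)


num_blocks, d = 6, 4
# Stacked weights: leading axis indexes the layer.
ws = 0.5 * jax.random.normal(jax.random.PRNGKey(0), (num_blocks, d, d))
x0 = jnp.ones((2, d))


def unrolled(x, ws):
    # Python loop: traced num_blocks times.
    for i in range(ws.shape[0]):
        x = block(x, ws[i])
    return x


def scanned(x, ws):
    def step(carry, w):
        return block(carry, w), None  # no per-layer outputs collected

    x, _ = jax.lax.scan(step, x, ws)  # body traced once
    return x


y_loop = unrolled(x0, ws)
y_scan = jax.jit(scanned)(x0, ws)
```

Both versions compute the same result; the difference is only in how much code the compiler has to process.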

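Hi @cmunna0052, `hk.experimental.layer_stack` provides this functionality. It wraps a function that builds a single layer, gives every parameter created inside it an extra leading axis of size `num_layers`, and applies the layers in sequence with a scan, so the layer body is traced once rather than once per layer: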

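```python
import haiku as hk

def block(x):
  # Builds one layer. Its output is the next layer's input, so the output
  # shape must match the input shape (here, a last dimension of 1).
  return hk.Linear(1)(x)

def f(x):
  stack = hk.experimental.layer_stack(num_layers=10, name='stack')(block)
  x = stack(x)
  return x

f = hk.transform(f)
```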

Perfect, thanks!