Is it impossible to turn a sequence of identical instance blocks into a compiled loop?
cmunna0052 opened this issue · 2 comments
My apologies if this is a trivial question, but I work with transformers and have a number of situations where I create a sequence of encoder or decoder blocks. These end up getting initialized by something like
self.blocks = [EncoderBlock(config) for _ in range(config.num_blocks)]
and then called with something like
for block in self.blocks:
    x = block(x)
What is interesting is that the layers are all fundamentally the same function, just with different parameter weights. It seems that this should be amenable to an hk.scan call to speed up compilation. However, I don't see a way to make this work that is not particularly ugly/hacky. Is there an easy implementation that I am missing, or is there no real intention to support this?
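To illustrate why scan fits this pattern, here is a minimal pure-JAX sketch. It assumes each block's parameters can be stacked along a new leading axis of size num_blocks. The names block_fn, init_stacked_params, apply_loop and apply_scan are hypothetical, and the small tanh residual block is only a stand-in for EncoderBlock. The Python loop is traced once per block, while jax.lax.scan traces block_fn once and iterates over the stacked parameters.

```python
import jax
import jax.numpy as jnp

NUM_BLOCKS = 4
WIDTH = 8

def block_fn(params, x):
    # Hypothetical stand-in for EncoderBlock: the same function for every
    # layer; only the parameters differ. Output shape equals input shape.
    return x + jnp.tanh(x @ params["w"] + params["b"])

def init_stacked_params(key, num_blocks, width):
    # Initialise one parameter set per block, stacked along axis 0.
    keys = jax.random.split(key, num_blocks)

    def init_one(k):
        return {"w": jax.random.normal(k, (width, width)) * 0.1,
                "b": jnp.zeros((width,))}

    return jax.vmap(init_one)(keys)

def apply_loop(stacked, x):
    # Unrolled Python loop: block_fn is traced num_blocks times under jit.
    for i in range(stacked["w"].shape[0]):
        layer_params = jax.tree_util.tree_map(lambda p: p[i], stacked)
        x = block_fn(layer_params, x)
    return x

def apply_scan(stacked, x):
    # lax.scan slices the leading axis of every leaf and traces step once.
    def step(carry, layer_params):
        return block_fn(layer_params, carry), None

    x, _ = jax.lax.scan(step, x, stacked)
    return x

params = init_stacked_params(jax.random.PRNGKey(0), NUM_BLOCKS, WIDTH)
x = jnp.ones((2, WIDTH))
y_loop = apply_loop(params, x)
y_scan = jax.jit(apply_scan)(params, x)
```

Both versions compute the same function; the scan version just keeps the traced program size independent of the number of blocks. The one requirement is that each block maps the carried activation to an array of the same shape and dtype.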
Hi @cmunna0052, hk.experimental.layer_stack provides this functionality:
import haiku as hk

def block(x):
    return hk.Linear(1)(x)

def f(x):
    # Applies `block` 10 times, each layer with its own parameters.
    stack = hk.experimental.layer_stack(num_layers=10, name='stack')(block)
    x = stack(x)
    return x

f = hk.transform(f)
Perfect, thanks!