Efficiency difference in using jax.lax.fori_loop vs looping over identical layers?
hrbigelow opened this issue · 2 comments
This might be a question for JAX, but I think it probably comes up in Haiku.
Suppose I have the following code within some `hk.Module`:
```python
out = input
# the code in each layer is identical, only the parameters differ
for layer in self.layers:
    out = layer(out)
return out
```
Assume also that each `layer` is an instance of the same `hk.Module` subclass, which calls `hk.get_parameter` inside its `__call__` method.
Since the code is identical in each layer, one could express the loop as a `jax.lax.fori_loop`, but it is quite awkward.
Would there be any efficiency gain in doing so? Or is the JAX compiler smart enough to do this effectively anyway?
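```python
import jax

# parameters previously defined by hk.get_parameter in the above, merged across layers
# (assumed stacked, so the array has a leading num_layers axis)
all_layer_params = ...

def layer_fn(i, input):
    # the code in any layer of above self.layers
    # pick layer i's slice off the leading axis; i is a traced loop index
    layer_params = jax.lax.dynamic_index_in_dim(all_layer_params, i, keepdims=False)
    ...

return jax.lax.fori_loop(0, num_layers, layer_fn, input)
```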
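A runnable version of the `fori_loop` idea above, under some assumptions: each layer is a hypothetical dense `tanh` layer, and the per-layer parameters have already been stacked into a dict of arrays with a leading `NUM_LAYERS` axis. Because the params are a pytree here, `jax.tree_util.tree_map` applies `jax.lax.dynamic_index_in_dim` to every leaf to select layer `i` with a traced index. The result is checked against the equivalent Python loop.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, WIDTH = 4, 8  # hypothetical sizes

def dense_layer(params, x):
    # hypothetical layer body: identical code for every layer, only params differ
    return jnp.tanh(x @ params["w"] + params["b"])

def make_stacked_params(key):
    w_key, b_key = jax.random.split(key)
    # every leaf carries a leading NUM_LAYERS axis
    return {
        "w": jax.random.normal(w_key, (NUM_LAYERS, WIDTH, WIDTH)) / jnp.sqrt(WIDTH),
        "b": 0.1 * jax.random.normal(b_key, (NUM_LAYERS, WIDTH)),
    }

def apply_python_loop(all_layer_params, x):
    # unrolled at trace time: NUM_LAYERS copies of dense_layer end up in the program
    for i in range(NUM_LAYERS):
        layer_params = jax.tree_util.tree_map(lambda p: p[i], all_layer_params)
        x = dense_layer(layer_params, x)
    return x

def apply_fori_loop(all_layer_params, x):
    def layer_fn(i, x):
        # select layer i along axis 0 of every leaf; works with a traced i
        layer_params = jax.tree_util.tree_map(
            lambda p: jax.lax.dynamic_index_in_dim(p, i, keepdims=False),
            all_layer_params)
        return dense_layer(layer_params, x)
    return jax.lax.fori_loop(0, NUM_LAYERS, layer_fn, x)

params = make_stacked_params(jax.random.PRNGKey(0))
x = jnp.ones((2, WIDTH))
out_loop = jax.jit(apply_python_loop)(params, x)
out_fori = jax.jit(apply_fori_loop)(params, x)
print(jnp.allclose(out_loop, out_fori, atol=1e-5))
```

Both functions compute the same thing; the difference is only in what program JAX hands to XLA.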
Is there an idiomatic way to do this in Haiku that takes advantage of the existing `hk.get_parameter` calls inside the layers?
Thanks in advance!
Hey @hrbigelow, both versions should work and, in theory, should be equally efficient. However, we've seen a few cases (in particular with transformer models) where using structured control flow lets the XLA compiler do a better job of optimizing, in particular by reducing peak memory usage and (sometimes) by overlapping communication with compute.
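One way to see what structured control flow changes is to count the top-level operations JAX traces for each version. The sketch below uses a hypothetical layer and hypothetical sizes. A Python `for` loop is unrolled while tracing, so the traced program grows with the number of layers. A `jax.lax.fori_loop` with static bounds (which, per its documentation, is implemented with `scan`) keeps a single loop body. Counting jaxpr equations is only a rough proxy for program size, not a measurement of runtime or memory.

```python
import jax
import jax.numpy as jnp

WIDTH = 8  # hypothetical layer width

def dense_layer(w, x):
    # hypothetical layer body shared by every layer
    return jnp.tanh(x @ w)

def unrolled(ws, x):
    for i in range(ws.shape[0]):  # Python loop: unrolled while tracing
        x = dense_layer(ws[i], x)
    return x

def looped(ws, x):
    body = lambda i, x: dense_layer(ws[i], x)  # one body, traced once
    return jax.lax.fori_loop(0, ws.shape[0], body, x)

def num_top_level_eqns(fn, num_layers):
    ws = jnp.zeros((num_layers, WIDTH, WIDTH))
    x = jnp.zeros((2, WIDTH))
    return len(jax.make_jaxpr(fn)(ws, x).jaxpr.eqns)

for n in (2, 16):
    print(n, num_top_level_eqns(unrolled, n), num_top_level_eqns(looped, n))
```

The unrolled count grows with `n` while the looped count stays fixed, which is also why compile time tends to grow with depth in the unrolled case.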
The recommended pattern in Haiku for repeated application of a block is to use `hk.experimental.layer_stack`.
The implementation of `layer_stack` is fairly complex (it handles quite a few edge cases), but it basically boils down to correctly using `jax.lax.scan` for the per-layer `init` and `apply` functions.
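To make the scan idea concrete without Haiku's edge-case handling, here is a plain-JAX sketch under stated assumptions. It is not `layer_stack`'s actual implementation: the layer is hypothetical, per-layer initialization uses `jax.vmap` over split PRNG keys (a substitution chosen for brevity), and application uses `jax.lax.scan`, which hands the body one layer's slice of the stacked parameters per step.

```python
import jax
import jax.numpy as jnp

def init_layer(key, width):
    # hypothetical per-layer init; returns one layer's params
    return {"w": jax.random.normal(key, (width, width)) / jnp.sqrt(width),
            "b": jnp.zeros((width,))}

def apply_layer(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])

def init_stack(key, num_layers, width):
    keys = jax.random.split(key, num_layers)
    # vmap over keys: every leaf gets a leading num_layers axis,
    # and each layer is initialized from a different key
    return jax.vmap(lambda k: init_layer(k, width))(keys)

def apply_stack(stacked_params, x):
    def step(h, layer_params):
        # scan slices stacked_params along axis 0, one layer per iteration
        return apply_layer(layer_params, h), None
    out, _ = jax.lax.scan(step, x, stacked_params)
    return out

stacked = init_stack(jax.random.PRNGKey(0), num_layers=6, width=16)
x = jnp.ones((3, 16))
y = jax.jit(apply_stack)(stacked, x)
```

The layer body is traced once regardless of depth; the constraint is that every layer must have identically shaped parameters and the carry `h` must keep the same shape from layer to layer.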
Thanks Tom. I'm looking at the examples for `hk.experimental.layer_stack` now. Just to make sure I understand: it doesn't seem possible to use `layer_stack` the same way you would use an ordinary `hk.Module` that calls `hk.get_parameter`. Is that right?
Instead, if you wanted to build an `hk.Module` method that used `layer_stack`, you'd need to somehow obtain the pure function `f` to pass to `stack(f)(...)`.