google-deepmind/dm-haiku

Efficiency difference in using jax.lax.fori_loop vs looping over identical layers?

hrbigelow opened this issue · 2 comments

This might be a question for JAX, but I think it comes up often in Haiku.

Suppose I have the following code inside some hk.Module:

out = input
# the code in each layer is identical, only the parameters differ
for layer in self.layers:
  out = layer(out)
return out

Assume also that each layer is an instance of the same hk.Module subclass, which calls hk.get_parameter inside its __call__ method.

Since the code is identical in each layer, one could express the loop as a jax.lax.fori_loop, but doing so is quite awkward.

Would there be any efficiency gain in doing so? Or is the JAX compiler smart enough to do this effectively anyway?

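# parameters previously defined by hk.get_parameter in the above, stacked
# across layers along a new leading axis of size num_layers
all_layer_params = ...

def layer_fn(i, out):
    # select layer i's slice of every stacked parameter array
    layer_params = jax.tree_util.tree_map(
        lambda p: jax.lax.dynamic_index_in_dim(p, i, keepdims=False),
        all_layer_params)
    # the code in any layer of the above self.layers, using layer_params
    ...
    return out

return jax.lax.fori_loop(0, num_layers, layer_fn, input)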
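The fori_loop idea can be made concrete with a pure-JAX sketch. The stacked parameter layout, the tanh layer body, and names such as init_stacked_params are hypothetical. The parameters are created directly rather than through hk.get_parameter, and the loop is checked against the equivalent unrolled Python loop:

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, D = 3, 4


def init_stacked_params(key):
    # Hypothetical init: every array gets a leading axis of size NUM_LAYERS.
    return {
        "w": 0.1 * jax.random.normal(key, (NUM_LAYERS, D, D)),
        "b": jnp.zeros((NUM_LAYERS, D)),
    }


def apply_fori(all_layer_params, x):
    def layer_fn(i, out):
        # Pick layer i's slice from every stacked array (i is a traced index).
        p = jax.tree_util.tree_map(lambda a: a[i], all_layer_params)
        return jnp.tanh(out @ p["w"] + p["b"])

    return jax.lax.fori_loop(0, NUM_LAYERS, layer_fn, x)


def apply_unrolled(all_layer_params, x):
    out = x
    for i in range(NUM_LAYERS):
        out = jnp.tanh(out @ all_layer_params["w"][i] + all_layer_params["b"][i])
    return out


params = init_stacked_params(jax.random.PRNGKey(0))
x = jnp.ones((2, D))
y_loop = jax.jit(apply_fori)(params, x)
y_unrolled = apply_unrolled(params, x)
```

Both versions compute the same result. The awkward part is the bookkeeping: parameters must be stacked by hand instead of being created per layer by hk.get_parameter.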

Is there an idiomatic way to do this in Haiku that still takes advantage of the hk.get_parameter calls inside each layer?

Thanks in advance!

Hey @hrbigelow, both versions should work and, in theory, should be equally efficient. However, we've seen a few cases (in particular with transformer models) where using structured control flow lets the XLA compiler do a better job of optimizing, in particular by reducing peak memory usage, and (sometimes) of overlapping communication with compute.
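One difference between the two styles can be checked directly. A Python for loop is unrolled while JAX traces the function, so the program handed to XLA contains one copy of the layer per iteration, whereas jax.lax.scan traces the body only once. The sketch below compares the number of equations in the two jaxprs. It shows program size (and so how much work the compiler is given), not runtime or peak memory, which depend on what XLA does with the program. The helper names are hypothetical:

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    return jnp.tanh(x @ w)


def unrolled(ws, x):
    # Python loop: ws.shape[0] is static, so every layer is traced separately.
    for i in range(ws.shape[0]):
        x = layer(x, ws[i])
    return x


def scanned(ws, x):
    def step(carry, w):
        return layer(carry, w), None

    # scan traces `step` once and iterates over the leading axis of ws.
    out, _ = jax.lax.scan(step, x, ws)
    return out


ws = jnp.ones((16, 4, 4)) * 0.1
x = jnp.ones((2, 4))
n_unrolled = len(jax.make_jaxpr(unrolled)(ws, x).jaxpr.eqns)
n_scanned = len(jax.make_jaxpr(scanned)(ws, x).jaxpr.eqns)
print(n_unrolled, n_scanned)
```

Doubling the number of layers roughly doubles n_unrolled, while n_scanned stays the same.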

The recommended pattern in Haiku for repeated application of a block is to use hk.experimental.layer_stack.

The implementation of layer_stack is fairly complex (it handles quite a few edge cases), but it basically boils down to correctly using jax.lax.scan for the per-layer init and apply functions.

Thanks, Tom. I'm looking at the examples for hk.experimental.layer_stack now. Just to make sure I understand: it doesn't seem possible to use layer_stack the same way you would use an ordinary hk.Module that calls hk.get_parameter, is that right?

Instead, if you wanted to build an hk.Module method that used layer_stack, you'd need to somehow obtain the pure function f to pass to stack(f)(...).