jax-ml/jax

How do you remat GSPMD inserted all-gathers?

Opened this issue · 2 comments

Problem: I have some JAX code that does sequence parallelism, somewhat similar to this:

activation = jax.lax.with_sharding_constraint(
    activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
activation = norm(activation)
# Replicating dim 0 (sharded over 'data' above) is what makes GSPMD insert the
# all-gather described below.
activation = jax.lax.with_sharding_constraint(
    activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
# I want to remat this one ^
activation = attention(activation)

I have tried everything I can to rematerialize the activation directly before attention, including JAX checkpoint policies and explicitly applying `jax.checkpoint` to that exact tensor, but nothing seems to make it remat. The activation directly before attention is produced by a GSPMD-inserted all-gather on the sequence dimension (dim=0).

I ended up writing an XLA pass to rematerialize large all-gathers and submitted a PR: openxla/xla#19163.

Question: Is this possible to do from Jax end or is my pass really needed?

Thanks for the question.

No, I don't think a new pass is needed.

As I understand it, the standard way to spell this is to use a remat policy to mark the with_sharding_constraint that induces the all-gather as not saveable. One way to do that would be to use save_only_these_names and to name only other arrays (ones that are either upstream of the all-gather-inducing with_sharding_constraint, or downstream of the operations that use the output of attention). Following your snippet, that might look something like:

activation = jax.lax.with_sharding_constraint(
    activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
# Name the still-scattered normalized activation. With a policy such as
# jax.checkpoint_policies.save_only_these_names('scattered_activations'),
# this value is saved and the all-gather induced by the next constraint is
# recomputed in the backward pass instead of having its output saved.
activation = checkpoint_name(norm(activation), 'scattered_activations')
activation = jax.lax.with_sharding_constraint(
    activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
activation = attention(activation)

together with a save_only_these_names policy that mentions 'scattered_activations' or something upstream of it.

Did you try something like that? If you already tried it, we should put together a minimal example to debug what's going on.

I am having trouble reproducing this in a unit test: GSPMD decides not to all-gather the sequence-parallel activation and uses a convolution op instead. I do not understand why there are no dot ops in my HLO; perhaps my unit test is broken. Any ideas, @mattjj? The only all-gathers here are the FSDP all-gathers, which are OK.

from functools import partial
from jax.sharding import NamedSharding

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.ad_checkpoint import checkpoint_name
# 2 x 4 device mesh (8 devices): axis 'data' has size 2, axis 'model' size 4.
mesh = jax.make_mesh((2, 4), ('data', 'model'))

# Each tensor gets its own fixed seed so the repro is deterministic.
key = jax.random.PRNGKey(0)
activation_weight = jax.random.uniform(key, (2, 2048, 1024))
key = jax.random.PRNGKey(1)

up_weight = jax.random.uniform(key, (1024, 4 * 1024))
key = jax.random.PRNGKey(2)

down_weight = jax.random.uniform(key, (4 * 1024, 1024))
key = jax.random.PRNGKey(3)

# Shape must be a tuple: `(1024)` is just the int 1024.
scale_weight = jax.random.uniform(key, (1024,))

# Place inputs with the same shardings that `jitted` declares as in_shardings.
activation = jax.device_put(activation_weight, NamedSharding(mesh, P('data', 'model', None)))
up = jax.device_put(up_weight, NamedSharding(mesh, P('data', 'model')))
down = jax.device_put(down_weight, NamedSharding(mesh, P('model', 'data')))
scale = jax.device_put(scale_weight, NamedSharding(mesh, P(None,)))


@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.save_any_names_but_these('all_gather'))
def forward(activation, up, down, scale):
  """Toy sequence-parallel block: RMSNorm -> all-gather -> 2-layer MLP -> mean.

  Args:
    activation: input activations, constrained to P('data', 'model', None) on
      `mesh`, i.e. dim 1 is sharded over 'model'.
    up: first MLP weight; its dim 0 is contracted with the activation's dim 2.
    down: second MLP weight; its dim 0 is contracted with the hidden's dim 2.
    scale: RMSNorm scale, multiplied elementwise (broadcast over the last dim).

  Returns:
    Scalar mean of the MLP output.
  """
  def rms_norm(x):
    # RMSNorm computed in float32, then cast back to the input dtype and scaled.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('data', 'model', None)))
    x_dtype = x.dtype
    x = x.astype(jnp.float32)
    moment2 = (x * x).mean(axis=-1, keepdims=True)
    x = x * jax.lax.rsqrt(moment2 + 1e-8)
    x = x.astype(x_dtype)
    x = x * scale
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('data', 'model', None)))
    return x

  activation = jax.lax.with_sharding_constraint(
      activation, NamedSharding(mesh, P('data', 'model', None)))
  activation = rms_norm(activation)
  # Despite its name, 'all_gather' labels the normalized activation while it is
  # still sharded P('data', 'model', None), i.e. the value *before* the gather.
  # NOTE(review): as I read the jax.checkpoint_policies docs,
  # save_any_names_but_these saves only *named* values; since this is the only
  # name, nothing would be saved and every residual recomputed -- confirm.
  activation = checkpoint_name(activation, name='all_gather')
  # Dropping the 'model' sharding of dim 1 makes GSPMD insert the all-gather.
  activation = jax.lax.with_sharding_constraint(
      activation, NamedSharding(mesh, P('data', None, None)))
  # dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)),
  # each a tuple of axes; write (2,) rather than (2), which is just the int 2.
  activation = jax.lax.dot_general(
      activation, up, dimension_numbers=(((2,), (0,)), ((), ())))
  activation = jax.lax.dot_general(
      activation, down, dimension_numbers=(((2,), (0,)), ((), ())))
  return jnp.mean(activation)


# Differentiate w.r.t. ALL inputs, as a training step would. With the default
# argnums=0 only d(loss)/d(activation) is computed. The MLP is linear and the
# loss is a mean, so that gradient does not depend on the all-gathered
# activation at all. XLA then drops the activation all-gather entirely, and the
# repro cannot show it being saved or rematerialized. d(loss)/d(up) does need
# the gathered activation.
grad_fn = jax.grad(forward, argnums=(0, 1, 2, 3))
jitted = jax.jit(
    grad_fn,
    in_shardings=(
        NamedSharding(mesh, P('data', 'model', None)),  # activation
        NamedSharding(mesh, P('data', 'model')),        # up
        NamedSharding(mesh, P('model', 'data')),        # down
        NamedSharding(mesh, P(None,)),                  # scale
    ),
)
result = jitted(activation, up, down, scale)
print(result)
# In the compiled HLO the dot_generals appear as `convolution` instructions;
# their op_name metadata still reads ".../dot_general".
print(jitted.lower(activation, up, down, scale).compile().as_text())

HLO


[[[-0.14977819  0.10456353  0.17018048 ...  0.18079197  0.0629528
    0.23236921]
  [-0.05597442  0.04220502  0.32394016 ...  0.04080057 -0.05255011
    0.2142255 ]
  [-0.00090512  0.14764619  0.22220817 ...  0.2685859   0.20059824
   -0.01242509]
  ...
  [-0.05803363  0.18066119  0.23036703 ...  0.33949146 -0.01644355
    0.25753504]
  [ 0.05420677 -0.02029712  0.06926298 ...  0.34460986  0.13754913
    0.15689133]
  [ 0.06267975  0.02606885  0.26323915 ...  0.08300793  0.18933085
    0.20336568]]

 [[-0.01766005  0.06890903  0.3483382  ...  0.04940775  0.008073
    0.23060131]
  [ 0.03536307 -0.06162715  0.19932668 ...  0.1174902  -0.00680336
   -0.00751439]
  [-0.11244462 -0.0749916   0.17251675 ...  0.27089232  0.00431225
    0.20271495]
  ...
  [-0.14335199  0.14109462  0.33877328 ...  0.245024   -0.03044733
    0.09327929]
  [-0.15114704 -0.04898666  0.20905332 ...  0.12068814  0.18083549
    0.23931928]
  [-0.19988659  0.00061852  0.30437952 ...  0.24774578  0.1629085
    0.01701334]]]
HloModule jit_forward, is_scheduled=true, entry_computation_layout={(f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)})->f32[1,512,1024]{2,1,0:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=8

%add.clone (x.1: f32[], y.1: f32[]) -> f32[] {
  %y.1 = f32[]{:T(256)} parameter(1)
  %x.1 = f32[]{:T(256)} parameter(0)
  ROOT %add.3 = f32[]{:T(256)} add(f32[]{:T(256)} %x.1, f32[]{:T(256)} %y.1)
}

%fused_computation.2.clone.clone (param_0.19: bf16[]) -> bf16[2048,1024] {
  %param_0.19 = bf16[]{:T(512)S(6)} parameter(0)
  ROOT %broadcast.43 = bf16[2048,1024]{1,0:T(8,128)(2,1)} broadcast(bf16[]{:T(512)S(6)} %param_0.19), dimensions={}
}

%bitcast_fusion.1 (bitcast_input.1: bf16[1024,1024]) -> bf16[1024,1024] {
  %bitcast_input.1 = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(0)
  ROOT %bitcast.4 = bf16[1024,1024]{1,0:T(8,128)(2,1)} bitcast(bf16[1024,1024]{1,0:T(8,128)(2,1)} %bitcast_input.1)
}

%fused_computation.1.clone.clone (param_0.20: bf16[1024,1024], param_1.18: bf16[]) -> bf16[2048,1024] {
  %param_1.18 = bf16[]{:T(512)S(6)} parameter(1)
  %fusion.8 = bf16[2048,1024]{1,0:T(8,128)(2,1)} fusion(bf16[]{:T(512)S(6)} %param_1.18), kind=kLoop, calls=%fused_computation.2.clone.clone
  %param_0.20 = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(0)
  %fusion.17 = bf16[1024,1024]{1,0:T(8,128)(2,1)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)} %param_0.20), kind=kLoop, calls=%bitcast_fusion.1
  ROOT %convolution.5 = bf16[2048,1024]{1,0:T(8,128)(2,1)} convolution(bf16[2048,1024]{1,0:T(8,128)(2,1)} %fusion.8, bf16[1024,1024]{1,0:T(8,128)(2,1)} %fusion.17), dim_labels=bf_oi->bf, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=48}
}

%bitcast_fusion (bitcast_input: bf16[1024,1024]) -> bf16[1024,1024] {
  %bitcast_input = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(0)
  ROOT %bitcast.3 = bf16[1024,1024]{1,0:T(8,128)(2,1)} bitcast(bf16[1024,1024]{1,0:T(8,128)(2,1)} %bitcast_input)
}

%fused_computation (param_0.21: bf16[1024,1024], param_1.19: bf16[1024,1024], param_2.14: bf16[]) -> f32[1,2048,1024] {
  %param_1.19 = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(1)
  %param_2.14 = bf16[]{:T(512)S(6)} parameter(2)
  %fusion.1.clone.1 = bf16[2048,1024]{1,0:T(8,128)(2,1)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)} %param_1.19, bf16[]{:T(512)S(6)} %param_2.14), kind=kOutput, calls=%fused_computation.1.clone.clone, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=48}
  %param_0.21 = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(0)
  %fusion.16 = bf16[1024,1024]{1,0:T(8,128)(2,1)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)} %param_0.21), kind=kLoop, calls=%bitcast_fusion
  %convolution.2 = f32[2048,1024]{1,0:T(8,128)} convolution(bf16[2048,1024]{1,0:T(8,128)(2,1)} %fusion.1.clone.1, bf16[1024,1024]{1,0:T(8,128)(2,1)} %fusion.16), dim_labels=bf_oi->bf, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}
  ROOT %bitcast.2 = f32[1,2048,1024]{2,1,0:T(8,128)} bitcast(f32[2048,1024]{1,0:T(8,128)} %convolution.2), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}
}

%fused_computation.3 (param_0.16: f32[1,512,1024], param_1.16: f32[1,2048,1024], param_2.33: s32[], param_3.24: f32[512], param_4.15: f32[512], param_5.7: f32[1024]) -> f32[1,512,1024] {
  %param_0.16 = f32[1,512,1024]{2,1,0:T(8,128)} parameter(0)
  %add.12 = f32[1,512,1024]{2,1,0:T(8,128)} add(f32[1,512,1024]{2,1,0:T(8,128)} %param_0.16, f32[1,512,1024]{2,1,0:T(8,128)} %param_0.16), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  %param_4.15 = f32[512]{0:T(512)} parameter(4)
  %broadcast.38 = f32[1,512,1024]{2,1,0:T(8,128)} broadcast(f32[512]{0:T(512)} %param_4.15), dimensions={1}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/broadcast_in_dim" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  %multiply.20 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %add.12, f32[1,512,1024]{2,1,0:T(8,128)} %broadcast.38), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/add_any" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  %param_1.16 = f32[1,2048,1024]{2,1,0:T(8,128)} parameter(1)
  %constant.49 = s32[]{:T(256)} constant(0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}
  %param_2.33 = s32[]{:T(256)} parameter(2)
  %dynamic-slice.12 = f32[1,512,1024]{2,1,0:T(8,128)} dynamic-slice(f32[1,2048,1024]{2,1,0:T(8,128)} %param_1.16, s32[]{:T(256)} %constant.49, s32[]{:T(256)} %param_2.33, s32[]{:T(256)} %constant.49), dynamic_slice_sizes={1,512,1024}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294965759","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}
  %param_5.7 = f32[1024]{0:T(1024)} parameter(5)
  %broadcast.39 = f32[1,512,1024]{2,1,0:T(8,128)} broadcast(f32[1024]{0:T(1024)} %param_5.7), dimensions={2}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=39}
  %multiply.27 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %dynamic-slice.12, f32[1,512,1024]{2,1,0:T(8,128)} %broadcast.39), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=39}
  %param_3.24 = f32[512]{0:T(512)} parameter(3)
  %broadcast.36 = f32[1,512,1024]{2,1,0:T(8,128)} broadcast(f32[512]{0:T(512)} %param_3.24), dimensions={1}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %multiply.21 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %multiply.27, f32[1,512,1024]{2,1,0:T(8,128)} %broadcast.36), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  ROOT %add.11 = f32[1,512,1024]{2,1,0:T(8,128)} add(f32[1,512,1024]{2,1,0:T(8,128)} %multiply.20, f32[1,512,1024]{2,1,0:T(8,128)} %multiply.21), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/add_any" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
}

%region_1.47 (Arg_0.48: f32[], Arg_1.49: f32[]) -> f32[] {
  %Arg_1.49 = f32[]{:T(256)} parameter(1), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum"}
  %Arg_0.48 = f32[]{:T(256)} parameter(0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum"}
  ROOT %add.50 = f32[]{:T(256)} add(f32[]{:T(256)} %Arg_0.48, f32[]{:T(256)} %Arg_1.49), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
}

%region_0.24 (Arg_0.25: f32[], Arg_1.26: f32[]) -> f32[] {
  %Arg_1.26 = f32[]{:T(256)} parameter(1), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/reduce_sum"}
  %Arg_0.25 = f32[]{:T(256)} parameter(0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/reduce_sum"}
  ROOT %add.27 = f32[]{:T(256)} add(f32[]{:T(256)} %Arg_0.25, f32[]{:T(256)} %Arg_1.26), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
}

%fused_computation.5 (param_0.44: f32[1,512,1024], param_1.46: f32[1,2048,1024], param_2.36: s32[], param_3.27: f32[1024]) -> (f32[512], f32[512]) {
  %param_0.44 = f32[1,512,1024]{2,1,0:T(8,128)} parameter(0)
  %param_1.46 = f32[1,2048,1024]{2,1,0:T(8,128)} parameter(1)
  %constant.50 = s32[]{:T(256)} constant(0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}
  %param_2.36 = s32[]{:T(256)} parameter(2)
  %dynamic-slice.14 = f32[1,512,1024]{2,1,0:T(8,128)} dynamic-slice(f32[1,2048,1024]{2,1,0:T(8,128)} %param_1.46, s32[]{:T(256)} %constant.50, s32[]{:T(256)} %param_2.36, s32[]{:T(256)} %constant.50), dynamic_slice_sizes={1,512,1024}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294965759","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}
  %param_3.27 = f32[1024]{0:T(1024)} parameter(3)
  %broadcast.40 = f32[1,512,1024]{2,1,0:T(8,128)} broadcast(f32[1024]{0:T(1024)} %param_3.27), dimensions={2}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=39}
  %multiply.29 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %dynamic-slice.14, f32[1,512,1024]{2,1,0:T(8,128)} %broadcast.40), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=39}
  %multiply.24 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %param_0.44, f32[1,512,1024]{2,1,0:T(8,128)} %multiply.29), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %constant.58 = f32[]{:T(256)} constant(0)
  %reduce.7 = f32[512]{0:T(512)} reduce(f32[1,512,1024]{2,1,0:T(8,128)} %multiply.24, f32[]{:T(256)} %constant.58), dimensions={0,2}, to_apply=%region_1.47, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %multiply.22.clone.1 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %param_0.44, f32[1,512,1024]{2,1,0:T(8,128)} %param_0.44), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  %reduce.6.clone.1 = f32[512]{0:T(512)} reduce(f32[1,512,1024]{2,1,0:T(8,128)} %multiply.22.clone.1, f32[]{:T(256)} %constant.58), dimensions={0,2}, to_apply=%region_0.24, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  ROOT %tuple.1 = (f32[512]{0:T(512)}, f32[512]{0:T(512)}) tuple(f32[512]{0:T(512)} %reduce.7, f32[512]{0:T(512)} %reduce.6.clone.1)
}

%multiply.15.reduce_sub_computation (lhs: f32[], rhs: f32[]) -> f32[] {
  %rhs = f32[] parameter(1)
  %lhs = f32[] parameter(0)
  ROOT %add.9 = f32[] add(f32[] %lhs, f32[] %rhs)
}

%rsqrt.2.reduce_sub_computation (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
  %rhs.1 = f32[] parameter(1)
  %lhs.1 = f32[] parameter(0)
  ROOT %add.10 = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}

%fused_computation.12 (param_0.41: f32[512], param_1.44: f32[512]) -> (f32[512], f32[512]) {
  %param_1.44 = f32[512]{0:T(512)} parameter(1)
  %reshape.49 = f32[1,512]{1,0:T(2,128)} reshape(f32[512]{0:T(512)} %param_1.44), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %param_0.41 = f32[512]{0:T(512)} parameter(0)
  %reshape.48 = f32[1,512]{1,0:T(2,128)} reshape(f32[512]{0:T(512)} %param_0.41), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/add" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %rsqrt.12 = f32[1,512]{1,0:T(2,128)} rsqrt(f32[1,512]{1,0:T(2,128)} %reshape.48), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/rsqrt" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %divide.10 = f32[1,512]{1,0:T(2,128)} divide(f32[1,512]{1,0:T(2,128)} %rsqrt.12, f32[1,512]{1,0:T(2,128)} %reshape.48), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/div" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %constant.53 = f32[]{:T(256)} constant(-0.5)
  %broadcast.45 = f32[1,512]{1,0:T(2,128)} broadcast(f32[]{:T(256)} %constant.53), dimensions={}
  %multiply.47 = f32[1,512]{1,0:T(2,128)} multiply(f32[1,512]{1,0:T(2,128)} %divide.10, f32[1,512]{1,0:T(2,128)} %broadcast.45), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %multiply.44 = f32[1,512]{1,0:T(2,128)} multiply(f32[1,512]{1,0:T(2,128)} %reshape.49, f32[1,512]{1,0:T(2,128)} %multiply.47), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %constant.56 = f32[]{:T(256)} constant(0.0009765625)
  %broadcast.46 = f32[1,512]{1,0:T(2,128)} broadcast(f32[]{:T(256)} %constant.56), dimensions={}
  %multiply.43 = f32[1,512]{1,0:T(2,128)} multiply(f32[1,512]{1,0:T(2,128)} %multiply.44, f32[1,512]{1,0:T(2,128)} %broadcast.46), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  %constant.51 = f32[] constant(-0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  %reduce.9 = f32[512]{0:T(512)} reduce(f32[1,512]{1,0:T(2,128)} %multiply.43, f32[] %constant.51), dimensions={0}, to_apply=%multiply.15.reduce_sub_computation, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  %reduce.8.clone.1 = f32[512]{0:T(512)} reduce(f32[1,512]{1,0:T(2,128)} %rsqrt.12, f32[] %constant.51), dimensions={0}, to_apply=%rsqrt.2.reduce_sub_computation, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/rsqrt" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  ROOT %tuple.2 = (f32[512]{0:T(512)}, f32[512]{0:T(512)}) tuple(f32[512]{0:T(512)} %reduce.9, f32[512]{0:T(512)} %reduce.8.clone.1)
}

%fused_computation.13 (param_0.40: f32[512]) -> f32[512] {
  %param_0.40 = f32[512]{0:T(512)} parameter(0)
  %constant.55 = f32[]{:T(256)} constant(0.0009765625)
  %broadcast.48 = f32[512]{0:T(512)} broadcast(f32[]{:T(256)} %constant.55), dimensions={}
  %multiply.48 = f32[512]{0:T(512)} multiply(f32[512]{0:T(512)} %param_0.40, f32[512]{0:T(512)} %broadcast.48), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  %constant.54 = f32[]{:T(256)} constant(1e-08)
  %broadcast.47 = f32[512]{0:T(512)} broadcast(f32[]{:T(256)} %constant.54), dimensions={}
  ROOT %add.13 = f32[512]{0:T(512)} add(f32[512]{0:T(512)} %multiply.48, f32[512]{0:T(512)} %broadcast.47), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/add" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
}

ENTRY %main.71_spmd (param: f32[1,512,1024], param.1: f32[512,1024], param.2: f32[1024,512], param.3: f32[1024]) -> f32[1,512,1024] {
  %constant.16 = f32[]{:T(256)} constant(2.38418579e-07)
  %constant.15 = f32[]{:T(256)} constant(1)
  %constant.29 = s32[8]{0:T(256)} constant({0, 512, 1024, 1536, 0, 512, 1024, 1536}), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}
  %param.3 = f32[1024]{0:T(1024)} parameter(3), sharding={replicated}, metadata={op_name="scale"}
  %param.2 = f32[1024,512]{1,0:T(8,128)} parameter(2), sharding={devices=[4,2]<=[2,4]T(1,0)}, metadata={op_name="down"}
  %param.1 = f32[512,1024]{1,0:T(8,128)} parameter(1), sharding={devices=[2,4]<=[8]}, metadata={op_name="up"}
  %param = f32[1,512,1024]{2,1,0:T(8,128)} parameter(0), sharding={devices=[2,4,1]<=[8]}, metadata={op_name="activation"}
  %partition-id = u32[]{:T(256)} partition-id()
  %dynamic-slice.7 = s32[1]{0:T(256)} dynamic-slice(s32[8]{0:T(256)} %constant.29, u32[]{:T(256)} %partition-id), dynamic_slice_sizes={1}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1"],"input_window_bounds":[],"estimated_cycles":"199","iteration_bounds":["1"]},"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"21504"}],"retry_config":{"retry_count":"0"}}
  %bitcast.1 = s32[]{:T(256)} bitcast(s32[1]{0:T(256)} %dynamic-slice.7), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}
  %tuple = (f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) tuple(f32[1,512,1024]{2,1,0:T(8,128)} %param, f32[512,1024]{1,0:T(8,128)} %param.1, f32[1024,512]{1,0:T(8,128)} %param.2, f32[1024]{0:T(1024)} %param.3, f32[]{:T(256)} %constant.15), metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
  %get-tuple-element.1 = f32[]{:T(256)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=4, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
  %multiply.8 = bf16[]{:T(512)S(6)} multiply(f32[]{:T(256)} %get-tuple-element.1, f32[]{:T(256)} %constant.16), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=49}
  %get-tuple-element.4 = f32[1024]{0:T(1024)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=3, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
  %get-tuple-element = f32[1,512,1024]{2,1,0:T(8,128)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=0, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
  %get-tuple-element.3 = f32[512,1024]{1,0:T(8,128)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=1, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
  %get-tuple-element.2 = f32[1024,512]{1,0:T(8,128)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=2, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
  %convert.1 = bf16[1024,512]{1,0:T(8,128)(2,1)} convert(f32[1024,512]{1,0:T(8,128)} %get-tuple-element.2), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32","4"],"input_window_bounds":[],"estimated_cycles":"12704","iteration_bounds":["4","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1572864"}],"retry_config":{"retry_count":"0"}}
  %all-gather = bf16[1024,1024]{1,0:T(8,128)(2,1)} all-gather(bf16[1024,512]{1,0:T(8,128)(2,1)} %convert.1), channel_id=1, replica_groups={{0,4},{1,5},{2,6},{3,7}}, dimensions={1}, use_global_device_ids=true, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=48}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"0"},"scoped_memory_configs":[],"collective_algorithm_config":{"emitter":"1DAllGatherNonMajorDim","debug":"\ngroup_size = 2\nper_stride_size = 8192 bytes\nshard_size = 1048576 bytes"},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"2097152"}],"retry_config":{"retry_count":"0"}}
  %convert = bf16[512,1024]{1,0:T(8,128)(2,1)} convert(f32[512,1024]{1,0:T(8,128)} %get-tuple-element.3), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16","8"],"input_window_bounds":[],"estimated_cycles":"12704","iteration_bounds":["4","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1572864"}],"retry_config":{"retry_count":"0"}}
  %all-gather.1 = bf16[1024,1024]{1,0:T(8,128)(2,1)} all-gather(bf16[512,1024]{1,0:T(8,128)(2,1)} %convert), channel_id=2, replica_groups=[4,2]<=[2,4]T(1,0), dimensions={0}, use_global_device_ids=true, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"0"},"scoped_memory_configs":[],"collective_algorithm_config":{"emitter":"1DAllGatherNonMajorDim","debug":"\ngroup_size = 2\nper_stride_size = 1048576 bytes\nshard_size = 1048576 bytes"},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"2097152"}],"retry_config":{"retry_count":"0"}}
  %fusion = f32[1,2048,1024]{2,1,0:T(8,128)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)} %all-gather.1, bf16[1024,1024]{1,0:T(8,128)(2,1)} %all-gather, bf16[]{:T(512)S(6)} %multiply.8), kind=kOutput, calls=%fused_computation, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":["128","2"],"output_window_bounds":["128","8"],"input_window_bounds":["128","2"],"estimated_cycles":"319985","iteration_bounds":["1","2","4"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"14073856"}],"retry_config":{"retry_count":"0"},"convolution_algorithm_config":{"emitter":"EmitAllBatchInSublanes"}}
  %all-reduce = f32[1,2048,1024]{2,1,0:T(8,128)} all-reduce(f32[1,2048,1024]{2,1,0:T(8,128)} %fusion), channel_id=3, replica_groups={{0,1,2,3},{4,5,6,7}}, use_global_device_ids=true, to_apply=%add.clone, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"1"},"scoped_memory_configs":[{"memory_space":"0","offset":"0","size":"67108864"}],"collective_algorithm_config":{"emitter":"RotatedPincerEmitter","strategy":"UniDirection1DRingStrategy","debug":"\nUniDirection1DRingStrategy{colors:2 phases:1 cores:{4},{4} nophase0:0 reserved_sflags:0 cross_module_on_2d_plane:0 has_reordering_map:0 use_routing_table_indices:0}"},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"8388608"}],"retry_config":{"retry_count":"0"}}
  %fusion.5 = (f32[512]{0:T(512)}, f32[512]{0:T(512)}) fusion(f32[1,512,1024]{2,1,0:T(8,128)} %get-tuple-element, f32[1,2048,1024]{2,1,0:T(8,128)} %all-reduce, s32[]{:T(256)} %bitcast.1, f32[1024]{0:T(1024)} %get-tuple-element.4), kind=kLoop, calls=%fused_computation.5, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","16","8"],"input_window_bounds":[],"estimated_cycles":"15476","iteration_bounds":["1","4","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"2326528"}],"retry_config":{"retry_count":"0"}}
  %get-tuple-element.6 = f32[512]{0:T(512)} get-tuple-element((f32[512]{0:T(512)}, f32[512]{0:T(512)}) %fusion.5), index=1, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %get-tuple-element.5 = f32[512]{0:T(512)} get-tuple-element((f32[512]{0:T(512)}, f32[512]{0:T(512)}) %fusion.5), index=0, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
  %fusion.15 = f32[512]{0:T(512)} fusion(f32[512]{0:T(512)} %get-tuple-element.6), kind=kLoop, calls=%fused_computation.13, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/add" source_file="<ipython-input-18-8aea41049d30>" source_line=37}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1"],"input_window_bounds":[],"estimated_cycles":"205","iteration_bounds":["1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"6144"}],"retry_config":{"retry_count":"0"}}
  %fusion.14 = (f32[512]{0:T(512)}, f32[512]{0:T(512)}) fusion(f32[512]{0:T(512)} %fusion.15, f32[512]{0:T(512)} %get-tuple-element.5), kind=kLoop, calls=%fused_computation.12, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","4"],"input_window_bounds":[],"estimated_cycles":"252","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"14336"}],"retry_config":{"retry_count":"0"}}
  %get-tuple-element.8 = f32[512]{0:T(512)} get-tuple-element((f32[512]{0:T(512)}, f32[512]{0:T(512)}) %fusion.14), index=1, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  %get-tuple-element.7 = f32[512]{0:T(512)} get-tuple-element((f32[512]{0:T(512)}, f32[512]{0:T(512)}) %fusion.14), index=0, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
  ROOT %fusion.3 = f32[1,512,1024]{2,1,0:T(8,128)} fusion(f32[1,512,1024]{2,1,0:T(8,128)} %get-tuple-element, f32[1,2048,1024]{2,1,0:T(8,128)} %all-reduce, s32[]{:T(256)} %bitcast.1, f32[512]{0:T(512)} %get-tuple-element.8, f32[512]{0:T(512)} %get-tuple-element.7, /*index=5*/f32[1024]{0:T(1024)} %get-tuple-element.4), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/add_any" source_file="<ipython-input-18-8aea41049d30>" source_line=37}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","16","8"],"input_window_bounds":[],"estimated_cycles":"15604","iteration_bounds":["1","4","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"3346432"}],"retry_config":{"retry_count":"0"}}
}