How do you remat GSPMD inserted all-gathers?
Opened this issue · 2 comments
Problem: I have some JAX code that does sequence parallelism, somewhat similar to this:
# Illustrative snippet: norm runs on the activation while it is constrained to
# PartitionSpec('data', 'tensor', None).
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
activation = norm(activation)
# Resharding dim 0 from 'data' to None before attention.
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
# I want to remat this one ^
activation = attention(activation)
I have tried everything I can to remat the activation directly before attention, including JAX remat policies and explicitly using jax.checkpoint on that exact tensor, but nothing seems to make it remat. The activation directly before attention is a GSPMD-inserted all-gather on the sequence dimension (dim=0).
I ended up writing an XLA pass to rematerialize large all-gathers and submitted a PR. openxla/xla#19163
Question: Is this possible to do from the JAX end, or is my pass really needed?
Thanks for the question.
No, I don't think a new pass is needed.
As I understand it, the standard way to spell this is to use a remat policy to mark the with_sharding_constraint which induces the all-gather as not-saveable. One way to do that would be to use save_only_these_names and to only name other arrays (those that are either upstream of the all-gather-inducing with_sharding_constraint, or downstream of the operations that use the output of attention). Following your snippet, that might look something like:
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
# Name the norm output (before the resharding constraint) so that a
# save_only_these_names policy can refer to it.
activation = checkpoint_name(norm(activation), 'scattered_activations')
# Resharding dim 0 from 'data' to None; this value is deliberately left unnamed.
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
activation = attention(activation)
together with a save_only_these_names policy that mentions 'scattered_activations' or something upstream of it.
Did you try something like that? If you already tried it, we should put together a minimal example to debug what's going on.
I am having problems reproducing this in a unit test, as GSPMD decides not to all-gather the sequence-parallel activation and uses a convolution op instead. I do not understand why there are no dot ops in my HLO. Perhaps my unit test is broken. Any ideas, @mattjj? The only all-gathers here are FSDP all-gathers, which are OK.
from functools import partial
from jax.sharding import NamedSharding
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.ad_checkpoint import checkpoint_name
# 2 x 4 device mesh with axes named 'data' and 'model' (requires 8 devices).
mesh = jax.make_mesh((2, 4), ('data', 'model'))

# Each tensor is drawn from its own fixed seed so the repro is deterministic.
key = jax.random.PRNGKey(0)
# presumably (batch, sequence, hidden) -- TODO confirm axis meaning
activation_weight = jax.random.uniform(key, (2, 2048, 1024))
key = jax.random.PRNGKey(1)
up_weight = jax.random.uniform(key, (1024, 4 * 1024))
key = jax.random.PRNGKey(2)
down_weight = jax.random.uniform(key, (4 * 1024, 1024))
key = jax.random.PRNGKey(3)
# Shape is spelled as a one-element tuple; a bare `(1024)` is just the int 1024.
scale_weight = jax.random.uniform(key, (1024,))

# Place the inputs on the mesh. The activation is split over 'data' on dim 0
# and over 'model' on dim 1; the two weight matrices are split over both mesh
# axes (in opposite order); the scale vector is replicated.
activation = jax.device_put(activation_weight, NamedSharding(mesh, P('data', 'model', None)))
up = jax.device_put(up_weight, NamedSharding(mesh, P('data', 'model')))
down = jax.device_put(down_weight, NamedSharding(mesh, P('model', 'data')))
scale = jax.device_put(scale_weight, NamedSharding(mesh, P(None,)))
@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.save_any_names_but_these('all_gather'))
def forward(activation, up, down, scale):
    """Return the mean of ``rms_norm(activation) @ up @ down``.

    The whole function is wrapped in ``jax.checkpoint`` with a policy that
    excludes values named ``'all_gather'`` from being saved.

    Args:
        activation: 3-D input; constrained to ``P('data', 'model', None)`` on
            the module-level ``mesh`` while the norm runs.
        up: matrix contracted with the last axis of the normalized activation.
        down: matrix contracted with the last axis of the ``up`` projection.
        scale: per-feature multiplier applied at the end of the RMS norm.

    Returns:
        A scalar: the mean over all elements of the final projection.
    """

    def rms_norm(x):
        """RMS-normalize ``x`` over its last axis and multiply by ``scale``."""
        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('data', 'model', None)))
        x_dtype = x.dtype
        # Do the statistics in float32, then cast back to the input dtype.
        x = x.astype(jnp.float32)
        moment2 = (x * x).mean(axis=-1, keepdims=True)
        x = x * jax.lax.rsqrt(moment2 + 1e-8)
        x = x.astype(x_dtype)
        x = x * scale
        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('data', 'model', None)))
        return x

    activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, P('data', 'model', None)))
    activation = rms_norm(activation)
    # NOTE(review): the name is attached to the value *before* the resharding
    # constraint below, i.e. to the activation still split over 'model' on dim 1.
    activation = checkpoint_name(activation, name='all_gather')
    # Drop the 'model' split on dim 1 ahead of the projections.
    activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, P('data', None, None)))
    # Contract the last activation axis (2) with the first weight axis (0);
    # no batch dimensions. Dimension lists are written as real tuples.
    activation = jax.lax.dot_general(activation, up, dimension_numbers=(((2,), (0,)), ((), ())))
    activation = jax.lax.dot_general(activation, down, dimension_numbers=(((2,), (0,)), ((), ())))
    return jnp.mean(activation)
# One partition spec per positional argument of forward:
# (activation, up, down, scale).
_arg_specs = (
    P('data', 'model', None),
    P('data', 'model'),
    P('model', 'data'),
    P(None,),
)
# NOTE(review): jax.grad differentiates only with respect to the first
# argument (activation) by default. Both dot_generals are linear in the
# activation, so this gradient needs the weights but not the gathered
# activation -- likely why the compiled module shows no forward dot ops.
# Confirm by also requesting the weight gradients.
jitted = jax.jit(
    jax.grad(forward),
    in_shardings=tuple(NamedSharding(mesh, spec) for spec in _arg_specs),
)
result = jitted(activation, up, down, scale)
print(result)
# Dump the optimized, partitioned HLO for inspection.
_compiled = jitted.lower(activation, up, down, scale).compile()
print(_compiled.as_text())
HLO
[[[-0.14977819 0.10456353 0.17018048 ... 0.18079197 0.0629528
0.23236921]
[-0.05597442 0.04220502 0.32394016 ... 0.04080057 -0.05255011
0.2142255 ]
[-0.00090512 0.14764619 0.22220817 ... 0.2685859 0.20059824
-0.01242509]
...
[-0.05803363 0.18066119 0.23036703 ... 0.33949146 -0.01644355
0.25753504]
[ 0.05420677 -0.02029712 0.06926298 ... 0.34460986 0.13754913
0.15689133]
[ 0.06267975 0.02606885 0.26323915 ... 0.08300793 0.18933085
0.20336568]]
[[-0.01766005 0.06890903 0.3483382 ... 0.04940775 0.008073
0.23060131]
[ 0.03536307 -0.06162715 0.19932668 ... 0.1174902 -0.00680336
-0.00751439]
[-0.11244462 -0.0749916 0.17251675 ... 0.27089232 0.00431225
0.20271495]
...
[-0.14335199 0.14109462 0.33877328 ... 0.245024 -0.03044733
0.09327929]
[-0.15114704 -0.04898666 0.20905332 ... 0.12068814 0.18083549
0.23931928]
[-0.19988659 0.00061852 0.30437952 ... 0.24774578 0.1629085
0.01701334]]]
HloModule jit_forward, is_scheduled=true, entry_computation_layout={(f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)})->f32[1,512,1024]{2,1,0:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=8
%add.clone (x.1: f32[], y.1: f32[]) -> f32[] {
%y.1 = f32[]{:T(256)} parameter(1)
%x.1 = f32[]{:T(256)} parameter(0)
ROOT %add.3 = f32[]{:T(256)} add(f32[]{:T(256)} %x.1, f32[]{:T(256)} %y.1)
}
%fused_computation.2.clone.clone (param_0.19: bf16[]) -> bf16[2048,1024] {
%param_0.19 = bf16[]{:T(512)S(6)} parameter(0)
ROOT %broadcast.43 = bf16[2048,1024]{1,0:T(8,128)(2,1)} broadcast(bf16[]{:T(512)S(6)} %param_0.19), dimensions={}
}
%bitcast_fusion.1 (bitcast_input.1: bf16[1024,1024]) -> bf16[1024,1024] {
%bitcast_input.1 = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(0)
ROOT %bitcast.4 = bf16[1024,1024]{1,0:T(8,128)(2,1)} bitcast(bf16[1024,1024]{1,0:T(8,128)(2,1)} %bitcast_input.1)
}
%fused_computation.1.clone.clone (param_0.20: bf16[1024,1024], param_1.18: bf16[]) -> bf16[2048,1024] {
%param_1.18 = bf16[]{:T(512)S(6)} parameter(1)
%fusion.8 = bf16[2048,1024]{1,0:T(8,128)(2,1)} fusion(bf16[]{:T(512)S(6)} %param_1.18), kind=kLoop, calls=%fused_computation.2.clone.clone
%param_0.20 = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(0)
%fusion.17 = bf16[1024,1024]{1,0:T(8,128)(2,1)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)} %param_0.20), kind=kLoop, calls=%bitcast_fusion.1
ROOT %convolution.5 = bf16[2048,1024]{1,0:T(8,128)(2,1)} convolution(bf16[2048,1024]{1,0:T(8,128)(2,1)} %fusion.8, bf16[1024,1024]{1,0:T(8,128)(2,1)} %fusion.17), dim_labels=bf_oi->bf, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=48}
}
%bitcast_fusion (bitcast_input: bf16[1024,1024]) -> bf16[1024,1024] {
%bitcast_input = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(0)
ROOT %bitcast.3 = bf16[1024,1024]{1,0:T(8,128)(2,1)} bitcast(bf16[1024,1024]{1,0:T(8,128)(2,1)} %bitcast_input)
}
%fused_computation (param_0.21: bf16[1024,1024], param_1.19: bf16[1024,1024], param_2.14: bf16[]) -> f32[1,2048,1024] {
%param_1.19 = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(1)
%param_2.14 = bf16[]{:T(512)S(6)} parameter(2)
%fusion.1.clone.1 = bf16[2048,1024]{1,0:T(8,128)(2,1)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)} %param_1.19, bf16[]{:T(512)S(6)} %param_2.14), kind=kOutput, calls=%fused_computation.1.clone.clone, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=48}
%param_0.21 = bf16[1024,1024]{1,0:T(8,128)(2,1)} parameter(0)
%fusion.16 = bf16[1024,1024]{1,0:T(8,128)(2,1)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)} %param_0.21), kind=kLoop, calls=%bitcast_fusion
%convolution.2 = f32[2048,1024]{1,0:T(8,128)} convolution(bf16[2048,1024]{1,0:T(8,128)(2,1)} %fusion.1.clone.1, bf16[1024,1024]{1,0:T(8,128)(2,1)} %fusion.16), dim_labels=bf_oi->bf, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}
ROOT %bitcast.2 = f32[1,2048,1024]{2,1,0:T(8,128)} bitcast(f32[2048,1024]{1,0:T(8,128)} %convolution.2), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}
}
%fused_computation.3 (param_0.16: f32[1,512,1024], param_1.16: f32[1,2048,1024], param_2.33: s32[], param_3.24: f32[512], param_4.15: f32[512], param_5.7: f32[1024]) -> f32[1,512,1024] {
%param_0.16 = f32[1,512,1024]{2,1,0:T(8,128)} parameter(0)
%add.12 = f32[1,512,1024]{2,1,0:T(8,128)} add(f32[1,512,1024]{2,1,0:T(8,128)} %param_0.16, f32[1,512,1024]{2,1,0:T(8,128)} %param_0.16), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
%param_4.15 = f32[512]{0:T(512)} parameter(4)
%broadcast.38 = f32[1,512,1024]{2,1,0:T(8,128)} broadcast(f32[512]{0:T(512)} %param_4.15), dimensions={1}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/broadcast_in_dim" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
%multiply.20 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %add.12, f32[1,512,1024]{2,1,0:T(8,128)} %broadcast.38), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/add_any" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
%param_1.16 = f32[1,2048,1024]{2,1,0:T(8,128)} parameter(1)
%constant.49 = s32[]{:T(256)} constant(0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}
%param_2.33 = s32[]{:T(256)} parameter(2)
%dynamic-slice.12 = f32[1,512,1024]{2,1,0:T(8,128)} dynamic-slice(f32[1,2048,1024]{2,1,0:T(8,128)} %param_1.16, s32[]{:T(256)} %constant.49, s32[]{:T(256)} %param_2.33, s32[]{:T(256)} %constant.49), dynamic_slice_sizes={1,512,1024}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294965759","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}
%param_5.7 = f32[1024]{0:T(1024)} parameter(5)
%broadcast.39 = f32[1,512,1024]{2,1,0:T(8,128)} broadcast(f32[1024]{0:T(1024)} %param_5.7), dimensions={2}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=39}
%multiply.27 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %dynamic-slice.12, f32[1,512,1024]{2,1,0:T(8,128)} %broadcast.39), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=39}
%param_3.24 = f32[512]{0:T(512)} parameter(3)
%broadcast.36 = f32[1,512,1024]{2,1,0:T(8,128)} broadcast(f32[512]{0:T(512)} %param_3.24), dimensions={1}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%multiply.21 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %multiply.27, f32[1,512,1024]{2,1,0:T(8,128)} %broadcast.36), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
ROOT %add.11 = f32[1,512,1024]{2,1,0:T(8,128)} add(f32[1,512,1024]{2,1,0:T(8,128)} %multiply.20, f32[1,512,1024]{2,1,0:T(8,128)} %multiply.21), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/add_any" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
}
%region_1.47 (Arg_0.48: f32[], Arg_1.49: f32[]) -> f32[] {
%Arg_1.49 = f32[]{:T(256)} parameter(1), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum"}
%Arg_0.48 = f32[]{:T(256)} parameter(0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum"}
ROOT %add.50 = f32[]{:T(256)} add(f32[]{:T(256)} %Arg_0.48, f32[]{:T(256)} %Arg_1.49), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
}
%region_0.24 (Arg_0.25: f32[], Arg_1.26: f32[]) -> f32[] {
%Arg_1.26 = f32[]{:T(256)} parameter(1), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/reduce_sum"}
%Arg_0.25 = f32[]{:T(256)} parameter(0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/reduce_sum"}
ROOT %add.27 = f32[]{:T(256)} add(f32[]{:T(256)} %Arg_0.25, f32[]{:T(256)} %Arg_1.26), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
}
%fused_computation.5 (param_0.44: f32[1,512,1024], param_1.46: f32[1,2048,1024], param_2.36: s32[], param_3.27: f32[1024]) -> (f32[512], f32[512]) {
%param_0.44 = f32[1,512,1024]{2,1,0:T(8,128)} parameter(0)
%param_1.46 = f32[1,2048,1024]{2,1,0:T(8,128)} parameter(1)
%constant.50 = s32[]{:T(256)} constant(0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}
%param_2.36 = s32[]{:T(256)} parameter(2)
%dynamic-slice.14 = f32[1,512,1024]{2,1,0:T(8,128)} dynamic-slice(f32[1,2048,1024]{2,1,0:T(8,128)} %param_1.46, s32[]{:T(256)} %constant.50, s32[]{:T(256)} %param_2.36, s32[]{:T(256)} %constant.50), dynamic_slice_sizes={1,512,1024}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294965759","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}
%param_3.27 = f32[1024]{0:T(1024)} parameter(3)
%broadcast.40 = f32[1,512,1024]{2,1,0:T(8,128)} broadcast(f32[1024]{0:T(1024)} %param_3.27), dimensions={2}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=39}
%multiply.29 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %dynamic-slice.14, f32[1,512,1024]{2,1,0:T(8,128)} %broadcast.40), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=39}
%multiply.24 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %param_0.44, f32[1,512,1024]{2,1,0:T(8,128)} %multiply.29), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%constant.58 = f32[]{:T(256)} constant(0)
%reduce.7 = f32[512]{0:T(512)} reduce(f32[1,512,1024]{2,1,0:T(8,128)} %multiply.24, f32[]{:T(256)} %constant.58), dimensions={0,2}, to_apply=%region_1.47, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%multiply.22.clone.1 = f32[1,512,1024]{2,1,0:T(8,128)} multiply(f32[1,512,1024]{2,1,0:T(8,128)} %param_0.44, f32[1,512,1024]{2,1,0:T(8,128)} %param_0.44), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
%reduce.6.clone.1 = f32[512]{0:T(512)} reduce(f32[1,512,1024]{2,1,0:T(8,128)} %multiply.22.clone.1, f32[]{:T(256)} %constant.58), dimensions={0,2}, to_apply=%region_0.24, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
ROOT %tuple.1 = (f32[512]{0:T(512)}, f32[512]{0:T(512)}) tuple(f32[512]{0:T(512)} %reduce.7, f32[512]{0:T(512)} %reduce.6.clone.1)
}
%multiply.15.reduce_sub_computation (lhs: f32[], rhs: f32[]) -> f32[] {
%rhs = f32[] parameter(1)
%lhs = f32[] parameter(0)
ROOT %add.9 = f32[] add(f32[] %lhs, f32[] %rhs)
}
%rsqrt.2.reduce_sub_computation (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
%rhs.1 = f32[] parameter(1)
%lhs.1 = f32[] parameter(0)
ROOT %add.10 = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}
%fused_computation.12 (param_0.41: f32[512], param_1.44: f32[512]) -> (f32[512], f32[512]) {
%param_1.44 = f32[512]{0:T(512)} parameter(1)
%reshape.49 = f32[1,512]{1,0:T(2,128)} reshape(f32[512]{0:T(512)} %param_1.44), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%param_0.41 = f32[512]{0:T(512)} parameter(0)
%reshape.48 = f32[1,512]{1,0:T(2,128)} reshape(f32[512]{0:T(512)} %param_0.41), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/add" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%rsqrt.12 = f32[1,512]{1,0:T(2,128)} rsqrt(f32[1,512]{1,0:T(2,128)} %reshape.48), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/rsqrt" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%divide.10 = f32[1,512]{1,0:T(2,128)} divide(f32[1,512]{1,0:T(2,128)} %rsqrt.12, f32[1,512]{1,0:T(2,128)} %reshape.48), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/div" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%constant.53 = f32[]{:T(256)} constant(-0.5)
%broadcast.45 = f32[1,512]{1,0:T(2,128)} broadcast(f32[]{:T(256)} %constant.53), dimensions={}
%multiply.47 = f32[1,512]{1,0:T(2,128)} multiply(f32[1,512]{1,0:T(2,128)} %divide.10, f32[1,512]{1,0:T(2,128)} %broadcast.45), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%multiply.44 = f32[1,512]{1,0:T(2,128)} multiply(f32[1,512]{1,0:T(2,128)} %reshape.49, f32[1,512]{1,0:T(2,128)} %multiply.47), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/mul" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%constant.56 = f32[]{:T(256)} constant(0.0009765625)
%broadcast.46 = f32[1,512]{1,0:T(2,128)} broadcast(f32[]{:T(256)} %constant.56), dimensions={}
%multiply.43 = f32[1,512]{1,0:T(2,128)} multiply(f32[1,512]{1,0:T(2,128)} %multiply.44, f32[1,512]{1,0:T(2,128)} %broadcast.46), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
%constant.51 = f32[] constant(-0), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
%reduce.9 = f32[512]{0:T(512)} reduce(f32[1,512]{1,0:T(2,128)} %multiply.43, f32[] %constant.51), dimensions={0}, to_apply=%multiply.15.reduce_sub_computation, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
%reduce.8.clone.1 = f32[512]{0:T(512)} reduce(f32[1,512]{1,0:T(2,128)} %rsqrt.12, f32[] %constant.51), dimensions={0}, to_apply=%rsqrt.2.reduce_sub_computation, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/rsqrt" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
ROOT %tuple.2 = (f32[512]{0:T(512)}, f32[512]{0:T(512)}) tuple(f32[512]{0:T(512)} %reduce.9, f32[512]{0:T(512)} %reduce.8.clone.1)
}
%fused_computation.13 (param_0.40: f32[512]) -> f32[512] {
%param_0.40 = f32[512]{0:T(512)} parameter(0)
%constant.55 = f32[]{:T(256)} constant(0.0009765625)
%broadcast.48 = f32[512]{0:T(512)} broadcast(f32[]{:T(256)} %constant.55), dimensions={}
%multiply.48 = f32[512]{0:T(512)} multiply(f32[512]{0:T(512)} %param_0.40, f32[512]{0:T(512)} %broadcast.48), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
%constant.54 = f32[]{:T(256)} constant(1e-08)
%broadcast.47 = f32[512]{0:T(512)} broadcast(f32[]{:T(256)} %constant.54), dimensions={}
ROOT %add.13 = f32[512]{0:T(512)} add(f32[512]{0:T(512)} %multiply.48, f32[512]{0:T(512)} %broadcast.47), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/add" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
}
ENTRY %main.71_spmd (param: f32[1,512,1024], param.1: f32[512,1024], param.2: f32[1024,512], param.3: f32[1024]) -> f32[1,512,1024] {
%constant.16 = f32[]{:T(256)} constant(2.38418579e-07)
%constant.15 = f32[]{:T(256)} constant(1)
%constant.29 = s32[8]{0:T(256)} constant({0, 512, 1024, 1536, 0, 512, 1024, 1536}), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}
%param.3 = f32[1024]{0:T(1024)} parameter(3), sharding={replicated}, metadata={op_name="scale"}
%param.2 = f32[1024,512]{1,0:T(8,128)} parameter(2), sharding={devices=[4,2]<=[2,4]T(1,0)}, metadata={op_name="down"}
%param.1 = f32[512,1024]{1,0:T(8,128)} parameter(1), sharding={devices=[2,4]<=[8]}, metadata={op_name="up"}
%param = f32[1,512,1024]{2,1,0:T(8,128)} parameter(0), sharding={devices=[2,4,1]<=[8]}, metadata={op_name="activation"}
%partition-id = u32[]{:T(256)} partition-id()
%dynamic-slice.7 = s32[1]{0:T(256)} dynamic-slice(s32[8]{0:T(256)} %constant.29, u32[]{:T(256)} %partition-id), dynamic_slice_sizes={1}, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1"],"input_window_bounds":[],"estimated_cycles":"199","iteration_bounds":["1"]},"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"21504"}],"retry_config":{"retry_count":"0"}}
%bitcast.1 = s32[]{:T(256)} bitcast(s32[1]{0:T(256)} %dynamic-slice.7), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/sharding_constraint" source_file="<ipython-input-18-8aea41049d30>" source_line=40}
%tuple = (f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) tuple(f32[1,512,1024]{2,1,0:T(8,128)} %param, f32[512,1024]{1,0:T(8,128)} %param.1, f32[1024,512]{1,0:T(8,128)} %param.2, f32[1024]{0:T(1024)} %param.3, f32[]{:T(256)} %constant.15), metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
%get-tuple-element.1 = f32[]{:T(256)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=4, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
%multiply.8 = bf16[]{:T(512)S(6)} multiply(f32[]{:T(256)} %get-tuple-element.1, f32[]{:T(256)} %constant.16), metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=49}
%get-tuple-element.4 = f32[1024]{0:T(1024)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=3, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
%get-tuple-element = f32[1,512,1024]{2,1,0:T(8,128)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=0, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
%get-tuple-element.3 = f32[512,1024]{1,0:T(8,128)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=1, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
%get-tuple-element.2 = f32[1024,512]{1,0:T(8,128)} get-tuple-element((f32[1,512,1024]{2,1,0:T(8,128)}, f32[512,1024]{1,0:T(8,128)}, f32[1024,512]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[]{:T(256)}) %tuple), index=2, metadata={op_name="jit(forward)/jit(main)/remat2" source_file="<ipython-input-18-8aea41049d30>" source_line=53}
%convert.1 = bf16[1024,512]{1,0:T(8,128)(2,1)} convert(f32[1024,512]{1,0:T(8,128)} %get-tuple-element.2), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32","4"],"input_window_bounds":[],"estimated_cycles":"12704","iteration_bounds":["4","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1572864"}],"retry_config":{"retry_count":"0"}}
%all-gather = bf16[1024,1024]{1,0:T(8,128)(2,1)} all-gather(bf16[1024,512]{1,0:T(8,128)(2,1)} %convert.1), channel_id=1, replica_groups={{0,4},{1,5},{2,6},{3,7}}, dimensions={1}, use_global_device_ids=true, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=48}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"0"},"scoped_memory_configs":[],"collective_algorithm_config":{"emitter":"1DAllGatherNonMajorDim","debug":"\ngroup_size = 2\nper_stride_size = 8192 bytes\nshard_size = 1048576 bytes"},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"2097152"}],"retry_config":{"retry_count":"0"}}
%convert = bf16[512,1024]{1,0:T(8,128)(2,1)} convert(f32[512,1024]{1,0:T(8,128)} %get-tuple-element.3), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16","8"],"input_window_bounds":[],"estimated_cycles":"12704","iteration_bounds":["4","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1572864"}],"retry_config":{"retry_count":"0"}}
%all-gather.1 = bf16[1024,1024]{1,0:T(8,128)(2,1)} all-gather(bf16[512,1024]{1,0:T(8,128)(2,1)} %convert), channel_id=2, replica_groups=[4,2]<=[2,4]T(1,0), dimensions={0}, use_global_device_ids=true, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"0"},"scoped_memory_configs":[],"collective_algorithm_config":{"emitter":"1DAllGatherNonMajorDim","debug":"\ngroup_size = 2\nper_stride_size = 1048576 bytes\nshard_size = 1048576 bytes"},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"2097152"}],"retry_config":{"retry_count":"0"}}
%fusion = f32[1,2048,1024]{2,1,0:T(8,128)} fusion(bf16[1024,1024]{1,0:T(8,128)(2,1)} %all-gather.1, bf16[1024,1024]{1,0:T(8,128)(2,1)} %all-gather, bf16[]{:T(512)S(6)} %multiply.8), kind=kOutput, calls=%fused_computation, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":["128","2"],"output_window_bounds":["128","8"],"input_window_bounds":["128","2"],"estimated_cycles":"319985","iteration_bounds":["1","2","4"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"14073856"}],"retry_config":{"retry_count":"0"},"convolution_algorithm_config":{"emitter":"EmitAllBatchInSublanes"}}
%all-reduce = f32[1,2048,1024]{2,1,0:T(8,128)} all-reduce(f32[1,2048,1024]{2,1,0:T(8,128)} %fusion), channel_id=3, replica_groups={{0,1,2,3},{4,5,6,7}}, use_global_device_ids=true, to_apply=%add.clone, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/dot_general" source_file="<ipython-input-18-8aea41049d30>" source_line=47}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"1"},"scoped_memory_configs":[{"memory_space":"0","offset":"0","size":"67108864"}],"collective_algorithm_config":{"emitter":"RotatedPincerEmitter","strategy":"UniDirection1DRingStrategy","debug":"\nUniDirection1DRingStrategy{colors:2 phases:1 cores:{4},{4} nophase0:0 reserved_sflags:0 cross_module_on_2d_plane:0 has_reordering_map:0 use_routing_table_indices:0}"},"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"8388608"}],"retry_config":{"retry_count":"0"}}
%fusion.5 = (f32[512]{0:T(512)}, f32[512]{0:T(512)}) fusion(f32[1,512,1024]{2,1,0:T(8,128)} %get-tuple-element, f32[1,2048,1024]{2,1,0:T(8,128)} %all-reduce, s32[]{:T(256)} %bitcast.1, f32[1024]{0:T(1024)} %get-tuple-element.4), kind=kLoop, calls=%fused_computation.5, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","16","8"],"input_window_bounds":[],"estimated_cycles":"15476","iteration_bounds":["1","4","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"2326528"}],"retry_config":{"retry_count":"0"}}
%get-tuple-element.6 = f32[512]{0:T(512)} get-tuple-element((f32[512]{0:T(512)}, f32[512]{0:T(512)}) %fusion.5), index=1, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%get-tuple-element.5 = f32[512]{0:T(512)} get-tuple-element((f32[512]{0:T(512)}, f32[512]{0:T(512)}) %fusion.5), index=0, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/reduce_sum" source_file="<ipython-input-18-8aea41049d30>" source_line=37}
%fusion.15 = f32[512]{0:T(512)} fusion(f32[512]{0:T(512)} %get-tuple-element.6), kind=kLoop, calls=%fused_computation.13, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/rematted_computation/add" source_file="<ipython-input-18-8aea41049d30>" source_line=37}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1"],"input_window_bounds":[],"estimated_cycles":"205","iteration_bounds":["1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"6144"}],"retry_config":{"retry_count":"0"}}
%fusion.14 = (f32[512]{0:T(512)}, f32[512]{0:T(512)}) fusion(f32[512]{0:T(512)} %fusion.15, f32[512]{0:T(512)} %get-tuple-element.5), kind=kLoop, calls=%fused_computation.12, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","4"],"input_window_bounds":[],"estimated_cycles":"252","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"14336"}],"retry_config":{"retry_count":"0"}}
%get-tuple-element.8 = f32[512]{0:T(512)} get-tuple-element((f32[512]{0:T(512)}, f32[512]{0:T(512)}) %fusion.14), index=1, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
%get-tuple-element.7 = f32[512]{0:T(512)} get-tuple-element((f32[512]{0:T(512)}, f32[512]{0:T(512)}) %fusion.14), index=0, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/div" source_file="<ipython-input-18-8aea41049d30>" source_line=36}
ROOT %fusion.3 = f32[1,512,1024]{2,1,0:T(8,128)} fusion(f32[1,512,1024]{2,1,0:T(8,128)} %get-tuple-element, f32[1,2048,1024]{2,1,0:T(8,128)} %all-reduce, s32[]{:T(256)} %bitcast.1, f32[512]{0:T(512)} %get-tuple-element.8, f32[512]{0:T(512)} %get-tuple-element.7, /*index=5*/f32[1024]{0:T(1024)} %get-tuple-element.4), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(forward)/jit(main)/transpose(jvp(checkpoint))/add_any" source_file="<ipython-input-18-8aea41049d30>" source_line=37}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","16","8"],"input_window_bounds":[],"estimated_cycles":"15604","iteration_bounds":["1","4","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"3346432"}],"retry_config":{"retry_count":"0"}}
}