openxla/xla

Compiling scatter results in very slow while-loop on TPU

Opened this issue · 0 comments

Original JAX issue: jax-ml/jax#21367. I'm using Python 3.12.1 and jaxlib 0.4.28, running on a TPU v4-8 VM.

In my use case I need to do batches of dynamic_update_slice operations, and I have been using JAX's vmap for that, but I was getting extremely slow runtimes (between 10x and 50x slower than expected). I profiled the code and found that the vmapped dynamic_update_slice, which I expected to compile to a scatter, was actually compiled to a while-loop of dynamic_update_slice ops that iterates over the batch axis. I think this while-loop may be preventing parallelization and causing the very slow runtimes.

This JAX code demonstrates the slow-down by comparing the vmapped dynamic_update_slice, which lowers to a single scatter, with an equivalent unrolled Python loop of dynamic_update_slices:

from timeit import timeit

from jax import jit, lax, vmap, make_jaxpr
import jax.numpy as jnp


def pymap(f):
    """Emulate ``vmap(f)`` with an unrolled Python loop.

    ``f`` must return a single array. It is applied to the i-th slice along
    axis 0 of every argument (iteration stops at the shortest argument, as
    with ``zip``), and the results are stacked along a new leading axis.
    Under ``jit`` this traces one independent call of ``f`` per batch row,
    i.e. a fully unrolled loop.
    """
    def mapped(*args):
        return jnp.stack([f(*row) for row in zip(*args)])

    return mapped


operands = jnp.ones((100, 32))
updates = jnp.ones((100, 2))
starts = jnp.ones((100, 1), dtype='int32')

f = lax.dynamic_update_slice

f_vmapped = jit(vmap(f))
f_pymapped = jit(pymap(f))


def time_per_call(fn, *args, number=100):
    """Return the mean wall-clock time, in seconds, of one ``fn(*args)`` call.

    ``fn`` must return a JAX array. It is first called once, untimed, and
    waited on, so that compilation and the warm-up execution have both
    finished before timing starts. Each timed call also blocks until its
    result is ready; JAX dispatches work asynchronously, so without this the
    timer would measure only dispatch.
    """
    # Warm-up: compile, then wait for the (asynchronously dispatched) run.
    fn(*args).block_until_ready()
    elapsed = timeit(lambda: fn(*args).block_until_ready(), number=number)
    return elapsed / number


if __name__ == "__main__":
    t_vmapped = time_per_call(f_vmapped, operands, updates, starts)
    t_pymapped = time_per_call(f_pymapped, operands, updates, starts)

    print(f"Time vmap(f): {t_vmapped:.2}s")
    print(f"Time pymap(f): {t_pymapped:.2}s")

Running it on a TPU v4-8 VM I get:

Time vmap(f): 0.00088s
Time pymap(f): 0.00036s

To be clear, my hypothesis is that the unrolled Python loop is faster than the scatter because its dynamic_update_slices can run in parallel (the loop iterations do not depend on each other), whereas the scatter is (for some reason) compiled to a while-loop whose iterations run one after another and so cannot be parallelized.

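The lowered StableHLO of f_vmapped does contain a single scatter and no loop, as expected. Note that both the unique_indices and indices_are_sorted flags of the scatter are true, so XLA is allowed to assume that no two updates write to the same element: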

>>> print(f_vmapped.lower(operands, updates, starts).as_text())
module @jit_dynamic_update_slice attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<100x32xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<100x2xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<100x1xi32> {mhlo.layout_mode = "default"}) -> (tensor<100x32xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.slice %arg2 [0:100, 0:1] : (tensor<100x1xi32>) -> tensor<100x1xi32>
    %1 = stablehlo.reshape %0 : (tensor<100x1xi32>) -> tensor<100xi32>
    %c = stablehlo.constant dense<0> : tensor<i32>
    %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<100xi32>
    %3 = stablehlo.compare  LT, %1, %2,  SIGNED : (tensor<100xi32>, tensor<100xi32>) -> tensor<100xi1>
    %c_0 = stablehlo.constant dense<32> : tensor<i32>
    %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<100xi32>
    %5 = stablehlo.add %1, %4 : tensor<100xi32>
    %6 = stablehlo.select %3, %5, %1 : tensor<100xi1>, tensor<100xi32>
    %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<100xi32>) -> tensor<100x1xi32>
    %8 = stablehlo.iota dim = 0 : tensor<100x1xi32>
    %9 = stablehlo.concatenate %8, %7, dim = 1 : (tensor<100x1xi32>, tensor<100x1xi32>) -> tensor<100x2xi32>
    %c_1 = stablehlo.constant dense<99> : tensor<i32>
    %10 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %c_2 = stablehlo.constant dense<30> : tensor<i32>
    %11 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %12 = stablehlo.concatenate %10, %11, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %c_3 = stablehlo.constant dense<2147483647> : tensor<ui32>
    %13 = stablehlo.convert %c_3 : (tensor<ui32>) -> tensor<i32>
    %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i32>) -> tensor<2xi32>
    %15 = stablehlo.minimum %12, %14 : tensor<2xi32>
    %16 = stablehlo.broadcast_in_dim %15, dims = [1] : (tensor<2xi32>) -> tensor<100x2xi32>
    %c_4 = stablehlo.constant dense<0> : tensor<i32>
    %17 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<i32>) -> tensor<100x2xi32>
    %18 = stablehlo.clamp %17, %9, %16 : tensor<100x2xi32>
    %19 = "stablehlo.scatter"(%arg0, %18, %arg1) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
      stablehlo.return %arg4 : tensor<f32>
    }) : (tensor<100x32xf32>, tensor<100x2xi32>, tensor<100x2xf32>) -> tensor<100x32xf32>
    return %19 : tensor<100x32xf32>
  }
}

However, after optimization/compilation, the HLO contains a while loop that runs 100 iterations (one per batch row), each performing a single dynamic-update-slice in the loop body:

>>> print(f_vmapped.lower(operands, updates, starts).compile().as_text())
HloModule jit_dynamic_update_slice, is_scheduled=true, entry_computation_layout={(f32[100,32]{0,1:T(8,128)}, f32[100,2]{0,1:T(2,128)}, s32[100,1]{0,1:T(1,128)})->f32[100,32]{0,1:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.3: s32[100,1]) -> s32[100,1] {
  %param_0.3 = s32[100,1]{0,1:T(1,128)S(3)} parameter(0)
  %constant.24 = s32[]{:T(128)} constant(0)
  %broadcast.20 = s32[100,1]{0,1:T(1,128)} broadcast(s32[]{:T(128)} %constant.24), dimensions={}
  %compare.4 = pred[100,1]{0,1:T(4,128)(4,1)} compare(s32[100,1]{0,1:T(1,128)S(3)} %param_0.3, s32[100,1]{0,1:T(1,128)} %broadcast.20), direction=LT, metadata={op_name="jit(dynamic_update_slice)/jit(main)/lt" source_file="/home/jamietownsend/issue.py" source_line=21}
  %constant.23 = s32[]{:T(128)} constant(32)
  %broadcast.18 = s32[100,1]{0,1:T(1,128)} broadcast(s32[]{:T(128)} %constant.23), dimensions={}
  %add.6 = s32[100,1]{0,1:T(1,128)} add(s32[100,1]{0,1:T(1,128)S(3)} %param_0.3, s32[100,1]{0,1:T(1,128)} %broadcast.18), metadata={op_name="jit(dynamic_update_slice)/jit(main)/add" source_file="/home/jamietownsend/issue.py" source_line=21}
  ROOT %select.2 = s32[100,1]{0,1:T(1,128)S(3)} select(pred[100,1]{0,1:T(4,128)(4,1)} %compare.4, s32[100,1]{0,1:T(1,128)} %add.6, s32[100,1]{0,1:T(1,128)S(3)} %param_0.3), metadata={op_name="jit(dynamic_update_slice)/jit(main)/select_n" source_file="/home/jamietownsend/issue.py" source_line=21}
}

%fused_computation.1 (param_0.4: s32[100,1], param_1.9: s32[100,1], param_2.10: s32[2]) -> s32[100,2] {
  %constant.25 = s32[]{:T(128)} constant(0)
  %broadcast.17 = s32[100,2]{0,1:T(2,128)} broadcast(s32[]{:T(128)} %constant.25), dimensions={}
  %param_1.9 = s32[100,1]{0,1:T(1,128)S(3)} parameter(1)
  %pad.7 = s32[100,2]{0,1:T(2,128)} pad(s32[100,1]{0,1:T(1,128)S(3)} %param_1.9, s32[]{:T(128)} %constant.25), padding=0_0x0_1, metadata={op_name="jit(dynamic_update_slice)/jit(main)/concatenate[dimension=1]" source_file="/home/jamietownsend/issue.py" source_line=21}
  %param_0.4 = s32[100,1]{0,1:T(1,128)S(3)} parameter(0)
  %pad.6 = s32[100,2]{0,1:T(2,128)} pad(s32[100,1]{0,1:T(1,128)S(3)} %param_0.4, s32[]{:T(128)} %constant.25), padding=0_0x1_0, metadata={op_name="jit(dynamic_update_slice)/jit(main)/concatenate[dimension=1]" source_file="/home/jamietownsend/issue.py" source_line=21}
  %add.5 = s32[100,2]{0,1:T(2,128)} add(s32[100,2]{0,1:T(2,128)} %pad.7, s32[100,2]{0,1:T(2,128)} %pad.6), metadata={op_name="jit(dynamic_update_slice)/jit(main)/concatenate[dimension=1]" source_file="/home/jamietownsend/issue.py" source_line=21}
  %param_2.10 = s32[2]{0:T(128)S(3)} parameter(2)
  %broadcast.16 = s32[100,2]{0,1:T(2,128)} broadcast(s32[2]{0:T(128)S(3)} %param_2.10), dimensions={1}, metadata={op_name="jit(dynamic_update_slice)/jit(main)/broadcast_in_dim[shape=(100, 2) broadcast_dimensions=(1,)]" source_file="/home/jamietownsend/issue.py" source_line=21}
  ROOT %clamp.0 = s32[100,2]{0,1:T(2,128)S(3)} clamp(s32[100,2]{0,1:T(2,128)} %broadcast.17, s32[100,2]{0,1:T(2,128)} %add.5, s32[100,2]{0,1:T(2,128)} %broadcast.16), metadata={op_name="jit(dynamic_update_slice)/jit(main)/clamp" source_file="/home/jamietownsend/issue.py" source_line=21}
}

%dynamic-slice.reduce_sub_computation (lhs.1: s32[], rhs.1: s32[]) -> s32[] {
  %rhs.1 = s32[] parameter(1)
  %lhs.1 = s32[] parameter(0)
  ROOT %add.2 = s32[] add(s32[] %lhs.1, s32[] %rhs.1)
}

%fused_computation.5.clone (param_0.15: s32[100,2], param_1.23: s32[]) -> (s32[2], s32[1,2]) {
  %param_0.15 = s32[100,2]{0,1:T(2,128)S(3)} parameter(0)
  %param_1.23 = s32[]{:T(128)} parameter(1)
  %constant.30.clone.2 = s32[]{:T(128)} constant(0)
  %dynamic-slice.8.clone.2 = s32[1,2]{0,1:T(2,128)S(3)} dynamic-slice(s32[100,2]{0,1:T(2,128)S(3)} %param_0.15, s32[]{:T(128)} %param_1.23, s32[]{:T(128)} %constant.30.clone.2), dynamic_slice_sizes={1,2}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[]}
  %constant.38 = s32[] constant(0)
  %reduce.4 = s32[2]{0:T(128)S(3)} reduce(s32[1,2]{0,1:T(2,128)S(3)} %dynamic-slice.8.clone.2, s32[] %constant.38), dimensions={0}, to_apply=%dynamic-slice.reduce_sub_computation
  ROOT %tuple.8 = (s32[2]{0:T(128)S(3)}, s32[1,2]{0,1:T(2,128)S(3)}) tuple(s32[2]{0:T(128)S(3)} %reduce.4, s32[1,2]{0,1:T(2,128)S(3)} %dynamic-slice.8.clone.2)
}

%and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] {
  %rhs = pred[]{:T(512)} parameter(1)
  %lhs = pred[]{:T(512)} parameter(0)
  ROOT %and = pred[]{:T(512)} and(pred[]{:T(512)} %lhs, pred[]{:T(512)} %rhs)
}

%fused_computation.2.clone (param_0.16: s32[2], param_1.24: s32[1], param_2.24: s32[1]) -> pred[] {
  %constant.40 = s32[]{:T(128)} constant(0)
  %broadcast.23 = s32[2]{0:T(128)} broadcast(s32[]{:T(128)} %constant.40), dimensions={}
  %param_2.24 = s32[1]{0:T(128)} parameter(2)
  %pad.11 = s32[2]{0:T(128)} pad(s32[1]{0:T(128)} %param_2.24, s32[]{:T(128)} %constant.40), padding=0_1
  %param_1.24 = s32[1]{0:T(128)} parameter(1)
  %pad.10 = s32[2]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.24, s32[]{:T(128)} %constant.40), padding=1_0
  %add.9 = s32[2]{0:T(128)} add(s32[2]{0:T(128)} %pad.11, s32[2]{0:T(128)} %pad.10)
  %compare.9 = pred[2]{0:T(512)(128)(4,1)} compare(s32[2]{0:T(128)} %broadcast.23, s32[2]{0:T(128)} %add.9), direction=LE
  %param_0.16 = s32[2]{0:T(128)S(3)} parameter(0)
  %compare.8 = pred[2]{0:T(512)(128)(4,1)} compare(s32[2]{0:T(128)S(3)} %param_0.16, s32[2]{0:T(128)} %add.9), direction=GE
  %and.3 = pred[2]{0:T(512)(128)(4,1)} and(pred[2]{0:T(512)(128)(4,1)} %compare.9, pred[2]{0:T(512)(128)(4,1)} %compare.8)
  %constant.41 = pred[]{:T(512)} constant(true)
  ROOT %reduce.5 = pred[]{:T(512)} reduce(pred[2]{0:T(512)(128)(4,1)} %and.3, pred[]{:T(512)} %constant.41), dimensions={0}, to_apply=%and.reduce_sub_computation
}

%fused_computation.3.clone (param_0.17: f32[100,2], param_1.25: s32[], param_2.25: f32[100,32], param_3.16: s32[], param_4.12: s32[], param_5.5: pred[]) -> f32[1,2] {
  %param_5.5 = pred[]{:T(512)} parameter(5)
  %broadcast.24 = pred[1,2]{0,1:T(4,128)(4,1)} broadcast(pred[]{:T(512)} %param_5.5), dimensions={}
  %param_0.17 = f32[100,2]{0,1:T(2,128)S(3)} parameter(0)
  %param_1.25 = s32[]{:T(128)} parameter(1)
  %constant.42 = s32[]{:T(128)} constant(0)
  %dynamic-slice.9 = f32[1,2]{0,1:T(2,128)} dynamic-slice(f32[100,2]{0,1:T(2,128)S(3)} %param_0.17, s32[]{:T(128)} %param_1.25, s32[]{:T(128)} %constant.42), dynamic_slice_sizes={1,2}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[]}
  %param_2.25 = f32[100,32]{0,1:T(8,128)S(3)} parameter(2)
  %param_3.16 = s32[]{:T(128)} parameter(3)
  %param_4.12 = s32[]{:T(128)} parameter(4)
  %dynamic-slice.10 = f32[1,2]{0,1:T(2,128)} dynamic-slice(f32[100,32]{0,1:T(8,128)S(3)} %param_2.25, s32[]{:T(128)} %param_3.16, s32[]{:T(128)} %param_4.12), dynamic_slice_sizes={1,2}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"0","ones":"0","bitwidth":"32"}]},"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[]}
  ROOT %select.4 = f32[1,2]{0,1:T(2,128)S(3)} select(pred[1,2]{0,1:T(4,128)(4,1)} %broadcast.24, f32[1,2]{0,1:T(2,128)} %dynamic-slice.9, f32[1,2]{0,1:T(2,128)} %dynamic-slice.10)
}

%wide.while_body (wide.param.1: (s32[], f32[100,32], s32[100,2], f32[100,2], s32[], /*index=5*/s32[2])) -> (s32[], f32[100,32], s32[100,2], f32[100,2], s32[], /*index=5*/s32[2]) {
  %constant.33..sunk = s32[]{:T(128)} constant(1)
  %wide.param.1 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) parameter(0)
  %get-tuple-element.47 = s32[]{:T(128)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=0
  %get-tuple-element.57 = s32[]{:T(128)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=4
  %get-tuple-element.48 = f32[100,32]{0,1:T(8,128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=1
  %get-tuple-element.55 = s32[100,2]{0,1:T(2,128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=2
  %get-tuple-element.56 = f32[100,2]{0,1:T(2,128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=3
  %get-tuple-element.58 = s32[2]{0:T(128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.1), index=5
  %fusion.7 = (s32[2]{0:T(128)S(3)}, s32[1,2]{0,1:T(2,128)S(3)}) fusion(s32[100,2]{0,1:T(2,128)S(3)} %get-tuple-element.55, s32[]{:T(128)} %get-tuple-element.47), kind=kLoop, calls=%fused_computation.5.clone, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"589","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"9216"}],"retry_config":{"retry_count":"0"}}
  %get-tuple-element.32 = s32[2]{0:T(128)S(3)} get-tuple-element((s32[2]{0:T(128)S(3)}, s32[1,2]{0,1:T(2,128)S(3)}) %fusion.7), index=0
  %slice.12 = s32[1]{0:T(128)} slice(s32[2]{0:T(128)S(3)} %get-tuple-element.32), slice={[1:2]}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1"],"input_window_bounds":[],"estimated_cycles":"586","iteration_bounds":["1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1024"}],"retry_config":{"retry_count":"0"}}
  %get-tuple-element.31 = s32[1,2]{0,1:T(2,128)S(3)} get-tuple-element((s32[2]{0:T(128)S(3)}, s32[1,2]{0,1:T(2,128)S(3)}) %fusion.7), index=1
  %slice.11 = s32[1,1]{0,1:T(1,128)} slice(s32[1,2]{0,1:T(2,128)S(3)} %get-tuple-element.31), slice={[0:1], [0:1]}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"587","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1024"}],"retry_config":{"retry_count":"0"}}
  %bitcast.5 = s32[1]{0:T(128)} bitcast(s32[1,1]{0,1:T(1,128)} %slice.11)
  %fusion.8 = pred[]{:T(512)} fusion(s32[2]{0:T(128)S(3)} %get-tuple-element.58, s32[1]{0:T(128)} %slice.12, s32[1]{0:T(128)} %bitcast.5), kind=kLoop, calls=%fused_computation.2.clone, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1"],"input_window_bounds":[],"estimated_cycles":"590","iteration_bounds":["1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"4096"}],"retry_config":{"retry_count":"0"}}
  %bitcast.3 = s32[]{:T(128)} bitcast(s32[1,1]{0,1:T(1,128)} %slice.11)
  %bitcast.4 = s32[]{:T(128)} bitcast(s32[1]{0:T(128)} %slice.12)
  %fusion.9 = f32[1,2]{0,1:T(2,128)S(3)} fusion(f32[100,2]{0,1:T(2,128)S(3)} %get-tuple-element.56, s32[]{:T(128)} %get-tuple-element.47, f32[100,32]{0,1:T(8,128)S(3)} %get-tuple-element.48, s32[]{:T(128)} %bitcast.3, s32[]{:T(128)} %bitcast.4, /*index=5*/pred[]{:T(512)} %fusion.8), kind=kLoop, calls=%fused_computation.3.clone, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"598","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"17408"}],"retry_config":{"retry_count":"0"}}
  %add.8 = s32[]{:T(128)} add(s32[]{:T(128)} %get-tuple-element.47, s32[]{:T(128)} %constant.33..sunk)
  %dynamic-update-slice.1 = f32[100,32]{0,1:T(8,128)S(3)} dynamic-update-slice(f32[100,32]{0,1:T(8,128)S(3)} %get-tuple-element.48, f32[1,2]{0,1:T(2,128)S(3)} %fusion.9, s32[]{:T(128)} %bitcast.3, s32[]{:T(128)} %bitcast.4), backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"0","ones":"0","bitwidth":"32"}]},"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"40960"}],"retry_config":{"retry_count":"0"}}
  ROOT %tuple.14 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) tuple(s32[]{:T(128)} %add.8, f32[100,32]{0,1:T(8,128)S(3)} %dynamic-update-slice.1, s32[100,2]{0,1:T(2,128)S(3)} %get-tuple-element.55, f32[100,2]{0,1:T(2,128)S(3)} %get-tuple-element.56, s32[]{:T(128)} %get-tuple-element.57, /*index=5*/s32[2]{0:T(128)S(3)} %get-tuple-element.58)
}

%wide.while_cond (wide.param.0: (s32[], f32[100,32], s32[100,2], f32[100,2], s32[], /*index=5*/s32[2])) -> pred[] {
  %constant.36 = s32[]{:T(128)} constant(100)
  %wide.param.0 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) parameter(0)
  %get-tuple-element.22 = s32[]{:T(128)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %wide.param.0), index=0
  ROOT %compare.7 = pred[]{:T(512)} compare(s32[]{:T(128)} %get-tuple-element.22, s32[]{:T(128)} %constant.36), direction=LT
}

ENTRY %main.25 (Arg_0.1: f32[100,32], Arg_1.2: f32[100,2], Arg_2.3: s32[100,1]) -> f32[100,32] {
  %constant.33 = s32[]{:T(128)} constant(1)
  %constant.4 = s32[]{:T(128)} constant(0)
  %constant.34 = s32[2]{0:T(128)} constant({99, 30})
  %copy-start.2 = (s32[2]{0:T(128)S(3)}, s32[2]{0:T(128)}, u32[]{:S(2)}) copy-start(s32[2]{0:T(128)} %constant.34)
  %constant.6 = s32[2]{0:T(128)} constant({99, 30})
  %copy-start.4 = (s32[2]{0:T(128)S(3)}, s32[2]{0:T(128)}, u32[]{:S(2)}) copy-start(s32[2]{0:T(128)} %constant.6)
  %Arg_2.3 = s32[100,1]{0,1:T(1,128)} parameter(2)
  %copy-start.3 = (s32[100,1]{0,1:T(1,128)S(3)}, s32[100,1]{0,1:T(1,128)}, u32[]{:S(2)}) copy-start(s32[100,1]{0,1:T(1,128)} %Arg_2.3)
  %Arg_1.2 = f32[100,2]{0,1:T(2,128)} parameter(1)
  %copy-start.1 = (f32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)}, u32[]{:S(2)}) copy-start(f32[100,2]{0,1:T(2,128)} %Arg_1.2)
  %Arg_0.1 = f32[100,32]{0,1:T(8,128)} parameter(0)
  %copy.9 = f32[100,32]{0,1:T(8,128)S(3)} copy(f32[100,32]{0,1:T(8,128)} %Arg_0.1), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["2","1"],"input_window_bounds":[],"estimated_cycles":"621","iteration_bounds":["2","1"]},"megacore_config":{"use_single_core":false,"core_id":"0","megacore_split_dim":"0","megacore_allreduce_bytes":"0"},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"16384"}],"retry_config":{"retry_count":"0"}}
  %copy.10 = s32[]{:T(128)} copy(s32[]{:T(128)} %constant.4)
  %copy-done.3 = s32[100,1]{0,1:T(1,128)S(3)} copy-done((s32[100,1]{0,1:T(1,128)S(3)}, s32[100,1]{0,1:T(1,128)}, u32[]{:S(2)}) %copy-start.3)
  %fusion = s32[100,1]{0,1:T(1,128)S(3)} fusion(s32[100,1]{0,1:T(1,128)S(3)} %copy-done.3), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(dynamic_update_slice)/jit(main)/select_n" source_file="/home/jamietownsend/issue.py" source_line=21}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"589","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1024"}],"retry_config":{"retry_count":"0"}}
  %iota = s32[100,1]{0,1:T(1,128)S(3)} iota(), iota_dimension=0, metadata={op_name="jit(dynamic_update_slice)/jit(main)/iota[dtype=int32 shape=(100, 1) dimension=0]" source_file="/home/jamietownsend/issue.py" source_line=21}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"592","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"1024"}],"retry_config":{"retry_count":"0"}}
  %copy-done.4 = s32[2]{0:T(128)S(3)} copy-done((s32[2]{0:T(128)S(3)}, s32[2]{0:T(128)}, u32[]{:S(2)}) %copy-start.4)
  %fusion.1 = s32[100,2]{0,1:T(2,128)S(3)} fusion(s32[100,1]{0,1:T(1,128)S(3)} %fusion, s32[100,1]{0,1:T(1,128)S(3)} %iota, s32[2]{0:T(128)S(3)} %copy-done.4), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(dynamic_update_slice)/jit(main)/clamp" source_file="/home/jamietownsend/issue.py" source_line=21}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","1"],"input_window_bounds":[],"estimated_cycles":"609","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"2048"}],"retry_config":{"retry_count":"0"}}
  %copy-done.2 = s32[2]{0:T(128)S(3)} copy-done((s32[2]{0:T(128)S(3)}, s32[2]{0:T(128)}, u32[]{:S(2)}) %copy-start.2)
  %copy-done.1 = f32[100,2]{0,1:T(2,128)S(3)} copy-done((f32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)}, u32[]{:S(2)}) %copy-start.1)
  %tuple.17 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) tuple(s32[]{:T(128)} %copy.10, f32[100,32]{0,1:T(8,128)S(3)} %copy.9, s32[100,2]{0,1:T(2,128)S(3)} %fusion.1, f32[100,2]{0,1:T(2,128)S(3)} %copy-done.1, s32[]{:T(128)} %constant.33, /*index=5*/s32[2]{0:T(128)S(3)} %copy-done.2)
  %while.1 = (s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) while((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %tuple.17), condition=%wide.while_cond, body=%wide.while_body
  %get-tuple-element.59 = f32[100,32]{0,1:T(8,128)S(3)} get-tuple-element((s32[]{:T(128)}, f32[100,32]{0,1:T(8,128)S(3)}, s32[100,2]{0,1:T(2,128)S(3)}, f32[100,2]{0,1:T(2,128)S(3)}, s32[]{:T(128)}, /*index=5*/s32[2]{0:T(128)S(3)}) %while.1), index=1
  %copy-start = (f32[100,32]{0,1:T(8,128)}, f32[100,32]{0,1:T(8,128)S(3)}, u32[]{:S(2)}) copy-start(f32[100,32]{0,1:T(8,128)S(3)} %get-tuple-element.59)
  ROOT %copy-done = f32[100,32]{0,1:T(8,128)} copy-done((f32[100,32]{0,1:T(8,128)}, f32[100,32]{0,1:T(8,128)S(3)}, u32[]{:S(2)}) %copy-start)
}