openxla/xla

Involuntary Full Rematerialization


Ported this issue from jax-ml/jax#21562

This code

import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as PS, NamedSharding, Mesh


devices = np.asarray(jax.devices()).reshape((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))  # axis 'x' has 4 devices, 'y' has 2

shardtype2 = NamedSharding(mesh, PS(None, ('x', 'y'), None))
shardtype1 = NamedSharding(mesh, PS('y', None, 'x'))

def f(a, b, c):
    d = a + b
    # Resharding d (which inherits shardtype1 from a and b) to shardtype2
    # triggers the warning below.
    d = jax.lax.with_sharding_constraint(d, shardtype2)
    return c + d


fjit = jax.jit(f, in_shardings=(shardtype1, shardtype1, shardtype2), out_shardings=shardtype2)

a = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
b = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
c = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)

print(fjit(a, b, c).block_until_ready())

Gives this warning

E0531 10:54:04.832741 2609008 spmd_partitioner.cc:569] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[2,1,4]<=[4,2]T(1,0)} to {devices=[1,8,1]<=[8]} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.
E0531 10:54:04.832805 2609008 spmd_partitioner.cc:569] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[2,1,4]<=[4,2]T(1,0)} to {devices=[1,8,1]<=[8]} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.

To avoid this issue, we've had to write our own resharding logic instead of relying solely on with_sharding_constraint.
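
For reference, a minimal sketch of the kind of workaround we use, assuming the goal is just to avoid the direct shardtype1 -> shardtype2 reshard; the intermediate PartitionSpec and the two-step approach here are illustrative, not our production logic:

# Sketch of a two-step reshard (illustrative assumption, not the code we ship).
# Changing one mesh axis at a time may give the partitioner a path that avoids
# the full rematerialization.
intermediate = NamedSharding(mesh, PS(None, 'y', 'x'))

def f_stepwise(a, b, c):
    d = a + b
    # Step 1: move 'y' from dim 0 to dim 1, keeping 'x' on dim 2.
    d = jax.lax.with_sharding_constraint(d, intermediate)
    # Step 2: collapse both mesh axes onto dim 1, matching shardtype2.
    d = jax.lax.with_sharding_constraint(d, shardtype2)
    return c + d

fjit_stepwise = jax.jit(
    f_stepwise,
    in_shardings=(shardtype1, shardtype1, shardtype2),
    out_shardings=shardtype2,
)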

This means there was a conflict in sharding, possibly a sharding that was propagated from an intermediate.
The XLA compiler had to rematerialize the full tensor in order to reshard it. In this repro, the source sharding {devices=[2,1,4]<=[4,2]T(1,0)} in the warning corresponds to shardtype1 (PS('y', None, 'x')) and the destination {devices=[1,8,1]<=[8]} to shardtype2 (PS(None, ('x', 'y'), None)).

This PR improves logging to show which HloInstruction has the conflict:
#15402

If you want to debug this further, look at the HLO dump after the ShardingPropagation pass but before SPMD partitioning. Then, using my logging PR, look at the HloSharding metadata of that HloInstruction and of the instructions around it. Most likely you will see that a conflict like (4,8)->(8,4) triggered the reshard.
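
For completeness, one way to get that dump from the JAX repro above; the dump directory and pass-name regex are illustrative choices, and XLA_FLAGS must be set before JAX initializes its backends:

import os

# Set before importing jax or running anything on a backend, otherwise the
# flags are ignored. The directory and regex below are illustrative.
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump "
    "--xla_dump_hlo_pass_re=sharding|spmd"
)

# ... then run the repro and inspect the dumped files: compare the module
# after sharding propagation with the one before SPMD partitioning, and look
# at the sharding={...} annotations on the conflicting HloInstruction.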