`hk.custom_getter` with `jax.lax.with_sharding_constraint`
kavorite opened this issue · 0 comments
I have this code:
# 2-D positional device layout of shape (n_devices // 2, 2); the reshape
# requires an even number of devices.
shards = jax.sharding.PositionalSharding(np.array(jax.devices())).reshape(-1, 2)


class ShardGetter:
    """Haiku custom getter that applies a sharding constraint to each parameter read.

    2-D parameters alternate between ``shards.replicate(0).T`` and
    ``shards.replicate(0)`` in the order in which they are *first* accessed.
    The choice is cached per ``(module_name, name)``, so every later access of
    the same parameter (from this same getter instance) gets the same layout.
    All other parameters are constrained to ``shards.replicate()``.

    NOTE: a Haiku custom getter transforms the value that ``hk.get_parameter``
    hands back to the module; it does not replace the parameter that ``init``
    stores and returns. The constraint therefore shapes the forward-pass
    computation, but the params pytree returned by ``init`` must be constrained
    separately (or parameter creation intercepted with ``hk.custom_creator``).
    """

    def __init__(self):
        # Parity for the next newly seen 2-D parameter; flipped after each assignment.
        self.transpose = True
        # module_name -> {parameter name -> sharding}
        self.placement_cache = defaultdict(dict)

    def _placement_2d(self, context):
        """Return the cached sharding for a 2-D parameter, assigning one on first access."""
        # Look up with the same (module_name, name) keys used for storage. Checking
        # ``context.full_name`` against the outer dict never matched, so every access
        # re-assigned a layout and flipped the parity.
        module_cache = self.placement_cache[context.module_name]
        if context.name not in module_cache:
            base = shards.replicate(0)
            module_cache[context.name] = base.T if self.transpose else base
            self.transpose = not self.transpose
        return module_cache[context.name]

    def __call__(
        self,
        next_getter: callable,
        value: jax.Array,
        context: hk.GetterContext,
    ):
        """Constrain ``value``'s sharding, then defer to the next getter in the chain.

        Args:
          next_getter: the next Haiku getter; called with the constrained value.
          value: the parameter value being read.
          context: Haiku's description of the parameter being read; its
            ``module_name`` and ``name`` key the placement cache.

        Returns:
          Whatever ``next_getter`` returns for the constrained value.
        """
        if value.ndim == 2:
            placement = self._placement_2d(context)
        else:
            placement = shards.replicate()
        value = jax.lax.with_sharding_constraint(value, placement)
        return next_getter(value)
What I want is for each matrix to be sharded 'vertically' if its access index is even and 'horizontally' if its access index is odd (an approach inspired by https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). It is more appropriate to shard the weights according to the order in which they are accessed than the order in which they are created, in order to minimize I/O overhead and the associated idleness in the forward pass.
However, this code does not work: after running it, everything is still placed on the first device (TPU 0):
# A single getter instance is used by every trace of `objective` (both init and
# apply), so its placement cache and access parity persist across them.
shard_getter = ShardGetter()


@hk.without_apply_rng
@hk.transform
def objective(inputs: Batch):
    """Compute the loss for ``inputs``, routing parameter reads through ``shard_getter``.

    NOTE(review): ``hk.custom_getter`` affects the values the model reads, not the
    parameters stored by ``objective.init`` — so the getter's sharding constraints
    are not expected to appear on the params returned from ``init``.
    """
    # ... do some application-specific, error-computing stuff...
    with hk.custom_getter(shard_getter):
        loss = model(inputs)
    return loss
@jax.jit
def train_init(rng, inputs):
    """Build the initial ``TrainState`` from a PRNG key and a sample batch.

    Parameters come straight from ``objective.init``; the optimizer state is
    initialised from them, and loss and step start at zero. No sharding
    constraint is applied to ``params`` in this function itself.
    """
    params = objective.init(rng, inputs)
    return TrainState(
        params,
        optimizer().init(params),
        0.0,  # initial loss
        0,  # initial step
    )
inputs = next(batches)
inputs = jax.device_put(inputs, shards.replicate(-1))  # meticulously arrange everything _just so..._
tstate = train_init(jax.device_put(jax.random.PRNGKey(42), shards.replicate()), inputs)

# visualize_array_sharding draws a single 1-D or 2-D array, not a pytree, so walk
# the parameter leaves, label each one, and skip ranks it cannot render.
for module_name, param_name, param in hk.data_structures.traverse(tstate.params):
    if param.ndim in (1, 2):
        print(f"{module_name}/{param_name}")
        jax.debug.visualize_array_sharding(param)
What I would like is for this code to apply the sharding constraints specified in `ShardGetter.__call__`. As it stands, my workaround for this limitation is:
@jax.jit
def train_init(rng, inputs):
    """Build the initial ``TrainState`` with every parameter pinned to an explicit sharding.

    Each parameter defaults to ``shards.replicate()``; placements recorded in
    ``shard_getter.placement_cache`` (module_name -> {name -> sharding}) while
    ``objective.init`` ran override that default. The constrained params then
    seed the optimizer state; loss and step start at zero.
    """
    params = objective.init(rng, inputs)
    recorded = dict(shard_getter.placement_cache)
    # Start from full replication everywhere, then let recorded placements win.
    fallback = hkds.map(lambda *_: shards.replicate(), params)
    placements = hkds.merge(fallback, recorded)
    params = jtu.tree_map(jax.lax.with_sharding_constraint, params, placements)
    return TrainState(
        params,
        optimizer().init(params),
        0.0,  # initial loss
        0,  # initial step
    )
This works fine. I understand this is mainly an aesthetic concern (working around it only required adding three lines). Still, I am mystified as to what erases the sharding constraints applied by the getter interceptor.