`tile_put_sharded` should support scalar inputs
balancap commented
This issue is closely related to migrating to a proper TileMesh
structure, potentially multi-dimensional, for representing IPU tiling information.
In other words, the following example should be supported:
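import jax.numpy as jnp

a = jnp.array(0.0)  # 0-d (scalar) array
a = tile_put_sharded(a, tiles=0)  # place the scalar on a single tile, given as an int
assert a.shape == ()  # the scalar shape is preserved, no tile axis is added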
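
To make the requested behaviour concrete, here is a minimal pure-Python sketch of the shape rule it implies. The helper name `check_sharding_shape` is hypothetical and not part of the library. The sketch assumes that a sequence of tiles requires a leading axis of matching length, while a single integer tile leaves the shape untouched. How a scalar combined with a sequence of tiles should behave is not stated in this issue, so rejecting it here is only an assumption.

```python
from typing import Sequence, Tuple, Union

Shape = Tuple[int, ...]


def check_sharding_shape(shape: Shape, tiles: Union[int, Sequence[int]]) -> Shape:
    """Hypothetical helper: validate `shape` against `tiles`, return the result shape."""
    if isinstance(tiles, int):
        # Single tile index: no tile axis is needed, so any shape
        # (including the scalar shape `()`) is accepted unchanged.
        return tuple(shape)
    tiles = tuple(tiles)
    if len(shape) == 0:
        # Assumption: a scalar has no axis to shard across several tiles.
        raise ValueError("scalar input requires a single integer tile index")
    if shape[0] != len(tiles):
        raise ValueError(
            f"leading axis {shape[0]} does not match number of tiles {len(tiles)}"
        )
    return tuple(shape)


# Scalar on a single tile, as in the example from this issue.
scalar_shape = check_sharding_shape((), 0)
# Array sharded over two tiles along its first axis.
sharded_shape = check_sharding_shape((2, 3), (0, 1))
```

Distinguishing `tiles=0` from `tiles=(0,)` by type is one possible convention. A `TileMesh`-based design could encode the same distinction in the mesh dimensionality instead.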