jax-ml/jax

`shard_map` doesn't work with `jnp.insert`


Description

The code below

import jax
import jax.numpy as jnp

def f(x):
  return jnp.insert(x, 0, 0)[None]

from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices("gpu"), axis_name:="test")
f = shard_map(f, mesh, P(axis_name), P(axis_name))
f(jnp.zeros(100))

raises the following error

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 193, in wrapped
    out_flat = shard_map_p.bind(
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 475, in bind
    outs = top_trace.process_shard_map(  # pytype: disable=attribute-error
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 810, in _shard_map_impl
    outs = fun.call_wrapped(*args)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "<stdin>", line 2, in f
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 8495, in insert
    values_ind = indices.at[argsort(indices)].add(arange(n_insert, dtype=indices.dtype))
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 2803, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1958, in process_primitive
    out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1248, in _pjit_rewrite
    out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 2803, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 902, in process_primitive
    out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1253, in _pjit_check
    return _check_rep(mesh, jaxpr.jaxpr, in_rep)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 631, in _check_rep
    map(write, e.outvars, out_rep)
TypeError: 'NoneType' object is not iterable
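The last frames show the `TypeError` is raised inside `_check_rep`. `process_primitive` only calls that rule when `self.check` is true, which corresponds to `shard_map`'s replication check (`check_rep`, on by default). One possible stopgap is to pass `check_rep=False`. It has not been tested on the reporter's GPU setup and assumes the bug is confined to that check. The sketch below also replaces `jax.devices("gpu")` with a one-device mesh on the default backend so it runs anywhere. That substitution is an assumption and is not part of the original report.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

def f(x):
    # Same body as the reproducer: prepend a 0, then add a leading axis.
    return jnp.insert(x, 0, 0)[None]

axis_name = "test"
# Assumption: a one-device mesh on the default backend (the report used "gpu").
mesh = Mesh(np.array(jax.devices()[:1]), (axis_name,))

# check_rep=False skips the replication check in which the TypeError is raised.
g = shard_map(f, mesh, P(axis_name), P(axis_name), check_rep=False)
out = g(jnp.zeros(100))
print(out.shape)  # (1, 101)
```

Turning the check off also disables the replication-based validation of `out_specs`. This avoids the error but does not fix the underlying interaction between `jnp.insert` and `shard_map`.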

Removing either the `shard_map` wrapper or the `jnp.insert` call makes the code run as expected.
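Here is a minimal sketch of the two control cases, assuming a one-device mesh on the default backend instead of `jax.devices("gpu")`. The first case calls the function directly without `shard_map`. The second wraps a function that omits `jnp.insert` in `shard_map`. That function is named `without_insert` here for illustration and only adds a leading axis.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

def with_insert(x):
    # Body from the reproducer.
    return jnp.insert(x, 0, 0)[None]

def without_insert(x):
    # Hypothetical control: same output layout, no jnp.insert.
    return x[None]

axis_name = "test"
# Assumption: a one-device mesh on the default backend (the report used "gpu").
mesh = Mesh(np.array(jax.devices()[:1]), (axis_name,))

# Control 1: jnp.insert without shard_map.
direct = with_insert(jnp.zeros(100))
print(direct.shape)  # (1, 101)

# Control 2: shard_map without jnp.insert (default check_rep=True).
mapped = shard_map(without_insert, mesh, P(axis_name), P(axis_name))(jnp.zeros(100))
print(mapped.shape)  # (1, 100)
```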

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.1.3
python: 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]
device info: NVIDIA RTX 6000 Ada Generation-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='**hidden**', release='5.15.0-92-generic', version='#102-Ubuntu SMP Wed Jan 10 09:33:48 UTC 2024', machine='x86_64')


$ nvidia-smi
Thu Nov  7 10:25:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+