`shard_map` doesn't work with `jnp.insert`
Opened this issue · 0 comments
mrlazy1708 commented
Description
The code below
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P


def f(x):
    """Prepend a 0 to the local shard ``x`` and add a leading axis.

    The leading axis (``[None]``) gives each device's result its own row, so
    ``out_specs=P(axis_name)`` stacks the per-device outputs along axis 0.
    """
    return jnp.insert(x, 0, 0)[None]


# One-dimensional mesh over the available GPUs, with a single named axis that
# is reused in the partition specs below.
# NOTE(review): the failure happens in trace-time replication checking, so
# jax.devices() (any backend) likely reproduces it too. Unverified.
axis_name = "test"
mesh = Mesh(jax.devices("gpu"), axis_name)

# Shard the input along its only axis and stack the per-device outputs.
# NOTE(review): the traceback dies inside shard_map's replication check
# (``_check_rep``). Passing ``check_rep=False`` to shard_map may sidestep it.
# This workaround is unverified.
f = shard_map(f, mesh, P(axis_name), P(axis_name))

# Expected to raise ``TypeError: 'NoneType' object is not iterable`` on
# jax/jaxlib 0.4.34 (see the traceback in this report).
f(jnp.zeros(100))
raises the following error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 193, in wrapped
out_flat = shard_map_p.bind(
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 475, in bind
outs = top_trace.process_shard_map( # pytype: disable=attribute-error
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 810, in _shard_map_impl
outs = fun.call_wrapped(*args)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "<stdin>", line 2, in f
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 8495, in insert
values_ind = indices.at[argsort(indices)].add(arange(n_insert, dtype=indices.dtype))
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 2803, in bind
return self.bind_with_trace(top_trace, args, params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1958, in process_primitive
out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1248, in _pjit_rewrite
out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 2803, in bind
return self.bind_with_trace(top_trace, args, params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 902, in process_primitive
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1253, in _pjit_check
return _check_rep(mesh, jaxpr.jaxpr, in_rep)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 631, in _check_rep
map(write, e.outvars, out_rep)
TypeError: 'NoneType' object is not iterable
Removing either the `shard_map` wrapper or the `jnp.insert` call makes the code work as expected.
System info (python version, jaxlib version, accelerator, etc.)
>>> import jax; jax.print_environment_info()
jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.1.3
python: 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0]
device info: NVIDIA RTX 6000 Ada Generation-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='**hidden**', release='5.15.0-92-generic', version='#102-Ubuntu SMP Wed Jan 10 09:33:48 UTC 2024', machine='x86_64')
$ nvidia-smi
Thu Nov 7 10:25:17 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+