[FR] jax.numpy.ndarray.at for Named Arrays
amifalk opened this issue · 4 comments
It would be great to support array.at[] syntax for named arrays, so it's not necessary to completely unwrap a NamedArray, operate, and then rewrap.
Thanks for the suggestion, this is a good idea!
For now, you can still use .at syntax without completely unwrapping, by doing something like nmap(lambda arr: arr.at[...].set(...))(my_named_array) (although admittedly this is a bit awkward).
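To make the workaround concrete, here is a minimal plain-JAX sketch of the kind of per-slice update it performs. It assumes nmap behaves like a vmap over the named axes; the array `data` and the helper `set_first` are illustrative stand-ins, not part of penzai.

```python
import jax
import jax.numpy as jnp

# Plain array standing in for a named array with axes {'batch': 3, 'a': 2}.
data = jnp.zeros((3, 2))

def set_first(row):
    # `row` has shape (2,): only the positional 'a' axis is visible here,
    # much like the function passed to nmap only sees positional axes.
    return row.at[0].set(1.0)

# vmap over axis 0 plays the role nmap would play for the 'batch' axis.
updated = jax.vmap(set_first)(data)
```

Note that the value passed to set(...) here is a plain scalar closed over by the function, not something that carries its own named axes.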
The other drawback to this approach is that it does not allow the value in the set(...) to be a named array. Ideally this syntax would be valid:
arr = pz.nx.zeros({'batch': 3, 'a': 2})
arr.at[{'a': 0}].set(pz.nx.ones({'batch': 3}))
I've added partial support for this in #41 (to be included in release v0.1.3), which lets you do
arr = pz.nx.zeros({'batch': 3, 'a': 2})
arr.untag('a').at[0].set(pz.nx.ones({'batch': 3})).tag('a')
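As a rough illustration of what an update like this computes on the underlying data, here is a plain-JAX sketch. It assumes the value's 'batch' axis is matched against the array's 'batch' axis, so each batch element receives its own value; `data`, `values`, and `set_first` are hypothetical names, and the values are 1, 2, 3 rather than ones so the pairing is visible.

```python
import jax
import jax.numpy as jnp

data = jnp.zeros((3, 2))        # stands in for axes {'batch': 3, 'a': 2}
values = jnp.arange(1.0, 4.0)   # stands in for a value with axes {'batch': 3}

def set_first(row, value):
    # Positional update along the untagged 'a' axis for one batch element.
    return row.at[0].set(value)

# Both arguments are mapped over their leading 'batch' axis together.
updated = jax.vmap(set_first)(data, values)

# The same result written as a single positional update.
direct = data.at[:, 0].set(values)
```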
Currently, only positional indexing is supported, because it's a bit tricky to determine what happens when you operate on a named axis slice like arr.at[{'a': pz.slice[:10]}].set(...), especially in combination with numpy/JAX advanced indexing. Leaving this open for now as a reminder for me to figure out dict-style indexing as well.
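One possible source of the trickiness (an illustration of standard numpy/JAX behavior, not necessarily the exact case in mind here) is that mixing slices with advanced indices can reorder the axes of the indexed result, which changes the shape that the value passed to set(...) has to match:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 4, 5))
idx = jnp.array([0, 1])

# A single advanced index keeps its place among the sliced axes.
adjacent = x[:, idx, :]      # shape (3, 2, 5)

# Two advanced indices separated by a slice: the broadcast index
# dimension moves to the front of the result.
separated = x[idx, :, idx]   # shape (2, 4)

# The value given to .set(...) must broadcast to the indexed shape.
update = jnp.ones(separated.shape)
updated = x.at[idx, :, idx].set(update)
```

With named axes, it is not obvious which name (if any) the relocated index dimension should carry.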
Fixed by #54.