google-deepmind/penzai

[FR] jax.numpy.ndarray.at for Named Arrays

amifalk opened this issue · 4 comments

It would be great to support array.at[] syntax for named arrays, so that it isn't necessary to fully unwrap a NamedArray, operate on the raw array, and then rewrap the result.

Thanks for the suggestion, this is a good idea!

For now, you can still use .at syntax without fully unwrapping the array, by doing something like:

pz.nx.nmap(lambda arr: arr.at[...].set(...))(my_named_array)

(although admittedly this is a bit awkward).
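As a rough plain-JAX analogy, assuming that nmap vectorizes over named axes much as jax.vmap vectorizes over a mapped axis, the function passed to nmap only sees the remaining positional axes, so .at[...] indexes those axes alone. In the sketch below, axis 0 of an ordinary array stands in for a named batch axis, and the helper name set_first is hypothetical. Here the new value is a constant closed over by the function.

```python
import jax
import jax.numpy as jnp

def set_first(row):
    # Inside the mapped function, `row` has only the positional axis (size 2),
    # so index 0 refers to that axis, not to the batch axis.
    return row.at[0].set(1.0)

# Axis 0 plays the role of a named "batch" axis of size 3; axis 1 is positional.
x = jnp.zeros((3, 2))

# Map over axis 0, analogous to how nmap maps over named axes (assumption).
y = jax.vmap(set_first)(x)
print(y)
```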

The other drawback to this approach is that it does not allow the value in the set(...) to be a named array. Ideally this syntax would be valid:

arr = pz.nx.zeros({'batch': 3, 'a': 2})
arr.at[{'a': 0}].set(pz.nx.ones({'batch': 3}))
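One way to read the intended semantics: the dict selects position 0 along the named axis a, and every axis not mentioned (here batch) is kept whole, so the value should cover the remaining batch axis. The sketch below emulates this with a plain JAX array plus a tuple of axis names. The helper named_at_set and the (array, axis_names) representation are hypothetical illustrations, not penzai's implementation, and unlike real named arrays the value must already be laid out in the remaining axes' order rather than being aligned by name.

```python
import jax.numpy as jnp

def named_at_set(data, axis_names, index, value):
    """Hypothetical: emulate arr.at[{name: i}].set(value) on a plain array.

    `axis_names` gives the name of each axis of `data`, in order.
    `index` maps axis names to integer positions; unlisted axes are kept whole.
    `value` must already have the remaining axes in the same relative order.
    """
    unknown = set(index) - set(axis_names)
    if unknown:
        raise KeyError(f"unknown axis names: {sorted(unknown)}")
    # Translate the dict into a positional index tuple, e.g. (slice(None), 0).
    positional_index = tuple(index.get(name, slice(None)) for name in axis_names)
    return data.at[positional_index].set(value)

# Stand-in for pz.nx.zeros({'batch': 3, 'a': 2}): axes ordered ('batch', 'a').
names = ('batch', 'a')
data = jnp.zeros((3, 2))

# Stand-in for arr.at[{'a': 0}].set(pz.nx.ones({'batch': 3})).
out = named_at_set(data, names, {'a': 0}, jnp.ones(3))
```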

I've added partial support for this in #41 (to be included in release v0.1.3), which lets you write:

arr = pz.nx.zeros({'batch': 3, 'a': 2})
arr.untag('a').at[0].set(pz.nx.ones({'batch': 3})).tag('a')
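Conceptually, untag('a') turns a back into a positional axis, .at[0] indexes that axis, and the named batch axis is vectorized over for both the array and the value, which is why the value can carry its own batch axis. A plain-JAX sketch of that view follows, assuming named-axis vectorization behaves like jax.vmap over the batch axis (axis 0 here); the function name set_a0 is hypothetical, and the value uses distinct entries so the per-batch alignment is visible.

```python
import jax
import jax.numpy as jnp

def set_a0(row, v):
    # `row` is one batch element whose only positional axis is the untagged
    # 'a' axis (size 2); `v` is the matching scalar from the value array.
    return row.at[0].set(v)

x = jnp.zeros((3, 2))            # axes: ('batch', 'a')
value = jnp.arange(1.0, 4.0)     # axis: ('batch',); entries 1.0, 2.0, 3.0

# Vectorize over 'batch' for both the array and the value.
out = jax.vmap(set_a0)(x, value)
print(out)
```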

Currently, only positional indexing is supported, because it's tricky to define what should happen when you operate on a named-axis slice like arr.at[{'a': pz.slice[:10]}].set(...), especially in combination with NumPy/JAX advanced indexing. I'm leaving this issue open as a reminder to figure out dict-style indexing as well.
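One concrete source of the difficulty: under NumPy's advanced-indexing rules (which JAX follows), where the broadcast dimension of integer-array indices lands in the result depends on whether those indices sit on adjacent axes or are separated by a slice. Named axes have no canonical position, so a dict-style design would have to decide how to resolve this. The sketch below only demonstrates the positional behavior; it makes no claim about how penzai ultimately handled it.

```python
import jax.numpy as jnp

x = jnp.zeros((5, 6, 7))
idx = jnp.array([0, 1, 2])

# Advanced indices on adjacent axes 1 and 2: their broadcast dimension (size 3)
# replaces them in place, after the sliced axis 0.
adjacent = x[:, idx, idx]

# Advanced indices on axes 0 and 2, separated by a slice: the broadcast
# dimension moves to the front of the result.
separated = x[idx, :, idx]

print(adjacent.shape)   # (5, 3)
print(separated.shape)  # (3, 6)
```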

Fixed by #54.