data-apis/array-api

RFC: add `"mutable arrays"` to `capabilities`

jakevdp opened this issue · 15 comments

Several parts of the Array API standard assume that array objects are mutable. Some array API implementations (notably JAX) do not support mutating array objects. As a result, the array API implementations currently being developed in scipy and sklearn are entirely unusable with JAX.

Given this, downstream implementations have a few choices:

  1. Use mutability semantics, excluding libraries like JAX.
  2. Avoid mutability semantics to support libraries like JAX.
  3. Explicitly special-case arrays of type jax.numpy.Array, changing the implementation logic for that case.

(1) is a bad choice, because it means JAX will not be supported. (2) is a bad choice, because for libraries like NumPy, it leads to excessive copying of buffers, worsening performance. (3) is a bad choice because it hard-codes the presence of specific implementations in a context that is supposed to be implementation-agnostic.

One way the Array API standard could address this is by adding "mutable arrays" or something similar to the existing capabilities dict. Then downstream implementations could use strategy (3) without special-casing particular implementations.

(to anticipate one response: no, it's not possible to make JAX arrays support mutation: central to JAX are transformations like jit, vmap, grad, etc. that rely on immutability assumptions in their program tracing)

For (3), could you prototype what it would look like in the case of gh-609? For capabilities["mutable arrays"] == True, we use the NumPy syntax x[i] += y. For capabilities["mutable arrays"] == False, we use ...? This would require standardising a way to do this for immutable arrays, right? Or can we just use xp.where?

Several parts of the Array API standard assume that array objects are mutable.

This is very surprising. It would be nice if we can have a list of such occurrences here, because this was not supposed to happen as per our design guideline https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html

This is very surprising. It would be nice if we can have a list of such occurrences here,

The main example is __setitem__, which as far as I can tell is supported in the standard.

For (3), could you prototype what it would look like in the case of gh-609?

For example, it could look something like this in the specific case of updating an array with a mask and a scalar:

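# "mutable arrays" is the capability key proposed in this issue; it is not part of the standard today
info = xp.__array_namespace_info__().capabilities()

if info['mutable arrays']:
  x[xp.isnan(x)] = 0
else:
  x = xp.where(xp.isnan(x), 0, x)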

That would certainly not cover all cases, but it would be enough to fix a large number of the incompatibilities currently being introduced into scipy and scikit-learn.
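A minimal runnable sketch of the two branches, assuming NumPy stands in for `xp` and that the mutability decision is passed in as a plain boolean (the `fill_nan` name and its `mutable` parameter are hypothetical, since no "mutable arrays" capability exists today). Both branches produce the same values; only the in-place branch modifies the caller's array:

```python
import numpy as np


def fill_nan(x, xp, mutable):
    """Replace NaNs in x with 0 (hypothetical helper for illustration).

    mutable=True  -> in-place boolean-mask assignment (modifies x itself)
    mutable=False -> out-of-place update via xp.where (x is left untouched)
    """
    mask = xp.isnan(x)
    if mutable:
        x[mask] = 0
    else:
        x = xp.where(mask, 0, x)
    return x


a = np.array([1.0, np.nan, 3.0])
b = a.copy()

out_inplace = fill_nan(a, np, mutable=True)
out_outofplace = fill_nan(b, np, mutable=False)

print(out_inplace, out_outofplace)  # [1. 0. 3.] [1. 0. 3.]
print(out_inplace is a)             # True: the caller's array was modified
print(np.isnan(b[1]))               # True: the caller's array is unchanged
```

The value-level equivalence is what makes the `where` fallback viable for masked scalar updates; the difference only shows up through other references to the same array.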

But in general, yes, it would also be beneficial if the array API standard could add some syntax for out-of-place array updates similar to what's being discussed in #609.

Adding "mutable arrays" to library-level capabilities is suboptimal for libraries that support both mutable and immutable arrays. For example, NumPy arrays have a flags.writeable attribute that signals whether an array should be considered mutable or immutable.
What about adding array-object-level flags to the Array API, similar to numpy.ndarray.flags?

Agreed with @pearu's comment. There are multiple other issues here though, for example:

(1) What does it mean to be a "mutable array"? To stay with the numpy.ndarray.flags example:

>>> import numpy as np
>>> x = np.arange(5)
>>> y = x[:3]
>>> y.flags.writeable = False
>>> y += 1
...
ValueError: output array is read-only

>>> y[0]
np.int64(0)
>>> x += 1
>>> y[0]
np.int64(1)

So is y mutable? I guess you'd have said no - but its values can still easily change. So there's no right answer here for numpy.ndarray right now.

(2) JAX you'd argue is immutable I'm sure, however as we saw in the example above numpy readonly arrays reject in-place operators like += while JAX doesn't:

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> x[0]
Array(0, dtype=int32)
>>> x += 1
>>> x[0]
Array(1, dtype=int32)

So I'd say "is a mutable array" is quite ambiguous.

This is very surprising. It would be nice if we can have a list of such occurrences here,

The main example is __setitem__, which as far as I can tell is supported in the standard.

__setitem__ is the one and only painful example for JAX here, I believe. However, it is not the case that it cannot be implemented in JAX, nor that it is incompatible with immutability. It's a complex topic, but https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html covers it. The key point is that there's no semantic difference between updating values in-place or out-of-place as long as the update modifies only a single array. The reason JAX never implemented slice/item assignment (with discussions going all the way back to gh-24) is that doing so would be confusing given the mismatch in semantics with NumPy. But it's NumPy's semantics that are undefined behavior as soon as views play a role; the JAX design is perfectly fine.


I think that we should not add .flags unless there's a real value-add. At the moment, the motivating code example is best written for one or more specific libraries like:

if is_jax(x):
  x[xp.isnan(x)] = 0
else:
  x = xp.where(xp.isnan(x), 0, x)

(1) ... So is y mutable? I guess you'd have said no

Yes, y is not mutable because you cannot mutate it via operations on y.

  • but its values can still easily change.

In general, there are always ways to mutate the data of immutable objects. One can even mutate JAX arrays easily via the DLPack or array interface protocols.

In this specific case, the example demonstrates a common practice of viewing data as read-only while the data could still be modified at some other level or time. For instance, one can open a file in read-only mode, and in that context the file descriptor represents an immutable object, while some other process may open the same file in writable mode, enabling mutations.

(2) JAX you'd argue is immutable I'm sure, however as we saw in the example above numpy readonly arrays reject in-place operators like += while JAX doesn't:

This means that NumPy and JAX implement different semantics for in-place operations: for NumPy, an in-place operation mutates the array, while for JAX, an in-place operation is syntactic sugar for the rebinding x = op(x, y).
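Plain Python already has both behaviours for `+=`, which may help make the distinction concrete. A `list` implements `__iadd__` and mutates in place (NumPy-like), while a `tuple` has no `__iadd__`, so `t += ...` falls back to `t = t + ...` and merely rebinds the name (JAX-like). This is a stdlib analogy only, not NumPy or JAX code:

```python
# Mutable: list.__iadd__ extends the existing object, so aliases see the change.
m = [0, 0]
alias_m = m
m += [1]
print(alias_m)  # [0, 0, 1]

# Immutable: tuple has no __iadd__, so `t += (1,)` means `t = t + (1,)`;
# the name t is rebound to a new object and the alias is unaffected.
t = (0, 0)
alias_t = t
t += (1,)
print(alias_t)  # (0, 0)
print(t)        # (0, 0, 1)
```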

So I'd say "is a mutable array" is quite ambiguous.

I disagree. By definition, an "array" is a certain view of (contiguous) data whose elements can be accessed via indexing operations.
So a mutable/immutable array is an array that allows/disallows data mutations via indexing operations. Even if there are other ways of mutating the underlying data (say, via direct memory access, via cosmic rays, etc.), these mutations happen outside the context of mutable/immutable array usage.

...

You probably meant to write:

if not is_jax(x):
  x[xp.isnan(x)] = 0
else:
  x = xp.where(xp.isnan(x), 0, x)

I find using "jax" in the name of a utility predicate function suboptimal, because JAX arrays are not the only array objects that are immutable. So, instead of introducing is_jax, I suggest:

if is_mutable(x):
  x[xp.isnan(x)] = 0
else:
  x = xp.where(xp.isnan(x), 0, x)

so that the same code in scipy/... will not need to be modified when someone invents another Array API-compliant array object that is immutable: it will be sufficient to update only the definition of is_mutable.
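A minimal sketch of what such a predicate could look like, assuming NumPy's `flags.writeable` as the per-array signal (`is_mutable` is the name proposed above, `replace_nan` is a hypothetical helper; neither is an existing API). Objects without a NumPy-style `flags` attribute are treated as immutable, which is safe but may cause unnecessary copies for mutable libraries that lack such a flag:

```python
import numpy as np


def is_mutable(x) -> bool:
    """Hypothetical per-array mutability check (not part of any standard).

    Assumption: objects exposing a NumPy-style ``flags.writeable`` attribute
    report it directly; anything else is conservatively treated as immutable,
    so callers fall back to the out-of-place path.
    """
    flags = getattr(x, "flags", None)
    writeable = getattr(flags, "writeable", None)
    if writeable is not None:
        return bool(writeable)
    return False


def replace_nan(x, xp):
    # Branch on the array itself rather than on the library it comes from.
    if is_mutable(x):
        x[xp.isnan(x)] = 0
    else:
        x = xp.where(xp.isnan(x), 0, x)
    return x


rw = np.array([np.nan, 2.0])
ro = np.array([np.nan, 2.0])
ro.flags.writeable = False

print(is_mutable(rw), is_mutable(ro))  # True False
print(replace_nan(rw, np), replace_nan(ro, np))
```

Note that a read-only NumPy array takes the out-of-place path even though NumPy as a library supports mutation, which is exactly why a namespace-level capability alone cannot express this.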

if is_mutable(x):

The problem is that we now need to fetch a capability of the array, rather than of the array namespace, since NumPy has both behaviours. So how could we address this, if not with an info or flags method on arrays (or by simply refusing to handle immutable NumPy arrays correctly)?

The key point is that there's no semantic difference between updating values in-place or out-of-place as long as the update modifies only a single array.

Thinking about this a bit, I think the language about "views" in https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html is not quite strong enough.

Let's assume "in-place update" is equivalent to x[0] = 1 in NumPy, and "out-of-place update" is equivalent to x = x.copy(); x[0] = 1 in NumPy, or to x = x.at[0].set(1) in JAX.

The kind of example I have in mind is this:

x = xp.zeros(2)
L = [x]
x[0] = 1
print(L)

What is the result here? If x[0] = 1 has in-place update semantics, then this prints [array([1., 0.])]. If x[0] = 1 has out-of-place update semantics, then this prints [array([0., 0.])].

So the equivalence of in-place and out-of-place semantics doesn't just require the absence of array views in the sense of what's tracked by x.flags.owndata; it also requires that the array's Python refcount be exactly 1.
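The same example can be made concrete with NumPy playing both roles, using the two definitions of "in-place" and "out-of-place" update given above (`x.copy()` followed by assignment stands in for JAX's `x.at[0].set(1)`). No views are involved in either case, yet the outputs differ because the list holds a second reference to the array:

```python
import numpy as np

# In-place update: the list holds a reference to the same array object.
x = np.zeros(2)
L_in = [x]
x[0] = 1
print(L_in)  # [array([1., 0.])]

# Out-of-place update: x is rebound to an updated copy; the list still
# refers to the original, unmodified array.
x = np.zeros(2)
L_out = [x]
x = x.copy()
x[0] = 1
print(L_out)  # [array([0., 0.])]
```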

Recall that Python ignores any value returned by __setitem__. That is, the assumption in #845 (comment), that one would be able to implement __setitem__ such that

x[0] = 1

uses JAX's semantics (x = x.at[0].set(1)), cannot hold in Python (a Python object cannot rebind the names that refer to it, not even by modifying the caller's frame locals, IIRC). Of course, under jit anything is possible, but I assume we want to keep the semantics of Python and jitted functions the same.
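A small stdlib-only sketch of this point, using a hypothetical `CopyOnSet` class whose `__setitem__` returns an updated copy instead of mutating. The assignment statement discards that return value, so the name `x` keeps pointing at the original, unchanged object:

```python
class CopyOnSet:
    """Hypothetical immutable container: __setitem__ returns an updated copy."""

    def __init__(self, data):
        self._data = tuple(data)

    def __getitem__(self, i):
        return self._data[i]

    def __setitem__(self, i, value):
        new = list(self._data)
        new[i] = value
        # This return value is silently discarded by `x[i] = value`.
        return CopyOnSet(new)


x = CopyOnSet([0, 0])
original = x
x[0] = 1                      # statement form: the returned copy is dropped
print(x is original, x[0])    # True 0

# Calling the method explicitly shows the copy is created, then thrown away.
updated = x.__setitem__(0, 1)
print(updated[0], x[0])       # 1 0
```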

When modifying the @jakevdp example as follows:

x = xp.zeros(2)
L = [x]
L[0][0] = 1
print(L)

the expected output would be [array([1., 0.])] and since I cannot see how the out-of-place semantics would be possible to support technically, I think there is no alternative output.

and since I cannot see how the out-of-place semantics would be possible to support technically, I think there is no alternative output.

Sorry, I think you misunderstood my point. I wasn't arguing that x[0] = 1 should have out-of-place semantics. I was using this as an example of where in-place and out-of-place semantics differ in a way not already identified by the doc I linked to. If it helps, you can replace x[0] = 1 with *** in that code example, then substitute the appropriate code for either in-place or out-of-place semantics and compare the outputs.

If it helps, you can replace x[0] = 1 with *** in that code example

OK, fair enough. A better example would be one that uses, say, an in-place operator (__iadd__, etc.), since those methods can return non-None values that Python binds back to the name. Consider my comment above as a reminder that __setitem__ and the in-place operator methods are not equivalent when discussing in-place and out-of-place semantics.
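A stdlib-only sketch of that difference, with a hypothetical `FrozenVec` class: unlike `__setitem__`, the value returned by `__iadd__` is bound back to the name, so `x += 1` can give JAX-style out-of-place semantics. As in the refcount example, any other reference (here the list `L`) still sees the old object:

```python
class FrozenVec:
    """Hypothetical immutable array-like whose __iadd__ returns a new object."""

    def __init__(self, data):
        self.data = tuple(data)

    def __iadd__(self, other):
        # Out-of-place: build and return a new object; self is not modified.
        return FrozenVec(a + other for a in self.data)


x = FrozenVec([0, 0])
L = [x]
x += 1                        # Python rebinds: x = x.__iadd__(1)
print(x.data, L[0].data)      # (1, 1) (0, 0)
```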

Rather than getting lost in implementation details, let's bring it back to the statement I was responding to:

The key point is that there's no semantic difference between updating values in-place or out-of-place as long as the update modifies only a single array.

I think this is untrue, unless you also consider Python-level references as well as views when reasoning about whether an operation affects a "single array". And limiting operations to objects with a refcount of 1 is far more intrusive than limiting operations to arrays whose buffer is not shared with any other array objects.

I think this is untrue, unless you also consider Python-level references as well as views when reasoning about whether an operation affects a "single array".

Yes, that is a good point, and I agree that that page should be more explicit and grow a section on Python refcount >1. The behavior difference applies not only to __setitem__, but to all in-place operators as well.

And limiting operations to objects with a refcount of 1 is far more intrusive

Agreed. I think (but am not sure; I have to give it some more thought) that this should remain undefined behavior: it's kind of baked into the Python language, and it's already a difference today between JAX and NumPy/PyTorch for += & co.

Hey all, we chatted about this in today's meeting, and here is a summary:

  • while it's true that mutability is an array-level property, it's also a framework-level property. For example, if you're branching an implementation on array mutability, it matters whether or not the framework is capable of creating mutable copies.
  • the main purpose of an API like this currently would be to decide between in-place and out-of-place operations for array updates. With the exception of masked scalar updates (which can be done out-of-place using where()), the standard does not provide any mechanism for out-of-place updates, and so there's little reason to add mutability inspection at the moment.
  • If out-of-place update APIs are ever added to the spec (cf. #609), then we should revisit this proposal.