data-apis/array-api

RFC: Consider updating `copy` semantics in `astype` to `False/None/True`

Micky774 opened this issue

Currently, the specification for astype, as added in #290, allows only copy=True or copy=False, which leaves no room for a "copy never" option. Moreover, the default copy=True means that calls to astype that do not specify the copy keyword can never be a no-op, so unnecessary copies are the default behavior.

Other functions which use the copy=False/None/True semantics include:

  • asarray
  • __dlpack__
  • from_dlpack
  • reshape

Hence,

  1. I think it is confusing that copy=False here still allows a copy to occur, while elsewhere it means "copy never"
  2. copy=True is imo a bad default that leads to more memory movement than may be necessary
  3. The ability to specify "copy never" is helpful when one must be careful with memory behavior

The specification's default behavior follows NumPy and its derivatives. From the NumPy docs:

By default, astype always returns a newly allocated array.

and the signature

ndarray.astype(dtype, order='K', casting='unsafe', subok=True, copy=True)

The default is True, not None. So the original specification simply followed suit.
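For concreteness, here is a small NumPy check (a sketch, assuming NumPy's documented astype defaults) showing that the default copy=True allocates a new buffer even when the requested dtype already matches the input:

```python
import numpy as np

x = np.ones((3, 2), dtype=np.float32)

# Default copy=True: NumPy allocates a new buffer even though the dtype is unchanged.
y = x.astype(np.float32)

print(y is x)                  # False
print(np.shares_memory(y, x))  # False
print(y.dtype == x.dtype)      # True
```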

If NumPy is the motivation for this, then I think there may be a problem: NumPy has explicitly decided to differ from the Array API in this respect; see https://numpy.org/neps/nep-0056-array-api-main-namespace.html#copy-keyword-semantics

My reading of that NEP is not that NumPy is choosing to differ. NumPy is following the spec, in which the copy kwarg behavior for astype is not the same as for asarray and others.

However, the situation is actually reversed. We based the current spec on what NumPy and its kin did at the time, which was to, by default, always copy.

I see, thanks.

We based the current spec on what NumPy and its kin did at the time, which was to, by default, always copy.

For the default, yes. NumPy does support copy=None though:

>>> import numpy as np
>>> x = np.ones((3, 2), dtype=np.float32)
>>> y = x.astype(dtype=np.float32, copy=None)
>>> y is x
True
>>> y = x.astype(dtype=np.float64, copy=None)
>>> y is x
False

I'm not sure whether this was deliberately left out of the standard's version of astype or whether its absence is an oversight.

In NumPy's implementation, do copy=None and copy=False have the same semantics of "copy only if necessary"?


Yes. But that's only because there was resistance to changing the copy=False behavior to match the asarray semantics, and it didn't seem nearly as important as it was for asarray. So astype(..., copy=False) still means "copy if needed", not "never copy".
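A quick NumPy check of this behavior (a sketch, assuming NumPy's documented astype semantics): with copy=False, astype returns the input when no conversion is needed, but silently copies instead of raising when the dtype changes.

```python
import numpy as np

x = np.ones((3, 2), dtype=np.float32)

# Same dtype: no conversion is needed, so NumPy hands back the input array itself.
same = x.astype(np.float32, copy=False)

# Different dtype: a conversion is required, so NumPy copies instead of raising.
widened = x.astype(np.float64, copy=False)

print(same is x)                   # True
print(widened is x)                # False
print(np.shares_memory(widened, x))  # False
print(widened.dtype)               # float64
```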

Maybe it will change on a longer timescale; I don't know.

Although the history behind the astype copy semantics makes sense (for both NumPy and the array API), I personally think that, now that the other copy keywords have been standardized to include None, the astype API is confusing or misleading for developers who aren't very familiar with it. Additionally, the ability to specify "copy never" can be helpful in heterogeneous computing contexts, where keeping track of data movement matters more.

To summarize:

  1. I think it is confusing that copy=False here still allows a copy to occur, while elsewhere it means "copy never"
  2. copy=True is imo a bad default that leads to more memory movement than may be necessary
  3. The ability to specify "copy never" is helpful

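I think you can get this behavior in the current standard by using asarray(x, dtype=new_dtype, copy=...) instead of astype: copy=None copies only if needed, and copy=False never copies (raising ValueError if a copy would be required). That's mostly equivalent to checking if x.dtype == new_dtype, but I guess there are also subtleties with devices.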
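A minimal sketch of the dtype-check emulation mentioned above, written against NumPy for concreteness. The helper name astype_with_copy_mode is hypothetical; it ignores device placement (the subtlety noted above) and memory order, and it treats copy=False as "never copy" by raising ValueError, mirroring how the standard's asarray handles copy=False.

```python
import numpy as np


def astype_with_copy_mode(x, dtype, copy=True):
    """Hypothetical astype variant with asarray-style copy semantics.

    copy=True  -> always return a new array
    copy=None  -> copy only if a dtype conversion is needed
    copy=False -> never copy; raise ValueError if a conversion is needed
    Device placement and memory order are deliberately ignored in this sketch.
    """
    target = np.dtype(dtype)
    needs_conversion = x.dtype != target
    if copy is True:
        return x.astype(target, copy=True)
    if copy is None:
        return x.astype(target, copy=True) if needs_conversion else x
    if copy is False:
        if needs_conversion:
            raise ValueError(f"casting {x.dtype} to {target} requires a copy")
        return x
    raise TypeError("copy must be True, False, or None")


x = np.ones((3, 2), dtype=np.float32)
print(astype_with_copy_mode(x, np.float32, copy=None) is x)   # True
print(astype_with_copy_mode(x, np.float64, copy=None).dtype)  # float64
print(astype_with_copy_mode(x, np.float32, copy=True) is x)   # False
```

Keeping copy=True as the default in this sketch matches the current spec; flipping the default to None would give the no-op behavior the issue asks for.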

copy=True is imo a bad default that leads to more memory movement than may be necessary

I am trying to figure out why this matters for JAX, since JAX doesn't actually have to make a physical memory copy: its arrays are immutable by design, so it guarantees there is no observable difference. My understanding is that JAX already had a copy keyword in several places to match NumPy's signatures, that it didn't actually make physical memory copies when that was implemented, and that nothing has to change because of the array API standard. Adding a true "never copy" mode was a lot of work in NumPy (and will be in CuPy as well), but I believe it should not have affected JAX.

This has now come up several times, so we really should make this clearer. The first time was in a comment on #495, I believe. More recently, we had a more extensive discussion on this in dmlc/dlpack#136 for DLPack. For DLPack, it's about exchange between two libraries rather than semantics within a single library, so we put more effort into defining what "copy" actually means.

In summary, the copy keyword means a "logical copy": the returned data is guaranteed to be unique for all intents and purposes. If that can be achieved without moving data in memory, then that is fine. This also aligns with https://data-apis.org/array-api/draft/purpose_and_scope.html, which describes some of the design principles of the standard, in particular that we want to describe the semantics of functions without prescribing execution semantics to the extent possible.
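To illustrate the "logical copy" idea, here is a toy sketch. FrozenArray is a hypothetical class, not part of NumPy or JAX, and it is not hardened against deliberate misuse (for example, re-enabling the writeable flag on the underlying buffer). Because no caller can write to the shared buffer, copy() can return a new wrapper around the same memory and still be indistinguishable from a fresh allocation, which is roughly the immutability argument made above for JAX.

```python
import numpy as np


class FrozenArray:
    """Hypothetical immutable array wrapper, used only to illustrate logical copies."""

    def __init__(self, data, _shared=None):
        if _shared is not None:
            # Internal path: wrap an existing read-only buffer without copying it.
            self._data = _shared
        else:
            # Public path: take one private copy up front, then make it read-only.
            self._data = np.array(data, copy=True)
            self._data.flags.writeable = False

    def copy(self):
        # Logical copy: nothing can write to the buffer, so sharing it is
        # observably equivalent to allocating fresh memory.
        return FrozenArray(None, _shared=self._data)

    @property
    def values(self):
        # Return a view of the read-only buffer; the view is read-only too.
        return self._data.view()


a = FrozenArray([1.0, 2.0, 3.0])
b = a.copy()
print(np.shares_memory(a.values, b.values))  # True: no data was moved
try:
    b.values[0] = 42.0
except ValueError:
    print("write rejected: the buffer is read-only")
```

The donated-buffer case discussed next is exactly where this assumption breaks: once memory can be reused by the compiler, sharing is no longer observably equivalent to a physical copy.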

I am trying to figure out why this matters for JAX

True copies are sometimes important in JAX. For example, functions can be called with donated buffers, in which case the compiler is free to reuse the donated memory, and donated buffers cannot be used in subsequent function calls. If a user wants a copy of an array for this purpose, we currently recommend jnp.array(x, copy=True), which copies data from the original array into a new buffer.

In recent work to make JAX compatible with the Array API, we've been trying to understand the recommended semantics of the copy keyword for the astype function. There is an inconsistency between the Array API spec and NumPy's implementation, and, further, the Array API's justification for departing from the norm seems to stem from incorrect claims about NumPy's semantics for this keyword. Hence this issue, to try to clarify things.