Calculate number of unique values in a lazy array
lucascolley opened this issue · 9 comments
An example from SciPy:
```python
n_clusters = int(xp.unique_values(T).shape[0])
```
We calculate the number of unique values in `T` via a data-dependent shape. As noted in the spec for `unique_values`, the function may be omitted by lazy libraries due to the data-dependent shape.

Is there an alternative in the standard which works with lazy arrays? If not, can there be?
To brainstorm a potential solution, assuming there is no easier solution here, could we:

1. add an `is_lazy` (or something like that) flag (edit: this exists already) to https://data-apis.org/array-api/draft/API_specification/inspection.html#inspection-apis to indicate that the restrictions described in https://data-apis.org/array-api/draft/design_topics/data_dependent_output_shapes.html (may/somewhat?) apply
2. add helpers into array-api-compat for these functions which check the flag and direct to the library function if `False`, else call `.compute()` first? (see the sketch below)
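A minimal sketch of what option (2) might look like. The helper name `unique_values_eager` and the eager fallback are assumptions for illustration, not an existing array-api-compat API:

```python
import numpy as np

def unique_values_eager(x, xp):
    """Hypothetical helper: eager unique_values for lazy libraries."""
    info = xp.__array_namespace_info__()
    if info.capabilities()["data-dependent shapes"]:
        # The library supports data-dependent output shapes directly.
        return xp.unique_values(x)
    # Otherwise materialise first: np.asarray() triggers computation
    # for Dask arrays (a fully lazy library would still fail here).
    return xp.asarray(np.unique(np.asarray(x)))
```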
I suppose (2) would require a resolution to #748 (comment) first though, given that I don't like the idea of array-api-compat having to track all the possible ways to materialise for the finitely many supported libraries:
> Materialization via some function/method in the API that triggers compute would be the one thing that is possibly actionable. However, that is quite tricky. The page I linked above has a few things to say about it.
The inspection API already has a way of determining whether a library supports data-dependent output shapes: https://data-apis.org/array-api/draft/API_specification/generated/array_api.info.capabilities.html#array_api.info.capabilities
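Concretely (the `"data-dependent shapes"` key is part of the standard's `capabilities()` dict):

```python
info = xp.__array_namespace_info__()
# False for fully lazy libraries, which may omit unique_values entirely.
supports_dd_shapes = info.capabilities()["data-dependent shapes"]
```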
The issue isn't `unique_values` (that can be kept lazy just fine in principle) but the use of `int()`. That must execute, and the reason that code is there is that it's passed to compiled code right after. That isn't going to work for a fully lazy library (e.g., `ndonnx` - I do have some hope for `jax.jit` via a callback mechanism).
So the only library with an issue is Dask. Rewriting the code to `int(np.asarray(xp.unique_values(T)).shape[0])` should fix that, since Dask auto-executes when calling `np.asarray` on a `dask.Array` IIRC.
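Spelled out as a sketch (assuming `np` is NumPy and `T` is a Dask array):

```python
import numpy as np

# Lazy: the shape of unique_values(T) is not concrete, so int() fails.
# n_clusters = int(xp.unique_values(T).shape[0])

# Eager: np.asarray() triggers computation on a Dask array, so the
# resulting shape is concrete and int() succeeds.
n_clusters = int(np.asarray(xp.unique_values(T)).shape[0])
```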
@lucascolley I think there's nothing to do or change for the standard here - can we close this?
Yep. I suppose the recommendation going forwards will be "if you need to call `int`, use DLPack to transfer to a known library which works with `int`, or accept that an error may be thrown for lazy libraries".
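The DLPack route might look like this (a sketch; it assumes the producing library implements `__dlpack__` and can materialise for the export):

```python
import numpy as np

# np.from_dlpack() imports any array exposing the DLPack protocol;
# int() on the resulting NumPy array's shape is then guaranteed to work.
n_clusters = int(np.from_dlpack(xp.unique_values(T)).shape[0])
```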
How is `n_clusters` used in SciPy? If it's being passed to another array function, then in principle, that too can remain lazy.
As the title of this issue says - to calculate the number of unique values, hence the `int`.
I guess it's here: https://github.com/scipy/scipy/blob/424708ed018cae6b6584d7d992940fd39f2ebcc0/scipy/cluster/hierarchy.py#L4157

```python
n_clusters = int(xp.unique_values(T).shape[0])
n_obs = int(Z.shape[0] + 1)
L = np.zeros(n_clusters, dtype=np.int32)
M = np.zeros(n_clusters, dtype=np.int32)
Z = np.asarray(Z)
T = np.asarray(T, dtype=np.int32)
s = _hierarchy.leaders(Z, T, L, M, n_clusters, n_obs)
if s >= 0:
    raise ValueError(('T is not a valid assignment vector. Error found '
                      'when examining linkage node %d (< 2n-1).') % s)
L, M = xp.asarray(L), xp.asarray(M)
return (L, M)
```
> If it's being passed to another array function, then in principle, that too can remain lazy.
To spell it out: in the code above, `_hierarchy.leaders` is implemented in Cython, and Cython does not know what to do with lazy objects. This is why the `int()` is there on the first line - no way to keep things lazy.