data-apis/array-api

Calculate number of unique values in a lazy array

lucascolley opened this issue · 9 comments

An example from SciPy:

n_clusters = int(xp.unique_values(T).shape[0])

We calculate the number of unique values in T via a data-dependent shape. As noted in the spec for unique_values, the function may be omitted by lazy libraries due to the data-dependent shape.

Is there an alternative in the standard which works with lazy arrays? If not, can there be?

x-ref gh-748, data-apis/array-api-compat#175

To brainstorm a potential solution, assuming there is no easier solution here, could we:

  1. Add an is_lazy (or similarly named) flag to https://data-apis.org/array-api/draft/API_specification/inspection.html#inspection-apis to indicate whether the restrictions described in https://data-apis.org/array-api/draft/design_topics/data_dependent_output_shapes.html apply. Edit: this exists already.
  2. Add helpers to array-api-compat for these functions which check the flag and dispatch to the library function if it is False, else call .compute() first?
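To sketch what (2) could look like: a hypothetical helper (names illustrative, not part of array-api-compat) that consults the standard capabilities() inspection API and materializes first when data-dependent shapes are unsupported. The .compute() branch is exactly the library-specific part:

```python
def unique_values_eager(x, xp):
    """Hypothetical helper: unique values with a concrete output shape.

    Defers to the library when it supports data-dependent output shapes;
    otherwise materializes first. Namespaces without the inspection API
    are assumed to be eager.
    """
    info = getattr(xp, "__array_namespace_info__", None)
    caps = info().capabilities() if info is not None else {}
    if caps.get("data-dependent shapes", True):
        return xp.unique_values(x)
    # Library-specific materialization: Dask arrays expose .compute();
    # other lazy libraries would each need their own branch here.
    return xp.unique_values(xp.asarray(x.compute()))
```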

I suppose (2) would require a resolution to #748 (comment) first, though, given that I don't like the idea of array-api-compat having to track all the possible ways to materialise across the finitely many supported libraries:

Materialization via some function/method in the API that triggers compute would be the one thing that is possibly actionable. However, that is quite tricky. The page I linked above has a few things to say about it.

The inspection API already has a way of determining if a library supports data-dependent output shapes https://data-apis.org/array-api/draft/API_specification/generated/array_api.info.capabilities.html#array_api.info.capabilities
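For illustration, a defensive way to read that flag (a sketch, assuming only the standard inspection API; libraries that predate it are treated as eager, which is the common case in practice):

```python
def supports_data_dependent_shapes(xp):
    """True if the namespace supports data-dependent output shapes.

    Namespaces without __array_namespace_info__ predate the capability
    flag; those are eager libraries in practice, so default to True.
    """
    info_fn = getattr(xp, "__array_namespace_info__", None)
    if info_fn is None:
        return True
    return bool(info_fn().capabilities().get("data-dependent shapes", True))
```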

The issue isn't unique_values (that can be kept lazy just fine in principle) but the use of int(). int() must execute, and the reason that code is there is that the result is passed to compiled code right after. That isn't going to work for a fully lazy library (e.g., ndonnx - I do have some hope for jax.jit via a callback mechanism).

So the only library with an issue is Dask. Rewriting the code to int(np.asarray(xp.unique_values(T)).shape[0]) should fix that, since Dask auto-executes when np.asarray is called on a dask array, IIRC.
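The rewrite, sketched with a stand-in namespace (any object exposing unique_values); np.asarray is what forces a lazy array to materialize:

```python
import numpy as np

def n_unique(T, xp):
    # np.asarray triggers computation for lazy arrays that support it
    # (a dask array auto-executes here), so the data-dependent shape is
    # concrete by the time int() runs.
    return int(np.asarray(xp.unique_values(T)).shape[0])
```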

@lucascolley I think there's nothing to do or change for the standard here - can we close this?

Yep. I suppose the recommendation going forwards will be: "if you need to call int, use DLPack to transfer to a known library which works with int, or accept that an error may be raised for lazy libraries".
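A sketch of the DLPack route, with NumPy as the "known library" (assumes the source array is 0-d and CPU-accessible; np.from_dlpack has been available since NumPy 1.22):

```python
import numpy as np

def as_python_int(x):
    # Transfer a 0-d array to NumPy via DLPack, where calling int()
    # on the result is guaranteed to work.
    return int(np.from_dlpack(x))
```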

How is n_clusters used in SciPy? If it's being passed to another array function, then in principle, that too can remain lazy.

As the title of this issue says - to calculate the number of unique values, hence the int.

I guess it's here https://github.com/scipy/scipy/blob/424708ed018cae6b6584d7d992940fd39f2ebcc0/scipy/cluster/hierarchy.py#L4157

    n_clusters = int(xp.unique_values(T).shape[0])
    n_obs = int(Z.shape[0] + 1)
    L = np.zeros(n_clusters, dtype=np.int32)
    M = np.zeros(n_clusters, dtype=np.int32)
    Z = np.asarray(Z)
    T = np.asarray(T, dtype=np.int32)
    s = _hierarchy.leaders(Z, T, L, M, n_clusters, n_obs)
    if s >= 0:
        raise ValueError(('T is not a valid assignment vector. Error found '
                          'when examining linkage node %d (< 2n-1).') % s)
    L, M = xp.asarray(L), xp.asarray(M)
    return (L, M)

If it's being passed to another array function, then in principle, that too can remain lazy.

To spell it out: in the code above, _hierarchy.leaders is implemented in Cython, and Cython does not know what to do with lazy objects. This is why the int() is there on the first line - no way to keep things lazy.