data-apis/array-api

Eager functions in API appear in conflict with lazy implementation

cbourjau opened this issue · 31 comments

We are looking at adopting this API for a lazy array library built on top of ONNX (and Spox). It seems to be an excellent fit for most parts. However, certain functions in the specification appear to be in conflict with a lazy implementation. One instance is __bool__ which is defined as:

array.__bool__() → bool

The output is documented as "a Python bool object representing the single element of the array". This is problematic for lazy implementations since the value is not available at this point. How should a standard compliant lazy implementation deal with this apparent conflict?

It's impossible for __bool__ to be lazy. bool(x) will call x.__bool__() and convert the result into a boolean.

How should a standard compliant lazy implementation deal with this apparent conflict?

It's up to how you want to design your library, but you really only have two choices: either make bool() implicitly perform a compute on the lazy graph, or make it raise an exception.
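
A minimal sketch of both options, assuming a purely hypothetical LazyArray wrapper around a computation-graph node (none of these names come from an actual library):

# Purely illustrative: LazyArray and its node API are hypothetical.
class LazyArray:
    def __init__(self, node, force_evaluation=False):
        self._node = node              # deferred computation-graph node
        self._force = force_evaluation

    def __bool__(self):
        if self._force:
            # Option 1: implicitly evaluate the graph up to this node.
            return bool(self._node.evaluate())
        # Option 2: refuse, since no concrete value exists yet.
        raise TypeError("bool() is not supported on a lazy array; materialise it first")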

Note that the standard does already discuss this in the context of data-dependent output shapes https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html, but perhaps we also need to identify other functions such as __bool__ that are problematic for such libraries.

One question I have is, for a lazy evaluation library, how should it work for array consuming libraries (e.g. scipy) that are written against a generic array API expecting both eager and non-eager array backends? Is it the case that a scipy function should just take in a lazy array object and return a computation graph object, and fail if it tries to call a function that cannot be performed lazily? Or do these libraries need to know how to call compute() (or whatever) at the end? Should the function to perform an actual compute be standardized?

There's also a question of how we can support libraries like this in the test suite. The test suite calls things like __bool__ all the time because it compares actual values in the tests. So it would need to support calling compute() or whatever to get actual array values. But this is orthogonal to the discussions here for the actual standard (feel free to open an issue in the test suite repo if you want to discuss this further).

One question I have is, for a lazy evaluation library, how should it work for array consuming libraries (e.g. scipy) that are written against a generic array API expecting both eager and non-eager array backends?

I think clearly advertising if the array backend in use currently is lazy or eager might make sense as part of the standard. If "lazy", any API return types are also lazy and must be explicitly materialised by the user with a standard top-level compute(array, *args) call. If a library is "eager" (as most are) then return values are already available.

This may make things easier to work with in the test suite as well?
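
A rough sketch of how an array-consuming library might use such a capability flag; both the is_lazy attribute and the top-level compute() below are assumptions for illustration, neither is part of the current standard:

# Hypothetical helper: materialise a zero-dimensional array into a Python float.
def to_python_float(x):
    xp = x.__array_namespace__()        # standard method for namespace discovery
    if getattr(xp, "is_lazy", False):   # assumed capability flag (not standardized)
        x = xp.compute(x)               # assumed materialisation call (not standardized)
    return float(x)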

I guess you really only have two choices: either make bool() implicitly perform a compute on the lazy graph, or make it raise an exception.

I think doing implicit computation is not possible in situations where you are looking to construct a computation graph but do not have your "inputs"/"arguments" available yet. It would also lead to a situation where some parts of the API yield concrete values and others are lazily constructing the computation graph which may be hard to navigate for a user.

I suppose this is very similar to https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html as you have mentioned. Just raising for __bool__ might be a solution, but maybe it makes sense to update the specification to suggest that the function should throw if the eager value is not available.

Sorry but it makes zero sense to return anything other than the built-in bool. It's part of the Python language requirement
and we can't do anything about it. In particular, the return type cannot be anything other than bool, such as an array type or a deferred evaluation proxy type. You can go check that, for example, Dask forces evaluating the graph in its array and errors out in its proxy object, respectively, when __bool__ is called.

If there's a computation which produces a Boolean that you then want to use lazily in subsequent evaluation, what you need should simply be arr.astype(arr.__array_namespace__().bool), which should work in a lazy array implementation.
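
As a small illustration, a data-dependent choice can stay inside the graph if both branches are expressed with standard functions (a sketch; xp stands for any conforming namespace and normalize_if_large is just an illustrative name):

# Keep the condition lazy: build both branches and select with xp.where,
# instead of branching on a Python bool.
def normalize_if_large(xp, x, threshold=1.0):
    norm = xp.sum(xp.abs(x))            # zero-dimensional array, may stay lazy
    cond = norm > threshold             # boolean zero-dimensional array, still lazy
    return xp.where(cond, x / norm, x)  # both branches are part of the graph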

I don't disagree, the lazy/eager suggestion was for the specific question not for __bool__.

Not sure what question you're referring to, Aditya?

Regarding Aaron's question about when to call compute(), I have a strong opinion that it should be the end user's responsibility, not any array-consuming library's. Suppose we have 3 libraries A, B, C, where A provides a lazy array, B calls A's APIs, C calls B's APIs, and the end user creates A arrays and calls C APIs. For simplicity we could assume A being Dask, B SciPy, and C scikit-learn. If B or C calls compute() on behalf of the user (of C), the graph created by the user would be disconnected/materialized at the library boundary, rather than spanning all 3 libraries. This also makes life easier for all of the array-consuming libraries, which can stay lazy/eager agnostic.

Regarding lazy and .compute(), I agree with @leofang's answer - array-consuming libraries should be agnostic to this, and return the same array type as that of the input array type. So no reason not to stay lazy. We had the same question on the dataframe side, where I wrote up a longer answer: data-apis/dataframe-api#120

Sorry but it makes zero sense to return anything other than the built-in bool. It's part of the Python language requirement
and we can't do anything about it.

I'm not sure I agree with that. We must indeed specify -> bool in the standard, however if there's an implementation that's purely lazy, I don't see why it would be a problem to return a proxy object that duck types with bool. It will technically deviate from what the Python language docs and the standard docs say, but everything will work just fine I believe. We have multiple similar cases where we specify returning a Python scalar, (e.g. finfo attributes) and we've always said it's fine to duck type. I don't yet see why __bool__ would be different.

Does anyone know why Dask forces evaluation? My guess would be that it's a pragmatic decision only - keeping conditionals as graph nodes can explode the graph size very quickly, since everything that follows bifurcates each time you call __bool__.

Thanks for the insightful discussion and links! A little more about our use case: We are building a lazy library that can be used to "trace" standard compliant code and then exports the resulting graph to ONNX. I.e. we are (primarily) interested in the computational graph. The values on which it will run are not necessarily available. That said, we happen to have some eager computations on top of that for debugging and testing.

I would be fine with raising an error when __bool__ is called.

Should the standard be updated to reflect the fact that lazy libraries may raise from __bool__ and similar functions?

I'm not sure I agree with that. We must indeed specify -> bool in the standard, however if there's an implementation that's purely lazy, I don't see why it would be a problem to return a proxy object that duck types with bool. It will technically deviate from what the Python language docs and the standard docs say, but everything will work just fine I believe. We have multiple similar cases where we specify returning a Python scalar, (e.g. finfo attributes) and we've always said it's fine to duck type. I don't yet see why bool would be different.

Just to be completely clear, it's physically impossible for __bool__ to return anything other than bool. Returning anything other than True or False from __bool__ results in an error:

>>> class Test:
...     def __bool__(self):
...         return 1
>>> bool(Test())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: __bool__ should return bool, returned int

(the same is true for __int__, __float__, and __complex__).

And even if this limitation didn't exist, that wouldn't really help, because 90% of the time bool(x) is called implicitly as the result of something like if x or assert x.

I think the main takeaways from this discussion are:

  • We should have a page discussing delayed evaluation, noting that it is supported, that certain operations may not work with it, and that the function to perform an actual compute is not standardized because it should typically be done by end-users on a per-library basis.

  • We should add a note to __bool__, __int__, __float__, and __complex__ that they may not be supported by such libraries.

Regarding lazy and .compute(), I agree with @leofang's answer - array-consuming libraries should be agnostic to this, and return the same array type as that of the input array type. So no reason not to stay lazy.

I don't see how you can follow this policy in an Array API consuming library that implements an iterative solver (e.g. scikit-learn fitting a machine learning model). At each iteration you typically compute some summary scalar value and compare that to a scalar tolerance level (the result would be a scalar boolean) to make the decision to do another iteration or not.

In this case it seems that we have no other choice than letting the Array API consuming library decide explicitly when it needs to trigger the evaluation to collect the scalar value, even when the outcome of this iterative loop is a collection of n-dimensional arrays of the same type as the input.

Just to be completely clear, it's physically impossible for __bool__ to return anything other than bool.

Thanks for clarifying @asmeurer, that's the thing I was missing in my answer.

In this case it seems that we have no other choice than letting the Array API consuming library decide explicitly when it needs to trigger the evaluation to collect the scalar value

The array-consuming library doesn't have to, as pointed out by Leo and Aaron, the Python language already forces the evaluation. And this is then what Dask does:

>>> import dask.array as da
>>> x = da.ones((2,3))
>>> y = x + 2*x
>>> y
dask.array<add, shape=(2, 3), dtype=float64, chunksize=(2, 3), chunktype=numpy.ndarray>
>>> y.compute()
array([[3., 3., 3.],
       [3., 3., 3.]])
>>> y.sum()
dask.array<sum-aggregate, shape=(), dtype=float64, chunksize=(), chunktype=numpy.ndarray>
>>> y.sum() > 1
dask.array<gt, shape=(), dtype=bool, chunksize=(), chunktype=numpy.ndarray>
>>> bool(y.sum() > 1)  # an if-statement does the same as `bool()` here
True

So no need for any .compute() or similar call within the consuming library.

FWIW it's the same in cuNumeric too. When a Python scalar is needed, an expression is force-evaluated in a blocking manner.

seberg commented

The question here remains whether the standard allows raising when bool() is used, to warn users about evaluation being triggered implicitly (and maybe __float__, etc.).

That would still mean that sklearn would have to call some form of compute() to avoid that error.

The question here remains whether the standard allows raising when bool() is used, to warn users about evaluation being triggered implicitly

If the Python language conclusively answers with "not allowed" here, as it seems to do, I think we should adhere to that. Given both Dask and cuNumeric also comply with it, that should be fine, right?

seberg commented

The Python language isn't that conclusive about it. We raise errors all the time, although for things where it's truly impossible. Aaron specifically mentioned that as a possible path above, and it was what @cbourjau was eyeing, I think.

cuNumeric, Dask, and maybe others happily do the implicit compute(). I am fine with prescribing that at least __bool__ is special enough that users of lazy libraries need to be aware that this is never lazy.
But, I am not sure how much of a trap that is for some lazy library users. Maybe there are solutions though, like an optional warning/error @cbourjau?

The array-consuming library doesn't have to, as pointed out by Leo and Aaron, the Python language already forces the evaluation. And this is then what Dask does:

Alright so an iterative function would have to call bool(scalar_array) (or float(scalar_array)) whenever needed to implicitly trigger evaluation.

However, if bool() / float() / int() are the only library-agnostic ways to trigger evaluation, I am not sure how we would avoid triggering redundant computation in the following idiom where we want to collect several scalar metrics:

from collections import defaultdict

# compute_one_step, metric_a, metric_b and stopping_criterion are placeholders
# assumed to be defined elsewhere.
def iterative_solver(data, params, tol=1e-4, maxiter=1_000):
    record = defaultdict(list)
    for iter_idx in range(maxiter):
        params = compute_one_step(data, params)
        record["iter"].append(iter_idx)
        record["a"].append(float(metric_a(data, params)))
        record["b"].append(float(metric_b(data, params)))

        if stopping_criterion(data, params) < tol:  # calls bool() implicitly
            break

    return params, record

My understanding is that with the current implicit semantics when data & params are lazy arrays, each call to float() and bool() in the previous code would redundantly recompute all the chained compute_one_step calls from the original inputs (and possibly re-evaluate the input generating ancestors redundantly as well).

The only way around this would be to insert explicit checkpoints (such as dask.array's persist method) after each call to compute_one_step.

EDIT: fixed missing .append calls.
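
For concreteness, a minimal sketch of that work-around using Dask's persist(), which triggers computation but keeps the result as a dask array, so that later scalar extractions only run the small graphs built on top of it:

import dask.array as da

x = da.random.random((1_000, 1_000), chunks=(250, 250))
for _ in range(5):
    x = ((x @ x.T) / 1_000).persist()   # checkpoint: computed once, kept in memory
    if float(abs(x).mean()) < 1e-12:    # only this small reduction is evaluated here
        break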

Alright so an iterative function would have to call bool(scalar_array) (or float(scalar_array)) whenever needed to implicitly trigger evaluation.

I think it's actively non-idiomatic to do so. You want to write code that does not care whether evaluation is triggered, but rather expresses only the logic and leaves execution semantics to the array library.

The only way around this would be to insert explicit checkpoints (such as dask.array's persist method) after each call to compute_one_step.

Isn't that just a quality-of-implementation issue? A good lazy library should automatically cache/persist calculated values that it can know have to be reused again. In this particular case though, the problem may be that Dask won't see the line for iter_idx in range(maxiter):? If you'd replace range with da.arange, it will be able to tell. Although also, perhaps the example is misleading - if data and params don't change inside the for-loop, you can move it out. And if they do change, then there's nothing to persist.

Thanks for this great discussion! I think it might be useful to reiterate the following point: A lazy library (such as the one we are building on top of ONNX) may have an eager mode on top of it for debugging purposes, but those eager values must never influence the lazy computational graph that we are building. We are essentially trying to compile a sklearn-like pipeline ahead of time into an ONNX graph. We don't have any (meaningful) values available when executing the Python code that produces our lazy arrays. We have no other choice but to throw an exception if an eager value is requested. It would, however, be a pity if that fact stopped a lazy array implementation from being standard compliant. Hence this issue, to clarify whether it would be OK by the standard to raise in those cases.

On the topic of control flow:
The ONNX standard does offer lazy control flow operators (If, Loop, Scan). Rather than using Python's syntactic sugar for if-else statements and for-loops, those operators are more akin to the built-in map function. It would be necessary to offer similar control flow functions through the array API if a use case like the above iterative_solver were to be supported lazily.
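
Purely as an illustration of the intended semantics, such a primitive could look like a functional while-loop whose condition stays inside the graph; the while_loop function below is hypothetical (no such function exists in the standard) and the eager reference implementation only demonstrates the behaviour a lazy backend would trace into a single loop node:

# Hypothetical: an eager reference implementation showing the semantics a
# lazy backend would instead capture as one loop node in its graph.
def while_loop(cond_fun, body_fun, init_state):
    state = init_state
    while cond_fun(state):
        state = body_fun(state)
    return state

# Usage: a data-dependent trip count, expressed without calling bool() on
# array values in user code.
final = while_loop(
    cond_fun=lambda s: s[1] > 1e-4,
    body_fun=lambda s: (s[0] + 1, s[1] * 0.5),
    init_state=(0, 1.0),
)
print(final)  # (14, 6.103515625e-05)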

params = compute_one_step(data, params) updates the local variable to point to a newly derived lazy array at each iteration, e.g. at the end of the second iteration: compute_one_step(data, compute_one_step(data, params)).

The number of necessary iterations is data dependent and cannot be guessed ahead of time. We could unroll the loop to maxiter, never calling float() or bool() inside the loop, and keep lazy variables for the recorded metrics. However, as a result, this alternative program would evaluate the full loop and then call stopping_criterion a posteriori to trim the output. That would be incredibly wasteful.

So we need to at least keep the bool() call inside the loop to interrupt it as soon as possible and avoid scheduling unneeded computation. Then the metric values could be kept as lazy variables and we could call float() on each of them after the Python loop break happens, but that would make the code much less natural to write in my opinion.

The only way around this would be to insert explicit checkpoints (such as dask.array's persist method) after each call to compute_one_step.

Isn't that just a quality-of-implementation issue? A good lazy library should automatically cache/persist calculated values that it can know have to be reused again

You are right that dask is probably clever enough not to recompute everything from the start in the code I wrote above, because it would not garbage-collect intermediate results for which there are still live references in the driver Python program. Not sure if other libraries such as jax would tolerate this pattern though.

I gave it a try and indeed dask is smart enough to avoid any recomputation while triggering computation as needed as part of the control flow:

https://gist.github.com/ogrisel/6a4304e1831051203a98118875ead2d4

I am not sure if we can expect all the other lazy Array API implementations to follow the same semantics without a more explicit API though.

I updated the above gist to also try Jax and it has the same semantics as dask w.r.t. float / bool once float64 support is enabled.

If float64 support is not enabled, it still works, but one gets warnings.

I think the current behavior of jax and dask is convenient: to come back to the original question of this issue, I think the Array API specification/documentation for __bool__ should be updated to state that this method should return a Python boolean scalar value and that lazy array implementations should therefore trigger evaluation (and block) when calling such method (similarly for __float__ and __int__).

I think the Array API specification/documentation for __bool__ should be updated to state that this method should return a Python boolean scalar value and that lazy array implementations should therefore trigger evaluation (and block)

As @cbourjau mentioned, it is not always possible to trigger a computation for lazy array implementations where you cannot use any "concrete input values" when building the computation graph (in ONNX, users serialise the computation graph for later execution with concrete inputs). This solution would make it impossible for such libraries to be fully standard compliant which would be a bit of a shame.

In that case it would be helpful to have:

  • a standard way to declare ahead of time if a given array namespace (that can otherwise implement lazy evaluation semantics) guarantees eager/blocking evaluation semantics on __bool__ & friends,
  • a standard exception type when __bool__ is called on an array that does not implement eager/blocking evaluation semantics for such methods.

Otherwise an array consuming library that implements something akin to iterative_solver would have no way to meaningfully report to its users that a given array api namespace is not suitable. Ideally, the consuming library would want a declarative way to detect this limitation before starting such a loop rather than having to protect the sensitive code with try/except.

As @cbourjau mentioned, it is not always possible to trigger a computation for lazy array implementations where you cannot use any "concrete input values" when building the computation graph

Can you explain what you mean? Coming from a scikit-learn background/use-case I can't quite imagine what the "without concrete input values" means. In what use case would it happen that __float__ gets called and there are no input values to the computation?

Are you "tracing" the computation to build a graph?

torch.compile can generate a symbolic graph from a Python program with a data-dependent control flow (presumably via static analysis).

jax.jit does not complain with the above iterative_solver function either but I am not sure what it does under the hood in case of data-dependent control flows.

EDIT: decorating the iterative_solver function with jax.jit makes jax complain as follows on the first call to float():

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float64[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
The error occurred while tracing the function iterative_solver at /var/folders/_y/lfnx34p13w3_sr2k12bjb05w0000gn/T/ipykernel_11860/3412930719.py:5 for jit. This concrete value was not available in Python because it depends on the values of the arguments data and params.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

See also: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit

and later in that same document:

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#structured-control-flow-primitives

As noted previously, this function runs fine without the jax.jit decorator.

I updated the gist with torch and jax if you are interested in reproducing the above.

Are you "tracing" the computation to build a graph?

Yes that's right.

If we didn't make it so that the dask/jax behaviour becomes what the standard says should be done, wouldn't you still end up in trouble for tracing? In scikit-learn we'd have to explicitly be triggering the computation (instead of implicitly via __bool__ or __float__) and then we'd be back to square one that there are no concrete values.

I'm still not sure I fully understand how spox (I assume this is the library you are thinking about) does its thing but it feels like tracing is not a use-case for the array API? Like, it is a neat trick but you wouldn't change the design to make it easier to do tracing if more mainstream uses got harder. The work PyTorch has done is pretty exciting (as Olivier already said), in particular I think Torch Dynamo is the bit that does the tracing (or is it torch inductor?). Maybe worth investigating how they do it.

The work PyTorch has done is pretty exciting (as Olivier already said), in particular I think Torch Dynamo is the bit that does the tracing (or is it torch inductor?). Maybe worth investigating how they do it.

From https://pytorch.org/docs/master/func.ux_limitations.html#data-dependent-python-control-flow: JAX supports transforming over data-dependent control flow using special control flow operators (e.g. jax.lax.cond, jax.lax.while_loop). We’re investigating adding equivalents of those to PyTorch.

In fact there now is a torch.cond, but it seems so new that it is not yet reflected in the docs.
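
For reference, a small runnable example of the JAX primitives mentioned above, where the loop condition remains a traced boolean instead of going through Python's bool() (newton_sqrt2 is just an illustrative name):

import jax
import jax.numpy as jnp

def newton_sqrt2(x, tol=1e-4):
    def cond_fun(state):
        _, err = state
        return err > tol                     # traced boolean, no Python bool()

    def body_fun(state):
        x, _ = state
        new_x = 0.5 * (x + 2.0 / x)          # Newton step for sqrt(2)
        return new_x, jnp.abs(new_x - x)

    init = (x, jnp.full_like(x, jnp.inf))
    x, _ = jax.lax.while_loop(cond_fun, body_fun, init)
    return x

print(jax.jit(newton_sqrt2)(jnp.float32(1.0)))  # ~1.4142135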

I'm still not sure I fully understand how spox (I assume this is the library you are thinking about) does its thing but it feels like tracing is not a use-case for the array API? Like, it is a neat trick but you wouldn't change the design to make it easier to do tracing if more mainstream uses got harder.

I believe the majority of tracing use cases will work, using Python control flow based on values is one of very few things that won't work. And such code isn't a good fit for tracing anyway. So I think:

  • we're fine keeping __bool__ & co around
  • we can add a note that most lazy implementations will force evaluation when hitting __bool__/__int__/__float__, but implementations that must be fully lazy will have to raise an exception here
  • this should be rare, and I don't think it is necessary at this point to either support a special API like cond for this, nor to have scikit-learn & co worry about this as a problem right now.

I agree with @rgommers's summary as a pragmatic stance for the short term.

I think we need to wait a year or two to see how tensor libraries with JIT compiler support evolve before we start thinking about how to standardize an API for data-dependent control flow operators (and maybe even a standard jit compiler decorator).

However, those compiler related aspects are an important evolution of the numerical / data science Python ecosystem and I think we should keep them in mind to later consider Array API spec extension (similar to what is done for the xp.linalg submodule).

I fully agree with @rgommers's summary, too. Thanks for the great discussion. Should I make a PR that clarifies the different behaviors in the standard?

@cbourjau that would be great, thank you