data-apis/array-api-extra

ENH: add `lazywhere`

lucascolley opened this issue · 6 comments

@mdhaber do you think lazywhere would be in scope for this library? It seems to me like it would be useful beyond SciPy?

Yup, probably! But perhaps should check whether other libraries would like to use it.

This proposal is lacking a bit of context. What is lazywhere and the proposed behavior?

This was not meant to be a full proposal yet! I just wanted to check whether Matt saw any obvious problems before putting more thought into it.


SciPy has a private helper called _lazywhere which implements a performance optimisation for the common case of xp.where(condition, x1, x2) in which x1 is the result of some (elementwise) function f applied to some array(s), i.e. xp.where(condition, f(arrays), x2).

While xp.where(condition, f(arrays), x2) has to evaluate f(arrays) at every element of the arrays, _lazywhere(condition, arrays, f, fillvalue=x2) evaluates f only at the elements where condition is True.

Example:

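import numpy as xp  # stands in for any array API compatible namespace
from scipy._lib._util import _lazywhere  # private SciPy helper; may move between versions

a = xp.asarray([1, 2, 3, 4])
b = xp.asarray([5, 6, 7, 8])

def f(a, b):
    # elementwise; only evaluated where the condition is True
    return a + b

_lazywhere(a > 2, (a, b), f, xp.nan)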

returns array([nan, nan, 10., 12.]) as desired, and never calculates 1 + 5 or 2 + 6.
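To make the mechanism concrete, here is a minimal NumPy sketch of the boolean-indexing approach. `lazywhere_sketch` is a hypothetical name, and this is a simplified illustration rather than SciPy's actual implementation, which handles more cases. The `calls` list records what `f` actually receives.

```python
import numpy as np

def lazywhere_sketch(cond, arrays, f, fillvalue):
    """Hypothetical illustration of the lazywhere idea (not SciPy's _lazywhere).

    Evaluates f only at the elements where cond is True, via boolean indexing,
    and fills the remaining elements with fillvalue.
    """
    # Broadcast the condition and all input arrays to a common shape.
    cond, *arrays = np.broadcast_arrays(cond, *arrays)
    # Call f on the selected elements only (1-D arrays of the True positions).
    temp = f(*(arr[cond] for arr in arrays))
    # The output dtype must accommodate both f's result and the fill value.
    out = np.full(cond.shape, fillvalue, dtype=np.result_type(temp, fillvalue))
    out[cond] = temp
    return out


calls = []  # records the inputs f actually receives

def f(a, b):
    calls.append((a.tolist(), b.tolist()))
    return a + b

a = np.asarray([1, 2, 3, 4])
b = np.asarray([5, 6, 7, 8])
res = lazywhere_sketch(a > 2, (a, b), f, np.nan)
```

Here `f` is called once, with only `[3, 4]` and `[7, 8]`, so `1 + 5` and `2 + 6` are never computed.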


_lazywhere can also take an f2 argument, which similarly computes x2 from arrays, but only at the elements where condition is False.
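A hypothetical two-branch variant of the same idea, again a simplified NumPy sketch rather than SciPy's code: `f` only ever sees the elements where the condition holds, and `f2` only the rest. The name `lazywhere2_sketch` is illustrative.

```python
import numpy as np

def lazywhere2_sketch(cond, arrays, f, f2):
    """Hypothetical sketch: f runs only where cond is True, f2 only where it is False."""
    cond, *arrays = np.broadcast_arrays(cond, *arrays)
    ncond = ~cond
    temp1 = f(*(arr[cond] for arr in arrays))    # True branch only
    temp2 = f2(*(arr[ncond] for arr in arrays))  # False branch only
    out = np.empty(cond.shape, dtype=np.result_type(temp1, temp2))
    out[cond] = temp1
    out[ncond] = temp2
    return out


x = np.asarray([-4.0, -1.0, 0.0, 9.0])
# errstate turns NumPy's "invalid value" warning into an error, which would fire
# if np.sqrt were ever applied to a negative number; neither branch does that.
with np.errstate(invalid="raise"):
    res = lazywhere2_sketch(x >= 0, (x,), np.sqrt, lambda v: -np.sqrt(-v))
```

The result is a "signed square root", `[-2., -1., 0., 3.]`, computed without any invalid-value warnings.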

If we add it here, one change would be to make the parameter names consistent with xp.where. There may be other differences too; I haven't thought about that yet.

Matt has already implemented this function in terms of the standard at https://github.com/scipy/scipy/blob/6f3a8bc82d07939dbca329ab3f8a8042f8668c44/scipy/_lib/_util.py#L85-L161.

Re: performance optimisation, I just wanted to clarify that, due to the overhead of indexing etc., the underlying calculation may need to be fairly slow before a speed-up is observable. However, it will typically reduce memory usage, and it is also commonly used to avoid warnings and/or errors (e.g. when the operation is division and some elements could be zero). I'd also mention that the (private) function has been in SciPy for a long time; I just translated (rewrote) it for the array API.
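For instance, with division the eager `where` needs `1 / x` at every element, zeros included, whereas the boolean-indexing pattern that lazywhere relies on only divides the selected elements. This plain-NumPy sketch uses `np.errstate` to turn NumPy's usual divide-by-zero `RuntimeWarning` into an error so the difference is visible.

```python
import numpy as np

x = np.asarray([0.0, 2.0, 4.0])
mask = x != 0

# Eager: where() needs 1 / x for *every* element, so 1 / 0.0 is evaluated.
with np.errstate(divide="raise"):  # turn the usual RuntimeWarning into an error
    try:
        np.where(mask, 1 / x, 0.0)
        eager_raised = False
    except FloatingPointError:
        eager_raised = True

# Lazy (boolean-indexing pattern): divide only where mask is True.
with np.errstate(divide="raise"):
    lazy = np.zeros_like(x)  # fill value 0.0 where mask is False
    lazy[mask] = 1 / x[mask]
```

The eager version raises, while the lazy version returns `[0., 0.5, 0.25]` without ever dividing by zero.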

Thanks for proposing this idea! 😀

This seems like a nice addition for libraries leveraging delayed computation as well 😉


Isn't this unnecessary for those libraries? where(cond, f(x1), x2) is already lazy there, so they can avoid computing f(x1) where cond is False. In fact, it looks like SciPy's lazywhere uses boolean indexing to perform the optimization, which means it wouldn't even work for a lazy library like Dask: https://github.com/scipy/scipy/blob/8db867294b4d2084b9a21b37695c0b70b172498a/scipy/_lib/_util.py#L157
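For comparison, one commonly used workaround that avoids boolean indexing altogether (not what SciPy's helper does) is to substitute a known-safe value into the inputs where the condition is False and then call `where` twice. The function name `where_with_safe_inputs` and its `safe_value` parameter are hypothetical. Every element is still evaluated, so there is no saving in computation, but the output shape never depends on the data. By contrast, the size of `x[cond]` is only known once `cond` has been computed, which is awkward for a lazy library.

```python
import numpy as np

def where_with_safe_inputs(cond, x, f, fillvalue, safe_value):
    """Hypothetical alternative to boolean indexing ("double where").

    Replace inputs where cond is False with a known-safe value, evaluate f
    elementwise on the full array (so no shape depends on the data),
    then select the final result with where.
    """
    x_safe = np.where(cond, x, safe_value)  # e.g. avoid zeros before dividing
    return np.where(cond, f(x_safe), fillvalue)


x = np.asarray([0.0, 2.0, 4.0])
with np.errstate(divide="raise"):  # would raise if 1 / 0.0 were evaluated
    res = where_with_safe_inputs(x != 0, x, lambda v: 1 / v, 0.0, 1.0)
```

This only uses `where` and elementwise operations, at the cost of computing `f` on the placeholder values too.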