ENH/API: xp-bound namespaces, array-api-compat
lucascolley opened this issue · 11 comments
Currently, functions of this package require passing a standard-compatible namespace as xp=xp. This works fine, but there have been suggestions that it might be nice to avoid this requirement. There are at least a few ways we could go about this:
(1) xpx.bind_namespace
Usage:
```python
import array_api_extra as xpx
from array_api_compat import array_namespace
...
xp = array_namespace(x)
xpx = xpx.bind_namespace(xp)
x = xpx.atleast_nd(x, ndim=2)
y = xp.sum(x)
z = xpx.some_func(y)
```
A potential implementation:
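```python
import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd}  # ... plus the other extra functions

def bind_namespace(xp: ModuleType) -> ModuleType:
    class BoundNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            else:
                raise AttributeError(...)  # unknown names are an error, not a fallback to xp
    return BoundNamespace()  # xp is captured by the closure
```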
I like this idea. If we encounter use cases where a library wants to use multiple xpx functions in the same local scope and finds the xp=xp pattern too cumbersome, I think we should add this. I think we can leave it out for now until that situation arises.
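To see the mechanics of option (1) without any array library installed, here is a minimal runnable sketch. The names toy_xp (a stand-in namespace providing only sum) and scale (a stand-in extra function) are hypothetical. Unlike the snippet above, it stores xp on the instance and adds __dir__ so that dir() and tab completion list the extra functions; that is a design choice for the sketch, not something the proposal specifies:

```python
import functools
from types import SimpleNamespace

# Hypothetical stand-in for a standard-compatible namespace; only `sum` is provided.
toy_xp = SimpleNamespace(sum=lambda x: sum(x))


def scale(x, /, factor, *, xp):
    """Hypothetical extra function: rescale x so that its elements sum to `factor`."""
    total = xp.sum(x)
    return [factor * v / total for v in x]


extra_funcs = {"scale": scale}


class BoundNamespace:
    """Expose each extra function with `xp` pre-bound via functools.partial."""

    def __init__(self, xp):
        self._xp = xp

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name in extra_funcs:
            return functools.partial(extra_funcs[name], xp=self._xp)
        raise AttributeError(f"no extra function named {name!r}")

    def __dir__(self):
        # Lets dir() and tab completion list the available extra functions.
        return sorted(extra_funcs)


def bind_namespace(xp):
    return BoundNamespace(xp)


xpx = bind_namespace(toy_xp)
print(xpx.scale([1, 3], 2))  # [0.5, 1.5]
print(hasattr(xpx, "sum"))   # False: standard functions stay on xp
```

Keeping standard functions off the bound namespace means a call like xpx.sum fails loudly instead of silently reaching into xp.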
(2) xpx.extra_namespace
Usage:
```python
import array_api_extra as xpx
from array_api_compat import array_namespace
...
xp = array_namespace(x)
xpx = xpx.extra_namespace(xp)
x = xpx.atleast_nd(x, ndim=2)
y = xpx.sum(x)  # XXX: xpx instead of xp
z = xpx.some_func(y)
```
A potential implementation:
```python
import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd}  # ... plus the other extra functions

def extra_namespace(xp: ModuleType) -> ModuleType:
    class ExtraNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            else:
                return getattr(xp, name)  # XXX: delegate to xp instead of error
    return ExtraNamespace()  # xp is captured by the closure
```
I would not want to add this yet. I think we should keep separation between the standard namespace and the 'extra' namespace, at least until this library matures.
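A comparable sketch for option (2), again using hypothetical stand-ins (toy_xp, normalize, capped_max) in place of a real array library. It keeps the closure-based class from the proposal, so the only behavioural difference from option (1) is the getattr fallback to xp:

```python
import functools
from types import SimpleNamespace

# Hypothetical stand-in for a standard namespace; provides `sum` and `max` only.
toy_xp = SimpleNamespace(sum=lambda x: sum(x), max=lambda x: max(x))


def normalize(x, /, *, xp):
    """Hypothetical extra function: divide every element by xp.sum(x)."""
    total = xp.sum(x)
    return [v / total for v in x]


def capped_max(x, /, *, xp):
    """Hypothetical extra function registered under the standard name 'max'."""
    return min(xp.max(x), 10)


extra_funcs = {"normalize": normalize, "max": capped_max}


def extra_namespace(xp):
    class ExtraNamespace:
        def __getattr__(self, name):
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            # Delegate everything else to xp; raises AttributeError if xp lacks it.
            return getattr(xp, name)

    return ExtraNamespace()


xpx = extra_namespace(toy_xp)
print(xpx.normalize([1, 3]))  # [0.25, 0.75]
print(xpx.sum([1, 3]))        # 4, delegated to toy_xp.sum
print(xpx.max([1, 50]))       # 10: the extra function shadows toy_xp.max
```

The last line shows one consequence of delegation: an extra function that reuses a standard name silently takes precedence over xp's version, which is a concrete reason to keep the two namespaces separate.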
(3) Use array_api_compat.array_namespace internally
This would provide the most flexible API and require the fewest lines of code to use. One could use xpx functions on standard-incompatible arrays and let array-api-compat handle the compatibility, without having to pass an xp argument.
We don't yet have a use case where it is clearly beneficial to be able to pass standard-incompatible arrays. Consumer libraries using array-api-extra would already be computing with standard-compatible arrays internally. I don't see the need to support the following use case:
```python
import torch
import array_api_extra as xpx
...
x = torch.asarray([1, 2, 3])
xpx.some_func(x)  # works
torch.some_standard_func(x)  # does not work
```
Another complication is that consumer libraries like SciPy wrap array_namespace to provide custom behaviour for scalars and other types. We would want the internal array_namespace to be the consumer library's wrapped version rather than the base one from array-api-compat.
I'm also not sure that saving one line of code over option (1) for standard-compatible arrays is worth introducing a dependency on array-api-compat.
Overall, this would complicate things a lot when array-api-compat and array-api-extra are co-vendored, which is the primary use case for the library right now. This might be a better idea in the future if a need arises for handling standard-incompatible arrays (for example, if one wants to use functions from xpx with just a single library).
Hi @lucascolley! Thanks for writing this up and making this library. I think it's really helpful for the ecosystem to have something like this, and also potentially a good place for staging things the standard may or may not want to adopt.
I read through all the above and will add my two cents.
To be honest, I think (3) is the way to go, because we want to make this as easy to use as possible. I know it sounds silly, but I think adding extra functions like xpx.bind_namespace will discourage people from using it. And it's nice for it to work the same way as array-api-compat. Regarding vendoring, I would just make array-api-compat a hard dependency.
On the issue of Python scalars, lists, and that kind of thing: I think there needs to be a solution to this, though I'm not sure whether it is in the scope of array-api-extra or something else.
I would suggest depending on array_api_compat, but always use it as import array_api_compat as compat (or whatever) and do that in a centralized place, so that people who want to vendor both can easily change it to point to their vendored array-api-compat.
It might take a little thinking to handle all the different vendoring/depending combinations, but I think it's doable, and it will be much simpler in the long run, as you're ultimately going to want to use a lot of stuff in compat (not just array_namespace).
I think I agree with depending on array-api-compat. It will be necessary soon enough, and the issue does not seem that hard.
I'd expect that even functions with a universal implementation in terms of the array API primitives only are going to need library-specific direct calls for performance reasons at some point once usage of this package takes off.
> I would suggest depending on array_api_compat, but always use it as import array_api_compat as compat (or whatever) and do that in a centralized place, so that people who want to vendor both can easily change it to point to their vendored array-api-compat.
If I understand correctly:
1. In array_api_extra, add a hard dependency on array_api_compat in requirements.txt.
2. In array_api_extra, add a file array_api_extra/src/array_api_extra/compat.py:
   ```python
   from array_api_compat import *
   ```
3. In array_api_extra, add Meson to the build toolchain.
4. In array_api_extra/src/array_api_extra/meson.build, copy-paste lines https://github.com/scipy/scipy/blob/aa4e5771d25e36ff31bc270a3e4d44b6ba240f1e/scipy/_lib/meson.build#L143-L214 from SciPy's meson.build.
   Note: that's a lot of lines to copy-paste! I think it would become necessary to move everything to a meson.build file in array_api_compat and produce a guide + CI tests?
5. Create a new file scipy/scipy/_lib/array_api_compat_vendor/compat.py:
   ```python
   # This file is copied to ../array_api_extra/ by meson.build, replacing the
   # external dependency on array_api_compat with the one vendored by scipy
   from ..array_api_compat import *
   ```
6. In scipy/scipy/_lib/meson.build, add a line to the section:
   ```meson
   py3.install_sources(
     [
       'array_api_extra/src/array_api_extra/__init__.py',
       'array_api_extra/src/array_api_extra/_funcs.py',
       'array_api_extra/src/array_api_extra/_typing.py',
       'array_api_compat_vendor/compat.py',  # new line
     ],
     subdir: 'scipy/_lib/array_api_extra',
   )
   ```
That doesn't look right to me; there should be no need to mess with Meson. Step 1 is fine, and step 2 can conceptually be:
```python
# Allow packages that vendor `array-api-extra` as well as `array-api-compat` to override the import location
import os

if os.path.exists('../../array_api_compat'):  # FIXME make this robust with __file__ etc.
    # array-api-compat should be vendored right next to array-api-extra
    ...  # TODO the 3-line importlib thingy to import directly from a relative location
else:
    import array_api_compat as compat
```
(sorry, no time to work it out, but it can be made to work with a few lines of code)
EDIT: back now. I think even that was overcomplicating things. If we have a _compat.py file in array-api-extra, and require that for the vendoring use case the package author puts a _array_api_compat.py next to the vendored packages, then this should do it:
```python
# Allow packages that vendor `array-api-extra` as well as `array-api-compat` to override the import location
try:
    # array-api-compat should be vendored right next to array-api-extra
    from .. import _array_api_compat as compat
except ImportError:
    # it's an external dependency
    import array_api_compat as compat
```
Fixed the example now. The main point is that this does not require any build system changes inside array-api-extra, and it will work within libraries that do the vendoring with any build system. Steps (5) and (6) above are needed, but those are very tiny changes. The total diff for this is about 15 lines of code at most, and is a one-time thing to add.
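The claim that this needs no build-system support can be checked with plain files. The sketch below, which assumes only the standard library, writes two layouts to a temporary directory: a standalone install next to a one-line stand-in for array_api_compat, and a hypothetical consumer package that vendors both libraries side by side. It then imports the same shim from each layout:

```python
import importlib
import pathlib
import sys
import tempfile
import textwrap

# The proposed _compat.py shim, with its indentation restored.
SHIM = textwrap.dedent("""
    try:
        # array-api-compat should be vendored right next to array-api-extra
        from .. import _array_api_compat as compat
    except ImportError:
        # it's an external dependency
        import array_api_compat as compat
""")

files = {
    # Standalone install; array_api_compat.py is a one-line hypothetical stand-in.
    "array_api_compat.py": "WHERE = 'external'\n",
    "array_api_extra/__init__.py": "",
    "array_api_extra/_compat.py": SHIM,
    # Hypothetical consumer package that vendors both libraries side by side.
    "consumer/__init__.py": "",
    "consumer/_array_api_compat.py": "WHERE = 'vendored'\n",
    "consumer/array_api_extra/__init__.py": "",
    "consumer/array_api_extra/_compat.py": SHIM,
}
root = pathlib.Path(tempfile.mkdtemp())
for rel, text in files.items():
    path = root / rel
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text)

# Forget any copies imported earlier so the temporary layout is picked up.
for name in list(sys.modules):
    if name.split(".")[0] in {"array_api_compat", "array_api_extra", "consumer"}:
        del sys.modules[name]
sys.path.insert(0, str(root))
importlib.invalidate_caches()

standalone = importlib.import_module("array_api_extra._compat")
vendored = importlib.import_module("consumer.array_api_extra._compat")
print(standalone.compat.WHERE)  # external
print(vendored.compat.WHERE)    # vendored
```

In the standalone case, from .. import reaches above the top-level package and raises ImportError, which is what triggers the fallback to the external package.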
What about the issue of wrapped array_namespace?
I think I agree with @izaid's take on array_namespace: (3) is most user-friendly, and then it works the same way as in array-api-compat.
The issue I'm referring to is that in SciPy, scipy._lib._array_api.array_namespace wraps array_api_compat.array_namespace to add extra checks (and default to NumPy for scalars). If array-api-extra uses array_api_compat.array_namespace internally, we get:
```python
from scipy._lib._array_api import array_namespace, array_api_extra as xpx
...
xp = array_namespace(x)
xpx.cov(x, xp=xp)  # uses `xp` from the wrapped `array_namespace`
...
xpx.cov(y)  # uses unwrapped `array_namespace`
```
This seems like a potential footgun to me, as one might expect the array_namespace behaviour to be uniform throughout.
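The footgun can be reproduced in a few lines. Everything here is hypothetical: base_array_namespace stands in for the unwrapped lookup (and, purely for illustration, rejects Python scalars), wrapped_array_namespace stands in for a consumer's wrapper that defaults scalars to NumPy as described above, and cov stands in for an extra function that falls back to the unwrapped lookup when no xp is passed:

```python
from types import SimpleNamespace

# Hypothetical stand-ins; neither mirrors the real array-api-compat or SciPy code.
base_ns = SimpleNamespace(name="base")
numpy_ns = SimpleNamespace(name="numpy")


def base_array_namespace(*arrays):
    """Stand-in for the unwrapped lookup; in this sketch it rejects Python scalars."""
    if any(isinstance(a, (int, float, complex)) for a in arrays):
        raise TypeError("unrecognized array input")
    return base_ns


def wrapped_array_namespace(*arrays):
    """Stand-in for a consumer's wrapper that defaults Python scalars to NumPy."""
    if arrays and all(isinstance(a, (int, float, complex)) for a in arrays):
        return numpy_ns
    return base_array_namespace(*arrays)


def cov(x, *, xp=None):
    """Hypothetical extra function that falls back to the *unwrapped* lookup."""
    if xp is None:
        xp = base_array_namespace(x)
    return xp.name


x = 2.0
print(cov(x, xp=wrapped_array_namespace(x)))  # numpy
try:
    cov(x)
except TypeError as exc:
    print("TypeError:", exc)  # TypeError: unrecognized array input
```

The same input succeeds or fails depending only on whether the caller passed xp, which is exactly the non-uniform behaviour described above.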
Ah okay. Yes, it seems necessary to allow array_namespace to be customized. It shouldn't be too hard, by modifying my code snippet from above a bit:
```python
# Allow packages that vendor `array-api-extra` as well as `array-api-compat` to override the import location
try:
    # allow vendoring array-api-compat and/or customizing array_namespace
    from .._array_api_extra_vendor import array_namespace, array_api_compat as compat
except ImportError:
    # it's an external dependency with no customization of array_namespace
    import array_api_compat as compat
    array_namespace = compat.array_namespace
```
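Here is a sketch of how that resolution plays out in both layouts. To stay short, it registers hypothetical stand-in modules (array_api_compat, consumer, consumer._array_api_extra_vendor) in sys.modules and executes the shim as if it lived in <package>._compat, rather than writing real files. The override in custom_array_namespace is illustrative only:

```python
import importlib.machinery
import sys
import types

# The proposed shim, with its indentation restored.
SHIM = """
try:
    # allow vendoring array-api-compat and/or customizing array_namespace
    from .._array_api_extra_vendor import array_namespace, array_api_compat as compat
except ImportError:
    # it's an external dependency with no customization of array_namespace
    import array_api_compat as compat
    array_namespace = compat.array_namespace
"""


def fake_module(name, **attrs):
    """Register a hypothetical stand-in module under `name` in sys.modules."""
    module = types.ModuleType(name)
    module.__dict__.update(attrs)
    sys.modules[name] = module
    return module


# Stand-in for array_api_compat (external or vendored copy) that always answers "base".
compat_stub = fake_module("array_api_compat", array_namespace=lambda *arrays: "base")


def custom_array_namespace(*arrays):
    """Hypothetical consumer override: Python scalars default to "numpy"."""
    if arrays and all(isinstance(a, (int, float, complex)) for a in arrays):
        return "numpy"
    return compat_stub.array_namespace(*arrays)


# Packages get an empty __path__ so relative imports can resolve against them.
fake_module("array_api_extra", __path__=[])
fake_module("consumer", __path__=[])  # hypothetical library vendoring both
fake_module("consumer.array_api_extra", __path__=[])
fake_module("consumer._array_api_extra_vendor",
            array_namespace=custom_array_namespace, array_api_compat=compat_stub)


def run_shim(package):
    """Execute SHIM as though it were the module <package>._compat."""
    name = f"{package}._compat"
    namespace = {
        "__name__": name,
        "__package__": package,
        "__spec__": importlib.machinery.ModuleSpec(name, None),
    }
    exec(SHIM, namespace)
    return namespace


standalone = run_shim("array_api_extra")
vendored = run_shim("consumer.array_api_extra")
print(standalone["array_namespace"](1.0))  # base
print(vendored["array_namespace"](1.0))    # numpy
```

One design caveat: because the whole import is wrapped in except ImportError, an ImportError raised inside a consumer's own vendor module (for example, a typo in its imports) would silently fall back to the external array-api-compat instead of surfacing.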
Would it be a bit cleaner as follows?
array_api_extra/src/array_api_extra/compat.py:
```python
try:
    from .._array_api_compat_vendor import *
except ImportError:
    from array_api_compat import *
```
scipy/scipy/_lib/_array_api_compat_vendor.py:
```python
from .array_api_compat import *
_array_namespace_orig = array_namespace
def array_namespace(...):
    # scipy-specific override
```