Multi-backend tensor support in jaxtyping for frameworks like Keras 3 and einops
anas-rz opened this issue · 3 comments
Currently, jaxtyping only supports `jax.numpy.array` as a valid tensor type. It also works with TensorFlow or PyTorch tensors individually, but there is no multi-backend tensor type. This limitation creates compatibility issues with libraries like Keras 3 and einops, which use multi-backend tensors. It would be great if jaxtyping supported a multi-backend tensor type that automatically detects the backend and the shape.
Sorry, just to be clear -- do you mean that you'd like an annotation of the sort `Float[AnyArrayOrTensorFromAnyBackend, "foo bar"]`? Or something else?
> do you mean that you'd like an annotation of the sort `Float[AnyArrayOrTensorFromAnyBackend, "foo bar"]`?
Yes, a type that works with all backends.
This can be done using `abc`:
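```python
import abc

import jax
import numpy as np
from jaxtyping import Float


# An empty abstract base class. Registering "virtual subclasses" makes
# isinstance() treat instances of those array types as instances of AnyArray.
class AnyArray(metaclass=abc.ABCMeta):
    pass


AnyArray.register(jax.Array)
AnyArray.register(np.ndarray)

x = np.arange(3.0)
# A floating-point array from any registered backend, with one axis named "foo".
tt = Float[AnyArray, "foo"]
assert isinstance(x, tt)
```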
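The same pattern can be extended to other backends. The following is a minimal sketch, assuming the end user wants to register whichever of NumPy, JAX, PyTorch and TensorFlow happen to be installed; the candidate list and the `_CANDIDATES` name are illustrative choices, not part of jaxtyping:

```python
import abc
import importlib


class AnyArray(metaclass=abc.ABCMeta):
    """User-defined umbrella type for array/tensor classes (not a jaxtyping API)."""


# Hypothetical list of (module, attribute) pairs naming each backend's array class.
_CANDIDATES = [
    ("numpy", "ndarray"),
    ("jax", "Array"),
    ("torch", "Tensor"),
    ("tensorflow", "Tensor"),
]

for module_name, attr in _CANDIDATES:
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        continue  # Backend not installed: skip it.
    array_cls = getattr(module, attr, None)
    if array_cls is not None:  # e.g. very old JAX versions lack jax.Array
        AnyArray.register(array_cls)
```

The resulting `AnyArray` can be used as in the snippet above (for example `Float[AnyArray, "foo"]`), and an array class from any other library can be added later with another `AnyArray.register(...)` call.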
I think, on balance, this is probably best done by the end user rather than as part of the jaxtyping library, so that a hypothetical `jaxtyping.AnyArray` doesn't shut out new libraries by only working with the types that have been registered with it.
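If registration itself is the concern, one alternative an end user could try is structural typing with `typing.Protocol`. This is a swapped-in technique, not something proposed in this thread: any object exposing `.shape` and `.dtype` is accepted, with no registration step. The `ArrayLike` and `FakeTensor` names are hypothetical, and whether jaxtyping accepts such a Protocol as the array type in `Float[...]` is an assumption not checked here:

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ArrayLike(Protocol):
    """Hypothetical structural 'any array': anything with .shape and .dtype."""

    shape: Any
    dtype: Any


class FakeTensor:
    """Illustrative stand-in for a tensor class from some third-party library."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


# isinstance() on a runtime_checkable Protocol only checks that the
# attributes exist; it does not validate their types or values.
print(isinstance(FakeTensor((2, 3), "float32"), ArrayLike))  # True
print(isinstance([1.0, 2.0], ArrayLike))  # False: lists have no .shape
```

The trade-off is the mirror image of `abc` registration: new libraries work automatically, but any non-array object that happens to have `shape` and `dtype` attributes is accepted too.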