patrick-kidger/jaxtyping

Multi-backend tensor support in jaxtyping for frameworks like Keras 3 and einops

anas-rz opened this issue · 3 comments

Currently, jaxtyping only supports jax.numpy.array as a valid tensor type. Although it works with TensorFlow and PyTorch tensors individually, there is no multi-backend tensor type. This limitation creates compatibility issues with libraries such as Keras 3 and einops, which work with tensors from multiple backends. It would be great if jaxtyping provided a multi-backend tensor type that automatically detects the backend and checks the tensor's shape.

Sorry, just to be clear -- do you mean that you'd like an annotation of the sort Float[AnyArrayOrTensorFromAnyBackend, "foo bar"]? Or something else?

> do you mean that you'd like an annotation of the sort Float[AnyArrayOrTensorFromAnyBackend, "foo bar"]?

Yes, a type that works with all backends.

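This can already be done with the standard-library abc module, by registering each backend's array type with an abstract base class: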

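```python
import abc

import jax
import numpy as np
from jaxtyping import Float


# An empty abstract base class: registering a type with it makes
# isinstance(obj, AnyArray) return True for instances of that type.
class AnyArray(metaclass=abc.ABCMeta):
    pass


AnyArray.register(jax.Array)
AnyArray.register(np.ndarray)

x = np.arange(3.0)  # float64 array of shape (3,)
tt = Float[AnyArray, "foo"]  # any registered array type, float dtype, rank 1
assert isinstance(x, tt)
```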

I think, on balance, this is probably best done by the end user rather than as part of the jaxtyping library, so that a hypothetical jaxtyping.AnyArray does not limit new libraries by only working with the types that have been registered with it.