patrick-kidger/jaxtyping

Random instances / Hypothesis-like generation

srush opened this issue · 3 comments

Was just curious whether anyone has built a way to use jaxtyping to generate random instances with the shapes specified by the annotations? Alternatively, if I wanted to build that, how might I hook into the constraint system?

https://hypothesis.readthedocs.io/en/latest/numpy.html#array-api
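A minimal sketch of the core of such a generator, assuming the dimension strings have already been pulled out of the annotations: sample one size per named dimension and reuse it everywhere that name appears, so that e.g. "foo bar" and "bar 3" agree on bar. `sample_shapes` is a hypothetical helper, not part of jaxtyping. It only understands plain named dimensions, integer sizes and the anonymous `_`; variadic (`*batch`, `...`), broadcastable (`#foo`) and arithmetic (`foo+1`) dimensions are deliberately rejected.

```python
import random
import re


def sample_shapes(dim_strs, rng, min_size=1, max_size=5):
    """Hypothetical helper: one concrete shape per dimension string.

    Named dimensions get a single sampled size shared across all strings.
    Only integer sizes, plain names and the anonymous "_" are supported;
    anything else (e.g. "*batch", "...", "#foo", "foo+1") raises.
    """
    sizes = {}  # named dimension -> sampled size, shared across arguments
    shapes = []
    for dim_str in dim_strs:
        shape = []
        for token in dim_str.split():
            if token.isdigit():
                shape.append(int(token))  # fixed-size dimension
            elif token == "_":
                # Anonymous dimension: any size, not tied to other arguments.
                shape.append(rng.randint(min_size, max_size))
            elif re.fullmatch(r"[A-Za-z][A-Za-z0-9_]*", token):
                if token not in sizes:
                    sizes[token] = rng.randint(min_size, max_size)
                shape.append(sizes[token])
            else:
                raise NotImplementedError(f"unsupported dimension {token!r}")
        shapes.append(tuple(shape))
    return shapes


rng = random.Random(0)
x_shape, y_shape = sample_shapes(["foo bar", "bar 3"], rng)
print(x_shape, y_shape)
```

The resulting shapes could then be turned into concrete arrays, e.g. with `numpy.random.default_rng().standard_normal(shape)`, or into Hypothesis strategies via `hypothesis.extra.numpy.arrays(dtype, shape)`.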

I've definitely heard this idea discussed a couple of times, but I don't know that anyone has both done it and published it open-source.

At least with JAX it should be particularly easy, as one could then use a jax.eval_shape call to evaluate the function abstractly, performing all the shape checks without running any actual computation or needing to know which values are legal to pass in.
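A minimal sketch of the jax.eval_shape approach, with a made-up function `f` and hand-picked shapes: the inputs are `jax.ShapeDtypeStruct`s, so no array data is allocated and nothing is computed, yet an incompatible shape still surfaces as an error during tracing. If `f` were additionally decorated with jaxtyping's runtime checking (`jaxtyped` plus a type checker such as beartype), its annotations should be checked during this same trace.

```python
import jax
import jax.numpy as jnp


# Hypothetical example function; any JAX-traceable function works here.
def f(x, w):
    return jnp.tanh(x @ w)


# Abstract inputs: just shape and dtype, no array data is allocated.
x = jax.ShapeDtypeStruct((4, 3), jnp.float32)
w = jax.ShapeDtypeStruct((3, 5), jnp.float32)

# eval_shape traces f abstractly and returns the output's shape/dtype.
out = jax.eval_shape(f, x, w)
print(out.shape, out.dtype)

# An incompatible shape is reported while tracing, still without computing anything.
bad_w = jax.ShapeDtypeStruct((2, 5), jnp.float32)
try:
    jax.eval_shape(f, x, bad_w)
    shape_error = None
except (TypeError, ValueError) as e:
    shape_error = e
print("shape error:", shape_error)
```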

Okay, I might try putting something together. Any recommendations on how best to get the constraints out?

As in, given an x defined by x = Float[Array, "foo bar"], you're asking how to obtain the "foo bar" from x?

This is accessible as x.dim_str, although that's only semi-public API. You could also look at x.dims to get the parsed result, although that's again only semi-public, as the types it contains are not public.