patrick-kidger/jaxtyping

Annotations for tensors with dynamic dimensions

martenlienen opened this issue · 2 comments

Hi,
What would be the idiomatic way in jaxtyping to describe the return shape of, for example, torch.zeros?

from jaxtyping import Float
from torch import Tensor
def zeros(size: tuple[int, ...]) -> Float[Tensor, "what do I put here?"]:
    pass

How can I encode that the output has one dimension per element of the size tuple, with each dimension's length given by that element?

Hey there! Unfortunately this isn't supported -- it'd require a fairly tricky rewrite of some internals: #140 (comment)

Thanks for clearing that up!