Breakage with wildcard and integer dimensions: `@dimchecked` raises AssertionError for an annotation like `[..., 3, 2]`
Opened this issue · 0 comments
jatentaki commented
import torch
# The decorator lives in the `torch_dimcheck` package (see the
# `torch_dimcheck/dimcheck.py` path in the traceback); importing from
# `torch_dimchecked` would fail before the reported error is reached.
from torch_dimcheck import dimchecked


@dimchecked
def box_area(box: [..., 3, 2]) -> [...]:
    """Return the product of the three per-axis extents of a box.

    ``box`` is annotated ``[..., 3, 2]``. The last axis is split into a
    (low, high) pair, and the axis before it is split into three rows
    (x, y, z). Any leading dimensions are matched by the ``...`` wildcard.
    """
    low, high = box.chunk(2, dim=-1)
    x, y, z = (high - low).chunk(3, dim=-2)
    # NOTE(review): `chunk` keeps the split dimensions, so for a (3, 2)
    # input this returns shape (1, 1) rather than the empty shape that
    # `[...]` would bind to here. The return annotation may also need
    # adjusting once input checking works -- confirm.
    return x * y * z


# Shape (3, 2): the leading `...` should match zero dimensions.
# Reported behaviour: the call fails inside dimcheck's `get_bindings`
# with AssertionError("Arrived at the end of procedure").
bbox = torch.tensor([[0, 1], [0, 1], [0, 1]])
box_area(bbox)
Calling `box_area(bbox)` raises:
AssertionError Traceback (most recent call last)
<ipython-input-51-c80f2a77db5a> in <module>
----> 1 box_area(bbox)
/mnt/disk-1/michat/miniconda3/lib/python3.8/site-packages/torch_dimcheck-0.0.1-py3.8.egg/torch_dimcheck/dimcheck.py in wrapped(*args, **kwargs)
103 if i in checked_parameters:
104 param = checked_parameters[i]
--> 105 shapes = get_bindings(
106 arg, param.annotation, tensor_name=param.name
107 )
/mnt/disk-1/michat/miniconda3/lib/python3.8/site-packages/torch_dimcheck-0.0.1-py3.8.egg/torch_dimcheck/dimcheck.py in get_bindings(tensor, annotation, tensor_name)
85 raise SizeMismatchError(len(tensor.shape) - i - 1, anno, dim, tensor_name)
86
---> 87 raise AssertionError("Arrived at the end of procedure")
88
89
AssertionError: Arrived at the end of procedure