jatentaki/torch-dimcheck

Shape check breaks when a `...` wildcard is followed by integer dimensions

Opened this issue · 0 comments

import torch
from torch_dimcheck import dimchecked


@dimchecked
def box_area(box: [..., 3, 2]) -> [...]:
    """Return the volume of axis-aligned 3-D boxes (product of side lengths).

    ``box[..., i, 0]`` and ``box[..., i, 1]`` are the lower and upper bounds
    along axis ``i`` (x, y, z). Any leading dimensions matched by ``...`` are
    kept in the output, which therefore has shape ``[...]``.
    """
    # ``unbind`` removes the split dimension, so the result has exactly the
    # annotated shape ``[...]``; ``chunk`` would leave trailing singleton
    # dimensions (``[..., 1, 1]``) and fail the return-shape check.
    low, high = box.unbind(dim=-1)
    x, y, z = (high - low).unbind(dim=-1)
    return x * y * z


# A single, un-batched box of shape (3, 2): ``...`` should bind to zero
# leading dimensions. Expected result: a 0-dim tensor equal to 1 (the unit
# cube); instead the argument check raises
# AssertionError("Arrived at the end of procedure").
bbox = torch.tensor([[0, 1], [0, 1], [0, 1]])
box_area(bbox)

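Running the snippet raises the error below. `bbox` has shape `(3, 2)`, which should match `[..., 3, 2]` with `...` binding to zero leading dimensions, but the argument check fails instead: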

AssertionError                            Traceback (most recent call last)
<ipython-input-51-c80f2a77db5a> in <module>
----> 1 box_area(bbox)

/mnt/disk-1/michat/miniconda3/lib/python3.8/site-packages/torch_dimcheck-0.0.1-py3.8.egg/torch_dimcheck/dimcheck.py in wrapped(*args, **kwargs)
    103             if i in checked_parameters:
    104                 param = checked_parameters[i]
--> 105                 shapes = get_bindings(
    106                     arg, param.annotation, tensor_name=param.name
    107                 )

/mnt/disk-1/michat/miniconda3/lib/python3.8/site-packages/torch_dimcheck-0.0.1-py3.8.egg/torch_dimcheck/dimcheck.py in get_bindings(tensor, annotation, tensor_name)
     85                 raise SizeMismatchError(len(tensor.shape) - i - 1, anno, dim, tensor_name)
     86 
---> 87     raise AssertionError("Arrived at the end of procedure")
     88 
     89 

AssertionError: Arrived at the end of procedure