jatentaki/torch-dimcheck

Explicitly specifying argument names fools the checker


With a signature like `def f(a1, a2)`, calling it as `f(a1=a1, a2=a2)` bypasses the dimension check, whereas the positional call `f(a1, a2)` is checked as expected. Repro:

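```python
import torch
from torch_dimcheck import dimchecked

@dimchecked
def attention(
    src: ['B', 'S', 'H', 'W'],
    key: ['B', 'C', 'H', 'W'],
    qry: ['B', 'C', 'H', 'W'],
) -> ['B', 'S', 'H', 'W']:

    return src


# src has H = W = 3 while key and qry have H = W = 5,
# so both calls below should be rejected.
src = torch.randn(2, 3, 3, 3)
key = torch.randn(2, 3, 5, 5)
qry = torch.randn(2, 3, 5, 5)

# Keyword arguments: the mismatch goes unnoticed and the call returns normally.
attention(src=src, key=key, qry=qry)
print('succeeded')
# Positional arguments: the mismatch is detected and the call raises.
attention(src, key, qry)
print('failed')  # not reached
```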
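My guess (not verified against the torch-dimcheck source) is that the decorator only matches annotations against positional `*args` and never looks at `**kwargs`. One way to make both call styles behave the same is to normalize the call with `inspect.signature(func).bind(*args, **kwargs)` before checking. The sketch below is a hypothetical stand-alone decorator, `dimchecked_sketch`, with a hypothetical `DimensionMismatch` error; it is not the library's API. It only assumes each argument has a `.shape` (a `torch.Tensor` works, since `torch.Size` is a tuple subclass), and uses a tiny `FakeTensor` stand-in so it runs without torch.

```python
import functools
import inspect
from dataclasses import dataclass


class DimensionMismatch(Exception):
    """Hypothetical error type used by this sketch."""


def dimchecked_sketch(func):
    """Hypothetical decorator: checks list-of-labels annotations for every
    argument, whether it was passed positionally or by keyword."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # bind() maps both positional and keyword arguments onto parameter
        # names, so f(a1, a2) and f(a1=a1, a2=a2) are checked identically.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        sizes = {}  # label -> size seen first for that label

        def check(name, value, labels):
            shape = tuple(value.shape)
            if len(shape) != len(labels):
                raise DimensionMismatch(
                    f"{name}: expected {len(labels)} dims, got {len(shape)}")
            for label, size in zip(labels, shape):
                # Record the first size seen for a label; later ones must agree.
                if sizes.setdefault(label, size) != size:
                    raise DimensionMismatch(
                        f"{name}: dim {label!r} is {size}, expected {sizes[label]}")

        for name, value in bound.arguments.items():
            labels = sig.parameters[name].annotation
            if isinstance(labels, list):
                check(name, value, labels)

        result = func(*bound.args, **bound.kwargs)
        if isinstance(sig.return_annotation, list):
            check("return value", result, sig.return_annotation)
        return result

    return wrapper


@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor: anything with a .shape works."""
    shape: tuple


@dimchecked_sketch
def attention(src: ['B', 'S', 'H', 'W'],
              key: ['B', 'C', 'H', 'W'],
              qry: ['B', 'C', 'H', 'W']) -> ['B', 'S', 'H', 'W']:
    return src


src = FakeTensor((2, 3, 3, 3))
key = FakeTensor((2, 3, 5, 5))
qry = FakeTensor((2, 3, 5, 5))

for call in (lambda: attention(src, key, qry),
             lambda: attention(src=src, key=key, qry=qry)):
    try:
        call()
        print("not caught")
    except DimensionMismatch as e:
        print("caught:", e)  # caught: key: dim 'H' is 5, expected 3
```

With binding in place, the positional and keyword calls from the repro are both rejected with the same message, because the check iterates over parameter names rather than over positions.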