Is it a typo in FLASH module?
marsggbo opened this issue · 1 comment
marsggbo commented
The original code is below:
FLASH-pytorch/flash_pytorch/flash_pytorch.py
Line 338 in edce0fd
Is that a typo? Maybe the correct version is n=x.shape[-2], or set g=self.group_size.
lucidrains commented
@marsggbo oh yeah, the einops equation isn't very clear. It should be b (n g) d -> b n g d, with g = self.group_size, but otherwise it is correct.
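To make the corrected pattern concrete, here is a minimal sketch of what b (n g) d -> b n g d with g = group_size does to a (batch, seq_len, dim) tensor. It assumes seq_len is already a multiple of the group size and uses a plain NumPy reshape, which produces the same result as einops `rearrange(x, 'b (n g) d -> b n g d', g=group_size)` because g is the inner (fastest-varying) factor. `group_tokens` is a hypothetical helper for illustration, not code from FLASH-pytorch.

```python
import numpy as np

def group_tokens(x: np.ndarray, group_size: int) -> np.ndarray:
    """Split the sequence axis into (num_groups, group_size).

    Same result as einops rearrange(x, 'b (n g) d -> b n g d', g=group_size):
    g (the inner factor) is the group size and n (number of groups) is inferred.
    Hypothetical helper; assumes seq_len is already padded to a multiple of group_size.
    """
    b, seq_len, d = x.shape
    assert seq_len % group_size == 0, "pad seq_len to a multiple of group_size first"
    n = seq_len // group_size
    return x.reshape(b, n, group_size, d)

b, seq_len, d, group_size = 2, 12, 4, 3
x = np.arange(b * seq_len * d).reshape(b, seq_len, d)

grouped = group_tokens(x, group_size)
print(grouped.shape)  # (2, 4, 3, 4): 4 groups of 3 tokens each

# Token t lands in group t // group_size at position t % group_size.
t = 7
assert np.array_equal(grouped[:, t // group_size, t % group_size], x[:, t])

# Illustration only: binding the group size to the outer factor instead
# silently gives a different split (3 groups of 4 tokens, not 4 groups of 3).
wrong = x.reshape(b, group_size, seq_len // group_size, d)
print(wrong.shape)  # (2, 3, 4, 4)
```

Which name gets bound matters: in b (n g) d the factor named g is the one whose consecutive tokens stay together, so passing g = self.group_size yields groups of exactly group_size tokens, while n is derived from the sequence length.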