lucidrains/FLASH-pytorch

Is it a typo in FLASH module?

marsggbo opened this issue · 1 comment

The original code is below:

quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v))

Is that a typo? Maybe the correct version is n=x.shape[-2], or g=self.group_size should be set instead.
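To see what the original pattern actually does, here is a plain-NumPy sketch of the same split. It assumes einops's row-major semantics for grouped axes, where the leftmost name inside `(g n)` varies slowest. It also assumes that the sequence length is already a multiple of `group_size`. The sizes are hypothetical and chosen only for illustration.

```python
import numpy as np

# Hypothetical sizes for illustration; seq_len is assumed divisible by group_size.
batch, seq_len, dim, group_size = 2, 8, 3, 4

x = np.arange(batch * seq_len * dim).reshape(batch, seq_len, dim)

# NumPy equivalent of rearrange(x, 'b (g n) d -> b g n d', n=group_size):
# g = seq_len // group_size groups, each holding n = group_size consecutive tokens.
grouped = x.reshape(batch, seq_len // group_size, group_size, dim)
print(grouped.shape)  # (2, 2, 4, 3)

# The second group holds tokens 4..7 of the original sequence, in order.
assert (grouped[:, 1] == x[:, 4:8]).all()

# Using n = x.shape[-2] (the full sequence length) instead would leave a single
# group containing every token, i.e. no local chunking at all.
single = x.reshape(batch, 1, seq_len, dim)
print(single.shape)  # (2, 1, 8, 3)
```

In the original pattern, `n` is the group size and `g` is the number of groups. The reshape is therefore already grouping tokens into chunks of `group_size`. What is confusing is only the choice of letters.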

@marsggbo Ohh yeah, the einops equation isn't very clear.

It should be b (n g) d -> b n g d, with g = self.group_size, but otherwise it is correct: the reshape itself is the same, and only the axis names change.