Masking attention with batches
ashrafflh opened this issue · 0 comments
Hi,
thank you for your awesome work. I need to introduce masking in the ViT. The following code works for a single image; how can I make it work with batches? Thanks!
```python
def forward(self, x, mask = None):
    qkv = self.to_qkv(x).chunk(3, dim = -1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

    dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        mask = F.pad(mask, (x.shape[-2] - mask.shape[-1], 0), value = True)
        dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

    attn = self.attend(dots)
    attn = self.dropout(attn)

    out = torch.matmul(attn, v)
    out = rearrange(out, 'b h n d -> b n (h d)')
    return self.to_out(out)
```
The error I get is:

```
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
RuntimeError: The size of tensor a (batch_size) must match the size of tensor b (4097) at non-singleton dimension 2
```
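
For context on where the error comes from: `dots` has shape `(batch, heads, n, n)`, while the flattened and padded mask has shape `(batch, n)`. PyTorch broadcasting right-aligns dimensions, so the 2-D mask is matched against the last two dims of `dots`, which would require `batch == n` (here n = 4097, presumably 4096 patches plus a CLS token). One way to make it batch-safe is to reshape the mask to `(batch, 1, 1, n)` so it broadcasts over all heads and query positions. Below is a minimal sketch, assuming a boolean mask of shape `(batch, num_patches)` where `True` marks tokens to keep; the helper name `apply_mask` and the `num_cls_tokens` parameter are illustrative, not from the repo:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def apply_mask(dots, mask, num_cls_tokens = 1):
    # dots: (batch, heads, n, n) attention logits
    # mask: (batch, num_patches) bool, True = keep token
    mask = rearrange(mask, 'b ... -> b (...)')
    # left-pad with True so the CLS token(s) are always attended to,
    # giving sequence length n = num_cls_tokens + num_patches
    mask = F.pad(mask, (num_cls_tokens, 0), value = True)
    # (batch, n) -> (batch, 1, 1, n): broadcasts over heads and query rows
    mask = rearrange(mask, 'b n -> b 1 1 n')
    return dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
```

In the `forward` above, this amounts to reshaping the mask before the `masked_fill` call. Note this masks keys only; padded query rows still produce outputs, which is usually fine since those positions are discarded downstream. If you also want padded queries masked, combine `rearrange(mask, 'b i -> b 1 i 1')` and `rearrange(mask, 'b j -> b 1 1 j')` with a logical AND to get a `(batch, 1, n, n)` mask.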