lucidrains/vit-pytorch

Masking attention with batches


Hi,

Thank you for your awesome work. I need to introduce masking into the ViT attention. The following code works for a single image, but how can I make it work with batches? Thanks!

    def forward(self, x, mask = None):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # attention logits, shape (b, h, n, n)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if mask is not None:
            mask = rearrange(mask, 'b ... -> b (...)')
            # left-pad with True so the CLS token is never masked; mask is now (b, n)
            mask = F.pad(mask, (x.shape[-2] - mask.shape[-1], 0), value = True)
            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

The error I get is:

    dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
    RuntimeError: The size of tensor a (batch_size) must match the size of tensor b (4097) at non-singleton dimension 2
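
If I read PyTorch's broadcasting rules correctly, that is the problem: after padding, mask has shape (b, n) while dots has shape (b, h, n, n). Dimensions are aligned from the right, so dimension 2 of dots (size n = 4097) gets compared against the batch dimension of mask. It only works for a single image because a batch size of 1 broadcasts. Here is a minimal, self-contained sketch of the reshape I believe is needed; the standalone shapes and the key-only masking are my own assumptions for illustration:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    # hypothetical standalone shapes (the failing run had n = 4097), kept small here
    b, h, n = 2, 8, 17                                 # batch, heads, 16 patches + 1 CLS token

    dots = torch.randn(b, h, n, n)                     # attention logits, (b, h, n, n)
    mask = torch.ones(b, n - 1, dtype = torch.bool)    # per-patch mask, True = keep, (b, 16)

    mask = rearrange(mask, 'b ... -> b (...)')         # flatten any spatial dims -> (b, 16)
    mask = F.pad(mask, (n - mask.shape[-1], 0), value = True)  # CLS stays visible -> (b, 17)
    mask = rearrange(mask, 'b j -> b 1 1 j')           # singleton head/query dims -> (b, 1, 1, 17)

    dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

In forward, this would amount to adding the single rearrange (or equivalently mask[:, None, None, :]) after the F.pad call, so the mask broadcasts over the head and query dimensions. This masks keys for every query; if the padded query rows should be masked as well, I suppose one could instead build a pairwise mask from the (b, n) mask, e.g. mask[:, None, :, None] & mask[:, None, None, :], which has shape (b, 1, n, n). Does that look right?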