dhansmair/flamingo-mini

Doubt about MaskedCrossAttention


Hi, I'm unsure about this piece of code in MaskedCrossAttention inside gated_cross_attention.py:

```python
media_time = torch.arange(n_media, device=y.device) + 1
# >> David:
# side note: here, text tokens attend to ALL previous visual tokens. If We only want to attend to the
# one image coming before in the text (like in the flamingo paper),
# we need to change >= to == at the line where 'text_to_media_mask' is created.
text_to_media_mask = rearrange(text_time, 'b i -> b 1 i 1') == repeat(media_time, 'j -> 1 1 1 (j m)', m=self.n_visual)
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

sim = sim - sim.amax(dim=-1, keepdim=True).detach()
```
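The comment above contrasts `==` (attend only to the most recent image) with `>=` (attend to every earlier image). Below is a minimal sketch of both masks on toy tensors. It assumes, since the snippet does not show it, that `text_time` holds for each text token the 1-based index of the most recent preceding image, with 0 meaning no image comes before that token. The sizes are hypothetical, and einops' `rearrange`/`repeat` are replaced by plain indexing so the sketch needs only torch.

```python
import torch

# Hypothetical toy sizes: 1 batch, 5 text tokens, 2 images, 3 visual tokens per image.
n_media, n_visual = 2, 3

# Assumption: index (1-based) of the latest image preceding each text token, 0 if none.
text_time = torch.tensor([[0, 1, 1, 2, 2]])          # shape (b, i)
media_time = torch.arange(n_media) + 1                 # tensor([1, 2])

# Same as rearrange(text_time, 'b i -> b 1 i 1'): shape (b, 1, i, 1)
t = text_time[:, None, :, None]
# Same as repeat(media_time, 'j -> 1 1 1 (j m)', m=n_visual): shape (1, 1, 1, j*m)
m = media_time.repeat_interleave(n_visual)[None, None, None, :]

only_latest = t == m   # True only for the visual tokens of the most recent image
all_previous = t >= m  # True for the visual tokens of every preceding image

print(only_latest[0, 0].int())
print(all_previous[0, 0].int())
```

In this toy setup the fourth text token (which follows the second image) sees only image 2 with `==`, but both images with `>=`. The first token has no preceding image, so its row is fully masked under either operator.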

It seems you are setting the positions you want to mask out to -torch.finfo(sim.dtype).max (a large negative number), but then finding the largest value with sim.amax and normalizing by it?

I would think it should be:

```python
sim = sim.masked_fill(~text_to_media_mask, torch.finfo(sim.dtype).max)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
```
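One quick way to compare the original fill value with the proposed one is to run both through the same fill, shift and softmax steps on a toy input. The helper `masked_attention` is hypothetical and exists only to mirror those three lines. In `mask`, True marks positions that may be attended to.

```python
import torch

mask = torch.tensor([False, False, True, True])  # True = position may be attended to
scores = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
big = torch.finfo(scores.dtype).max

def masked_attention(x, mask, fill):
    # hypothetical helper mirroring the snippet: fill, shift by the max, softmax
    x = x.masked_fill(~mask, fill)
    x = x - x.amax(dim=-1, keepdim=True).detach()
    return x.softmax(dim=-1)

original = masked_attention(scores, mask, -big)  # most negative finite float64
proposed = masked_attention(scores, mask, big)   # most positive finite float64

print(original)  # weight only on the allowed positions 2 and 3
print(proposed)  # weight only on the masked positions 0 and 1
```

With the positive fill, the masked entries become the row maximum. After the shift they sit at 0 while the real scores drop to about -1.8e308, so the softmax sends all of the attention to the positions that were supposed to be hidden.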

Any clarification on this logic is appreciated. Thanks!

Hi @eileenforwhat, it's been a while since I worked on this, so I had to think it through myself. This snippet is taken from lucidrains' code, which does the same thing: https://github.com/lucidrains/flamingo-pytorch/blob/10913abbc8b2ceabb2320560d7d9b85fcb85eee3/flamingo_pytorch/flamingo_pytorch.py#L170

Consider this toy example:

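```python
import torch

mask = torch.tensor([0,0,1,1], dtype=bool)  # True = position may be attended to
x = torch.tensor([1,2,3,4], dtype=float)    # toy attention scores (float64)
print("mask:", mask)
print("inverted mask:", ~mask)
# fill the positions that must NOT be attended to with the most negative finite float64
x = x.masked_fill(~mask, -torch.finfo(x.dtype).max)
print("x:", x)
# shift the scores so that the largest one becomes 0
x = x - x.amax(dim=-1, keepdim=True).detach()
print("x:", x)
alphas = x.softmax(dim=-1)
print("alphas: ", alphas)
```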

This gives the following result:

```
$ python test.py
mask: tensor([False, False,  True,  True])
inverted mask: tensor([ True,  True, False, False])
x: tensor([-1.7977e+308, -1.7977e+308,   3.0000e+00,   4.0000e+00],
       dtype=torch.float64)
x: tensor([-1.7977e+308, -1.7977e+308,  -1.0000e+00,   0.0000e+00],
       dtype=torch.float64)
alphas:  tensor([0.0000, 0.0000, 0.2689, 0.7311], dtype=torch.float64)
```

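Here, subtracting the maximum value of 4 is not "normalizing"; it shifts the largest value to zero. This does not change the result of the softmax, because softmax(x - c) = softmax(x) for any constant c (the factor exp(-c) cancels between numerator and denominator). So my assumption is that it is done for numerical stability: after the shift every exponent is at most 0, so no exp() term can overflow.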
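Here is a minimal sketch of both points. The float32 scores are hypothetical, picked only so that exp() of the raw values overflows. Subtracting the maximum leaves the softmax unchanged, while a naive exp() of the raw scores overflows to inf and produces NaN.

```python
import torch

# Hypothetical float32 scores, large enough that exp() overflows.
x = torch.tensor([1000.0, 1001.0, 1002.0])

# Naive softmax: exp(1000) overflows float32 to inf, and inf / inf is nan.
naive = torch.exp(x) / torch.exp(x).sum()

# Shifted softmax: after subtracting the max, every exponent is <= 0.
shifted = x - x.amax(dim=-1, keepdim=True)
stable = torch.exp(shifted) / torch.exp(shifted).sum()

# Shift invariance on small values: subtracting a constant leaves the softmax unchanged.
small = torch.tensor([1.0, 2.0, 3.0])
same = torch.allclose(torch.softmax(small, dim=-1), torch.softmax(small - small.max(), dim=-1))

print(naive)                      # all nan
print(stable)                     # finite probabilities summing to 1
print(torch.softmax(x, dim=-1))   # built-in softmax on the raw scores
print(same)
```

`torch.softmax` returns finite values on the same raw input. That suggests it already applies a similar shift internally, so the explicit shift in the snippet may be a safeguard rather than a strict requirement. This is an inference and is not stated in the thread.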

Note that setting the values we want to mask to a very large negative number (effectively -infinity) makes them 0 after the softmax, which is exactly what we want to achieve.
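The snippet fills masked positions with the finite value -torch.finfo(sim.dtype).max rather than float('-inf'). One plausible reason is sketched below. This is an assumption and is not stated in the thread. When every position in a row is masked, for example a text token with no preceding image if such tokens occur, -inf makes the max-subtraction compute -inf - (-inf) = NaN, whereas the finite fill produces a uniform row. When at least one position is unmasked, both fills give the same weights. The helper `masked_attention` is hypothetical and only mirrors the snippet's three lines.

```python
import torch

scores = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
big = torch.finfo(scores.dtype).max

def masked_attention(x, mask, fill):
    # hypothetical helper mirroring the snippet: fill, shift by the max, softmax
    x = x.masked_fill(~mask, fill)
    x = x - x.amax(dim=-1, keepdim=True).detach()
    return x.softmax(dim=-1)

# Fully masked row: nothing may be attended to.
none_allowed = torch.zeros(3, dtype=torch.bool)
full_inf = masked_attention(scores, none_allowed, float("-inf"))  # -inf - (-inf) -> nan
full_max = masked_attention(scores, none_allowed, -big)           # all entries shift to 0 -> uniform

# Partially masked row: the first position is hidden.
some_allowed = torch.tensor([False, True, True])
part_inf = masked_attention(scores, some_allowed, float("-inf"))
part_max = masked_attention(scores, some_allowed, -big)

print(full_inf)  # nan in every entry
print(full_max)  # 1/3 in every entry
print(torch.allclose(part_inf, part_max))  # both fills agree here
```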

Hope this helps!

I see, that makes sense. Thank you!

Sure, feel free to ask if you have any more doubts :)