Confused about the implementation of mutual self-attention
parryppp opened this issue · 6 comments

Hi @parryppp, note that in the new attention operation, qu is the unconditional part of both the source and target images, while ku[:num_heads] and vu[:num_heads] come only from the source image. As a result, the target image can query content from the source image. The same goes for the conditional part. With this implementation, the source-image reconstruction and target-image editing processes are performed simultaneously.
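For reference, here is a rough sketch of how the batched features can be split in this setup (the batch order [uncond source, uncond target, cond source, cond target] and the sizes below are assumptions for illustration, not the repo's literal code):

import torch

num_heads, N, D = 8, 4096, 40          # illustrative sizes only
# assumed batch layout: [uncond_src, uncond_tgt, cond_src, cond_tgt], num_heads slices each
q = torch.randn(4 * num_heads, N, D)
k = torch.randn(4 * num_heads, N, D)
v = torch.randn(4 * num_heads, N, D)

qu, qc = q.chunk(2)   # unconditional / conditional halves, each (2*num_heads, N, D)
ku, kc = k.chunk(2)
vu, vc = v.chunk(2)
# ku[:num_heads] and vu[:num_heads] are the source-image keys/values, so both the
# source and target queries attend to source content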
I appreciate your help, but I am still a little confused about the meaning of qu, qc, ku, kc, vu, and vc. Could you clarify them for me? If qu is the unconditional part of both the source and target images, does that mean qc also contains information from both the source and the target image? Then both out_u and out_c would contain information from the source and the target image, but according to your reply, out_c should only contain information from the source image?
Actually, qu[:num_heads] (with shape [num_heads, N, D]) is the query feature from the source image, and qu[num_heads:] (with shape [num_heads, N, D]) is the query feature from the target image. Thus, the output out_u is the concatenation of the attention outputs of the source and target images, with shape [2*num_heads, N, D].
You can check the shapes of the intermediate tensors to verify this for yourself.
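For example, checking the slices directly (the dimensions below are only illustrative, not the actual model sizes):

import torch

num_heads, N, D = 8, 4096, 40          # illustrative sizes
qu = torch.randn(2 * num_heads, N, D)  # [source heads ; target heads]

print(qu[:num_heads].shape)   # torch.Size([8, 4096, 40]) -> source-image queries
print(qu[num_heads:].shape)   # torch.Size([8, 4096, 40]) -> target-image queries
# out_u concatenates the two attention outputs, so its shape is (2*num_heads, N, D) again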
Thanks for your suggestion. I am new to diffusion models and still stuck on the meaning of the feature qu[:num_heads], so I conducted a small experiment to verify it.
I created a tensor b with shape (2, 3, 4), i.e. (batch_size, N, dim), where the batch size of 2 corresponds to the two prompts (source and target). I then rearranged it into shape (4, 3, 2) for the multi-head attention calculation, with the number of heads set to 2, and applied the chunk function, which gave me the following tensors:
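Roughly, my experiment looked like the following sketch (the reshape below is my own reconstruction of the multi-head layout and may differ from the exact rearrange used in the repo):

import torch

num_heads = 2
b = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)  # (batch=2, N=3, dim=4); b[0] = source, b[1] = target

# reshape (2, 3, 4) -> (batch * num_heads, N, dim // num_heads) = (4, 3, 2)
b_heads = b.reshape(2, 3, num_heads, 2).permute(0, 2, 1, 3).reshape(4, 3, 2)

c1, c2 = b_heads.chunk(2)
print(c1)   # every value comes from b[0], the source prompt
print(c2)   # every value comes from b[1], the target prompt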
In this case, the feature c1 comes entirely from the source prompt and c2 entirely from the target prompt. This result confuses me. Could you help clarify? Thank you so much!
Your toy experiment is correct. However, that is only the first step: the feature interaction is then performed by the attention mechanism, where the target feature serves as the query and the source feature serves as the key and value. You can refer to the attention operation for more details.
Btw, this process is equivalent to the following implementation:
# source queries attend to the source keys/values (reconstruction branch)
out_u_src = Attention(qu[:num_heads], ku[:num_heads], vu[:num_heads])
# target queries also attend to the SOURCE keys/values (mutual self-attention)
out_u_tgt = Attention(qu[num_heads:], ku[:num_heads], vu[:num_heads])
out_u = torch.cat([out_u_src, out_u_tgt])  # (2*num_heads, N, D)
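Putting it together, a self-contained sketch of the unconditional branch (the attention helper below is a plain scaled dot-product stand-in for the repo's attention op, and the sizes are illustrative):

import torch

def attention(q, k, v):
    # plain scaled dot-product attention over (heads, N, head_dim) tensors
    sim = torch.einsum("h i d, h j d -> h i j", q, k) / (q.shape[-1] ** 0.5)
    return torch.einsum("h i j, h j d -> h i d", sim.softmax(dim=-1), v)

num_heads, N, D = 2, 3, 4
qu = torch.randn(2 * num_heads, N, D)   # [source heads ; target heads]
ku = torch.randn(2 * num_heads, N, D)
vu = torch.randn(2 * num_heads, N, D)

out_u_src = attention(qu[:num_heads], ku[:num_heads], vu[:num_heads])  # source reconstructs itself
out_u_tgt = attention(qu[num_heads:], ku[:num_heads], vu[:num_heads])  # target queries the source
out_u = torch.cat([out_u_src, out_u_tgt])                              # (2*num_heads, N, D)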
Thank you so much for your help!