TencentARC/MasaCtrl

confused about the implementation of mutual self-attention

parryppp opened this issue · 6 comments

SORRY to bother you, I am stuck on the implementation of mutual self-attention. In masactrl.py, the forward function of class MutualSelfAttentionControl seems to just split the batch, run the source attention and the target attention separately, and then concatenate the results. How does the mutual self-attention QUERY the SOURCE KEY and SOURCE VALUE?

Hi @parryppp, note that in the new attention operation, qu is the unconditional part of both source and target images, while ku[:num_heads] and vu[:num_heads] are only from the source image. As a result, the target image can query source information from the source image. The same goes for the conditional part. With this implementation, the source image reconstruction and target image editing processes are performed simultaneously.
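Concretely, this can be sketched as follows (a minimal sketch with made-up shapes; `attention` here is a plain scaled dot-product stand-in for the U-Net's attention, and the source/target fold mirrors the description above rather than the repo's exact code):

```python
import torch

def attention(q, k, v):
    # Plain scaled dot-product attention over [heads, tokens, dim] tensors.
    sim = torch.einsum("h i d, h j d -> h i j", q, k) * q.shape[-1] ** -0.5
    return torch.einsum("h i j, h j d -> h i d", sim.softmax(dim=-1), v)

num_heads, N, D = 8, 16, 4
# qu/ku/vu stack source heads first, then target heads: [2*num_heads, N, D]
qu, ku, vu = (torch.randn(2 * num_heads, N, D) for _ in range(3))

# Fold source/target into the token axis so that BOTH sets of queries attend
# to the SOURCE keys/values (ku[:num_heads], vu[:num_heads]) in one call.
q = qu.reshape(2, num_heads, N, D).transpose(0, 1).reshape(num_heads, 2 * N, D)
out = attention(q, ku[:num_heads], vu[:num_heads])        # [num_heads, 2N, D]
out_u = out.reshape(num_heads, 2, N, D).transpose(0, 1).reshape(2 * num_heads, N, D)
print(out_u.shape)  # torch.Size([16, 16, 4])
```

The first half of `out_u` reconstructs the source image; the second half is the target output, computed against source keys and values.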

I appreciate your help, but I am still a little confused about the meaning of qu, qc, ku, kc, vu, vc; could you clarify them for me? If qu is the unconditional part of both source and target images, does that mean qc also contains information from both the source and the target image? Then out_u and out_c would both contain information from the source and target images, but according to your reply, out_c should only contain information from the source image?

Actually, qu[:num_heads] (with shape [num_heads, N, D]) is the query feature from the source image, and qu[num_heads:] (with shape [num_heads, N, D]) is the query feature from the target image. Thus, the output out_u is the concatenation of the attention outputs of the source and target images, with shape [2*num_heads, N, D].

You can verify the tensor shapes of the intermediate tensor variables for better understanding.
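For example, the layout claim above can be checked with a few lines (the shapes here are arbitrary, and `src`/`tgt` are hypothetical stand-ins for the per-image query features):

```python
import torch

num_heads, N, D = 8, 16, 40
src = torch.randn(num_heads, N, D)   # source-image query features, per head
tgt = torch.randn(num_heads, N, D)   # target-image query features, per head

# The U-Net folds (batch, heads) into one leading axis, so the unconditional
# query stacks source heads first, then target heads:
qu = torch.cat([src, tgt])           # [2*num_heads, N, D]

assert torch.equal(qu[:num_heads], src)   # first half: source query
assert torch.equal(qu[num_heads:], tgt)   # second half: target query
print(qu.shape)  # torch.Size([16, 16, 40])
```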

Thanks for your suggestion. I am new to diffusion models and am still stuck on the meaning of the feature qu[:num_heads]. I conducted a small experiment to verify it:
[screenshot: definition of the tensor b]
The tensor b has shape (2, 3, 4) (batch_size, N, dim), where 2 indicates the 2 prompts (source and target). I then permute the tensor into shape (4, 3, 2) for the multi-head attention calculation, where the number of heads is 2:
[screenshots: the reshaped tensor]
Then I execute the chunk function and get the following tensors:
[screenshot: tensors c1 and c2]
In this case, the feature c1 is entirely from the source feature and c2 is entirely from the target feature. This result confuses me; could you help clarify? Thank you so much!
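For reference, the toy experiment can be reproduced in code (the values are arbitrary; the shapes follow the description above):

```python
import torch

# Rebuild the toy experiment: batch of 2 prompts (source, target), N=3, dim=4.
b = torch.arange(24.0).reshape(2, 3, 4)       # [batch, N, dim]
num_heads, head_dim = 2, 2

# Split each token's dim across heads and fold heads into the batch axis,
# giving [batch*heads, N, head_dim] = [4, 3, 2].
mh = b.reshape(2, 3, num_heads, head_dim).permute(0, 2, 1, 3).reshape(4, 3, head_dim)

c1, c2 = mh.chunk(2)   # c1: both heads of the source, c2: both heads of the target
# c1 is built only from the source prompt b[0]:
assert torch.equal(c1.permute(1, 0, 2).reshape(3, 4), b[0])
print(c1.shape, c2.shape)  # torch.Size([2, 3, 2]) torch.Size([2, 3, 2])
```

So the chunk itself only separates source heads from target heads; the mixing happens later, inside the attention call, where the target chunk is used as the query against source keys and values.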

Your toy experiment is correct. However, this is just the first step: the feature interaction is performed by the attention mechanism, where the target feature serves as the query and the source feature serves as the key and value. You can refer to the attention operation for more details.

Btw, this process is equivalent to the following implementation:

out_u_src = Attention(qu[:num_heads], ku[:num_heads], vu[:num_heads])
out_u_tgt = Attention(qu[num_heads:], ku[:num_heads], vu[:num_heads])
out_u = torch.cat([out_u_src, out_u_tgt])
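As a sanity check, this equivalence can be verified numerically. The sketch below assumes a plain scaled dot-product `attention` over [heads, tokens, dim] tensors and made-up shapes, not the repo's exact attn_batch:

```python
import torch

def attention(q, k, v):
    # Plain scaled dot-product attention over [heads, tokens, dim] tensors.
    sim = torch.einsum("h i d, h j d -> h i j", q, k) * q.shape[-1] ** -0.5
    return torch.einsum("h i j, h j d -> h i d", sim.softmax(dim=-1), v)

num_heads, N, D = 4, 5, 8
qu, ku, vu = (torch.randn(2 * num_heads, N, D) for _ in range(3))

# Per-part formulation from the comment above.
out_u_src = attention(qu[:num_heads], ku[:num_heads], vu[:num_heads])
out_u_tgt = attention(qu[num_heads:], ku[:num_heads], vu[:num_heads])
out_u_ref = torch.cat([out_u_src, out_u_tgt])

# Batched formulation: fold source/target into the token axis so both sets
# of queries attend to the source keys/values in a single attention call.
q = qu.reshape(2, num_heads, N, D).transpose(0, 1).reshape(num_heads, 2 * N, D)
out = attention(q, ku[:num_heads], vu[:num_heads])
out_u = out.reshape(num_heads, 2, N, D).transpose(0, 1).reshape(2 * num_heads, N, D)

assert torch.allclose(out_u, out_u_ref, atol=1e-5)
```

The two agree because each query row attends independently over the same set of source keys, so batching the source and target queries together does not change any individual output.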

Thank you so much for your help!