LituRout/RB-Modulation

Question about AFA module

Vincent-luo opened this issue · 6 comments

I'm a bit confused about the K and V in the AFA module. According to the paper, K and V are projected from the latent embedding z_t using a linear projection. However, in a typical cross-attention layer, only the query comes from z_t, while the key and value come from the conditioning (e.g., text embeddings). So do we need to concatenate K and V with K_p, V_p, and the other K, V values here? I thought Q, K, and V had already interacted with each other in the preceding self-attention layer. Could you explain why you chose this design? Please correct me if I'm misunderstanding something.

You are right about the cross-attention layer: the query (Q) comes from z_t, while the key (K) and value (V) come from the conditioning (e.g., text embeddings).

K and V in the AFA module are projections of z_t in the sense that they are the keys and values obtained from the self-attention layer. The intuition behind this design choice is that we do not want the latent z_t to deviate too much from the state it reached through self-attention. Since cross-attention is typically applied after the self-attention layers for text conditioning, we followed the same recipe for style-image conditioning. Finally, we concatenate [K, K_p, K_s] so that the text prompt can attend to the reference style image features.
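The key/value concatenation described above can be sketched in plain Python. This is a minimal illustration, not the RB-Modulation implementation: it assumes a single scaled dot-product attention call in which the query comes from z_t and the keys and values are stacked along the token axis as [K, K_p, K_s] and [V, V_p, V_s]. The helper names (`attention`, `afa_attention`) and any toy matrices are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows are tokens, columns are feature dimensions


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    d = len(q[0])
    scores = matmul(q, transpose(k))
    scaled = [[s / math.sqrt(d) for s in row] for row in scores]
    return matmul(softmax_rows(scaled), v)


def afa_attention(q: Matrix,
                  k: Matrix, v: Matrix,      # self-attention keys/values of z_t
                  k_p: Matrix, v_p: Matrix,  # prompt (text) keys/values
                  k_s: Matrix, v_s: Matrix   # style-image keys/values
                  ) -> Matrix:
    """Query from z_t attends over the token-wise concatenation [K, K_p, K_s].

    List '+' stacks rows, i.e. concatenates along the token axis, so a single
    softmax distributes weight across latent, text, and style tokens together.
    """
    return attention(q, k + k_p + k_s, v + v_p + v_s)
```

A standard text cross-attention layer would instead call `attention(q, k_p, v_p)`. In this sketch, keeping the self-attention keys and values in the stack means part of the attention weight can land on z_t's own features, which is one way to read the "do not deviate too much" intuition above.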

Thanks for your quick reply! I noticed in the paper that you use the CLIP image encoder to extract the embedding of the reference style image I_S and project it to K_s and V_s. I'm wondering why you don't just use the style descriptor to extract the embedding, since that could prevent content leakage from the style image.

We experimented with both the CLIP image encoder and the CSD descriptor. The CSD descriptor does help prevent content leakage, and there is an argument to enable it. The CLIP image encoder is the default setting in Stable Cascade, so we kept it as well in order to highlight our improvements in the ablation study.

I just realized this method is training-free. Does that mean the linear projection layers that output K_p, V_p, K_s, V_s are not learned? If so, how do you initialize them?

The projection layers are not learned. They are initialized with the pre-trained weights provided by the base generative models. For prompts, we use the CLIP text encoder, and for images, either the CLIP image encoder or the CSD style descriptor.
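To illustrate what "not learned" means in practice, here is a minimal sketch, assuming the key/value projection weights are simply read from a pretrained checkpoint and applied as fixed linear maps to encoder embeddings. The class `FrozenKVProjection`, the state-dict key names (`<prefix>.to_k`, `<prefix>.to_v`), and any toy numbers are hypothetical; real CLIP or CSD embeddings and the base model's actual parameter names would take their place.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def linear(x: Matrix, weight: Matrix) -> Matrix:
    """Bias-free linear map y = x W^T, with weight stored as (out_dim, in_dim)."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in weight] for row in x]


@dataclass(frozen=True)
class FrozenKVProjection:
    """Key/value projections copied from a pretrained checkpoint.

    frozen=True blocks reassigning the weight attributes; in a training-free
    setup nothing updates them anyway, since no optimizer or loss is involved.
    """
    w_k: Matrix
    w_v: Matrix

    @classmethod
    def from_pretrained(cls, state_dict: Dict[str, Matrix], prefix: str) -> "FrozenKVProjection":
        # Hypothetical key names: a real base model uses its own parameter names.
        return cls(w_k=state_dict[f"{prefix}.to_k"], w_v=state_dict[f"{prefix}.to_v"])

    def __call__(self, embeddings: Matrix) -> Tuple[Matrix, Matrix]:
        """Project encoder embeddings (e.g. CLIP or CSD features) to keys and values."""
        return linear(embeddings, self.w_k), linear(embeddings, self.w_v)
```

For example, `k_s, v_s = FrozenKVProjection.from_pretrained(checkpoint, "attn")(style_embedding)` would yield style keys and values without any gradient step. Because no optimizer ever sees these weights, the method stays training-free: the same pretrained projections are reused at inference time, and only the embeddings fed into them change.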

Thank you so much for your detailed explanation. Looking forward to the code release!