labmlai/annotated_deep_learning_paper_implementations

q, k, v have different shapes but torch.stack works?

junsukha opened this issue · 0 comments

In `labml_nn/diffusion/stable_diffusion/model/unet_attention.py`, `def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor)` calls `torch.stack((q, k, v), dim=2)`, where, I believe, `q` has a different shape from `k` and `v`.

How does `torch.stack` work in that case?

When I run `text_to_image.py`, `q`, `k`, and `v` have shapes `[8, 1024, 640]`, `[8, 77, 640]`, and `[8, 77, 640]` respectively.
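For reference, a minimal standalone check (not taken from the repo) shows that `torch.stack` cannot combine tensors of different shapes at all: with the shapes above it raises a `RuntimeError`. So presumably `flash_attention` is only ever reached when `q`, `k`, and `v` match, i.e. in the self-attention case:

```python
import torch

# Cross-attention shapes from the report: q attends over 1024 image
# tokens, k/v over 77 text tokens. torch.stack requires all tensors
# to have exactly the same shape, so this raises a RuntimeError.
q = torch.randn(8, 1024, 640)
k = torch.randn(8, 77, 640)
v = torch.randn(8, 77, 640)

try:
    torch.stack((q, k, v), dim=2)
    stacked_ok = True
except RuntimeError as e:
    stacked_ok = False
    print("stack failed:", e)

# With equal shapes (self-attention, where k and v come from the
# same sequence as q), the stack succeeds and interleaves q, k, v
# along a new dimension of size 3.
k_self = torch.randn(8, 1024, 640)
v_self = torch.randn(8, 1024, 640)
qkv = torch.stack((q, k_self, v_self), dim=2)
print(qkv.shape)  # torch.Size([8, 1024, 3, 640])
```

This suggests the cross-attention path (where the 77-token text conditioning supplies `k` and `v`) must go through a different branch than `flash_attention`.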