q, k, v have different shapes, but torch.stack works?
junsukha opened this issue · 0 comments
junsukha commented
In labml_nn/diffusion/stable_diffusion/model/unet_attention.py, the method
def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): calls torch.stack((q, k, v), dim=2), but I believe q has a different shape from k and v.
How can torch.stack work in that case?
When I run text_to_image.py,
q, k, v have shapes [8, 1024, 640], [8, 77, 640], and [8, 77, 640] respectively.
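For what it's worth, a minimal standalone check (not using the repo's code, just plain PyTorch) confirms that torch.stack requires all input tensors to have identical shapes, so stacking the shapes reported above would raise an error:

```python
import torch

# Shapes reported above: q differs from k and v in the sequence dimension.
q = torch.randn(8, 1024, 640)
k = torch.randn(8, 77, 640)
v = torch.randn(8, 77, 640)

# torch.stack requires every tensor to have exactly the same shape,
# so stacking these three raises a RuntimeError.
try:
    torch.stack((q, k, v), dim=2)
except RuntimeError as e:
    print("stack failed:", e)

# With matching shapes (e.g. self-attention, where q, k, and v are all
# computed from the same sequence), the stack succeeds:
q2, k2, v2 = (torch.randn(8, 77, 640) for _ in range(3))
qkv = torch.stack((q2, k2, v2), dim=2)
print(qkv.shape)  # torch.Size([8, 77, 3, 640])
```

So if the flash_attention path is only ever reached when q, k, and v share a shape (e.g. for self-attention rather than cross-attention with the 77-token text conditioning), the stack would be valid there; whether the repo guarantees that is exactly the question here.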