q, k, v have different shapes, but torch.stack works?
junsukha opened this issue · 0 comments
junsukha commented
In labml_nn/diffusion/stable_diffusion/model/unet_attention.py, the method `flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor)` calls `torch.stack((q, k, v), dim=2)`, but as far as I can tell `q` has a different shape from `k` and `v`. How can `torch.stack` work in that case, given that it requires all tensors to have the same shape?
When I run text_to_image.py, `q`, `k`, and `v` have shapes `[8, 1024, 640]`, `[8, 77, 640]`, and `[8, 77, 640]` respectively.
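For reference, here is a minimal standalone reproduction (outside the repo) of what I mean. The shapes mirror the ones I observed; my guess is that `flash_attention` is perhaps only reached on the self-attention path, where `q`, `k`, and `v` share a shape, but I have not confirmed that:

```python
import torch

# Shapes I observed in the cross-attention path of text_to_image.py
q = torch.randn(8, 1024, 640)  # query: batch, 1024 image tokens, width
k = torch.randn(8, 77, 640)    # key: batch, 77 text tokens, width
v = torch.randn(8, 77, 640)    # value: same shape as k

# torch.stack requires every tensor to have exactly the same shape,
# so stacking these three raises a RuntimeError.
try:
    qkv = torch.stack((q, k, v), dim=2)
except RuntimeError as e:
    print("torch.stack failed:", e)

# In self-attention, q, k, v all derive from the same input and share a
# shape, so the stack succeeds and produces [batch, seq, 3, width].
q2 = k2 = v2 = torch.randn(8, 1024, 640)
qkv = torch.stack((q2, k2, v2), dim=2)
print(qkv.shape)  # torch.Size([8, 1024, 3, 640])
```

So with the cross-attention shapes above, the `torch.stack` call should fail outright, which is why I am confused about how this code path runs.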