[Bug] Error when `rotary_pos_emb` set to True in cross attention
Closed this issue · 3 comments
BakerBunker commented
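import torch
from x_transformers import Encoder, CrossAttender

# self-attention encoder for the neighbor (context) sequence
enc = Encoder(dim=512, depth=6)

# cross-attention-only stack with rotary embeddings enabled (this triggers the error)
model = CrossAttender(
    dim=512,
    depth=6,
    rotary_pos_emb=True,
    attn_flash=True,
)

nodes = torch.randn(1, 1, 512)            # queries: batch 1, 1 node, dim 512
node_masks = torch.ones(1, 1).bool()
neighbors = torch.randn(1, 5, 512)        # context: batch 1, 5 neighbors, dim 512
neighbor_masks = torch.ones(1, 5).bool()

encoded_neighbors = enc(neighbors, mask=neighbor_masks)

model(
    nodes, context=encoded_neighbors, mask=node_masks, context_mask=neighbor_masks
)  # expected output shape: (1, 1, 512)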
lucidrains commented
hmm, are the source and target sequences in some shared coordinate space? usually you cannot use rotary embeddings in cross attention
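To illustrate why a shared coordinate space matters, here is a minimal pure-Python sketch of rotary position embedding. It is not the x_transformers implementation, and the helper names `rotate`, `dot`, and `rotary_score` are hypothetical. Each pair of features is rotated by an angle proportional to the token's position, so the attention score between a query at position i and a key at position j depends on the positions only through the offset i - j. In cross attention the queries (here the single node) and the keys (the five neighbors) index different sequences, so that offset carries no meaning unless both sequences are laid out in one shared coordinate space.

```python
import math


def rotate(vec, pos, base=10000.0):
    """Rotary embedding sketch: rotate each feature pair (x0, x1) of an
    even-length vector by angle pos * base**(-k / d), k = pair start index."""
    d = len(vec)
    out = []
    for k in range(0, d, 2):
        theta = pos * base ** (-k / d)
        c, s = math.cos(theta), math.sin(theta)
        x0, x1 = vec[k], vec[k + 1]
        out.extend([x0 * c - x1 * s, x0 * s + x1 * c])
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def rotary_score(q, k, q_pos, k_pos):
    """Attention logit between a rotated query and a rotated key."""
    return dot(rotate(q, q_pos), rotate(k, k_pos))


q = [1.0, 0.5, -0.3, 0.8]
k = [0.2, -1.0, 0.7, 0.4]
# Same offset (q_pos - k_pos = 2) gives the same score at any absolute position.
print(rotary_score(q, k, 3, 1), rotary_score(q, k, 10, 8))
```

Because rotating q by angle a and k by angle b is equivalent to rotating only k by b - a, the score is a function of the relative offset. That is exactly the quantity that is undefined when query and key positions come from unrelated sequences.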
BakerBunker commented
Thank you for the explanation; it was my mistake to use rotary embeddings in cross attention.
lucidrains commented
@BakerBunker no problem, I should have added an assert to prevent this in the cross attention setting
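The assert mentioned above could be as simple as a check at construction time. The following is a hypothetical sketch: `CrossAttenderConfig` is an illustrative stand-in, not a class in x_transformers, and the real library may place such a check elsewhere.

```python
from dataclasses import dataclass


@dataclass
class CrossAttenderConfig:
    """Hypothetical config for a cross-attention-only stack (not x_transformers API)."""
    dim: int
    depth: int
    rotary_pos_emb: bool = False

    def __post_init__(self):
        # Rotary embeddings encode the relative offset between query and key
        # positions, which is only meaningful when both sequences share one
        # coordinate space; a pure cross-attention stack cannot assume that.
        assert not self.rotary_pos_emb, (
            "rotary_pos_emb is not supported in cross attention: "
            "query and key positions come from different sequences"
        )


cfg = CrossAttenderConfig(dim=512, depth=6)  # accepted
# CrossAttenderConfig(dim=512, depth=6, rotary_pos_emb=True)  # raises AssertionError
```

Python strips `assert` statements when run with `-O`; raising a `ValueError` instead would keep the check active in that mode.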