lucidrains/x-transformers

[Bug] Error when `rotary_pos_emb` set to True in cross attention


import torch
from x_transformers import Encoder, CrossAttender

enc = Encoder(dim=512, depth=6)
model = CrossAttender(
    dim=512,
    depth=6,
    rotary_pos_emb=True,  # rotary embeddings enabled in a cross-attention-only stack
    attn_flash=True,
)

# one query node attending over five encoded neighbors
nodes = torch.randn(1, 1, 512)
node_masks = torch.ones(1, 1).bool()

neighbors = torch.randn(1, 5, 512)
neighbor_masks = torch.ones(1, 5).bool()

encoded_neighbors = enc(neighbors, mask=neighbor_masks)
model(
    nodes, context=encoded_neighbors, mask=node_masks, context_mask=neighbor_masks
)  # expected shape (1, 1, 512), but this call raises an error because of rotary_pos_emb=True

hmm, are the source and target sequences in some shared coordinate space? usually you cannot use rotary embeddings in cross attention
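To make the point concrete, here is a small illustrative sketch of the rotary property (toy code, not the x-transformers implementation): rotating the query at position m and the key at position n leaves their dot product depending only on the offset m - n, so both indices have to be measured along the same axis. In cross attention, query index 0 and context index 0 are in general unrelated, so that offset carries no information.

import torch

def rotate(x, pos, base=10000):
    # apply a rotary-style rotation to a single head vector x (even dim) at integer position pos
    half = x.shape[-1] // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = pos * freqs
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:half], x[half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos])

q, k = torch.randn(64), torch.randn(64)

# the score depends only on the relative offset between the two positions ...
score_a = rotate(q, 3) @ rotate(k, 1)    # positions (3, 1), offset 2
score_b = rotate(q, 10) @ rotate(k, 8)   # positions (10, 8), same offset 2
print(torch.allclose(score_a, score_b, atol=1e-4))  # True

# ... which only means something when query and key positions share an origin,
# as in self attention; "position 0" of the target sequence has no such
# relation to "position 0" of the context sequence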

Thank you for the explanation, it was my mistake to use rotary embeddings in cross attention
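For anyone who lands here with the same error, a working variant of the snippet above (assuming the goal is simply to cross attend from the nodes to the encoded neighbors) just drops rotary_pos_emb=True from the CrossAttender; rotary can still be used inside the Encoder, where queries and keys do share positions.

import torch
from x_transformers import Encoder, CrossAttender

enc = Encoder(dim=512, depth=6, rotary_pos_emb=True)  # rotary is fine here: self attention over the neighbors
model = CrossAttender(dim=512, depth=6, attn_flash=True)  # no rotary in the cross-attention stack

nodes = torch.randn(1, 1, 512)
node_masks = torch.ones(1, 1).bool()

neighbors = torch.randn(1, 5, 512)
neighbor_masks = torch.ones(1, 5).bool()

encoded_neighbors = enc(neighbors, mask=neighbor_masks)
out = model(
    nodes, context=encoded_neighbors, mask=node_masks, context_mask=neighbor_masks
)  # (1, 1, 512)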

@BakerBunker no problem, i should have added an assert to prevent this in the cross attention setting
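Something along these lines would do it (a hypothetical sketch modeled on the cross_attend / rotary_pos_emb settings used above, not taken from the actual x-transformers source); it fails fast with a readable message instead of erroring deeper in the rotary path.

class AttentionLayersSketch:
    # hypothetical guard, not the real x-transformers class
    def __init__(self, *, cross_attend=False, rotary_pos_emb=False, **kwargs):
        assert not (cross_attend and rotary_pos_emb), (
            'rotary positional embeddings assume queries and keys share a coordinate space, '
            'which does not hold for cross attention'
        )
        self.cross_attend = cross_attend
        self.rotary_pos_emb = rotary_pos_emb

# a cross-attention wrapper would pass cross_attend=True, so combining it with
# rotary_pos_emb=True trips the assert immediately
AttentionLayersSketch(cross_attend=True, rotary_pos_emb=False)   # ok
# AttentionLayersSketch(cross_attend=True, rotary_pos_emb=True)  # AssertionError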