lucidrains/x-transformers

[Question] Embedding Inputs to Transformer [batch, seq_len, embedding_dim]


Hi,

Thanks for the work on this amazing library. I've read the docs, but I'm uncertain how to input sequences of embeddings. The embeddings come from another network, and I would like to model the sequence of them.
Is there a way to achieve the following?

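```python
import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

embedding_dim = 512
# embeddings from another network: (batch, seq_len, embedding_dim)
x = torch.randn(1, 1024, embedding_dim).cuda()
# padding mask: one entry per position, (batch, seq_len)
mask = torch.ones(1, 1024, dtype = torch.bool).cuda()

model(x, mask = mask) # Error: TransformerWrapper expects integer token ids of shape (batch, seq_len)
```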

Apologies for opening a GitHub issue for a question.
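The error comes from the wrapper's first step, which treats its input as integer token ids and looks them up in an embedding table. The sketch below is a minimal illustration of that shape difference, assuming only PyTorch is installed; `nn.Embedding` stands in for the wrapper's token embedding and is not the library's internal code.

```python
import torch
from torch import nn

batch, seq_len, dim = 1, 1024, 512

# A token-level wrapper first looks up integer ids in an embedding table.
token_emb = nn.Embedding(20000, dim)
token_ids = torch.randint(0, 20000, (batch, seq_len))  # integer ids: (batch, seq_len)
embedded = token_emb(token_ids)                         # float: (batch, seq_len, dim)

# Embeddings from another network already have the (batch, seq_len, dim) shape,
# so the lookup step is unnecessary and they can go straight to the attention layers.
precomputed = torch.randn(batch, seq_len, dim)

# A padding mask is per position, not per feature: (batch, seq_len).
# In the x-transformers convention, True marks positions to attend to.
mask = torch.ones(batch, seq_len, dtype=torch.bool)

print(embedded.shape, precomputed.shape, mask.shape)
```

This is also why `torch.ones_like(x)` gives the wrong mask shape for embedding inputs: it copies the trailing feature dimension.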

@francotheengineer you would just need an `Encoder`, since your inputs are already embeddings:

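```python
import torch
from x_transformers import Encoder

# the attention layers on their own, without the token embedding
# and output head that TransformerWrapper adds around them
encoder = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
)

embed = torch.randn(1, 1024, 512)  # (batch, seq_len, dim) embeddings from another network
attended = encoder(embed)          # output keeps the input shape
assert embed.shape == attended.shape
```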
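For comparison, here is the same idea (feeding precomputed embeddings straight into an encoder stack) sketched with PyTorch's built-in `nn.TransformerEncoder`. This is a different implementation swapped in for illustration, not x-transformers; the batch of two sequences with padding is a made-up example. Note that PyTorch's `src_key_padding_mask` uses the opposite convention (True = ignore), so a "True = attend" mask has to be inverted.

```python
import torch
from torch import nn

dim, heads, depth = 512, 8, 6

# Plain PyTorch encoder stack (a stand-in, not the x-transformers Encoder).
# nn.TransformerEncoder adds no positional encoding; add one yourself if order matters.
layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=depth)
encoder.eval()  # disable dropout for a deterministic forward pass

embed = torch.randn(2, 16, dim)  # (batch, seq_len, dim) from another network

# keep: True = real position to attend to; the second sequence has 6 padded positions
keep = torch.ones(2, 16, dtype=torch.bool)
keep[1, 10:] = False

with torch.no_grad():
    # PyTorch expects True = padding, so invert the keep mask
    attended = encoder(embed, src_key_padding_mask=~keep)

assert attended.shape == embed.shape
```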

@lucidrains Thank you very much for your help! I will give it a try.