lucidrains/x-transformers

Confusion about image->caption example

mtran14 opened this issue · 1 comment

Hello,

Thank you for creating a great repository. I'm new to x-transformers and I'm a bit confused about the provided sample usage for image captioning:

import torch
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder

encoder = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True
    )
)

img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

encoded = encoder(img, return_embeddings = True)  # patch embeddings from the ViT encoder
decoder(caption, context = encoded)               # (1, 1024, 20000) logits, cross-attending to the image

I suppose the code is for model training, where pairs of [img, caption] are available.
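If I understand correctly, a single teacher-forced training step would then look roughly like this (the shifted-target loss below is my own guess at the wiring, not something taken from the README):

import torch.nn.functional as F

# my guess at one training step: the decoder sees the caption and is trained to
# predict each next token, with the image embeddings as cross-attention context
logits = decoder(caption, context = encoded)   # (1, 1024, 20000)
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, 20000),         # predictions for positions 1..1023
    caption[:, 1:].reshape(-1)                 # next-token targets
)
loss.backward()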

  1. Why do we feed caption (our target predictions) into the decoder? Shouldn't the decoder only take encoded as input and produce predictions for caption?
  2. How should I use the trained model for inference, when only img is available and caption is unknown/hidden? (My rough guess at an inference loop is sketched below.)
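For question 2, my current guess is a step-by-step greedy decoding loop like the one below (the start-token convention and the loop itself are my own assumptions, not from the README):

# my guess at inference: encode the image once, then decode tokens one at a time
encoded = encoder(img, return_embeddings = True)
tokens = torch.zeros(1, 1, dtype = torch.long)     # assuming index 0 is a <bos>/start token

for _ in range(50):                                # decode up to 50 tokens greedily
    logits = decoder(tokens, context = encoded)    # (1, current_length, 20000)
    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    tokens = torch.cat((tokens, next_token), dim = -1)

Is that the intended usage, or should the decoder instead be wrapped in AutoregressiveWrapper and sampled with its generate method?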

Thanks in advance!

I have the same question and am hoping someone can clarify. Thank you!