Confusion about image->caption example
mtran14 opened this issue · 1 comment
mtran14 commented
Hello,
Thank you for creating a great repository. I'm new to x-transformers
and I'm a bit confused about the provided sample usage for image captioning:
import torch
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder

encoder = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True
    )
)

img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

encoded = encoder(img, return_embeddings = True)
decoder(caption, context = encoded) # (1, 1024, 20000)
I suppose this code is for model training, where pairs of [img, caption] are available.
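For reference, this is roughly how I imagine the training step would look, building on the snippet above. The loss computation below is just my guess at teacher forcing, not something shown in the README:

import torch.nn.functional as F

# my assumption: teacher forcing, i.e. feed the ground-truth caption in
# and compare the predicted logits against the caption shifted by one
logits = decoder(caption[:, :-1], context = encoded)  # (1, 1023, 20000)
loss = F.cross_entropy(
    logits.transpose(1, 2),  # (batch, num_tokens, seq)
    caption[:, 1:]           # target is the next token at each position
)
loss.backward()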
- Why do we feed caption (our target predictions) into the decoder? Shouldn't the decoder only take encoded as input and produce predictions for caption?
- How should I use the trained model for inference, when only img is available (and caption is unknown/hidden)? My rough guess is sketched below.
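Concretely, is something like the following the intended way to run inference? AutoregressiveWrapper and its generate method are my assumption based on the decoder-only examples elsewhere in the README, not something shown in the captioning snippet, so please correct me if this is wrong:

from x_transformers import AutoregressiveWrapper

# my assumption: wrap the trained decoder so it samples caption tokens
# one at a time, while cross-attending to the encoded image
generator = AutoregressiveWrapper(decoder)

encoded = encoder(img, return_embeddings = True)
start_tokens = torch.zeros((1, 1), dtype = torch.long)  # e.g. a <bos> token id
generated = generator.generate(start_tokens, seq_len = 256, context = encoded)  # generated caption token ids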
Thanks in advance!
mk-runner commented
I also have the same question and hope it can be clarified. Thank you!