Generation for PaLI?
BurgerAndreas commented
How would one generate an action (output text) using PaLI?
PaLI example from readme.md
import torch
from x_transformers import ViTransformerWrapper, XTransformer, Encoder
# PaLI is composed of
# 1. vision transformer (ViTransformerWrapper) +
# 2. encoder-decoder transformer (XTransformer)
vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
pali = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)
# training data
img = torch.randn(1, 3, 256, 256) # images
prompt = torch.randint(0, 256, (1, 1024)) # prompt
prompt_mask = torch.ones(1, 1024).bool() # prompt text mask
output_text = torch.randint(0, 256, (1, 1024)) # target output text
# train
img_embeds = vit(
    img,
    return_embeddings = True
)
loss = pali(
    prompt,
    output_text,
    mask = prompt_mask,
    src_prepend_embeds = img_embeds  # will prepend image embeddings to encoder text embeddings before attention
)
loss.backward()
Desired behaviour
vit.eval()
pali.eval()

with torch.no_grad():
    img_embeds = vit(
        img,
        return_embeddings = True
    )

    # how to do this?
    # XTransformer.generate() does not take src_prepend_embeds to feed to the encoder
    output_text = pali.generate(
        img_embeds,
        prompt,
        mask = prompt_mask,
    )
Idea?
img_embeds = vit(img, return_embeddings = True)
# from XTransformer.forward()
enc = pali.encoder(prompt, mask = prompt_mask, prepend_embeds = img_embeds, return_embeddings = True)
# from XTransformer.generate()
output_text = pali.decoder.generate(seq_out_start, seq_len, context = enc, context_mask = prompt_mask)
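A rough, untested sketch along those lines, continuing from the variables above. It mirrors what XTransformer.forward() / generate() do: encode the prompt with the image embeddings prepended, then let the decoder generate against that context. The mask handling is an assumption (depending on the x-transformers version, the encoder may extend the mask itself when prepend_embeds is given, in which case pass prompt_mask to the encoder instead), and seq_out_start / seq_len are placeholder choices (token 0 as a start token, up to 1024 generated tokens):

vit.eval()
pali.eval()

with torch.no_grad():
    img_embeds = vit(img, return_embeddings = True)  # (batch, num_patches, dim) = (1, 64, 512) for 256px images with 32px patches

    # the encoder sequence becomes [image tokens, prompt tokens],
    # so the mask is extended to cover the 64 prepended image positions (always attended to)
    img_mask = torch.ones(img_embeds.shape[:2], dtype = torch.bool)
    enc_mask = torch.cat((img_mask, prompt_mask), dim = -1)  # (1, 64 + 1024)

    # encode prompt + prepended image embeddings, as XTransformer.forward() does
    enc = pali.encoder(
        prompt,
        mask = enc_mask,
        prepend_embeds = img_embeds,
        return_embeddings = True
    )

    # hypothetical choices: assume token 0 is the decoder start token, cap generation at 1024 tokens
    seq_out_start = torch.zeros((1, 1), dtype = torch.long)
    seq_len = 1024

    # decode autoregressively against the encoded context, as XTransformer.generate() does
    output_text = pali.decoder.generate(
        seq_out_start,
        seq_len,
        context = enc,
        context_mask = enc_mask
    )

The key detail is that the decoder's context_mask has to match the full encoder output length (image tokens + prompt), not just the prompt length.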