What’s the best way to use Color palette and another image to condition outputs?
amaarora opened this issue · 2 comments
amaarora commented
Thank you authors for open sourcing your amazing work.
What would be the best way to use a color palette as a conditioning input for image generation and image retrieval, please?
amaarora commented
This is what I tried so far as color palette input and used text tokenizer.
# Conditioning inputs written in the model's sequence text format.
# NOTE(review): [S_1]/[S_2]/[S_3] look like sentinel tokens chaining the
# modalities together — confirm against the tokenizer's special-token list.
caption = 'a nice house with outdoor view [S_1]'
# Bounding boxes: v0..v3 presumably encode discretized box coordinates
# (xmin, ymin, xmax, ymax), each box followed by its class name — verify.
bboxes = '[S_1] v0=0 v1=250 v2=420 v3=100 potted plant ' \
'v0=700 v1=720 v2=740 v3=850 bottle [S_2]'
# Color palette: "color = 2" presumably declares the palette size, followed
# by that many R/G/B triplets — confirm against the expected palette format.
color_palette = '[S_2] color = 2 R = 79 G = 158 B = 143 R = 29 G = 107 B = 137 [S_3]'
Next, I build `batched_sample`:
# Assemble the conditioning batch (batch size 1).
batched_sample = {}
# Initialize target modalities: allocate empty placeholders that the
# sampler will fill in during generation.
for target_mod, ntoks in zip(target_domains, tokens_per_target):
    batched_sample = init_empty_target_modality(batched_sample, MODALITY_INFO, target_mod, 1, ntoks, device)
# Tokenize and attach each text-based conditioning modality.
batched_sample = custom_text(
    batched_sample, input_text=caption, eos_token='[EOS]',
    key='caption', device=device, text_tokenizer=text_tok
)
# Bounding boxes are fed through the same text pathway under the 'det' key.
batched_sample = custom_text(
    batched_sample, input_text=bboxes, eos_token='[EOS]',
    key='det', device=device, text_tokenizer=text_tok
)
batched_sample = custom_text(
    batched_sample, input_text=color_palette, eos_token='[EOS]',
    key='color_palette', device=device, text_tokenizer=text_tok
)
Finally, I create `out_dict` and `dec_dict`, but the `decode_dict` call fails with an error.
# Run generation with the chained schedule, then decode the sampled tokens.
out_dict = sampler.generate(
    batched_sample, schedule, text_tokenizer=text_tok,
    verbose=True, seed=0,
    top_p=top_p, top_k=top_k,
)
# NOTE(review): decode_dict reportedly raises here. `toks` must map every
# modality key present in out_dict to its detokenizer; a missing entry
# (e.g. for the 'det' or 'color_palette' text modalities) is a likely
# cause — verify which keys out_dict contains before decoding.
dec_dict = decode_dict(
    out_dict, toks, text_tok,
    image_size=224, patch_size=16,
    decoding_steps=1
)
amaarora commented
Just want to confirm the best way to extract features using the caption and color palette as well, please, for retrieval.
This is what I have so far based on the input notebook.
# Generation configurations
cond_domains = ["caption", "color_palette"]  # conditioning modalities
target_domains = ["tok_dinov2_global"]       # generate global DINOv2 tokens
tokens_per_target = [16]                     # 16 tokens for this target
generation_config = {
    # NOTE(review): scheme/temperature/CFG values copied from the notebook
    # defaults — their exact semantics come from the library, not shown here.
    "autoregression_schemes": ["roar"],
    "decoding_steps": [1],
    "token_decoding_schedules": ["linear"],
    "temps": [2.0],
    "temp_schedules": ["onex:0.5:0.5"],
    "cfg_scales": [1.0],
    "cfg_schedules": ["constant"],
    "cfg_grow_conditioning": True,
}
# Nucleus / top-k sampling parameters (top_k=0.0 presumably disables top-k).
top_p, top_k = 0.8, 0.0
# Build one chained schedule covering all conditioning and target domains.
schedule = build_chained_generation_schedules(
    cond_domains=cond_domains,
    target_domains=target_domains,
    tokens_per_target=tokens_per_target,
    **generation_config,
)
# Load the pretrained 4M model in eval mode and wrap it in a sampler.
fm_model = FM.from_pretrained(FM_MODEL_PATH).eval().to(DEVICE)
sampler = GenerationSampler(fm_model)
# Build the conditioning batch (batch size 1).
# FIX: batched_sample must be initialized before the loop below — the
# original snippet used it without defining it first, which raises a
# NameError when this comment's code is run on its own.
batched_sample = {}

# Allocate an empty placeholder for each target modality the sampler fills.
for target_mod, ntoks in zip(target_domains, tokens_per_target):
    batched_sample = init_empty_target_modality(
        batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
    )

# Attach the text-based conditioning modalities.
# `caption` and `color_palette` are the input strings from the earlier comment.
batched_sample = custom_text(
    batched_sample,
    input_text=caption,
    eos_token="[EOS]",
    key="caption",
    device=DEVICE,
    text_tokenizer=text_tokenizer,
)
batched_sample = custom_text(
    batched_sample,
    input_text=color_palette,
    eos_token="[EOS]",
    key="color_palette",
    device=DEVICE,
    text_tokenizer=text_tokenizer,
)

# Sample the global DINOv2 tokens conditioned on caption + color palette.
out_dict = sampler.generate(
    batched_sample,
    schedule,
    text_tokenizer=text_tokenizer,
    verbose=True,
    seed=0,
    top_p=top_p,
    top_k=top_k,
)

# Decode the sampled tokens back into the feature space used for retrieval.
# Only the tok_dinov2_global detokenizer is supplied, matching target_domains.
with torch.no_grad():
    dec_dict = decode_dict(
        out_dict,
        {"tok_dinov2_global": vqvae.to(DEVICE)},
        text_tokenizer,
        image_size=IMG_SIZE,
        patch_size=16,
        decoding_steps=1,
    )

# The decoded global embedding to use as the retrieval feature vector.
combined_features = dec_dict["tok_dinov2_global"]