apple/ml-4m

What’s the best way to use a color palette and another image to condition outputs?

amaarora opened this issue · 2 comments

Thank you, authors, for open-sourcing your amazing work.

What would be the best way to use the color palette modality for image generation and image retrieval, please?

This is what I have tried so far as the color palette input, using the text tokenizer:

caption = 'a nice house with outdoor view [S_1]'
bboxes = '[S_1] v0=0 v1=250 v2=420 v3=100 potted plant ' \
         'v0=700 v1=720 v2=740 v3=850 bottle [S_2]'
color_palette = '[S_2] color = 2 R = 79 G = 158 B = 143 R = 29 G = 107 B = 137 [S_3]'
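
As a quick sanity check, I verify that the palette string tokenizes cleanly; this only assumes text_tok is the HuggingFace tokenizers object used below:

# Sanity check: make sure no part of the palette string falls out of vocab.
enc = text_tok.encode(color_palette)
print(enc.tokens)  # inspect the WordPiece split
assert '[UNK]' not in enc.tokens, 'palette string contains out-of-vocab pieces'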

Next, I build up batched_sample:

from fourm.data.modality_info import MODALITY_INFO
from fourm.models.generate import (
    init_empty_target_modality, init_full_input_modality, custom_text,
)

batched_sample = {}

# Initialize empty target modalities so the sampler knows what to predict
for target_mod, ntoks in zip(target_domains, tokens_per_target):
    batched_sample = init_empty_target_modality(batched_sample, MODALITY_INFO, target_mod, 1, ntoks, device)

# Add each conditioning modality from its text representation
batched_sample = custom_text(
    batched_sample, input_text=caption, eos_token='[EOS]',
    key='caption', device=device, text_tokenizer=text_tok
)
batched_sample = custom_text(
    batched_sample, input_text=bboxes, eos_token='[EOS]',
    key='det', device=device, text_tokenizer=text_tok
)
batched_sample = custom_text(
    batched_sample, input_text=color_palette, eos_token='[EOS]',
    key='color_palette', device=device, text_tokenizer=text_tok
)
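
One thing I am not sure about: does custom_text already set the input masks, or do the conditioning domains also need init_full_input_modality, as in the repo's token-input examples? I tried adding the following (possibly redundant, and the exact signature is my assumption):

# Possibly redundant: mark each conditioning domain as a fully observed input.
for cond_mod in ['caption', 'det', 'color_palette']:
    batched_sample = init_full_input_modality(
        batched_sample, MODALITY_INFO, cond_mod, device,
        eos_id=text_tok.token_to_id('[EOS]')
    )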

Finally, I create out_dict and dec_dict, but the decode_dict call fails with an error.

out_dict = sampler.generate(
    batched_sample, schedule, text_tokenizer=text_tok, 
    verbose=True, seed=0,
    top_p=top_p, top_k=top_k,
)
dec_dict = decode_dict(
    out_dict, toks, text_tok,  # toks: my dict of detokenizers keyed by token modality
    image_size=224, patch_size=16,
    decoding_steps=1
)
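
To narrow down the failure, I dump what the sampler actually returned before decoding (plain PyTorch introspection, nothing repo-specific):

# Debug aid: print the structure and tensor shapes the sampler produced.
for mod, entry in out_dict.items():
    if isinstance(entry, dict):
        shapes = {k: tuple(v.shape) for k, v in entry.items() if hasattr(v, 'shape')}
        print(mod, shapes)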

I also want to confirm the best way to use a caption plus a color palette to extract features for retrieval, please.

This is what I have so far, based on the provided notebook:

# Imports (module paths as I understand them from the repo's generation examples)
import torch
from fourm.models.fm import FM
from fourm.models.generate import (
    GenerationSampler, build_chained_generation_schedules,
    init_empty_target_modality, custom_text,
)
from fourm.data.modality_info import MODALITY_INFO
from fourm.utils.plotting_utils import decode_dict

# Generation configurations
cond_domains = ["caption", "color_palette"]
target_domains = ["tok_dinov2_global"]
tokens_per_target = [16]
generation_config = {
    "autoregression_schemes": ["roar"],
    "decoding_steps": [1],
    "token_decoding_schedules": ["linear"],
    "temps": [2.0],
    "temp_schedules": ["onex:0.5:0.5"],
    "cfg_scales": [1.0],
    "cfg_schedules": ["constant"],
    "cfg_grow_conditioning": True,
}
top_p, top_k = 0.8, 0.0

schedule = build_chained_generation_schedules(
    cond_domains=cond_domains,
    target_domains=target_domains,
    tokens_per_target=tokens_per_target,
    **generation_config,
)

fm_model = FM.from_pretrained(FM_MODEL_PATH).eval().to(DEVICE)
sampler = GenerationSampler(fm_model)
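
For completeness, this is how I load the text tokenizer; the JSON path is the WordPiece tokenizer I found in the repo, so adjust it if yours lives elsewhere:

from tokenizers import Tokenizer

# Assumption: the WordPiece tokenizer JSON that ships with ml-4m.
text_tokenizer = Tokenizer.from_file(
    'fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json'
)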

batched_sample = {}

# Initialize empty target modalities
for target_mod, ntoks in zip(target_domains, tokens_per_target):
    batched_sample = init_empty_target_modality(
        batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
    )

batched_sample = custom_text(
    batched_sample,
    input_text=caption,
    eos_token="[EOS]",
    key="caption",
    device=DEVICE,
    text_tokenizer=text_tokenizer,
)
batched_sample = custom_text(
    batched_sample,
    input_text=color_palette,
    eos_token="[EOS]",
    key="color_palette",
    device=DEVICE,
    text_tokenizer=text_tokenizer,
)

out_dict = sampler.generate(
    batched_sample,
    schedule,
    text_tokenizer=text_tokenizer,
    verbose=True,
    seed=0,
    top_p=top_p,
    top_k=top_k,
)

with torch.no_grad():
    dec_dict = decode_dict(
        out_dict,
        {"tok_dinov2_global": vqvae.to(DEVICE)},
        text_tokenizer,
        image_size=IMG_SIZE,
        patch_size=16,
        decoding_steps=1,
    )

combined_features = dec_dict["tok_dinov2_global"]
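
For the retrieval itself I then rank a precomputed gallery of embeddings by cosine similarity; database_features below is my own hypothetical tensor, not something the repo provides:

import torch.nn.functional as F

# database_features: (N, D) tensor of embeddings I precomputed for my gallery
query = F.normalize(combined_features.flatten(1).float(), dim=-1)
gallery = F.normalize(database_features.flatten(1).float(), dim=-1)
scores = query @ gallery.T             # cosine similarities, shape (1, N)
top5 = scores.topk(5, dim=-1).indices  # indices of the 5 most similar images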