Nota-NetsPresso/BK-SDM

Question of Dreambooth evaluation

ofzlo opened this issue · 6 comments

ofzlo commented

Hi, thank you for sharing your awesome work ☺️
How can we reproduce your DreamBooth quantitative performance in Table 5?
Would you provide the evaluation code?

Hi, thanks for your interest.

We’ve followed the protocol described in “Sect. 4.1 Dataset and Evaluation” in the DreamBooth paper.

We hope the following information is helpful. Sharing the code requires some cleanup and refactoring; however, due to other projects and workload, it may take some time for us to provide the full evaluation script/code. We kindly ask for your understanding.

1. Generate 3000 images (30 subjects × 25 prompts × 4 generated images)

  • (1) Download the dataset by following this link.
  • (2) Run the script bash scripts/finetune_full.sh for per-subject finetuning without LoRA.
    • This contains the exact hyperparameters we used and produces a finetuned network for "dog2".
  • (3a) Get 4 generated images for each prompt using bash scripts/generate_after_full_ft.sh.
  • (3b) Repeat (3a) for 25 prompts of prompt_list in this link.
  • (4) Repeat (2) and (3) for all 30 subjects in this list by changing $SUBJECT_NAME and $CLASS_NAME in scripts/finetune_full.sh and scripts/generate_after_full_ft.sh (a driver-loop sketch is given after this list).
    • For Live Subject classes (i.e., cat, dog) with 9 subject_names {cat, cat2, dog, dog2, dog3, dog5, dog6, dog7, dog8}, Live Subject Prompts should be used.
    • For Object classes (i.e., the classes except cat and dog) with 21 subject_names, Object Prompts should be used.
    • [Caution] Ensure that your hard drive has enough space to save 30 networks from 30 per-subject finetuning runs.
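
For step (4), a possible driver loop is sketched below. This is only an illustration, not a script from the repository: it assumes you have modified scripts/finetune_full.sh and scripts/generate_after_full_ft.sh to read SUBJECT_NAME and CLASS_NAME from the environment (by default these variables are set inside the scripts), and the mapping for the 21 object subjects must be filled in from the DreamBooth dataset.

import os
import subprocess

# Hypothetical driver loop (not part of the repository); it assumes the two shell scripts
# read SUBJECT_NAME and CLASS_NAME from the environment.
LIVE_SUBJECTS = {s: ("cat" if s.startswith("cat") else "dog")
                 for s in ["cat", "cat2", "dog", "dog2", "dog3", "dog5", "dog6", "dog7", "dog8"]}
OBJECT_SUBJECTS = {}  # fill in the 21 object subjects and their class names, e.g. {"teapot": "teapot"}

for subject_name, class_name in {**LIVE_SUBJECTS, **OBJECT_SUBJECTS}.items():
    env = {**os.environ, "SUBJECT_NAME": subject_name, "CLASS_NAME": class_name}
    subprocess.run(["bash", "scripts/finetune_full.sh"], env=env, check=True)           # per-subject finetuning
    subprocess.run(["bash", "scripts/generate_after_full_ft.sh"], env=env, check=True)  # 25 prompts x 4 images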

2. Compute {DINO, CLIP-I, CLIP-T} scores over 3000 images

  • We implemented the scores below ourselves, based on the DreamBooth paper:
    • CLIP-I (DINO): average pairwise cosine similarity between CLIP (DINO) embeddings of generated and real images.
    • CLIP-T: average cosine similarity between the CLIP embeddings of the prompt and the generated image.
  • Pseudocode is shown below; the embeddings and scores are computed similarly to src/eval_clip_score.py. A short cosine-similarity sanity check follows the pseudocode.
score_txt_img_list = [] # CLIP-T scores over 30 subjects and 25 prompts
score_img_img_list = [] # CLIP-I or DINO scores over 30 subjects and 25 prompts

for subject in subject_list:
    # Compute embeddings of real images
    real_img_list = get_real_img_list(subject)        
    real_img_embs = []
    for real_img in real_img_list:
        real_img_emb = get_img_emb(real_img, 'clip') # CLIP or DINO embedding of a single real image
        real_img_embs.append(real_img_emb)

    # Compute embeddings of prompts & generated images
    prompt_list = get_prompt_list(subject)
    for prompt in prompt_list:
        txt_emb = get_txt_emb(prompt, 'clip') # CLIP embedding of a single text prompt        

        score_txt_img = 0.0
        score_img_img = 0.0
        gen_img_list = get_gen_img_list(subject, prompt)
        for gen_img in gen_img_list:
            gen_img_emb = get_img_emb(gen_img, 'clip') # CLIP or DINO embedding of a single generated image
            score_txt_img += (txt_emb @ gen_img_emb.T) # CLIP-T score for a single pair

            for real_img_emb in real_img_embs:
                score_img_img += (real_img_emb @ gen_img_emb.T) # CLIP-I or DINO score for a single pair

        score_txt_img /= len(gen_img_list) 
        score_img_img /= (len(gen_img_list) * len(real_img_list))
        
        # Collect per-subject-per-prompt CLIP-T and {CLIP-I or DINO} scores
        score_txt_img_list.append(score_txt_img)
        score_img_img_list.append(score_img_img)

# Compute final average CLIP-T and {CLIP-I or DINO} scores
final_score_txt_img = sum(score_txt_img_list) / len(score_txt_img_list)
final_score_img_img = sum(score_img_img_list) / len(score_img_img_list)
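
One detail the pseudocode leaves implicit: the dot products (@) are cosine similarities only because the embeddings are assumed to be L2-normalized (as in the ImageEmbedder code shared later in this thread). A tiny self-contained check of that equivalence, using random 512-dimensional vectors purely for illustration:

import numpy as np

# Illustration only: with L2-normalized embeddings, the dot product used in the
# pseudocode (emb_a @ emb_b.T) equals the cosine similarity of the raw embeddings.
rng = np.random.default_rng(0)
a = rng.standard_normal((1, 512)).astype(np.float32)  # stand-in for a CLIP/DINO embedding
b = rng.standard_normal((1, 512)).astype(np.float32)
a_n = a / np.linalg.norm(a, axis=-1, keepdims=True)
b_n = b / np.linalg.norm(b, axis=-1, keepdims=True)
cosine = (a @ b.T) / (np.linalg.norm(a) * np.linalg.norm(b))
assert np.allclose(a_n @ b_n.T, cosine, atol=1e-6)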
ofzlo commented

Hi, thank you for your kind response. I have another question: do you also use CLIP ViT-g-14 when reproducing CLIP-I and CLIP-T, as in the CLIP-score reproduction for BK-SDM? And do you use ViT-S/16 DINO for DreamBooth? If not, could you provide more details about the get_img_emb function?

I found it :) However, I would still appreciate more details about get_img_emb.

Hi, apologies for the late response. I was away on a business trip followed by a personal vacation.
I've tried to provide a clearer explanation by refactoring our evaluation code; hope the example below is helpful.

  • It seems to work correctly based on the given example: the pair {cat0, cat1} is closer than {cat0, teapot}, with images sourced from the DreamBooth dataset.

Reference: For DINO embeddings, the validation part from DINO's eval_linear.py was used.

# ------------------------------------------------------------------------------------
# Copyright (c) 2023 Nota Inc. All Rights Reserved.
# Code modified from 
# [dino] https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/eval_linear.py#L196-L214
# [clip] https://github.com/mlfoundations/open_clip/tree/37b729bc69068daa7e860fb7dbcf1ef1d03a4185#usage
# ------------------------------------------------------------------------------------

import argparse
import torch
from PIL import Image
from torchvision import transforms as pth_transforms
import open_clip

class ImageEmbedder:
    def __init__(self, dino_model, dino_n_last_blocks, dino_avgpool_patchtokens,
                 clip_model, clip_data, device):     
        self.device = device   
        # dino-related part
        self.dino_model = torch.hub.load('facebookresearch/dino:main', dino_model)
        self.dino_model.to(device)
        self.dino_model.eval()
        self.dino_preprocess = pth_transforms.Compose([
            pth_transforms.Resize(256, interpolation=3),
            pth_transforms.CenterCrop(224),
            pth_transforms.ToTensor(),
            pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.dino_n_last_blocks = dino_n_last_blocks
        self.dino_avgpool_patchtokens = dino_avgpool_patchtokens        
        # clip-related part
        self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(clip_model,
                                                                                          pretrained=clip_data,
                                                                                          device=device)        
        
    def get_img_emb(self, pil_img, model_type):
        with torch.no_grad():
            if model_type == 'dino':
                image = self.dino_preprocess(pil_img).unsqueeze(0).to(self.device)
                intermediate_output = self.dino_model.get_intermediate_layers(image, self.dino_n_last_blocks)
                output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
                if self.dino_avgpool_patchtokens:
                    output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
                    output = output.reshape(output.shape[0], -1)                
                image_feature = output            
            elif model_type == 'clip':
                image = self.clip_preprocess(pil_img).unsqueeze(0).to(self.device)
                image_feature = self.clip_model.encode_image(image)
            else:
                raise NotImplementedError
        # print(image_feature.shape)
        image_feature /= image_feature.norm(dim=-1, keepdim=True)
        image_feature = image_feature.cpu().numpy()        
        return image_feature

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dino_model", type=str, default='dino_vits16', help='see https://github.com/facebookresearch/dino#pretrained-models-on-pytorch-hub')
    parser.add_argument('--dino_n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
        for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
    # Note: argparse's type=bool treats any non-empty string (even "False") as True,
    # so leave this flag unset to keep the default (False) for ViT-Small.
    parser.add_argument('--dino_avgpool_patchtokens', default=False, type=bool,
        help="""Whether or not to concatenate the global average pooled features to the [CLS] token.
        We typically set this to False for ViT-Small and to True with ViT-Base.""")
    parser.add_argument("--clip_model", type=str, default='ViT-g-14')
    parser.add_argument("--clip_data", type=str, default='laion2b_s34b_b88k')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use, cuda:gpu_number or cpu')
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    embedder = ImageEmbedder(args.dino_model, args.dino_n_last_blocks, args.dino_avgpool_patchtokens,
                             args.clip_model, args.clip_data, args.device)

    ## test DINO and CLIP image embeddings
    for model_type in ['dino', 'clip']:
        emb_cat0 = embedder.get_img_emb(Image.open('dreambooth/dataset/cat/00.jpg'), model_type)
        emb_cat1 = embedder.get_img_emb(Image.open('dreambooth/dataset/cat/01.jpg'), model_type)
        emb_teapot = embedder.get_img_emb(Image.open('dreambooth/dataset/teapot/00.jpg'), model_type)

        probs_cat0_cat1 = (emb_cat0 @ emb_cat1.T)[0][0]    
        probs_cat0_teapot = (emb_cat0 @ emb_teapot.T)[0][0]
        print(f"[{model_type}] cat0 vs cat1: {probs_cat0_cat1}")
        print(f"[{model_type}] cat0 vs teapot: {probs_cat0_teapot}") 
        
    # [dino] cat0 vs cat1: 0.8219528794288635
    # [dino] cat0 vs teapot: 0.24008619785308838
    # [clip] cat0 vs cat1: 0.9108940362930298
    # [clip] cat0 vs teapot: 0.41291046142578125
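
To connect this with the pseudocode above: ImageEmbedder.get_img_emb plays the role of get_img_emb. For CLIP-T, a text embedding is also needed; the sketch below is a hypothetical counterpart (not the repository's code; see src/eval_clip_score.py for the actual CLIP-score computation) that reuses the open_clip model loaded by ImageEmbedder and assumes the tokenizer name matches --clip_model.

import torch
import open_clip

def get_txt_emb(embedder, prompt, clip_model_name='ViT-g-14'):
    # Hypothetical CLIP text-embedding helper for CLIP-T; `embedder` is an ImageEmbedder
    # instance from the snippet above, and clip_model_name should match --clip_model.
    tokenizer = open_clip.get_tokenizer(clip_model_name)
    tokens = tokenizer([prompt]).to(embedder.device)
    with torch.no_grad():
        text_feature = embedder.clip_model.encode_text(tokens)
    text_feature /= text_feature.norm(dim=-1, keepdim=True)  # L2-normalize, as for the image embeddings
    return text_feature.cpu().numpy()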

ofzlo commented

Hi, thanks for your response. I have confirmed that the code works well. However, when I executed scripts/generate_after_full_ft.sh as you instructed, I noticed that the images were generated differently from what I intended: I observed some distortion in the images produced through the pipeline. I suspect that the error might be occurring due to the image size argument, so I removed the image size option and rewrote the pipeline code (please refer to the code below). However, when I evaluated the images generated through this inference pipeline, I obtained values of DINO 0.696, CLIP-I 0.687, and CLIP-T 0.240. These scores are lower than the DINO 0.723, CLIP-I 0.717, and CLIP-T 0.260 presented in your paper. Could you help me identify the issue? I would appreciate your response.

import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16).to(args.device)
generator = torch.Generator(args.device).manual_seed(args.seed)
img = pipeline(val_prompt, num_inference_steps=25, generator=generator).images[0]

Hi,
The images below are obtained by simply running scripts/finetune_full.sh followed by scripts/generate_after_full_ft.sh. For the original model, change UNET_TYPE="nota-ai/bk-sdm-base" to "CompVis/stable-diffusion-v1-4" at this line.
[image: DreamBooth samples generated by the above scripts]

when I executed scripts/generate_after_full_ft.sh as you instructed, I noticed that the images were generated differently from what I intended.

  • It would be better if you could share your results and the exact procedure you followed.

I suspect that the error might be occurring due to the image size argument, so I removed the image size option

  • It would be helpful if you could explain why you think the image size option matters.

rewrote the pipeline code (please refer to the code below)

  • (Our paper) We use DPM-Solver [36, 37] for the DreamBooth results.
  • To match this, please adjust your code by referring to our code [here and here]; a rough sketch follows.
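  • For illustration only, the sketch below swaps in diffusers' DPMSolverMultistepScheduler; the exact scheduler configuration in our scripts may differ and should take precedence (args.* and val_prompt are as in your snippet).

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

# Sketch only: replace the pipeline's default scheduler with multistep DPM-Solver.
# The repository's generation scripts may configure the scheduler differently.
pipeline = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float16).to(args.device)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(args.device).manual_seed(args.seed)
img = pipeline(val_prompt, num_inference_steps=25, generator=generator).images[0]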

I obtained values of DINO 0.696, CLIP-I 0.687, and CLIP-T 0.240.

  • It would be better if you also measured the scores of the original SD-v1.4 with your code.
ofzlo commented

Thank you for your response. If I learn more precise details about the issue I mentioned, I will open a new issue at that time :)