hananshafi/llmblueprint

the iterative_refinement function should be improved.

Bailey-24 opened this issue · 1 comments

I didn't remember i have solve how many bug, but this bug I didn't know how to solve.
I have change the obj_descs to a dict.
"{'a Golden Retriever':['A friendly and affectionate Golden Retriever with a soft, golden-furred coat and its warm eyes filled with joy.'], 'a white cat':['a graceful white cat gracefully stretching, showing off its fluffy, pristine fur'], 'a sleek television':['a sleek and modern television'],'a vase of flowers':['a vase of vibrant flowers'], 'a wooden table':['a wooden table']}"
change the code to

def get_clip_metric(torch_image, bbox, target_text="A graceful white cat gracefully stretches, showing off its fluffy, pristine fur."):
        
    target_region = crop_image(torch_image, bbox)
    if target_region.max()<100:
        target_region = (target_region*(255)).to(torch.uint8)
    score = clip_metric(target_region, target_text)
    return score/100

def crop_image(img, bbox):
    if img.shape[0]!=3:
        img = img.permute(0, 3, 1, 2)
    x,y,w,h = bbox[0], bbox[1], bbox[2], bbox[3]
    cropped_region = img[:,y:y+h,x:x+w] #img[:,x:x+w,y:y+h]
    return cropped_region #torchvision.transforms.Resize((512,512))(cropped_region)

def iterative_refinement(first_stage_gen, objects_dict_bboxes,objects_dict_desc, add_guidance=True, optim_steps=1,guidance_weight=200,skip_small=False):
    shift_flag=False
    
    for obj, bbox in objects_dict_bboxes.items():
        if obj in objects_dict_desc.keys():
            object_desc = objects_dict_desc[obj]
            if not shift_flag:
                torch_image = torch.tensor(first_stage_gen).cuda()
                first_stage_gen_numpy = first_stage_gen[0]
                tensor_image = torch.from_numpy(first_stage_gen_numpy)
                # Permute the dimensions of the tensor
                tensor_image = tensor_image.permute(2, 0, 1)
                
                pil_image = torchvision.transforms.ToPILImage()(tensor_image)

            
            clip_score = get_clip_metric(torch_image, tuple(bbox), target_text=object_desc)
            if clip_score < 0.2:
                if skip_small:
                    if (bbox[-1]*bbox[-2])/(512*512)>0.01:
                        shift_flag=True
                        ref_image = gen_hq_image_sd(object_desc)
                        image_mask_pil = create_square_mask(pil_image,tuple(bbox))
                        p_by_ex, torch_ex = main_paint_by_example(img_p=pil_image, ref_p=ref_image, mask=image_mask_pil,
                                                                    bbox=bbox,text_desc=object_desc,
                                                                    add_guidance=add_guidance, optim_steps=optim_steps, guidance_weight=guidance_weight)
                        torch_image=torch_ex
                        pil_image=p_by_ex
    return pil_image,torch_image

bug is

    Traceback (most recent call last):
  File "main.py", line 487, in <module>
    pil_img, _ = iterative_refinement(first_stage_image, sorted_bbox_dict, object_descs, add_guidance=conf.second_stage_gen_config.add_guidance,
  File "main.py", line 433, in iterative_refinement
    clip_score = get_clip_metric(torch_image, tuple(bbox), target_text=object_desc)
  File "main.py", line 405, in get_clip_metric
    if target_region.max()<100:
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

Apologies, the previous code was not comprehensive. We have updated the code recently. Kindly clone the repo again and make a new environment as instructed. Let us know if you again face the issues.