the iterative_refinement function should be improved.
Bailey-24 opened this issue · 1 comments
I didn't remember i have solve how many bug, but this bug I didn't know how to solve.
I have change the obj_descs to a dict.
"{'a Golden Retriever':['A friendly and affectionate Golden Retriever with a soft, golden-furred coat and its warm eyes filled with joy.'], 'a white cat':['a graceful white cat gracefully stretching, showing off its fluffy, pristine fur'], 'a sleek television':['a sleek and modern television'],'a vase of flowers':['a vase of vibrant flowers'], 'a wooden table':['a wooden table']}"
change the code to
def get_clip_metric(torch_image, bbox, target_text="A graceful white cat gracefully stretches, showing off its fluffy, pristine fur."):
target_region = crop_image(torch_image, bbox)
if target_region.max()<100:
target_region = (target_region*(255)).to(torch.uint8)
score = clip_metric(target_region, target_text)
return score/100
def crop_image(img, bbox):
if img.shape[0]!=3:
img = img.permute(0, 3, 1, 2)
x,y,w,h = bbox[0], bbox[1], bbox[2], bbox[3]
cropped_region = img[:,y:y+h,x:x+w] #img[:,x:x+w,y:y+h]
return cropped_region #torchvision.transforms.Resize((512,512))(cropped_region)
def iterative_refinement(first_stage_gen, objects_dict_bboxes,objects_dict_desc, add_guidance=True, optim_steps=1,guidance_weight=200,skip_small=False):
shift_flag=False
for obj, bbox in objects_dict_bboxes.items():
if obj in objects_dict_desc.keys():
object_desc = objects_dict_desc[obj]
if not shift_flag:
torch_image = torch.tensor(first_stage_gen).cuda()
first_stage_gen_numpy = first_stage_gen[0]
tensor_image = torch.from_numpy(first_stage_gen_numpy)
# Permute the dimensions of the tensor
tensor_image = tensor_image.permute(2, 0, 1)
pil_image = torchvision.transforms.ToPILImage()(tensor_image)
clip_score = get_clip_metric(torch_image, tuple(bbox), target_text=object_desc)
if clip_score < 0.2:
if skip_small:
if (bbox[-1]*bbox[-2])/(512*512)>0.01:
shift_flag=True
ref_image = gen_hq_image_sd(object_desc)
image_mask_pil = create_square_mask(pil_image,tuple(bbox))
p_by_ex, torch_ex = main_paint_by_example(img_p=pil_image, ref_p=ref_image, mask=image_mask_pil,
bbox=bbox,text_desc=object_desc,
add_guidance=add_guidance, optim_steps=optim_steps, guidance_weight=guidance_weight)
torch_image=torch_ex
pil_image=p_by_ex
return pil_image,torch_image
bug is
Traceback (most recent call last):
File "main.py", line 487, in <module>
pil_img, _ = iterative_refinement(first_stage_image, sorted_bbox_dict, object_descs, add_guidance=conf.second_stage_gen_config.add_guidance,
File "main.py", line 433, in iterative_refinement
clip_score = get_clip_metric(torch_image, tuple(bbox), target_text=object_desc)
File "main.py", line 405, in get_clip_metric
if target_region.max()<100:
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
Apologies, the previous code was not comprehensive. We have updated the code recently. Kindly clone the repo again and make a new environment as instructed. Let us know if you again face the issues.