Inquiry about the code/result difference between SAM and GSAM
TikaToka commented
Hello, thank you for sharing amazing work!
I am trying to adapt the GSAM code as a base model, but I have a question.
```python
import torch
from transformers import SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# `image`, `preds`, and `sampled_points` come from my own pipeline
sam_masks = []
for idx in range(preds.shape[0]):
    # one point prompt per prediction
    sam_inputs = sam_processor(image, input_points=[sampled_points[idx]], return_tensors="pt").to(device)
    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs)
    print(sam_outputs)
    print(sam_outputs.pred_masks.cpu().shape)
    # upscale the predicted masks back to the original image resolution
    sam_masks.append(sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(), sam_inputs["original_sizes"].cpu(), sam_inputs["reshaped_input_sizes"].cpu()
    ))
```
With this SAM code, each entry of `sam_masks` has shape (1, 1, 3, h, w), so the whole list is (n, 1, 1, 3, h, w).
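If that 3 is just the three candidate masks that `SamModel` returns by default, I guess I could also drop the dimension on the transformers side by passing `multimask_output=False` to the forward call (my assumption from reading the transformers docs); a minimal sketch reusing the names from the snippet above:

```python
with torch.no_grad():
    # assumption: multimask_output=False returns a single mask per point prompt
    sam_outputs = sam_model(**sam_inputs, multimask_output=False)
print(sam_outputs.pred_masks.shape)  # expecting (1, 1, 1, h', w') instead of (1, 1, 3, h', w')
```

But what I actually want is the other direction, i.e. getting the three-mask output from the GSAM pipeline.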
However, if we use this code from GSAM:
```python
import cv2
import torch
import torchvision
from segment_anything import SamPredictor, build_sam
# load_image, load_model, generate_caption, generate_tags, get_grounding_output,
# and check_caption are the helper functions from the Grounded-SAM demo script

image_pil, im = load_image(rgb_path)

# load the Grounding DINO model
model = load_model(config_file, grounded_checkpoint, device=device)

caption = generate_caption(image_path, device=device)
# Currently ", " is better for detecting single tags
# while ". " is a little worse in some case
text_prompt = generate_tags(caption, split=split)

# run Grounding DINO to get boxes for the generated tags
boxes_filt, scores, pred_phrases = get_grounding_output(
    model, im, text_prompt, box_threshold, text_threshold, device='cuda'
)
print(boxes_filt, scores, pred_phrases)

# initialize SAM
# if use_sam_hq:
#     print("Initialize SAM-HQ Predictor")
#     predictor = SamPredictor(build_sam_hq(checkpoint=sam_hq_checkpoint).to(device))
# else:
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))

im = cv2.imread(rgb_path)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
predictor.set_image(im)

# convert the normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2)
size = image_pil.size
H, W = size[1], size[0]
for i in range(boxes_filt.size(0)):
    boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
    boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
    boxes_filt[i][2:] += boxes_filt[i][:2]
boxes_filt = boxes_filt.cpu()

# use NMS to handle overlapped boxes
print(f"Before NMS: {boxes_filt.shape[0]} boxes")
nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
boxes_filt = boxes_filt[nms_idx]
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
print(f"After NMS: {boxes_filt.shape[0]} boxes")

caption = check_caption(caption, pred_phrases)
print(f"Revise caption with number: {caption}")

# prompt SAM with the boxes only (no point prompts), one mask per box
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, im.shape[:2]).to(device)
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes.to(device),
    multimask_output=False,
)
```
then each mask in `masks` has shape (1, h, w), so the whole tensor is (n, 1, h, w).
I just wonder why there is this dimensional gap between SAM and GSAM, and whether there is a way to get (1, 1, 3, h, w) masks from the GSAM code as well.
I think this looks like the effect of `multimask_output=True`, and if that is right, the conversion code might be:
```python
new = [torch.tensor([[mask]]) for mask in masks.cpu().tolist()]
```
but I want to make sure of it.
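If the difference really is just `multimask_output`, maybe the cleaner route is to ask the predictor for the three candidate masks directly and only add the missing leading dimensions afterwards. This is only a sketch of what I have in mind, reusing the variables from the GSAM snippet above (I have not verified that `predict_torch` with `multimask_output=True` returns exactly (n, 3, h, w)):

```python
masks3, iou_preds, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes.to(device),
    multimask_output=True,  # assumption: three candidate masks per box -> (n, 3, h, w)
)
# add back the two leading singleton dimensions to match the transformers layout
sam_like_masks = [m.unsqueeze(0).unsqueeze(0).cpu() for m in masks3]  # each (1, 1, 3, h, w)
```

The two `unsqueeze` calls are only there to mimic the (batch, point_batch) axes that the transformers output carries.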
Thank you in advance, and have a nice day!