facebookresearch/segment-anything

How to make SAM (ViT-L, ONNX) faster?

Siwakonrome opened this issue · 0 comments

My Configuration

  1. Windows OS
  2. vit_l onnx
  3. RTX 2080

My processing time is 10.1 seconds per call. I need it to be under 1–2 seconds for my application.

def __call__(self, image_np, box):
    """Segment the object inside ``box`` and report how long it took.

    Args:
        image_np: input image, passed unchanged to ``self.preprocess``.
        box: box description, passed unchanged to ``self.create_box_as_a_prompt``.

    Returns:
        ``(contours, hierarchy, process_time)``: the first two come from
        ``self.postprocess``; ``process_time`` is the elapsed time in seconds.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so it is the right clock for measuring an interval.
    start = time.perf_counter()
    box = self.create_box_as_a_prompt(box=box)
    input_tensor, resized, orig = self.preprocess(image_np=image_np)
    contours, hierarchy = self.postprocess(input_tensor=input_tensor,
                                           box=box,
                                           resized=resized,
                                           orig=orig)
    process_time = time.perf_counter() - start
    return contours, hierarchy, process_time

My code.

import cv2
import time
import numpy as np
from PIL import Image
import onnxruntime as ort
from copy import deepcopy


class SegmentSamOnnxOperator:
    """Segment one object in an image with a SAM ONNX encoder/decoder pair.

    Each ``__call__`` runs the image encoder once on a padded square tensor
    and then the mask decoder once with a box prompt. The encoder is the
    expensive part: when several boxes are prompted on the same image, call
    ``preprocess`` and ``_encode`` once and ``_decode_mask`` per box.
    For GPU inference, pass suitable ``providers`` to the constructor.
    """

    # Side length of the square encoder input; the longer image side is
    # resized to this and the shorter side is zero-padded up to it.
    # Must match the size the encoder/decoder were exported with.
    IMG_SIZE = 1024
    # Per-channel RGB mean / std used to normalise the encoder input.
    PIXEL_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    PIXEL_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

    def __init__(self, onnx_encoder_path, onnx_decoder_path, providers=None, sess_options=None):
        """Load the encoder and decoder ONNX models.

        Args:
            onnx_encoder_path: path (or bytes) of the image-encoder model.
            onnx_decoder_path: path (or bytes) of the mask-decoder model.
            providers: optional ONNX Runtime execution-provider list, e.g.
                ``["CUDAExecutionProvider", "CPUExecutionProvider"]`` (the
                CUDA provider needs the onnxruntime-gpu package). ``None``
                keeps ONNX Runtime's default selection.
            sess_options: optional ``ort.SessionOptions`` shared by both
                sessions; ``None`` keeps the defaults.
        """
        self.encoder = ort.InferenceSession(onnx_encoder_path, sess_options=sess_options,
                                            providers=providers)
        self.decoder = ort.InferenceSession(onnx_decoder_path, sess_options=sess_options,
                                            providers=providers)

    def __call__(self, image_np, box):
        """Segment the object inside ``box``.

        Args:
            image_np: BGR image array (converted with ``cv2.COLOR_BGR2RGB``).
            box: dict with keys ``x``, ``y``, ``width``, ``height`` in
                original-image pixels.

        Returns:
            ``(contours, hierarchy, process_time)``: the first two as returned
            by ``cv2.findContours``; ``process_time`` is elapsed seconds.
        """
        # perf_counter is monotonic and high-resolution; time.time() is not.
        start = time.perf_counter()
        box = self.create_box_as_a_prompt(box=box)
        input_tensor, resized, orig = self.preprocess(image_np=image_np)
        contours, hierarchy = self.postprocess(input_tensor=input_tensor,
                                               box=box,
                                               resized=resized,
                                               orig=orig)
        process_time = time.perf_counter() - start
        return contours, hierarchy, process_time

    def create_box_as_a_prompt(self, box):
        """Convert an ``{x, y, width, height}`` dict to ``[x0, y0, x1, y1]``."""
        return np.array([box['x'], box['y'], box['x'] + box['width'], box['y'] + box['height']])

    def postprocess(self, input_tensor, box, resized, orig):
        """Encode the image, decode a mask for ``box`` and return its contours.

        Args:
            input_tensor: encoder input produced by ``preprocess``.
            box: ``[x0, y0, x1, y1]`` in original-image pixels.
            resized: ``(width, height)`` of the image inside ``input_tensor``.
            orig: ``(width, height)`` of the original image.

        Returns:
            ``(contours, hierarchy)`` from ``cv2.findContours``.
        """
        embeddings = self._encode(input_tensor)
        mask = self._decode_mask(embeddings, box, resized, orig)
        return self._mask_to_contours(mask)

    def _encode(self, input_tensor):
        """Run the image encoder and return its first output (the embeddings)."""
        return self.encoder.run(None, {"images": input_tensor})[0]

    def _decode_mask(self, embeddings, box, resized, orig):
        """Decode a uint8 mask (0 or 255 per pixel) for a box prompt."""
        orig_width, orig_height = orig
        resized_width, resized_height = resized
        # The box is fed as two points; np.array always copies, so the
        # caller's box is never modified by the in-place scaling below.
        coords = np.array(box, dtype=np.float64).reshape(1, 2, 2)
        # Prompt coordinates must be expressed in the resized image frame.
        coords[..., 0] *= resized_width / orig_width
        coords[..., 1] *= resized_height / orig_height
        # Labels 2 and 3: SAM's box-corner labels (top-left, bottom-right).
        labels = np.array([[2, 3]], dtype=np.float32)
        masks, _, _ = self.decoder.run(None, {
            "image_embeddings": embeddings,
            "point_coords": coords.astype(np.float32),
            "point_labels": labels,
            # No prior low-resolution mask is supplied (has_mask_input = 0).
            "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
            "has_mask_input": np.zeros(1, dtype=np.float32),
            "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32),
        })
        # Positive logits are foreground.
        return (masks[0, 0] > 0).astype(np.uint8) * 255

    @staticmethod
    def _mask_to_contours(mask):
        """Return ``(contours, hierarchy)`` of a binary 0/255 uint8 mask."""
        # The mask is already binary, so no colour conversion or threshold
        # is needed. OpenCV 3 returns 3 values and OpenCV 4 returns 2; the
        # last two are always (contours, hierarchy).
        found = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        contours, hierarchy = found[-2:]
        return contours, hierarchy

    @staticmethod
    def _get_preprocess_shape(orig_width, orig_height, long_side):
        """Return ``(width, height)`` after scaling the longer side to ``long_side``.

        Rounds to nearest (not truncation), like SAM's
        ``ResizeLongestSide.get_preprocess_shape``, so the resized size agrees
        with the size the SAM decoder derives from ``orig_im_size``.
        """
        scale = long_side / max(orig_width, orig_height)
        return int(orig_width * scale + 0.5), int(orig_height * scale + 0.5)

    def preprocess(self, image_np):
        """Resize, normalise and pad a BGR image into the encoder input tensor.

        Returns:
            ``(input_tensor, (resized_width, resized_height),
            (orig_width, orig_height))`` where ``input_tensor`` is float32
            NCHW of shape ``(1, 3, IMG_SIZE, IMG_SIZE)``.
        """
        img = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
        orig_width, orig_height = img.size
        resized_width, resized_height = self._get_preprocess_shape(
            orig_width, orig_height, self.IMG_SIZE)
        img = img.resize((resized_width, resized_height), Image.Resampling.BILINEAR)
        # Normalise straight in float32; the encoder consumes float32 anyway.
        pixels = (np.asarray(img, dtype=np.float32) - self.PIXEL_MEAN) / self.PIXEL_STD
        # HWC -> NCHW, then zero-pad bottom/right to a square IMG_SIZE canvas
        # (the longer side is already IMG_SIZE, so only one pad is non-zero).
        input_tensor = np.pad(
            pixels.transpose(2, 0, 1)[None],
            ((0, 0), (0, 0),
             (0, self.IMG_SIZE - resized_height),
             (0, self.IMG_SIZE - resized_width)),
        )
        return input_tensor, (resized_width, resized_height), (orig_width, orig_height)