How can I make SAM (ViT-L, ONNX) faster?
Siwakonrome opened this issue · 0 comments
Siwakonrome commented
My configuration:
- Windows OS
- SAM ViT-L (vit_l) exported to ONNX
- NVIDIA RTX 2080 GPU
My processing time is 10.1 seconds per call. I need it to be under 1–2 seconds for it to be usable in my application.
def __call__(self, image_np, box):
    """Segment the object inside ``box`` and report how long it took.

    Args:
        image_np: Input image as a NumPy array (passed to ``self.preprocess``).
        box: Box prompt (passed to ``self.create_box_as_a_prompt``).

    Returns:
        ``(contours, hierarchy, process_time)``. ``contours`` and ``hierarchy``
        are whatever ``self.postprocess`` returns. ``process_time`` is the
        elapsed time in seconds.
    """
    # perf_counter is monotonic and high-resolution, so it is the right clock for
    # measuring durations (time.time() can jump when the system clock is adjusted).
    start = time.perf_counter()
    box = self.create_box_as_a_prompt(box=box)
    input_tensor, resized, orig = self.preprocess(image_np=image_np)
    contours, hierarchy = self.postprocess(
        input_tensor=input_tensor,
        box=box,
        resized=resized,
        orig=orig,
    )
    process_time = time.perf_counter() - start
    return contours, hierarchy, process_time
My full code:
import hashlib
import time
from copy import deepcopy

import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image
class SegmentSamOnnxOperator:
    """Segment one object with SAM (ONNX encoder + decoder) from a box prompt.

    Usage::

        op = SegmentSamOnnxOperator("encoder.onnx", "decoder.onnx")
        contours, hierarchy, seconds = op(
            image_bgr, {"x": 10, "y": 20, "width": 100, "height": 80})

    Performance notes:

    * The image encoder dominates the run time. ONNX Runtime only uses the GPU
      when the installed package provides a GPU execution provider (e.g.
      ``onnxruntime-gpu`` for CUDA, ``onnxruntime-directml`` on Windows) and
      that provider is requested for the session, which ``__init__`` does.
    * The embeddings of the most recently encoded image are cached, so calling
      the operator again on the same image with a different box only reruns
      the decoder.
    """

    # Side length (pixels) of the square encoder input; the long image side is
    # resized to this and the short side is zero-padded up to it.
    INPUT_SIZE = 1024
    # Per-channel RGB normalization constants applied before encoding.
    PIXEL_MEAN = (123.675, 116.28, 103.53)
    PIXEL_STD = (58.395, 57.12, 57.375)
    # Execution providers requested by default, in priority order; entries that
    # this ONNX Runtime build does not offer are skipped.
    PREFERRED_PROVIDERS = (
        "CUDAExecutionProvider",
        "DmlExecutionProvider",
        "CPUExecutionProvider",
    )

    def __init__(self, onnx_encoder_path, onnx_decoder_path, providers=None):
        """Load the SAM encoder and decoder ONNX models.

        Args:
            onnx_encoder_path: Path to the image-encoder ONNX model.
            onnx_decoder_path: Path to the prompt/mask-decoder ONNX model.
            providers: Optional list of ONNX Runtime execution providers. When
                omitted, the entries of ``PREFERRED_PROVIDERS`` that this build
                supports are used (GPU first, CPU last).
        """
        if providers is None:
            providers = self._default_providers()
        # Passing providers explicitly is required on GPU builds of ONNX
        # Runtime >= 1.9 and is what makes the session run on the GPU.
        self.encoder = ort.InferenceSession(onnx_encoder_path, providers=providers)
        self.decoder = ort.InferenceSession(onnx_decoder_path, providers=providers)
        # One-entry cache: (key of the last encoder input, its embeddings).
        self._embedding_cache = None

    @classmethod
    def _default_providers(cls):
        """Return the preferred execution providers available in this build."""
        available = set(ort.get_available_providers())
        chosen = [name for name in cls.PREFERRED_PROVIDERS if name in available]
        return chosen or ["CPUExecutionProvider"]

    def __call__(self, image_np, box):
        """Segment the object inside ``box`` in a BGR image.

        Args:
            image_np: Image as a NumPy array in OpenCV's BGR channel order.
            box: Dict with ``x``, ``y``, ``width`` and ``height`` in original
                image pixels.

        Returns:
            ``(contours, hierarchy, process_time)``: the first two come from
            ``cv2.findContours`` on the predicted mask; ``process_time`` is the
            elapsed time in seconds.
        """
        # perf_counter is monotonic, unlike time.time(), so it suits durations.
        start = time.perf_counter()
        box = self.create_box_as_a_prompt(box=box)
        input_tensor, resized, orig = self.preprocess(image_np=image_np)
        contours, hierarchy = self.postprocess(
            input_tensor=input_tensor,
            box=box,
            resized=resized,
            orig=orig,
        )
        process_time = time.perf_counter() - start
        return contours, hierarchy, process_time

    def create_box_as_a_prompt(self, box):
        """Convert an ``{x, y, width, height}`` dict into ``[x1, y1, x2, y2]``."""
        return np.array([box['x'], box['y'], box['x'] + box['width'], box['y'] + box['height']])

    def postprocess(self, input_tensor, box, resized, orig):
        """Encode the image, decode a mask for ``box`` and return its contours.

        Args:
            input_tensor: ``(1, 3, INPUT_SIZE, INPUT_SIZE)`` float32 tensor
                produced by :meth:`preprocess`.
            box: ``[x1, y1, x2, y2]`` in original image pixels.
            resized: ``(width, height)`` of the image after resizing.
            orig: ``(width, height)`` of the original image.

        Returns:
            ``(contours, hierarchy)`` from ``cv2.findContours``.
        """
        orig_width, orig_height = orig
        embeddings = self._encode(input_tensor)

        # The box is fed to the decoder as its two corner points, with SAM's
        # box-corner labels 2 and 3.
        onnx_coord = self._scale_box(box, resized, orig)
        onnx_label = np.array([[2, 3]], dtype=np.float32)
        masks, _, _ = self.decoder.run(None, {
            "image_embeddings": embeddings,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            # No previous low-resolution mask is supplied.
            "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
            "has_mask_input": np.zeros(1, dtype=np.float32),
            "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32),
        })

        # Threshold the mask logits at 0 into a 0/255 uint8 image. It is already
        # binary, so it goes straight to findContours without the former
        # PIL -> RGB -> gray -> threshold round trip (which yielded the same array).
        mask = (masks[0][0] > 0).astype(np.uint8) * 255
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        return contours, hierarchy

    def _encode(self, input_tensor):
        """Return encoder embeddings, reusing the previous result for identical input."""
        digest = hashlib.blake2b(input_tensor.tobytes(), digest_size=16).digest()
        key = (input_tensor.shape, input_tensor.dtype.str, digest)
        cache = getattr(self, "_embedding_cache", None)
        if cache is not None and cache[0] == key:
            return cache[1]
        embeddings = self.encoder.run(None, {"images": input_tensor})[0]
        self._embedding_cache = (key, embeddings)
        return embeddings

    @staticmethod
    def _scale_box(box, resized, orig):
        """Map ``[x1, y1, x2, y2]`` from original to resized pixels.

        Returns:
            A ``(1, 2, 2)`` float32 array holding the two corner points.
        """
        resized_width, resized_height = resized
        orig_width, orig_height = orig
        corners = np.asarray(box, dtype=np.float64).reshape(1, 2, 2)
        scale = np.array([resized_width / orig_width, resized_height / orig_height])
        return (corners * scale).astype(np.float32)

    @staticmethod
    def _preprocess_shape(orig_width, orig_height, long_side=1024):
        """Return ``(width, height)`` with the long side scaled to ``long_side``.

        Sides are rounded to the nearest pixel (not truncated), matching SAM's
        ``ResizeLongestSide`` and the size the exported decoder uses when it
        crops the padding off the predicted mask. Each side is at least 1 px.
        """
        scale = long_side / max(orig_width, orig_height)
        return (
            max(1, int(orig_width * scale + 0.5)),
            max(1, int(orig_height * scale + 0.5)),
        )

    @classmethod
    def _to_input_tensor(cls, rgb):
        """Normalize an ``(H, W, 3)`` RGB array and zero-pad it to the encoder input.

        Returns:
            A ``(1, 3, INPUT_SIZE, INPUT_SIZE)`` float32 tensor with the image
            in the top-left corner and zeros elsewhere.
        """
        mean = np.array(cls.PIXEL_MEAN, dtype=np.float32)
        std = np.array(cls.PIXEL_STD, dtype=np.float32)
        normalized = (rgb.astype(np.float32) - mean) / std
        height, width = normalized.shape[:2]
        size = cls.INPUT_SIZE
        # Allocating the padded tensor once replaces the transpose + np.pad copies.
        tensor = np.zeros((1, 3, size, size), dtype=np.float32)
        tensor[0, :, :height, :width] = normalized.transpose(2, 0, 1)
        return tensor

    def preprocess(self, image_np):
        """Resize, normalize and pad a BGR image for the SAM encoder.

        Args:
            image_np: Image as a NumPy array in OpenCV's BGR channel order.

        Returns:
            ``(input_tensor, (resized_width, resized_height),
            (orig_width, orig_height))``.
        """
        img = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
        orig_width, orig_height = img.size
        resized_width, resized_height = self._preprocess_shape(
            orig_width, orig_height, self.INPUT_SIZE)
        img = img.resize((resized_width, resized_height), Image.Resampling.BILINEAR)
        input_tensor = self._to_input_tensor(np.asarray(img))
        return input_tensor, (resized_width, resized_height), (orig_width, orig_height)