how to quantize the lightweight SAM model?
ranpin opened this issue · 0 comments
Hi, nice work it is. I'm tring your method to do some application and have some questions about the quantization.
I have carefully looked at the code in demo_quan.py and layer.py, but currently the model in demo_quan.py is loaded directly from the quantized weights. I would like to ask how to quantize the instances created from an existing pre-trained SAM model using your quantization method?
Since I don't know how you quantize your lightweight SAM model using the quantization method in layer.py, can you provide a reference example of how did you do when quantizing the model? Thank you very much!
Here is the demo I wrote, it runs successfully, but the test result after quantization is close to 0. does it need retraining? Or maybe I'm not thinking correctly? hoping your reply!
from quantization_layer.layers import InferQuantConv2d, InferQuantConvTranspose2d
model_type = 'vit_b'
checkpoint = 'checkpoints/sam_vit_b_01ec64.pth'
model = sam_model_registry[model_type](checkpoint=checkpoint)
model.to(device)
model.eval()
predictor = SamPredictor(model)
w_bit = 8
a_bit = 8
input_size = (1, 3, 1024, 1024)
n_V = input_size[2]
n_H = input_size[3]
a_interval = torch.tensor(0.1)
a_bias = torch.tensor(0.0)
w_interval = torch.tensor(0.01)
# 量化模型中的卷积层和卷积转置层
def replace_with_quantized_layers(model):
layers_to_replace = []
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):
layers_to_replace.append((name, module))
for name, module in layers_to_replace:
if isinstance(module, nn.Conv2d):
quantized_module = InferQuantConv2d(
in_channels=module.in_channels,
out_channels=module.out_channels,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
bias=module.bias is not None,
mode='quant_forward',
w_bit=w_bit,
a_bit=a_bit
)
quantized_module.get_parameter(n_V=n_V,
n_H=n_H,
a_interval=a_interval,
a_bias=a_bias,
w_interval=w_interval)
elif isinstance(module, nn.ConvTranspose2d):
quantized_module = InferQuantConvTranspose2d(
in_channels=module.in_channels,
out_channels=module.out_channels,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
output_padding=module.output_padding,
groups=module.groups,
bias=module.bias is not None,
mode='quant_forward',
w_bit=w_bit,
a_bit=a_bit
)
quantized_module.get_parameter(n_V=n_V,
n_H=n_H,
a_interval=a_interval,
a_bias=a_bias,
w_interval=w_interval)
setattr(model, name, quantized_module)
return model
quan_model = replace_with_quantized_layers(model)
print(quan_model)