How to export TOSA/linalg when the conv input and weight are torch.int8 and the bias is torch.int32?
zccyman opened this issue · 0 comments
zccyman commented
import torch
import torch.nn.functional as F
def export_linalg_by_shark(model, dummy_input, output_path="test.mlir"):
    """Import ``model`` through SHARK's importer and save the MLIR text.

    Args:
        model: Module handed to ``SharkImporter`` with ``frontend="torch"``.
        dummy_input: Single example input; it is wrapped in a one-element
            tuple before being passed to the importer.
        output_path: File the resulting MLIR string is written to. Defaults
            to ``"test.mlir"``, the path that was previously hard-coded.
    """
    # Imported inside the function, so SHARK is only required when an export
    # is actually requested.
    from extension.shark.shark_importer import SharkImporter

    mlir_type = "linalg"
    mlir_importer = SharkImporter(
        model,
        (dummy_input,),
        frontend="torch",
        return_str=True,
    )
    # NOTE(review): `_torch_mlir` is underscore-prefixed, so it is presumably
    # a private entry point of SharkImporter -- confirm whether a public
    # equivalent exists.
    mlir_str = mlir_importer._torch_mlir(
        is_dynamic=False, tracing_required=True, mlir_type=mlir_type
    )
    with open(output_path, "w") as f:
        f.write(mlir_str)
class SimpleNet(torch.nn.Module):
    """Parameter-free module that applies a single int8 3x3 convolution.

    The kernel is not stored on the module; ``forward`` rebuilds it on every
    call as a (64, 12, 3, 3) int8 tensor whose entries are all 100.
    """

    def forward(self, x):
        """Cast ``x`` to int8 and convolve it with the constant kernel.

        The convolution uses no bias, stride 1, zero padding, dilation 1 and
        a single group.
        """
        # Cast first, then scale: same op order as the original expression
        # `100 * torch.ones(...).type(torch.int8)`.
        ones_int8 = torch.ones(64, 12, 3, 3).type(torch.int8)
        kernel = 100 * ones_int8
        x_int8 = x.type(torch.int8)
        return F.conv2d(
            input=x_int8,
            weight=kernel,
            bias=None,
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            groups=1,
        )
# Build the module and put it in eval mode before running/exporting it.
model = SimpleNet()
model.eval()

# Constant int8 example input of shape (1, 12, 224, 224), every entry 100.
# Named `dummy_input` so the builtin `input` is not shadowed.
dummy_input = 100 * torch.ones(1, 12, 224, 224).type(torch.int8)

# Eager forward pass before the export; the result is not used afterwards.
output = model(dummy_input)

export_linalg_by_shark(
    model,
    dummy_input,
)
print("test")
When the input, weight, and bias are all torch.float32, the linalg export succeeds.