The goal of this experiment is to provide a simple, easy-to-use C API for registering custom ONNX operators in ONNX Runtime, along with convenience APIs for scaffolding kernel compute implementations.
static OrtSimpleCustomOpConfig custom_ops[] = {
{
.name = "CustomOpOne",
.inputs = {
.count = 2,
.homogenous_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
},
.outputs = {
.count = 1,
.homogenous_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
},
.kernel_compute = custom_op1_kernel
},
{
.name = "CustomOpTwo",
.inputs = {
.count = 1,
.homogenous_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
},
.outputs = {
.count = 1,
.homogenous_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
},
.kernel_compute = custom_op2_kernel
}
};
/*
 * Example registration sequence (parts elided with "..."):
 *  1. Passes the domain name "test.customop" and every entry of the
 *     custom_ops table to OrtSimpleCustomOpRegister, which is given
 *     &custom_op_domain to fill in.
 *  2. Passes ort_session_options and that domain to ort->AddCustomOpDomain.
 * If either call returns a non-NULL status, that status is passed to
 * ort_error and its result is returned.
 *
 * NOTE(review): ort, ort_status and ort_session_options are not declared in
 * the visible code; presumably they are set up in the elided "..." part.
 * NOTE(review): the function is declared void yet executes
 * "return ort_error(...)"; C forbids a return statement with an expression
 * in a void function, so a real implementation needs a non-void return type
 * compatible with ort_error's result — confirm ort_error's signature.
 * NOTE(review): "(...)" with no named parameter is only valid from C23 on;
 * it reads as a placeholder for the real parameter list.
 */
void Register(...)
{
...
/* Filled in by OrtSimpleCustomOpRegister through the &custom_op_domain argument. */
OrtCustomOpDomain* custom_op_domain;
/* The op count is derived from the array itself so it tracks the table size.
   NOTE(review): the meaning of the NULL allocator and of the trailing NULL
   (labelled custom_ops) is not visible here — presumably optional
   parameters; confirm against OrtSimpleCustomOpRegister. */
if ((ort_status = OrtSimpleCustomOpRegister(
ort,
NULL /* allocator */,
"test.customop",
custom_ops,
sizeof(custom_ops) / sizeof(custom_ops[0]),
&custom_op_domain,
NULL /* custom_ops */)) != NULL) {
return ort_error(ort, ort_status);
}
if ((ort_status = ort->AddCustomOpDomain(ort_session_options, custom_op_domain)) != NULL) {
return ort_error(ort, ort_status);
}
...
}
/*
 * Kernel for "CustomOpOne": element-wise sum of two float tensors,
 * output[i] = x[i] + y[i]. The output is requested with the dims of input 0.
 *
 * op      - the simple custom op this kernel was registered for.
 * ort     - ONNX Runtime API table (not used by this kernel).
 * context - kernel context used to fetch the inputs and obtain the output.
 *
 * No broadcasting is performed. Only indices present in both inputs are
 * summed. If input 1 is shorter than input 0, the remaining output elements
 * are set to 0.0f instead of reading past the end of input 1, which would be
 * undefined behavior.
 *
 * NOTE(review): assumes buffer_len is an element count rather than a byte
 * count, since the indexing below relies on that. Confirm against
 * OrtSimpleCustomOpIO.
 * NOTE(review): the return values of OrtSimpleCustomOpGetInput/GetOutput are
 * not checked; their error-reporting convention is not visible here.
 */
static void custom_op1_kernel(
    const OrtSimpleCustomOp* op,
    const OrtApi* ort,
    const OrtKernelContext* context)
{
    (void)ort; /* not needed by this kernel */

    OrtSimpleCustomOpIO input_x;
    OrtSimpleCustomOpGetInput(op, context, 0, &input_x);
    OrtSimpleCustomOpIO input_y;
    OrtSimpleCustomOpGetInput(op, context, 1, &input_y);
    OrtSimpleCustomOpIO output;
    /* The output has the same dims as input 0, so it holds input_x.buffer_len elements. */
    OrtSimpleCustomOpGetOutput(op, context, 0, input_x.dims, input_x.dims_len, &output);

    const float* input_x_buffer = (const float*)input_x.buffer;
    const float* input_y_buffer = (const float*)input_y.buffer;
    float* output_buffer = (float*)output.buffer;

    size_t out_len = input_x.buffer_len;
    /* Never read past the end of the shorter input. */
    size_t common_len = input_y.buffer_len < out_len ? input_y.buffer_len : out_len;

    for (size_t i = 0; i < common_len; i++) {
        output_buffer[i] = input_x_buffer[i] + input_y_buffer[i];
    }
    /* Give any output elements without a matching input_y element a defined value. */
    for (size_t i = common_len; i < out_len; i++) {
        output_buffer[i] = 0.0f;
    }

    OrtSimpleCustomOpIORelease(op, &input_x);
    OrtSimpleCustomOpIORelease(op, &input_y);
    OrtSimpleCustomOpIORelease(op, &output);
}
/*
 * Round a float to the nearest integer using round(), which rounds halfway
 * cases away from zero, then convert the result to int32_t without undefined
 * behavior:
 * - values beyond the int32_t range saturate to INT32_MIN / INT32_MAX;
 * - NaN maps to 0.
 *
 * Converting a NaN or out-of-range floating value directly to an integer type
 * is undefined behavior in C. In-range results are identical to a plain
 * (int32_t)round(value).
 */
static int32_t custom_op2_round_to_int32(float value)
{
    double rounded = round(value);
    if (isnan(rounded)) {
        return 0;
    }
    if (rounded > (double)INT32_MAX) {
        return INT32_MAX;
    }
    if (rounded < (double)INT32_MIN) {
        return INT32_MIN;
    }
    return (int32_t)rounded;
}

/*
 * Kernel for "CustomOpTwo": rounds each float element of input 0 to the
 * nearest integer and stores it as int32. The output is requested with the
 * dims of input 0. Out-of-range values saturate and NaN becomes 0 (see
 * custom_op2_round_to_int32).
 *
 * op      - the simple custom op this kernel was registered for.
 * ort     - ONNX Runtime API table (not used by this kernel).
 * context - kernel context used to fetch the input and obtain the output.
 *
 * NOTE(review): assumes buffer_len is an element count, as the indexing below
 * relies on. The return values of OrtSimpleCustomOpGetInput/GetOutput are not
 * checked; their error convention is not visible here.
 */
static void custom_op2_kernel(
    const OrtSimpleCustomOp* op,
    const OrtApi* ort,
    const OrtKernelContext* context)
{
    (void)ort; /* not needed by this kernel */

    OrtSimpleCustomOpIO input;
    OrtSimpleCustomOpGetInput(op, context, 0, &input);
    OrtSimpleCustomOpIO output;
    OrtSimpleCustomOpGetOutput(op, context, 0, input.dims, input.dims_len, &output);

    const float* input_buffer = (const float*)input.buffer;
    int32_t* output_buffer = (int32_t*)output.buffer;
    for (size_t i = 0; i < input.buffer_len; i++) {
        output_buffer[i] = custom_op2_round_to_int32(input_buffer[i]);
    }

    OrtSimpleCustomOpIORelease(op, &input);
    OrtSimpleCustomOpIORelease(op, &output);
}