Basics
./code/Makefile
CXXFLAGS := -std=c++17
CXXFLAGS += -I /Users/fangjun/Downloads/onnxruntime-osx-x86_64-1.18.0/include
LDFLAGS := -L /Users/fangjun/Downloads/onnxruntime-osx-x86_64-1.18.0/lib
LDFLAGS += -l onnxruntime
LDFLAGS += -Wl,-rpath,/Users/fangjun/Downloads/onnxruntime-osx-x86_64-1.18.0/lib

main: main.cc c-api-test.cc cpp-api-test.cc ./custom-op.cc ./custom-op-2.cc ./custom-op-3.cc
	$(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)
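Note: the -Wl,-rpath line embeds the ONNX Runtime lib directory into the resulting binary, so the dynamic loader can find libonnxruntime at run time without setting DYLD_LIBRARY_PATH (or LD_LIBRARY_PATH on Linux). Adjust the three hard-coded paths to wherever you unpacked the onnxruntime release.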
./code/main.cc
#include <iostream>

#include "onnxruntime_cxx_api.h"  // NOLINT

void TestCApi();
void TestCppApi();
void TestCustomModel();
void TestCustomModel2();
void TestCustomModel3();

int main() {
  TestCApi();
  TestCppApi();

  std::cout << "---test custom model---\n";
  TestCustomModel();

  std::cout << "---test custom model2---\n";
  TestCustomModel2();

  std::cout << "---test custom model3---\n";
  TestCustomModel3();

  std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";
  return 0;
}
/* Sample output. The stray leading spaces come from printing each tensor
element followed by "\n ":

GetVersionString(): 1.18.0
Available providers: CoreMLExecutionProvider, CPUExecutionProvider
allocator name: Cpu
---test custom model---
110
 220
 330
 ---test custom model2---
11.5
 2.5
 3.5
 44.5
 ---test custom model3---
11
 22
 ORT_API_VERSION: 18
*/
./code/c-api-test.cc
1#include "onnxruntime_c_api.h" // NOLINT
2#include <cassert>
3#include <stdio.h>
4
5static void TestOrtStatus() {
6 const OrtApiBase *api_base = OrtGetApiBase();
7 const OrtApi *api = api_base->GetApi(ORT_API_VERSION);
8 OrtErrorCode code = ORT_OK;
9 const char *msg = "this is a message";
10
11 OrtStatus *status = api->CreateStatus(code, msg);
12 assert(api->GetErrorCode(status) == code);
13
14 const char *msg2 = api->GetErrorMessage(status);
15 assert(strcmp(msg, msg2) == 0);
16
17 // status addr: 0x600001e54040, msg2 addr: 0x600001e54044
18 fprintf(stderr, "status addr: %p, msg2 addr: %p\n", status, msg2);
19
20 // note that sizeof(code) is 4 in my test
21 assert((intptr_t)status + sizeof(code) == (intptr_t)msg2);
22
23 // we have to free the status to avoid memory leak
24 api->ReleaseStatus(status);
25}
26
27static void TestOrtApiBase() {
28
29 // OrtApiBase only has two method
30 const OrtApiBase *api_base = OrtGetApiBase();
31 fprintf(stderr, "GetVersionString(): %s\n", api_base->GetVersionString());
32
33 const OrtApi *api = api_base->GetApi(ORT_API_VERSION);
34 fprintf(stderr, "OrtApi: %p\n", api);
35
36 const char *info = api->GetBuildInfoString();
37 fprintf(stderr, "info: %s\n", info);
38}
39
40void TestCApi() {
41 TestOrtApiBase();
42 TestOrtStatus();
43}
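In the test above we create an OrtStatus ourselves. In real code, OrtStatus is what every fallible OrtApi function returns: NULL means success, and a non-NULL pointer carries the error code and message and must be released by the caller. Below is a minimal sketch of that check-and-release pattern; the helper name CheckStatus is ours, but CreateEnv and ReleaseEnv are real C API calls.

#include <stdio.h>
#include <stdlib.h>

#include "onnxruntime_c_api.h"  // NOLINT

// Hypothetical helper: print the error and abort if a call failed.
static void CheckStatus(const OrtApi *api, OrtStatus *status) {
  if (status != NULL) {  // NULL means success
    fprintf(stderr, "error %d: %s\n", api->GetErrorCode(status),
            api->GetErrorMessage(status));
    api->ReleaseStatus(status);  // the caller owns the returned status
    exit(-1);
  }
}

void TestStatusPattern() {
  const OrtApi *api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtEnv *env = NULL;
  CheckStatus(api, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "demo", &env));

  // ... use env ...

  api->ReleaseEnv(env);
}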
./code/cpp-api-test.cc
1#include "onnxruntime_cxx_api.h" // NOLINT
2#include <assert.h>
3#include <iostream>
4#include <sstream>
5
6static void TestOrtGetApi() {
7 const OrtApi &api = Ort::GetApi(); // it returns a const reference
8
9 std::string version = Ort::GetVersionString();
10 std::cout << "version: " << version << "\n";
11}
12
13static void PrintAvailableProviders() {
14 std::vector<std::string> providers = Ort::GetAvailableProviders();
15 std::ostringstream os;
16 os << "Available providers: ";
17 std::string sep = "";
18 for (const auto &p : providers) {
19 os << sep << p;
20 sep = ", ";
21 }
22 std::cout << os.str() << "\n";
23}
24
25static void TestCreateTensorFromBuffer() {
26 std::vector<int32_t> v = {1, 2, 3, 4, 5, 6};
27 std::array<int64_t, 2> shape = {2, 3};
28 auto memory_info =
29 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
30
31 Ort::Value x = Ort::Value::CreateTensor<int32_t>(
32 memory_info, v.data(), v.size(), shape.data(), shape.size());
33
34 // memory is shared between x and v
35 int32_t *p = x.GetTensorMutableData<int32_t>();
36 p[0] = 10;
37 assert(v[0] == 10);
38
39 v[1] = 20;
40 assert(p[1] == 20);
41}
42
43static void TestCreateTensor() {
44 Ort::AllocatorWithDefaultOptions allocator;
45
46 std::array<int64_t, 2> shape = {2, 3};
47 auto memory_info =
48 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
49
50 Ort::Value x =
51 Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
52 assert(x.IsTensor());
53 assert(x.HasValue());
54 Ort::TypeInfo type_info = x.GetTypeInfo();
55 auto tensor_type_and_shape_info = type_info.GetTensorTypeAndShapeInfo();
56 assert(tensor_type_and_shape_info.GetElementCount() == 2 * 3);
57 assert(tensor_type_and_shape_info.GetDimensionsCount() == 2);
58 std::vector<int64_t> x_shape = tensor_type_and_shape_info.GetShape();
59 assert(x_shape.size() == shape.size());
60 assert(x_shape[0] == shape[0]);
61 assert(x_shape[1] == shape[1]);
62
63 assert(tensor_type_and_shape_info.GetElementType() ==
64 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
65
66 Ort::ConstMemoryInfo memory_info2 = x.GetTensorMemoryInfo();
67 std::cout << "allocator name: " << memory_info2.GetAllocatorName() << "\n";
68}
69
70static void TestDataType() {
71 static_assert(Ort::TypeToTensorType<float>::type ==
72 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
73
74 static_assert(Ort::TypeToTensorType<double>::type ==
75 ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE);
76
77 static_assert(Ort::TypeToTensorType<int8_t>::type ==
78 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8);
79 static_assert(Ort::TypeToTensorType<int16_t>::type ==
80 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16);
81 static_assert(Ort::TypeToTensorType<int32_t>::type ==
82 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
83 static_assert(Ort::TypeToTensorType<int64_t>::type ==
84 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
85 static_assert(Ort::TypeToTensorType<uint8_t>::type ==
86 ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
87 static_assert(Ort::TypeToTensorType<uint16_t>::type ==
88 ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16);
89 static_assert(Ort::TypeToTensorType<uint32_t>::type ==
90 ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32);
91 static_assert(Ort::TypeToTensorType<uint64_t>::type ==
92 ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64);
93 static_assert(Ort::TypeToTensorType<bool>::type ==
94 ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
95}
96
97void TestCppApi() {
98 TestOrtGetApi();
99 PrintAvailableProviders();
100 TestCreateTensorFromBuffer();
101 TestCreateTensor();
102 TestDataType();
103}
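Unlike the C API, the C++ wrapper converts non-NULL OrtStatus returns into exceptions, so instead of checking a status after every call you catch Ort::Exception. A minimal sketch, assuming a model path that does not exist so the load fails:

#include <iostream>

#include "onnxruntime_cxx_api.h"  // NOLINT

void TestException() {
  Ort::Env env;
  Ort::SessionOptions opts;
  try {
    // loading a non-existent model makes the underlying C API return an
    // error status, which the C++ wrapper rethrows as Ort::Exception
    Ort::Session sess(env, "./does-not-exist.onnx", opts);
  } catch (const Ort::Exception &e) {
    std::cout << "code: " << e.GetOrtErrorCode() << ", msg: " << e.what()
              << "\n";
  }
}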
./code/custom-op.cc
/*
references:
https://onnxruntime.ai/docs/reference/operators/add-custom-op.html
*/
#include <array>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include "onnxruntime_lite_custom_op.h"

static void KernelOne(const Ort::Custom::Tensor<float> &X,
                      const Ort::Custom::Tensor<float> &Y,
                      Ort::Custom::Tensor<float> &Z) {
  auto input_shape = X.Shape();
  auto x_raw = X.Data();
  auto y_raw = Y.Data();
  auto z_raw = Z.Allocate(input_shape);
  for (int64_t i = 0; i < Z.NumberOfElement(); ++i) {
    z_raw[i] = x_raw[i] + y_raw[i];
  }
}

static Ort::CustomOpDomain TestCustomOp() {
  Ort::CustomOpDomain v1_domain{"com.k2fsa.org"};
  // Please make sure that custom_op_one has the same lifetime as the
  // consuming session.
  //
  // Here we use a static variable so it is never released.
  // In practice, we can make it a member variable of a class.
  static std::unique_ptr<Ort::Custom::OrtLiteCustomOp> custom_op_one{
      Ort::Custom::CreateLiteCustomOp("CustomOpOne", "CPUExecutionProvider",
                                      KernelOne)};
  v1_domain.Add(custom_op_one.get());

  return v1_domain;
}

void TestCustomModel() {
  Ort::Env env;
  Ort::SessionOptions sess_opts;
  sess_opts.SetIntraOpNumThreads(1);
  sess_opts.SetInterOpNumThreads(1);

  Ort::CustomOpDomain v1_domain = TestCustomOp();

  // register the custom op domain before creating the session
  sess_opts.Add(v1_domain);

  std::unique_ptr<Ort::Session> sess =
      std::make_unique<Ort::Session>(env, "./e.onnx", sess_opts);

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  std::vector<float> x = {10, 20, 30};
  std::vector<float> y = {100, 200, 300};

  std::array<int64_t, 1> shape = {3};

  Ort::Value x_tensor = Ort::Value::CreateTensor(
      memory_info, x.data(), x.size(), shape.data(), shape.size());

  Ort::Value y_tensor = Ort::Value::CreateTensor(
      memory_info, y.data(), y.size(), shape.data(), shape.size());

  std::vector<Ort::Value> inputs;
  inputs.push_back(std::move(x_tensor));
  inputs.push_back(std::move(y_tensor));

  // the input/output names must match the exported graph; see e.txt below
  std::vector<const char *> input_names = {"l_x_", "l_y_"};
  std::vector<const char *> output_names = {"my_add_op"};
  auto out = sess->Run({}, input_names.data(), inputs.data(), inputs.size(),
                       output_names.data(), output_names.size());
  const float *p = out[0].GetTensorData<float>();
  for (int i = 0; i < 3; ++i) {
    std::cout << p[i] << "\n ";
  }
}
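The names "l_x_", "l_y_" and "my_add_op" are hard-coded above to match the exported graph; they can also be queried from the session at run time. A sketch of that, assuming the same session object (the helper name PrintIONames is ours):

#include <iostream>

#include "onnxruntime_cxx_api.h"  // NOLINT

void PrintIONames(Ort::Session &sess) {
  Ort::AllocatorWithDefaultOptions allocator;
  for (size_t i = 0; i != sess.GetInputCount(); ++i) {
    // GetInputNameAllocated returns a smart pointer that frees the
    // string with the given allocator when it goes out of scope
    auto name = sess.GetInputNameAllocated(i, allocator);
    std::cout << "input " << i << ": " << name.get() << "\n";
  }
  for (size_t i = 0; i != sess.GetOutputCount(); ++i) {
    auto name = sess.GetOutputNameAllocated(i, allocator);
    std::cout << "output " << i << ": " << name.get() << "\n";
  }
}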
./code/e.py
#!/usr/bin/env python3

from torch._custom_op import impl as custom_op
import torch
import onnx
import onnxscript
from onnxscript import opset18

import warnings

warnings.filterwarnings("ignore")


@custom_op.custom_op("mylibrary::my_add_op")
def my_add_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Since we register mylibrary::my_add_op, the function name must be
    # my_add_op; otherwise, an error is thrown when this script is run
    pass


@my_add_op.impl_abstract()
def my_add_op_impl_abstract_any_name_is_ok(x, y):
    return torch.empty_like(x)


@my_add_op.impl("cpu")
def my_add_op_impl_any_name_is_ok(x, y):
    # the eager CPU implementation; it must accept the same arguments
    # as my_add_op
    return x + y


class CustomFoo(torch.nn.Module):
    def forward(self, x, y):
        return my_add_op(x, y)


custom_opset = onnxscript.values.Opset(domain="com.k2fsa.org", version=1)


@onnxscript.script(custom_opset)
def custom_my_add(x, y):
    return custom_opset.CustomOpOne(x, y)


def main():
    torch._dynamo.allow_in_graph(my_add_op)
    x = torch.randn(3)
    y = torch.randn(3)
    custom_model = CustomFoo()
    onnx_registry = torch.onnx.OnnxRegistry()
    onnx_registry.register_op(
        namespace="mylibrary", op_name="my_add_op", overload="default", function=custom_my_add
    )

    export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
    onnx_program = torch.onnx.dynamo_export(
        custom_model, x, y, export_options=export_options
    )
    onnx_program.save("./e.onnx")
    with open('e.txt', 'w') as f:
        f.write(str(onnx_program.model_proto))
    onnx_model = onnx.load("e.onnx")
    onnx.checker.check_model(onnx_model)


if __name__ == '__main__':
    main()
./code/e.txt
ir_version: 8
opset_import {
  domain: "com.k2fsa.org"
  version: 1
}
opset_import {
  domain: ""
  version: 18
}
opset_import {
  domain: "pkg.onnxscript.torch_lib.common"
  version: 1
}
producer_name: "pytorch"
producer_version: "2.4.0"
graph {
  node {
    input: "l_x_"
    input: "l_y_"
    output: "my_add_op"
    name: "custom_my_add_0_n0"
    op_type: "CustomOpOne"
    domain: "com.k2fsa.org"
  }
  name: "main_graph"
  input {
    name: "l_x_"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
  input {
    name: "l_y_"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
  output {
    name: "my_add_op"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
        }
      }
    }
  }
}
./code/custom-op-2.cc
/*
references:
https://onnxruntime.ai/docs/reference/operators/add-custom-op.html
*/
#include <array>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include "onnxruntime_lite_custom_op.h"

static void KernelOne(const Ort::Custom::Tensor<float> &X,
                      const Ort::Custom::Tensor<float> &Y,
                      Ort::Custom::Tensor<float> &Z) {
  auto input_shape = X.Shape();
  auto x_raw = X.Data();
  auto y_raw = Y.Data();
  auto z_raw = Z.Allocate(input_shape);
  for (int64_t i = 0; i < Z.NumberOfElement(); ++i) {
    z_raw[i] = x_raw[i] + y_raw[i];
  }
}

static Ort::CustomOpDomain TestCustomOp2() {
  Ort::CustomOpDomain v1_domain{"com.k2fsa.org"};
  // Please make sure that custom_op_one has the same lifetime as the
  // consuming session.
  //
  // Here we use a static variable so it is never released.
  // In practice, we can make it a member variable of a class.
  static std::unique_ptr<Ort::Custom::OrtLiteCustomOp> custom_op_one{
      Ort::Custom::CreateLiteCustomOp("CustomOpOne2", "CPUExecutionProvider",
                                      KernelOne)};
  v1_domain.Add(custom_op_one.get());

  return v1_domain;
}

void TestCustomModel2() {
  Ort::Env env;
  Ort::SessionOptions sess_opts;
  sess_opts.SetIntraOpNumThreads(1);
  sess_opts.SetInterOpNumThreads(1);

  Ort::CustomOpDomain v1_domain = TestCustomOp2();

  // register the custom op domain before creating the session
  sess_opts.Add(v1_domain);

  std::unique_ptr<Ort::Session> sess =
      std::make_unique<Ort::Session>(env, "./f.onnx", sess_opts);

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // the model computes relu(x) + foo, where foo is [1.5, 2.5, 3.5, 4.5]
  std::vector<float> x = {10, -20, -30, 40};

  std::array<int64_t, 1> shape = {4};

  Ort::Value x_tensor = Ort::Value::CreateTensor(
      memory_info, x.data(), x.size(), shape.data(), shape.size());

  std::vector<Ort::Value> inputs;
  inputs.push_back(std::move(x_tensor));

  std::vector<const char *> input_names = {"l_x_"};
  std::vector<const char *> output_names = {"my_add_op2"};
  auto out = sess->Run({}, input_names.data(), inputs.data(), inputs.size(),
                       output_names.data(), output_names.size());
  const float *p = out[0].GetTensorData<float>();
  for (int i = 0; i < 4; ++i) {
    std::cout << p[i] << "\n ";
  }
}
./code/f.py
#!/usr/bin/env python3

from torch._custom_op import impl as custom_op
import torch
import onnx
import onnxscript
from onnxscript import opset18

import warnings

warnings.filterwarnings("ignore")


@custom_op.custom_op("mylibrary::my_add_op2")
def my_add_op2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Since we register mylibrary::my_add_op2, the function name must be
    # my_add_op2; otherwise, an error is thrown when this script is run
    pass


@my_add_op2.impl_abstract()
def my_add_op_impl_abstract_any_name_is_ok(x, y):
    return torch.empty_like(x)


@my_add_op2.impl("cpu")
def my_add_op_impl_any_name_is_ok(x, y):
    # the eager CPU implementation: relu(x) + y
    x = torch.nn.functional.relu(x)
    return x + y


class CustomFoo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter(
            'foo', torch.nn.Parameter(torch.tensor([1.5, 2.5, 3.5, 4.5]))
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(x)
        return my_add_op2(x, self.foo)


custom_opset = onnxscript.values.Opset(domain="com.k2fsa.org", version=1)


@onnxscript.script(custom_opset)
def custom_my_add(x, y):
    return custom_opset.CustomOpOne2(x, y)


@torch.no_grad()
def main():
    torch._dynamo.allow_in_graph(my_add_op2)
    x = torch.randn(4)
    custom_model = CustomFoo()
    onnx_registry = torch.onnx.OnnxRegistry()
    onnx_registry.register_op(
        namespace="mylibrary", op_name="my_add_op2", overload="default", function=custom_my_add
    )

    export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
    onnx_program = torch.onnx.dynamo_export(
        custom_model, x, export_options=export_options
    )
    onnx_program.save("./f.onnx")
    with open('f.txt', 'w') as f:
        f.write(str(onnx_program.model_proto))
    onnx_model = onnx.load("f.onnx")
    onnx.checker.check_model(onnx_model)


if __name__ == '__main__':
    main()
./code/f.txt
ir_version: 8
opset_import {
  domain: "pkg.onnxscript.torch_lib"
  version: 1
}
opset_import {
  domain: "pkg.torch.2.4.0+cpu"
  version: 1
}
opset_import {
  domain: "com.k2fsa.org"
  version: 1
}
opset_import {
  domain: ""
  version: 18
}
opset_import {
  domain: "pkg.onnxscript.torch_lib.common"
  version: 1
}
producer_name: "pytorch"
producer_version: "2.4.0"
graph {
  node {
    input: "l_x_"
    output: "relu_1"
    name: "torch_nn_modules_activation_ReLU_relu_1_0_aten_relu_0_n0"
    op_type: "Relu"
  }
  node {
    input: "relu_1"
    input: "foo"
    output: "my_add_op2"
    name: "custom_my_add_1_n0"
    op_type: "CustomOpOne2"
    domain: "com.k2fsa.org"
  }
  name: "main_graph"
  initializer {
    dims: 4
    data_type: 1
    name: "foo"
    raw_data: "\000\000\300?\000\000 @\000\000`@\000\000\220@"
  }
  input {
    name: "l_x_"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
  output {
    name: "my_add_op2"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
  value_info {
    name: "relu_1"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
}
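In the initializer above, raw_data stores the little-endian IEEE-754 bytes of the four float32 values of foo: "\000\000\300?" is 0x3fc00000, i.e. 1.5, and so on for 2.5, 3.5 and 4.5. A quick standalone sketch of how such bytes decode, with the bytes hard-coded from the dump above:

#include <cstring>
#include <iostream>

int main() {
  // the raw_data bytes of "foo" from f.txt (little-endian float32)
  const unsigned char raw[] = {0x00, 0x00, 0xc0, 0x3f,   // 1.5
                               0x00, 0x00, 0x20, 0x40,   // 2.5
                               0x00, 0x00, 0x60, 0x40,   // 3.5
                               0x00, 0x00, 0x90, 0x40};  // 4.5
  float f[4];
  std::memcpy(f, raw, sizeof(f));  // reinterpret the bytes as float32
  for (float v : f) {
    std::cout << v << "\n";  // prints 1.5 2.5 3.5 4.5
  }
  return 0;
}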
./code/custom-op-3.cc
/*
references:
https://onnxruntime.ai/docs/reference/operators/add-custom-op.html
*/
#include <array>
#include <cstdint>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include "onnxruntime_lite_custom_op.h"

static void KernelOne(const Ort::Custom::Tensor<uint8_t> &X,
                      const Ort::Custom::Tensor<float> &scale_tensor,
                      Ort::Custom::Tensor<float> &Y) {
  auto input_shape = X.Shape();
  auto x_raw = X.Data();
  auto scale = scale_tensor.Data()[0];
  auto y_raw = Y.Allocate(input_shape);
  for (int64_t i = 0; i < Y.NumberOfElement(); ++i) {
    // scale each uint8 number while converting it to float
    y_raw[i] = x_raw[i] * scale;
  }
}

static Ort::CustomOpDomain TestCustomOp3() {
  Ort::CustomOpDomain v1_domain{"com.k2fsa.org"};
  // Please make sure that custom_op_one has the same lifetime as the
  // consuming session.
  //
  // Here we use a static variable so it is never released.
  // In practice, we can make it a member variable of a class.
  static std::unique_ptr<Ort::Custom::OrtLiteCustomOp> custom_op_one{
      Ort::Custom::CreateLiteCustomOp("MyCast", "CPUExecutionProvider",
                                      KernelOne)};
  v1_domain.Add(custom_op_one.get());

  return v1_domain;
}

void TestCustomModel3() {
  Ort::Env env;
  Ort::SessionOptions sess_opts;
  sess_opts.SetIntraOpNumThreads(1);
  sess_opts.SetInterOpNumThreads(1);

  Ort::CustomOpDomain v1_domain = TestCustomOp3();

  // register the custom op domain before creating the session
  sess_opts.Add(v1_domain);

  std::unique_ptr<Ort::Session> sess =
      std::make_unique<Ort::Session>(env, "./g.onnx", sess_opts);

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // the model computes x + MyCast(foo, scale), where foo is [10, 20]
  std::vector<float> x = {1, 2};

  std::array<int64_t, 1> shape = {2};

  Ort::Value x_tensor = Ort::Value::CreateTensor(
      memory_info, x.data(), x.size(), shape.data(), shape.size());

  std::vector<Ort::Value> inputs;
  inputs.push_back(std::move(x_tensor));

  std::vector<const char *> input_names = {"l_x_"};
  std::vector<const char *> output_names = {"add"};
  auto out = sess->Run({}, input_names.data(), inputs.data(), inputs.size(),
                       output_names.data(), output_names.size());
  const float *p = out[0].GetTensorData<float>();
  for (int i = 0; i < 2; ++i) {
    std::cout << p[i] << "\n ";
  }
}
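Since g.onnx mixes a uint8 initializer with float inputs, it is worth double-checking what the session actually returns. A sketch that inspects an output tensor's element type and shape with the same C++ API used above (the helper name PrintOutputInfo is ours):

#include <iostream>

#include "onnxruntime_cxx_api.h"  // NOLINT

void PrintOutputInfo(const Ort::Value &out) {
  auto info = out.GetTensorTypeAndShapeInfo();
  // for the "add" output of g.onnx this should print element type 1
  // (ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) and shape [2]
  std::cout << "element type: " << info.GetElementType() << "\n";
  std::cout << "shape:";
  for (int64_t d : info.GetShape()) {
    std::cout << " " << d;
  }
  std::cout << "\n";
}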
./code/g.py
#!/usr/bin/env python3

from torch._custom_op import impl as custom_op
import torch
import onnx
import onnxscript
from onnxscript import opset18

import warnings

warnings.filterwarnings("ignore")


@custom_op.custom_op("mylibrary::my_cast")
def my_cast(x: torch.Tensor, scale: float = 0.25) -> torch.Tensor:
    # Since we register mylibrary::my_cast, the function name must be
    # my_cast; otherwise, an error is thrown when this script is run
    return x.to(torch.float32)


@my_cast.impl_abstract()
def my_cast_impl_abstract_any_name_is_ok(x, scale: float = 0.25):
    return x.to(torch.float32)


@my_cast.impl("cpu")
def my_cast_impl_any_name_is_ok(x, scale: float = 0.25):
    # the eager CPU implementation only casts; scale is ignored here
    return x.to(torch.float32)


class CustomFoo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('foo', torch.tensor([10, 20], dtype=torch.uint8))
        self.scale = 0.125

    def forward(self, x):
        return x + my_cast(self.foo, self.scale)


custom_opset = onnxscript.values.Opset(domain="com.k2fsa.org", version=1)


@onnxscript.script(custom_opset, default_opset=opset18)
def custom_my_cast(x, scale: float = 0.5):
    return custom_opset.MyCast(x, scale)


@torch.no_grad()
def main():
    torch._dynamo.allow_in_graph(my_cast)
    x = torch.randn(2)
    custom_model = CustomFoo()
    onnx_registry = torch.onnx.OnnxRegistry()
    onnx_registry.register_op(
        namespace="mylibrary", op_name="my_cast", overload="default", function=custom_my_cast
    )

    export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
    onnx_program = torch.onnx.dynamo_export(
        custom_model, x, export_options=export_options
    )
    onnx_program.save("./g.onnx")
    with open('g.txt', 'w') as f:
        f.write(str(onnx_program.model_proto))
    onnx_model = onnx.load("g.onnx")
    onnx.checker.check_model(onnx_model)


if __name__ == '__main__':
    main()
./code/g.txt
ir_version: 8
opset_import {
  domain: "com.k2fsa.org"
  version: 1
}
opset_import {
  domain: "pkg.onnxscript.torch_lib"
  version: 1
}
opset_import {
  domain: ""
  version: 18
}
opset_import {
  domain: "pkg.onnxscript.torch_lib.common"
  version: 1
}
producer_name: "pytorch"
producer_version: "2.4.0"
graph {
  node {
    output: "custom_my_cast_0_scale"
    name: "custom_my_cast_0_n0"
    op_type: "Constant"
    attribute {
      name: "value_float"
      type: FLOAT
      f: 0.125
    }
  }
  node {
    input: "foo"
    input: "custom_my_cast_0_scale"
    output: "my_cast"
    name: "custom_my_cast_0_n1"
    op_type: "MyCast"
    domain: "com.k2fsa.org"
  }
  node {
    input: "l_x_"
    input: "my_cast"
    output: "add"
    name: "aten_add_1"
    op_type: "aten_add"
    domain: "pkg.onnxscript.torch_lib"
    attribute {
      name: "alpha"
      type: FLOAT
      f: 1
    }
  }
  name: "main_graph"
  initializer {
    dims: 2
    data_type: 2
    name: "foo"
    raw_data: "\n\024"
  }
  input {
    name: "l_x_"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  output {
    name: "add"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  value_info {
    name: "custom_my_cast_0_scale"
    type {
      tensor_type {
        elem_type: 1
        shape {
        }
      }
    }
  }
  value_info {
    name: "my_cast"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  value_info {
    name: "pkg.onnxscript.torch_lib::aten_add/self"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  value_info {
    name: "pkg.onnxscript.torch_lib::aten_add/other"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  value_info {
    name: "pkg.onnxscript.torch_lib::aten_add/alpha"
    type {
      tensor_type {
        elem_type: 1
        shape {
        }
      }
    }
  }
  value_info {
    name: "pkg.onnxscript.torch_lib::aten_add/other_1"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  value_info {
    name: "pkg.onnxscript.torch_lib::aten_add/return_val"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
}
functions {
  name: "aten_add"
  input: "self"
  input: "other"
  output: "return_val"
  attribute_proto {
    name: "alpha"
    type: FLOAT
    f: 1
  }
  node {
    output: "alpha"
    name: "n0"
    op_type: "Constant"
    attribute {
      name: "value_float"
      ref_attr_name: "alpha"
      type: FLOAT
    }
  }
  node {
    input: "other"
    input: "alpha"
    output: "other_1"
    name: "n2"
    op_type: "Mul"
  }
  node {
    input: "self"
    input: "other_1"
    output: "return_val"
    name: "n3"
    op_type: "Add"
  }
  doc_string: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
  opset_import {
    domain: ""
    version: 18
  }
  domain: "pkg.onnxscript.torch_lib"
}