Basics

./code/Makefile
# Build the onnxruntime API examples into a single binary `main`.
# Point ONNXRUNTIME_DIR at an extracted onnxruntime release to build
# against a different installation (default matches the original setup).
ONNXRUNTIME_DIR := /Users/fangjun/Downloads/onnxruntime-osx-x86_64-1.18.0

CXXFLAGS := -std=c++17
CXXFLAGS += -I $(ONNXRUNTIME_DIR)/include
LDFLAGS := -L $(ONNXRUNTIME_DIR)/lib
LDFLAGS += -l onnxruntime
# Embed the library directory so the binary finds libonnxruntime at run time.
LDFLAGS += -Wl,-rpath,$(ONNXRUNTIME_DIR)/lib

main: main.cc c-api-test.cc cpp-api-test.cc ./custom-op.cc ./custom-op-2.cc ./custom-op-3.cc
	$(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)
./code/main.cc
 1#include <iostream>
 2
 3#include "onnxruntime_cxx_api.h" // NOLINT
 4
 5void TestCApi();
 6void TestCppApi();
 7void TestCustomModel();
 8void TestCustomModel2();
 9void TestCustomModel3();
10
11int main() {
12  TestCApi();
13  TestCppApi();
14
15  std::cout << "---test custom model---\n";
16  TestCustomModel();
17
18  std::cout << "---test custom model2---\n";
19  TestCustomModel2();
20
21  std::cout << "---test custom model3---\n";
22  TestCustomModel3();
23
24  std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";
25  return 0;
26}
27/*
28GetVersionString(): 1.18.0
29Available providers: CoreMLExecutionProvider, CPUExecutionProvider
30allocator name: Cpu
31---test custom model---
32110
33 220
34 330
35 ---test custom model2---
3611.5
37 2.5
38 3.5
39 44.5
40 ---test custom model3---
4111
42 22
43 ORT_API_VERSION: 18
44 */
./code/c-api-test.cc
 1#include "onnxruntime_c_api.h" // NOLINT
 2#include <cassert>
 3#include <stdio.h>
 4
 5static void TestOrtStatus() {
 6  const OrtApiBase *api_base = OrtGetApiBase();
 7  const OrtApi *api = api_base->GetApi(ORT_API_VERSION);
 8  OrtErrorCode code = ORT_OK;
 9  const char *msg = "this is a message";
10
11  OrtStatus *status = api->CreateStatus(code, msg);
12  assert(api->GetErrorCode(status) == code);
13
14  const char *msg2 = api->GetErrorMessage(status);
15  assert(strcmp(msg, msg2) == 0);
16
17  // status addr: 0x600001e54040, msg2 addr: 0x600001e54044
18  fprintf(stderr, "status addr: %p, msg2 addr: %p\n", status, msg2);
19
20  // note that sizeof(code) is 4 in my test
21  assert((intptr_t)status + sizeof(code) == (intptr_t)msg2);
22
23  // we have to free the status to avoid memory leak
24  api->ReleaseStatus(status);
25}
26
27static void TestOrtApiBase() {
28
29  // OrtApiBase only has two method
30  const OrtApiBase *api_base = OrtGetApiBase();
31  fprintf(stderr, "GetVersionString(): %s\n", api_base->GetVersionString());
32
33  const OrtApi *api = api_base->GetApi(ORT_API_VERSION);
34  fprintf(stderr, "OrtApi: %p\n", api);
35
36  const char *info = api->GetBuildInfoString();
37  fprintf(stderr, "info: %s\n", info);
38}
39
40void TestCApi() {
41  TestOrtApiBase();
42  TestOrtStatus();
43}
./code/cpp-api-test.cc
  1#include "onnxruntime_cxx_api.h" // NOLINT
  2#include <assert.h>
  3#include <iostream>
  4#include <sstream>
  5
  6static void TestOrtGetApi() {
  7  const OrtApi &api = Ort::GetApi(); // it returns a const reference
  8
  9  std::string version = Ort::GetVersionString();
 10  std::cout << "version: " << version << "\n";
 11}
 12
 13static void PrintAvailableProviders() {
 14  std::vector<std::string> providers = Ort::GetAvailableProviders();
 15  std::ostringstream os;
 16  os << "Available providers: ";
 17  std::string sep = "";
 18  for (const auto &p : providers) {
 19    os << sep << p;
 20    sep = ", ";
 21  }
 22  std::cout << os.str() << "\n";
 23}
 24
 25static void TestCreateTensorFromBuffer() {
 26  std::vector<int32_t> v = {1, 2, 3, 4, 5, 6};
 27  std::array<int64_t, 2> shape = {2, 3};
 28  auto memory_info =
 29      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
 30
 31  Ort::Value x = Ort::Value::CreateTensor<int32_t>(
 32      memory_info, v.data(), v.size(), shape.data(), shape.size());
 33
 34  // memory is shared between x and v
 35  int32_t *p = x.GetTensorMutableData<int32_t>();
 36  p[0] = 10;
 37  assert(v[0] == 10);
 38
 39  v[1] = 20;
 40  assert(p[1] == 20);
 41}
 42
 43static void TestCreateTensor() {
 44  Ort::AllocatorWithDefaultOptions allocator;
 45
 46  std::array<int64_t, 2> shape = {2, 3};
 47  auto memory_info =
 48      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
 49
 50  Ort::Value x =
 51      Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
 52  assert(x.IsTensor());
 53  assert(x.HasValue());
 54  Ort::TypeInfo type_info = x.GetTypeInfo();
 55  auto tensor_type_and_shape_info = type_info.GetTensorTypeAndShapeInfo();
 56  assert(tensor_type_and_shape_info.GetElementCount() == 2 * 3);
 57  assert(tensor_type_and_shape_info.GetDimensionsCount() == 2);
 58  std::vector<int64_t> x_shape = tensor_type_and_shape_info.GetShape();
 59  assert(x_shape.size() == shape.size());
 60  assert(x_shape[0] == shape[0]);
 61  assert(x_shape[1] == shape[1]);
 62
 63  assert(tensor_type_and_shape_info.GetElementType() ==
 64         ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
 65
 66  Ort::ConstMemoryInfo memory_info2 = x.GetTensorMemoryInfo();
 67  std::cout << "allocator name: " << memory_info2.GetAllocatorName() << "\n";
 68}
 69
 70static void TestDataType() {
 71  static_assert(Ort::TypeToTensorType<float>::type ==
 72                ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
 73
 74  static_assert(Ort::TypeToTensorType<double>::type ==
 75                ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE);
 76
 77  static_assert(Ort::TypeToTensorType<int8_t>::type ==
 78                ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8);
 79  static_assert(Ort::TypeToTensorType<int16_t>::type ==
 80                ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16);
 81  static_assert(Ort::TypeToTensorType<int32_t>::type ==
 82                ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
 83  static_assert(Ort::TypeToTensorType<int64_t>::type ==
 84                ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
 85  static_assert(Ort::TypeToTensorType<uint8_t>::type ==
 86                ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
 87  static_assert(Ort::TypeToTensorType<uint16_t>::type ==
 88                ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16);
 89  static_assert(Ort::TypeToTensorType<uint32_t>::type ==
 90                ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32);
 91  static_assert(Ort::TypeToTensorType<uint64_t>::type ==
 92                ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64);
 93  static_assert(Ort::TypeToTensorType<bool>::type ==
 94                ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
 95}
 96
 97void TestCppApi() {
 98  TestOrtGetApi();
 99  PrintAvailableProviders();
100  TestCreateTensorFromBuffer();
101  TestCreateTensor();
102  TestDataType();
103}
./code/custom-op.cc
 1/*
 2references:
 3https://onnxruntime.ai/docs/reference/operators/add-custom-op.html
 4
 5 */
 6#include "onnxruntime_lite_custom_op.h"
 7#include <iostream>
 8#include <utility>
 9#include <vector>
10
11static void KernelOne(const Ort::Custom::Tensor<float> &X,
12                      const Ort::Custom::Tensor<float> &Y,
13                      Ort::Custom::Tensor<float> &Z) {
14  auto input_shape = X.Shape();
15  auto x_raw = X.Data();
16  auto y_raw = Y.Data();
17  auto z_raw = Z.Allocate(input_shape);
18  for (int64_t i = 0; i < Z.NumberOfElement(); ++i) {
19    z_raw[i] = x_raw[i] + y_raw[i];
20  }
21}
22
23static Ort::CustomOpDomain TestCustomOp() {
24  Ort::CustomOpDomain v1_domain{"com.k2fsa.org"};
25  // please make sure that custom_op_one has the same lifetime as the consuming
26  // session
27  //
28  // Here we use a static variable so it is never released.
29  // in practice, we can move it to a member variable of a class
30  static std::unique_ptr<Ort::Custom::OrtLiteCustomOp> custom_op_one{
31      Ort::Custom::CreateLiteCustomOp("CustomOpOne", "CPUExecutionProvider",
32                                      KernelOne)};
33  v1_domain.Add(custom_op_one.get());
34
35  return v1_domain;
36}
37
38void TestCustomModel() {
39  Ort::Env env;
40  Ort::SessionOptions sess_opts;
41  sess_opts.SetIntraOpNumThreads(1);
42  sess_opts.SetInterOpNumThreads(1);
43
44  Ort::CustomOpDomain v1_domain = TestCustomOp();
45
46  Ort::SessionOptions session_options;
47  sess_opts.Add(v1_domain);
48  // create a session with the session_options ...
49
50  std::unique_ptr<Ort::Session> sess =
51      std::make_unique<Ort::Session>(env, "./e.onnx", sess_opts);
52
53  auto memory_info =
54      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
55
56  std::vector<float> x = {10, 20, 30};
57  std::vector<float> y = {100, 200, 300};
58
59  std::array<int64_t, 1> shape = {3};
60
61  Ort::Value x_tensor = Ort::Value::CreateTensor(
62      memory_info, x.data(), x.size(), shape.data(), shape.size());
63
64  Ort::Value y_tensor = Ort::Value::CreateTensor(
65      memory_info, y.data(), y.size(), shape.data(), shape.size());
66
67  std::vector<Ort::Value> inputs;
68  inputs.push_back(std::move(x_tensor));
69  inputs.push_back(std::move(y_tensor));
70
71  std::vector<const char *> input_names = {"l_x_", "l_y_"};
72  std::vector<const char *> output_names = {"my_add_op"};
73  auto out = sess->Run({}, input_names.data(), inputs.data(), inputs.size(),
74                       output_names.data(), output_names.size());
75  const float *p = out[0].GetTensorData<float>();
76  for (int i = 0; i < 3; ++i) {
77    std::cout << p[i] << "\n ";
78  }
79}
./code/e.py
 1#!/usr/bin/env python3
 2
 3from torch._custom_op import impl as custom_op
 4import torch
 5import onnx
 6import onnxscript
 7from onnxscript import opset18
 8
 9import warnings
10warnings.filterwarnings("ignore")
11
12@custom_op.custom_op("mylibrary::my_add_op")
13def my_add_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
14    # Since we are using mylibrary::my_add_op, so the function
15    # name must be my_add_op; otherwise, it will throw an error
16    # when this script is run
17    pass
18
19@my_add_op.impl_abstract()
20def my_add_op_impl_abstract_any_name_is_ok(x, y):
21    return torch.empty_like(x)
22
23@my_add_op.impl("cpu")
24def my_add_op_impl_any_name_is_ok(tensor_x):
25    return torch.round(tensor_x + tensor_x)  # add x to itself, and round the result
26
27class CustomFoo(torch.nn.Module):
28    def forward(self, x, y):
29        return my_add_op(x, y)
30
31
32custom_opset = onnxscript.values.Opset(domain="com.k2fsa.org", version=1)
33
34@onnxscript.script(custom_opset)
35def custom_my_add(x, y):
36    return custom_opset.CustomOpOne(x, y)
37
38def main():
39    torch._dynamo.allow_in_graph(my_add_op)
40    x = torch.randn(3)
41    y = torch.randn(3)
42    custom_addandround_model = CustomFoo()
43    onnx_registry = torch.onnx.OnnxRegistry()
44    onnx_registry.register_op(
45        namespace="mylibrary", op_name="my_add_op", overload="default", function=custom_my_add
46        )
47
48    export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
49    onnx_program = torch.onnx.dynamo_export(
50        custom_addandround_model, x, y, export_options=export_options
51        )
52    onnx_program.save("./e.onnx")
53    with open('e.txt', 'w') as f:
54        f.write(str(onnx_program.model_proto))
55    onnx_model = onnx.load("e.onnx")
56    onnx.checker.check_model(onnx_model)
57
58
59if __name__ == '__main__':
60    main()
./code/e.txt
 1ir_version: 8
 2opset_import {
 3  domain: "com.k2fsa.org"
 4  version: 1
 5}
 6opset_import {
 7  domain: ""
 8  version: 18
 9}
10opset_import {
11  domain: "pkg.onnxscript.torch_lib.common"
12  version: 1
13}
14producer_name: "pytorch"
15producer_version: "2.4.0"
16graph {
17  node {
18    input: "l_x_"
19    input: "l_y_"
20    output: "my_add_op"
21    name: "custom_my_add_0_n0"
22    op_type: "CustomOpOne"
23    domain: "com.k2fsa.org"
24  }
25  name: "main_graph"
26  input {
27    name: "l_x_"
28    type {
29      tensor_type {
30        elem_type: 1
31        shape {
32          dim {
33            dim_value: 3
34          }
35        }
36      }
37    }
38  }
39  input {
40    name: "l_y_"
41    type {
42      tensor_type {
43        elem_type: 1
44        shape {
45          dim {
46            dim_value: 3
47          }
48        }
49      }
50    }
51  }
52  output {
53    name: "my_add_op"
54    type {
55      tensor_type {
56        elem_type: 1
57        shape {
58          dim {
59            dim_value: 3
60          }
61        }
62      }
63    }
64  }
65}
./code/custom-op-2.cc
 1/*
 2references:
 3https://onnxruntime.ai/docs/reference/operators/add-custom-op.html
 4
 5 */
 6#include "onnxruntime_lite_custom_op.h"
 7#include <iostream>
 8#include <utility>
 9#include <vector>
10
11static void KernelOne(const Ort::Custom::Tensor<float> &X,
12                      const Ort::Custom::Tensor<float> &Y,
13                      Ort::Custom::Tensor<float> &Z) {
14  auto input_shape = X.Shape();
15  auto x_raw = X.Data();
16  auto y_raw = Y.Data();
17  auto z_raw = Z.Allocate(input_shape);
18  for (int64_t i = 0; i < Z.NumberOfElement(); ++i) {
19    z_raw[i] = x_raw[i] + y_raw[i];
20  }
21}
22
23static Ort::CustomOpDomain TestCustomOp2() {
24  Ort::CustomOpDomain v1_domain{"com.k2fsa.org"};
25  // please make sure that custom_op_one has the same lifetime as the consuming
26  // session
27  //
28  // Here we use a static variable so it is never released.
29  // in practice, we can move it to a member variable of a class
30  static std::unique_ptr<Ort::Custom::OrtLiteCustomOp> custom_op_one{
31      Ort::Custom::CreateLiteCustomOp("CustomOpOne2", "CPUExecutionProvider",
32                                      KernelOne)};
33  v1_domain.Add(custom_op_one.get());
34
35  return v1_domain;
36}
37
38void TestCustomModel2() {
39  Ort::Env env;
40  Ort::SessionOptions sess_opts;
41  sess_opts.SetIntraOpNumThreads(1);
42  sess_opts.SetInterOpNumThreads(1);
43
44  Ort::CustomOpDomain v1_domain = TestCustomOp2();
45
46  Ort::SessionOptions session_options;
47  sess_opts.Add(v1_domain);
48  // create a session with the session_options ...
49
50  std::unique_ptr<Ort::Session> sess =
51      std::make_unique<Ort::Session>(env, "./f.onnx", sess_opts);
52
53  auto memory_info =
54      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
55  // foo is [1.5, 2.5, 3.5, 4.5]
56  //
57
58  std::vector<float> x = {10, -20, -30, 40};
59
60  std::array<int64_t, 1> shape = {4};
61
62  Ort::Value x_tensor = Ort::Value::CreateTensor(
63      memory_info, x.data(), x.size(), shape.data(), shape.size());
64
65  std::vector<Ort::Value> inputs;
66  inputs.push_back(std::move(x_tensor));
67
68  std::vector<const char *> input_names = {"l_x_"};
69  std::vector<const char *> output_names = {"my_add_op2"};
70  auto out = sess->Run({}, input_names.data(), inputs.data(), inputs.size(),
71                       output_names.data(), output_names.size());
72  const float *p = out[0].GetTensorData<float>();
73  for (int i = 0; i < 4; ++i) {
74    std::cout << p[i] << "\n ";
75  }
76}
./code/f.py
 1#!/usr/bin/env python3
 2
 3from torch._custom_op import impl as custom_op
 4import torch
 5import onnx
 6import onnxscript
 7from onnxscript import opset18
 8
 9import warnings
10warnings.filterwarnings("ignore")
11
12@custom_op.custom_op("mylibrary::my_add_op2")
13def my_add_op2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
14    # Since we are using mylibrary::my_add_op2, so the function
15    # name must be my_add_op2; otherwise, it will throw an error
16    # when this script is run
17    pass
18
19@my_add_op2.impl_abstract()
20def my_add_op_impl_abstract_any_name_is_ok(x, y):
21    return torch.empty_like(x)
22
23@my_add_op2.impl("cpu")
24def my_add_op_impl_any_name_is_ok(x, y):
25    x = torch.nn.functional.relu(x)
26    return x+y   # add x to itself, and round the result
27
28class CustomFoo(torch.nn.Module):
29    def __init__(self):
30        super().__init__()
31        self.register_parameter('foo', torch.nn.Parameter(torch.tensor([1.5, 2.5, 3.5, 4.5])))
32        self.relu = torch.nn.ReLU()
33    def forward(self, x):
34        x = self.relu(x)
35        return my_add_op2(x, self.foo)
36
37
38custom_opset = onnxscript.values.Opset(domain="com.k2fsa.org", version=1)
39
40@onnxscript.script(custom_opset)
41def custom_my_add(x, y):
42    return custom_opset.CustomOpOne2(x, y)
43
44@torch.no_grad()
45def main():
46    torch._dynamo.allow_in_graph(my_add_op2)
47    x = torch.randn(4)
48    custom_addandround_model = CustomFoo()
49    onnx_registry = torch.onnx.OnnxRegistry()
50    onnx_registry.register_op(
51        namespace="mylibrary", op_name="my_add_op2", overload="default", function=custom_my_add
52        )
53
54    export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
55    onnx_program = torch.onnx.dynamo_export(
56        custom_addandround_model, x,  export_options=export_options
57        )
58    onnx_program.save("./f.onnx")
59    with open('f.txt', 'w') as f:
60        f.write(str(onnx_program.model_proto))
61    onnx_model = onnx.load("f.onnx")
62    onnx.checker.check_model(onnx_model)
63
64
65if __name__ == '__main__':
66    main()
./code/f.txt
 1ir_version: 8
 2opset_import {
 3  domain: "pkg.onnxscript.torch_lib"
 4  version: 1
 5}
 6opset_import {
 7  domain: "pkg.torch.2.4.0+cpu"
 8  version: 1
 9}
10opset_import {
11  domain: "com.k2fsa.org"
12  version: 1
13}
14opset_import {
15  domain: ""
16  version: 18
17}
18opset_import {
19  domain: "pkg.onnxscript.torch_lib.common"
20  version: 1
21}
22producer_name: "pytorch"
23producer_version: "2.4.0"
24graph {
25  node {
26    input: "l_x_"
27    output: "relu_1"
28    name: "torch_nn_modules_activation_ReLU_relu_1_0_aten_relu_0_n0"
29    op_type: "Relu"
30  }
31  node {
32    input: "relu_1"
33    input: "foo"
34    output: "my_add_op2"
35    name: "custom_my_add_1_n0"
36    op_type: "CustomOpOne2"
37    domain: "com.k2fsa.org"
38  }
39  name: "main_graph"
40  initializer {
41    dims: 4
42    data_type: 1
43    name: "foo"
44    raw_data: "\000\000\300?\000\000 @\000\000`@\000\000\220@"
45  }
46  input {
47    name: "l_x_"
48    type {
49      tensor_type {
50        elem_type: 1
51        shape {
52          dim {
53            dim_value: 4
54          }
55        }
56      }
57    }
58  }
59  output {
60    name: "my_add_op2"
61    type {
62      tensor_type {
63        elem_type: 1
64        shape {
65          dim {
66            dim_value: 4
67          }
68        }
69      }
70    }
71  }
72  value_info {
73    name: "relu_1"
74    type {
75      tensor_type {
76        elem_type: 1
77        shape {
78          dim {
79            dim_value: 4
80          }
81        }
82      }
83    }
84  }
85}
./code/custom-op-3.cc
 1/*
 2references:
 3https://onnxruntime.ai/docs/reference/operators/add-custom-op.html
 4
 5 */
 6#include "onnxruntime_lite_custom_op.h"
 7#include <iostream>
 8#include <utility>
 9#include <vector>
10
11static void KernelOne(const Ort::Custom::Tensor<uint8_t> &X,
12                      const Ort::Custom::Tensor<float> &scale_tensor,
13                      Ort::Custom::Tensor<float> &Y) {
14  auto input_shape = X.Shape();
15  auto x_raw = X.Data();
16  auto scale = scale_tensor.Data()[0];
17  auto y_raw = Y.Allocate(input_shape);
18  for (int64_t i = 0; i < Y.NumberOfElement(); ++i) {
19
20    // scale each uint8 number
21    y_raw[i] = x_raw[i] * scale;
22  }
23}
24
25static Ort::CustomOpDomain TestCustomOp3() {
26  Ort::CustomOpDomain v1_domain{"com.k2fsa.org"};
27  // please make sure that custom_op_one has the same lifetime as the consuming
28  // session
29  //
30  // Here we use a static variable so it is never released.
31  // in practice, we can move it to a member variable of a class
32  static std::unique_ptr<Ort::Custom::OrtLiteCustomOp> custom_op_one{
33      Ort::Custom::CreateLiteCustomOp("MyCast", "CPUExecutionProvider",
34                                      KernelOne)};
35  v1_domain.Add(custom_op_one.get());
36
37  return v1_domain;
38}
39
40void TestCustomModel3() {
41  Ort::Env env;
42  Ort::SessionOptions sess_opts;
43  sess_opts.SetIntraOpNumThreads(1);
44  sess_opts.SetInterOpNumThreads(1);
45
46  Ort::CustomOpDomain v1_domain = TestCustomOp3();
47
48  Ort::SessionOptions session_options;
49  sess_opts.Add(v1_domain);
50  // create a session with the session_options ...
51
52  std::unique_ptr<Ort::Session> sess =
53      std::make_unique<Ort::Session>(env, "./g.onnx", sess_opts);
54
55  auto memory_info =
56      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
57  // foo is [10, 20]
58  //
59
60  std::vector<float> x = {1, 2};
61
62  std::array<int64_t, 1> shape = {2};
63
64  Ort::Value x_tensor = Ort::Value::CreateTensor(
65      memory_info, x.data(), x.size(), shape.data(), shape.size());
66
67  std::vector<Ort::Value> inputs;
68  inputs.push_back(std::move(x_tensor));
69
70  std::vector<const char *> input_names = {"l_x_"};
71  std::vector<const char *> output_names = {"add"};
72  auto out = sess->Run({}, input_names.data(), inputs.data(), inputs.size(),
73                       output_names.data(), output_names.size());
74  const float *p = out[0].GetTensorData<float>();
75  for (int i = 0; i < 2; ++i) {
76    std::cout << p[i] << "\n ";
77  }
78}
./code/g.py
 1#!/usr/bin/env python3
 2
 3from torch._custom_op import impl as custom_op
 4import torch
 5import onnx
 6import onnxscript
 7from onnxscript import opset18
 8
 9import warnings
10warnings.filterwarnings("ignore")
11
12@custom_op.custom_op("mylibrary::my_cast")
13def my_cast(x: torch.Tensor, scale: float = 0.25) -> torch.Tensor:
14    # Since we are using mylibrary::my_cast, so the function
15    # name must be my_cast; otherwise, it will throw an error
16    # when this script is run
17    return x.to(torch.float32)
18
19@my_cast.impl_abstract()
20def my_cast_impl_abstract_any_name_is_ok(x, scale: float = 0.25):
21    return x.to(torch.float32)
22
23@my_cast.impl("cpu")
24def my_cast_impl_any_name_is_ok(x, scale: float = 0.25):
25    return x.to(torch.float32)   # add x to itself, and round the result
26
27class CustomFoo(torch.nn.Module):
28    def __init__(self):
29        super().__init__()
30        self.register_buffer('foo', torch.tensor([10, 20], dtype=torch.uint8))
31        self.scale = 0.125
32
33    def forward(self, x):
34        return x + my_cast(self.foo, self.scale)
35
36
37custom_opset = onnxscript.values.Opset(domain="com.k2fsa.org", version=1)
38
39@onnxscript.script(custom_opset, default_opset=opset18)
40def custom_my_cast(x, scale: float = 0.5):
41    return custom_opset.MyCast(x, scale)
42
43@torch.no_grad()
44def main():
45    torch._dynamo.allow_in_graph(my_cast)
46    x = torch.randn(2)
47    custom_addandround_model = CustomFoo()
48    onnx_registry = torch.onnx.OnnxRegistry()
49    onnx_registry.register_op(
50        namespace="mylibrary", op_name="my_cast", overload="default", function=custom_my_cast
51        )
52
53    export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
54    onnx_program = torch.onnx.dynamo_export(
55        custom_addandround_model, x,  export_options=export_options
56        )
57    onnx_program.save("./g.onnx")
58    with open('g.txt', 'w') as f:
59        f.write(str(onnx_program.model_proto))
60    onnx_model = onnx.load("g.onnx")
61    onnx.checker.check_model(onnx_model)
62
63
64if __name__ == '__main__':
65    main()
./code/g.txt
  1ir_version: 8
  2opset_import {
  3  domain: "com.k2fsa.org"
  4  version: 1
  5}
  6opset_import {
  7  domain: "pkg.onnxscript.torch_lib"
  8  version: 1
  9}
 10opset_import {
 11  domain: ""
 12  version: 18
 13}
 14opset_import {
 15  domain: "pkg.onnxscript.torch_lib.common"
 16  version: 1
 17}
 18producer_name: "pytorch"
 19producer_version: "2.4.0"
 20graph {
 21  node {
 22    output: "custom_my_cast_0_scale"
 23    name: "custom_my_cast_0_n0"
 24    op_type: "Constant"
 25    attribute {
 26      name: "value_float"
 27      type: FLOAT
 28      f: 0.125
 29    }
 30  }
 31  node {
 32    input: "foo"
 33    input: "custom_my_cast_0_scale"
 34    output: "my_cast"
 35    name: "custom_my_cast_0_n1"
 36    op_type: "MyCast"
 37    domain: "com.k2fsa.org"
 38  }
 39  node {
 40    input: "l_x_"
 41    input: "my_cast"
 42    output: "add"
 43    name: "aten_add_1"
 44    op_type: "aten_add"
 45    domain: "pkg.onnxscript.torch_lib"
 46    attribute {
 47      name: "alpha"
 48      type: FLOAT
 49      f: 1
 50    }
 51  }
 52  name: "main_graph"
 53  initializer {
 54    dims: 2
 55    data_type: 2
 56    name: "foo"
 57    raw_data: "\n\024"
 58  }
 59  input {
 60    name: "l_x_"
 61    type {
 62      tensor_type {
 63        elem_type: 1
 64        shape {
 65          dim {
 66            dim_value: 2
 67          }
 68        }
 69      }
 70    }
 71  }
 72  output {
 73    name: "add"
 74    type {
 75      tensor_type {
 76        elem_type: 1
 77        shape {
 78          dim {
 79            dim_value: 2
 80          }
 81        }
 82      }
 83    }
 84  }
 85  value_info {
 86    name: "custom_my_cast_0_scale"
 87    type {
 88      tensor_type {
 89        elem_type: 1
 90        shape {
 91        }
 92      }
 93    }
 94  }
 95  value_info {
 96    name: "my_cast"
 97    type {
 98      tensor_type {
 99        elem_type: 1
100        shape {
101          dim {
102            dim_value: 2
103          }
104        }
105      }
106    }
107  }
108  value_info {
109    name: "pkg.onnxscript.torch_lib::aten_add/self"
110    type {
111      tensor_type {
112        elem_type: 1
113        shape {
114          dim {
115            dim_value: 2
116          }
117        }
118      }
119    }
120  }
121  value_info {
122    name: "pkg.onnxscript.torch_lib::aten_add/other"
123    type {
124      tensor_type {
125        elem_type: 1
126        shape {
127          dim {
128            dim_value: 2
129          }
130        }
131      }
132    }
133  }
134  value_info {
135    name: "pkg.onnxscript.torch_lib::aten_add/alpha"
136    type {
137      tensor_type {
138        elem_type: 1
139        shape {
140        }
141      }
142    }
143  }
144  value_info {
145    name: "pkg.onnxscript.torch_lib::aten_add/other_1"
146    type {
147      tensor_type {
148        elem_type: 1
149        shape {
150          dim {
151            dim_value: 2
152          }
153        }
154      }
155    }
156  }
157  value_info {
158    name: "pkg.onnxscript.torch_lib::aten_add/return_val"
159    type {
160      tensor_type {
161        elem_type: 1
162        shape {
163          dim {
164            dim_value: 2
165          }
166        }
167      }
168    }
169  }
170}
171functions {
172  name: "aten_add"
173  input: "self"
174  input: "other"
175  output: "return_val"
176  attribute_proto {
177    name: "alpha"
178    type: FLOAT
179    f: 1
180  }
181  node {
182    output: "alpha"
183    name: "n0"
184    op_type: "Constant"
185    attribute {
186      name: "value_float"
187      ref_attr_name: "alpha"
188      type: FLOAT
189    }
190  }
191  node {
192    input: "other"
193    input: "alpha"
194    output: "other_1"
195    name: "n2"
196    op_type: "Mul"
197  }
198  node {
199    input: "self"
200    input: "other_1"
201    output: "return_val"
202    name: "n3"
203    op_type: "Add"
204  }
205  doc_string: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
206  opset_import {
207    domain: ""
208    version: 18
209  }
210  domain: "pkg.onnxscript.torch_lib"
211}