Hello
./code/hello/ex0.py
1#!/usr/bin/env python3
2
3import torch
4import torch.nn as nn
5
6
class Foo(nn.Module):
    """Toy module with a value-dependent branch, used to demo ONNX export.

    ``forward`` chooses between two element-wise ops based on the sign of the
    input's total sum. Because the branch depends on tensor *values*, the
    module is meant to be compiled with ``torch.jit.script``; tracing would
    record only the branch taken for the example input.
    """

    def __init__(self, i):
        """Create the module.

        Args:
            i: Value stored as ``self.i``. It is not read by ``forward``.
        """
        super().__init__()
        self.relu = nn.ReLU()
        # Keep the caller-supplied value rather than a hard-coded constant.
        self.i = i

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(x + 1)`` if ``x.sum() > 0``, otherwise ``relu(x + 2)``.

        Args:
            x: Input tensor of any shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # ``.item()`` turns the 0-d sum into a Python number so the branch is
        # an ordinary ``if`` that TorchScript compiles into graph control flow.
        if x.sum().item() > 0:
            return self.relu(x + 1)
        else:
            return self.relu(x + 2)
18
19
def main():
    """Script ``Foo`` and export it to ``f.onnx`` with dynamic N and T axes."""
    model = Foo(1)
    model.train(False)  # equivalent to model.eval()
    scripted = torch.jit.script(model)

    # Example input laid out as [N, T, C]; only N and T are marked dynamic.
    example = torch.rand(2, 3, 4)

    # "x" gets explicit axis names; "y" uses the list form, so the exporter
    # generates names for those axes itself.
    axes = {
        "x": {0: "batch_size", 1: "T"},
        "y": [0, 1],
    }
    torch.onnx.export(
        scripted,
        example,
        "f.onnx",
        verbose=False,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes=axes,
    )


if __name__ == "__main__":
    main()
./code/hello/ex0-1.py
1#!/usr/bin/env python3
2
3import onnx
4
5
def main():
    """Validate ``f.onnx``, print its graph, and re-save it as ``f2.onnx``."""
    onnx_model = onnx.load("f.onnx")

    # Fails (raises) if the loaded model is not well formed.
    onnx.checker.check_model(onnx_model)

    graph_text = onnx.helper.printable_graph(onnx_model.graph)
    print(graph_text)

    onnx.save(onnx_model, "f2.onnx")


if __name__ == "__main__":
    main()
./code/hello/ex0-2.py
1#!/usr/bin/env python3
2
3import onnxruntime as ort
4import numpy as np
5
6
def main():
    """Run ``f.onnx`` with ONNX Runtime on CPU and check its I/O signature.

    Feeds the same input three ways: as an ``OrtValue``, as a NumPy array keyed
    by the name the session reports, and as a NumPy array keyed by the literal
    name ``"x"``. Each result is printed. The function then asserts that the
    session's input/output metadata match the names and dynamic axes used at
    export time.
    """
    # Use one thread each for inter-op and intra-op parallelism.
    # https://github.com/microsoft/onnxruntime/issues/10113
    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    # Request the CPU provider explicitly. Since ORT 1.9, builds with more than
    # one execution provider (e.g. GPU builds) require ``providers`` to be
    # passed, and this demo only targets CPU.
    ort_session = ort.InferenceSession(
        "f.onnx",
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

    x = np.arange(24).reshape(2, 3, 4).astype(np.float32)

    # Wrap the array as an OrtValue and sanity-check its metadata.
    ortvalue = ort.OrtValue.ortvalue_from_numpy(x)
    assert ortvalue.device_name() == "cpu"
    assert list(ortvalue.shape()) == list(x.shape)
    assert ortvalue.data_type() == "tensor(float)"
    assert ortvalue.is_tensor() is True

    # 1) Feed the OrtValue directly.
    results = ort_session.run(["y"], {"x": ortvalue})
    print(results)

    # 2) Feed the NumPy array under the input name reported by the session.
    ort_inputs = {ort_session.get_inputs()[0].name: x}
    results = ort_session.run(["y"], ort_inputs)
    print(results)

    # 3) Feed the NumPy array under the hard-coded export name.
    results = ort_session.run(["y"], {"x": x})
    print(results)

    # https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.NodeArg
    inputs = ort_session.get_inputs()
    assert isinstance(inputs, list)
    assert len(inputs) == 1
    assert isinstance(inputs[0], ort.NodeArg)
    print(inputs[0].name, inputs[0].type, inputs[0].shape)
    assert inputs[0].name == "x"
    assert inputs[0].type == "tensor(float)"
    # Named dynamic axes are reported as strings, the static last axis as an int.
    assert inputs[0].shape == ["batch_size", "T", 4]

    outputs = ort_session.get_outputs()
    assert isinstance(outputs, list)
    # Check the count before indexing (same order as the input checks) so an
    # empty output list fails this assertion instead of raising IndexError.
    assert len(outputs) == 1
    assert isinstance(outputs[0], ort.NodeArg)
    assert outputs[0].name == "y"
    assert outputs[0].type == "tensor(float)"
    # Output axes exported via the list form carry auto-generated names.
    assert outputs[0].shape == ["y_dynamic_axes_1", "y_dynamic_axes_2", 4]


if __name__ == "__main__":
    main()