Hello, ONNX: export a scripted PyTorch model, validate it, and run it with onnxruntime

./code/hello/ex0.py
 1#!/usr/bin/env python3
 2
 3import torch
 4import torch.nn as nn
 5
 6
 7class Foo(nn.Module):
 8    def __init__(self, i):
 9        super().__init__()
10        self.relu = nn.ReLU()
11        self.i = 1
12
13    def forward(self, x):
14        if x.sum().item() > 0:
15            return self.relu(x + 1)
16        else:
17            return self.relu(x + 2)
18
19
def main():
    """Export ``Foo`` to ``f.onnx`` with dynamic batch/time axes."""
    model = Foo(1)
    model.eval()  # equivalent to model.train(False)

    # Script (don't trace) so the value-dependent branch in forward()
    # is preserved in the exported graph.
    scripted = torch.jit.script(model)

    example = torch.rand(2, 3, 4)  # layout: [N, T, C]
    torch.onnx.export(
        scripted,
        example,
        "f.onnx",
        verbose=False,
        input_names=["x"],
        output_names=["y"],
        # Named axes for the input; the output axes are auto-named.
        dynamic_axes={"x": {0: "batch_size", 1: "T"}, "y": [0, 1]},
        #  dynamic_axes={"x": [0, 1], "y": [0, 1]},
    )


if __name__ == "__main__":
    main()
./code/hello/ex0-1.py
 1#!/usr/bin/env python3
 2
 3import onnx
 4
 5
 6def main():
 7    model = onnx.load("f.onnx")
 8    #  print(model)
 9    # Check that the model is well formed
10    onnx.checker.check_model(model)
11    # Print a human readable representation of the graph
12    print(onnx.helper.printable_graph(model.graph))
13    onnx.save(model, "f2.onnx")
14
15
16if __name__ == "__main__":
17    main()
./code/hello/ex0-2.py
 1#!/usr/bin/env python3
 2
 3import onnxruntime as ort
 4import numpy as np
 5
 6
 7def main():
 8    # https://github.com/microsoft/onnxruntime/issues/10113
 9    options = ort.SessionOptions()
10    options.inter_op_num_threads = 1
11    options.intra_op_num_threads = 1
12
13    ort_session = ort.InferenceSession("f.onnx", sess_options=options)
14
15    x = np.arange(24).reshape(2, 3, 4).astype(np.float32)
16    ortvalue = ort.OrtValue.ortvalue_from_numpy(x)
17    assert ortvalue.device_name() == "cpu"
18    assert list(ortvalue.shape()) == list(x.shape)
19    assert ortvalue.data_type() == "tensor(float)"
20    assert ortvalue.is_tensor() is True
21
22    results = ort_session.run(["y"], {"x": ortvalue})
23    print(results)
24
25    ort_inputs = {ort_session.get_inputs()[0].name: x}
26    results = ort_session.run(["y"], ort_inputs)
27    print(results)
28
29    results = ort_session.run(["y"], {"x": x})
30    print(results)
31
32    # https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.NodeArg
33    inputs = ort_session.get_inputs()
34    assert isinstance(inputs, list)
35    assert len(inputs) == 1
36    assert isinstance(inputs[0], ort.NodeArg)
37    print(inputs[0].name, inputs[0].type, inputs[0].shape)
38    assert inputs[0].name == "x"
39    assert inputs[0].type == "tensor(float)"
40    assert inputs[0].shape == ["batch_size", "T", 4]
41
42    outputs = ort_session.get_outputs()
43    assert isinstance(outputs, list)
44    assert isinstance(outputs[0], ort.NodeArg)
45    assert len(outputs) == 1
46    assert outputs[0].name == "y"
47    assert outputs[0].type == "tensor(float)"
48    assert outputs[0].shape == ["y_dynamic_axes_1", "y_dynamic_axes_2", 4]
49
50
51if __name__ == "__main__":
52    main()