Multiple models
./code/multiple-models/ex.py
1#!/usr/bin/env python3
2
3import torch
4import torch.nn as nn
5import onnx
6import onnxruntime as ort
7import numpy as np
8import os
9
10
class Foo(nn.Module):
    """Toy module whose forward pass adds one to every element of its input."""

    def forward(self, x):
        incremented = x + 1
        return incremented
14
15
class Bar(nn.Module):
    """Toy module whose forward pass subtracts one from every element of its input."""

    def forward(self, x):
        decremented = x - 1
        return decremented
19
20
def export_to_onnx():
    """Export ``Foo`` to ``f.onnx`` and ``Bar`` to ``b.onnx``.

    ``Foo`` is traced with a random (2, 3) float32 tensor and its input/output
    (``x1``/``y1``) get dynamic axes 0 ("N") and 1 ("T"). ``Bar`` is traced
    with a random (1,) float32 tensor and its input/output (``x2``/``y2``) get
    a dynamic axis 0 ("N").
    """

    def _dump(module, dummy, path, in_name, out_name, axes):
        # Input and output share the same symbolic axes; give each its own
        # dict, as the original literals did.
        torch.onnx.export(
            module,
            dummy,
            path,
            verbose=False,
            input_names=[in_name],
            output_names=[out_name],
            dynamic_axes={in_name: dict(axes), out_name: dict(axes)},
        )

    foo_dummy = torch.rand(2, 3, dtype=torch.float32)
    _dump(Foo(), foo_dummy, "f.onnx", "x1", "y1", {0: "N", 1: "T"})

    bar_dummy = torch.rand(1, dtype=torch.float32)
    _dump(Bar(), bar_dummy, "b.onnx", "x2", "y2", {0: "N"})
51
52
def merge_models():
    """Combine ``f.onnx`` and ``b.onnx`` into one model saved as ``all.onnx``.

    Every name in the Foo model is prefixed with ``f/`` (so its input and
    output become ``f/x1`` and ``f/y1``), which keeps them distinct from the
    Bar model's names. ``io_map`` is empty, so no output of one graph is wired
    into an input of the other; the two sub-graphs simply sit side by side.
    """
    prefixed_foo = onnx.compose.add_prefix(onnx.load("f.onnx"), prefix="f/")
    bar_model = onnx.load("b.onnx")
    merged = onnx.compose.merge_models(prefixed_foo, bar_model, io_map={})
    onnx.save(merged, "all.onnx")
59
60
def test_merged_model():
    """Extract both sub-models from ``all.onnx``, run them, and check results.

    The ``f/x1 -> f/y1`` sub-graph (Foo, prefixed during merging) should add
    one to its input; the ``x2 -> y2`` sub-graph (Bar) should subtract one.
    Outputs are printed and also verified.

    Raises:
        AssertionError: if either extracted model produces unexpected output.
    """
    # Use a single inter-/intra-op thread. The linked issue presumably
    # motivated this setting -- TODO confirm:
    # https://github.com/microsoft/onnxruntime/issues/10113
    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    all_model = onnx.load("all.onnx")
    extractor = onnx.utils.Extractor(all_model)

    f_model = extractor.extract_model(input_names=["f/x1"], output_names=["f/y1"])
    f_session = ort.InferenceSession(f_model.SerializeToString(), sess_options=options)
    f_in = np.array([[1, 3]], dtype=np.float32)
    f_out = f_session.run(["f/y1"], {"f/x1": f_in})
    print(f_out[0])  # [[2. 4.]]
    np.testing.assert_allclose(f_out[0], f_in + 1)

    b_model = extractor.extract_model(input_names=["x2"], output_names=["y2"])
    b_session = ort.InferenceSession(b_model.SerializeToString(), sess_options=options)
    b_in = np.array([1, 3], dtype=np.float32)
    b_out = b_session.run(["y2"], {"x2": b_in})
    print(b_out[0])  # [0. 2.]
    np.testing.assert_allclose(b_out[0], b_in - 1)
82
83
def main():
    """Export, merge, and exercise the example models, then delete the files.

    The intermediate ONNX files are removed in a ``finally`` clause, so a
    failure in any step no longer leaves stale ``*.onnx`` files behind.
    """
    try:
        export_to_onnx()
        merge_models()
        test_merged_model()
    finally:
        for path in ("f.onnx", "b.onnx", "all.onnx"):
            try:
                os.remove(path)
            except FileNotFoundError:
                # The step that would have produced this file never ran
                # (an earlier step failed); nothing to clean up.
                pass
91
92
if __name__ == "__main__":
    # Run the export -> merge -> extract demo only when executed as a script.
    main()
We can first merge multiple models into one and then extract each of them again from the merged model by selecting its input and output names.