Multiple models

./code/multiple-models/ex.py
 1#!/usr/bin/env python3
 2
 3import torch
 4import torch.nn as nn
 5import onnx
 6import onnxruntime as ort
 7import numpy as np
 8import os
 9
10
class Foo(nn.Module):
    """Toy module whose forward pass adds the scalar 1 to its input."""

    def forward(self, x):
        # Element-wise addition of the scalar 1.
        incremented = x + 1
        return incremented
15
class Bar(nn.Module):
    """Toy module whose forward pass subtracts the scalar 1 from its input."""

    def forward(self, x):
        # Element-wise subtraction of the scalar 1.
        decremented = x - 1
        return decremented
20
def export_to_onnx():
    """Export ``Foo`` to ``f.onnx`` and ``Bar`` to ``b.onnx`` in the working directory.

    ``Foo`` is traced with a random float32 tensor of shape (2, 3). Its input
    ``x1`` and output ``y1`` get dynamic axes 0 -> ``N`` and 1 -> ``T``.
    ``Bar`` is traced with a random float32 tensor of shape (1,). Its input
    ``x2`` and output ``y2`` get dynamic axis 0 -> ``N``.
    """
    # One row per exported model:
    # (module class, example-input shape, target file, input name, output name, dynamic dims)
    export_specs = (
        (Foo, (2, 3), "f.onnx", "x1", "y1", {0: "N", 1: "T"}),
        (Bar, (1,), "b.onnx", "x2", "y2", {0: "N"}),
    )
    for module_cls, shape, path, in_name, out_name, dims in export_specs:
        example = torch.rand(*shape, dtype=torch.float32)
        torch.onnx.export(
            module_cls(),
            example,
            path,
            verbose=False,
            input_names=[in_name],
            output_names=[out_name],
            # Separate dict objects per name, as in a hand-written literal.
            dynamic_axes={in_name: dict(dims), out_name: dict(dims)},
        )
51
52
def merge_models():
    """Merge ``f.onnx`` and ``b.onnx`` into a single model saved as ``all.onnx``.

    Every name in the ``f`` model is prefixed with ``"f/"`` via
    ``onnx.compose.add_prefix`` (so its input/output become ``f/x1``/``f/y1``).
    The ``b`` model is merged unchanged.

    ``io_map`` is the list of ``(output_of_f, input_of_b)`` pairs to connect.
    It is deliberately empty here: nothing from ``f`` feeds ``b``, so the two
    graphs sit side by side in the merged model.
    """
    f = onnx.load("f.onnx")
    f = onnx.compose.add_prefix(f, prefix="f/")
    b = onnx.load("b.onnx")
    # io_map is a list of (str, str) pairs; the original passed `{}`, which
    # only worked because an empty dict happens to iterate as nothing.
    combined_model = onnx.compose.merge_models(f, b, io_map=[])
    onnx.save(combined_model, "all.onnx")
59
60
def test_merged_model():
    """Extract the two sub-models back out of ``all.onnx`` and run each one.

    The ``f/x1 -> f/y1`` sub-model is fed ``[[1, 3]]`` and the ``x2 -> y2``
    sub-model is fed ``[1, 3]`` (both float32). The outputs are printed; the
    expected values are given in the trailing comments.
    """
    # https://github.com/microsoft/onnxruntime/issues/10113
    # NOTE(review): both thread pools are pinned to a single thread because of
    # the issue linked above; confirm the workaround is still required.
    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    all_model = onnx.load("all.onnx")

    # One extractor is reused for both sub-models of the merged graph.
    extractor = onnx.utils.Extractor(all_model)

    f = extractor.extract_model(input_names=["f/x1"], output_names=["f/y1"])
    f_session = ort.InferenceSession(f.SerializeToString(), sess_options=options)
    f_out = f_session.run(["f/y1"], {"f/x1": np.array([[1, 3]], dtype=np.float32)})
    print(f_out[0])  # [[2. 4.]]

    b = extractor.extract_model(input_names=["x2"], output_names=["y2"])
    b_session = ort.InferenceSession(b.SerializeToString(), sess_options=options)
    b_out = b_session.run(["y2"], {"x2": np.array([1, 3], dtype=np.float32)})
    print(b_out[0])  # [0. 2.]
82
83
def main():
    """Export the toy models, merge them, run the extracted sub-models, then clean up.

    The intermediate files ``f.onnx``, ``b.onnx`` and ``all.onnx`` are removed
    even when one of the steps raises, so a failed run does not leave stale
    models in the working directory.
    """
    try:
        export_to_onnx()
        merge_models()
        test_merged_model()
    finally:
        for path in ("f.onnx", "b.onnx", "all.onnx"):
            # A step may have failed before this particular file was written.
            if os.path.exists(path):
                os.remove(path)


if __name__ == "__main__":
    main()

We can first merge multiple models into one and then extract them again from the merged model.