Node
./code/node/main.cc
1#include "torch/csrc/jit/passes/quantization/helper.h" // for removeTorchMangle
2#include "torch/script.h"
3
4static void TestRemoveTorchMangle() {
5 std::string s = torch::jit::removeTorchMangle("a.___torch_mangle_1.foo");
6 TORCH_CHECK(s == "a.foo");
7
8 s = torch::jit::removeTorchMangle("a.___torch_mangle_123.foo");
9 TORCH_CHECK(s == "a.foo");
10}
11
12static void TestSimple() {
13 torch::jit::Module m("m");
14 m.define(R"(
15 def forward(self, x: torch.Tensor, y: torch.Tensor):
16 a = x + 2
17 b = y * 3
18 return a + b
19 )");
20 std::shared_ptr<torch::jit::Graph> graph = m.get_method("forward").graph();
21 std::cout << "graph string: \n" << graph->toString() << "\n";
22 // Or we can use graph->dump();
23 torch::jit::Block *block = graph->block();
24 for (auto it = block->nodes().begin(), end = block->nodes().end();
25 it != end;) {
26 torch::jit::Node *n = *it++;
27 torch::jit::NodeKind k = n->kind();
28 std::cout << "node kind: " << k << " " << k.toQualString() << "\n";
29 }
30#if 0
31graph string:
32graph(%self : __torch__.m,
33 %x.1 : Tensor,
34 %y.1 : Tensor):
35 %5 : int = prim::Constant[value=1]()
36 %4 : int = prim::Constant[value=2]() # <string>:3:14
37 %8 : int = prim::Constant[value=3]() # <string>:4:14
38 %a.1 : Tensor = aten::add(%x.1, %4, %5) # <string>:3:10
39 %b.1 : Tensor = aten::mul(%y.1, %8) # <string>:4:10
40 %13 : Tensor = aten::add(%a.1, %b.1, %5) # <string>:5:13
41 return (%13)
42
43node kind: 14 prim::Constant
44node kind: 14 prim::Constant
45node kind: 14 prim::Constant
46node kind: 534 aten::add
47node kind: 241 aten::mul
48node kind: 534 aten::add
49#endif
50}
51
52static void TestFunctionCall() {
53 torch::jit::Module m("m");
54 m.define(R"(
55 def add(self, x: torch.Tensor, y: torch.Tensor):
56 '''my add doc'''
57 return x + y + 3
58
59 def forward(self, x: torch.Tensor, y: torch.Tensor):
60 c = self.add(x, y)
61 return c
62 )");
63 std::shared_ptr<torch::jit::Graph> graph = m.get_method("forward").graph();
64 std::cout << "graph string: \n" << graph->toString() << "\n";
65 torch::jit::Block *block = graph->block();
66 for (auto it = block->nodes().begin(), end = block->nodes().end();
67 it != end;) {
68 torch::jit::Node *n = *it++;
69 torch::jit::NodeKind k = n->kind();
70 std::cout << "node kind: " << k << " " << k.toQualString() << "\n";
71 }
72#if 0
73graph string:
74graph(%self.1 : __torch__.m,
75 %x.1 : Tensor,
76 %y.1 : Tensor):
77 %c.1 : Tensor = prim::CallMethod[name="add"](%self.1, %x.1, %y.1) # <string>:6:10
78 return (%c.1)
79
80node kind: 149 prim::CallMethod
81#endif
82 for (auto it = block->nodes().begin(), end = block->nodes().end();
83 it != end;) {
84 torch::jit::Node *n = *it++;
85 torch::jit::NodeKind k = n->kind();
86 if (k == c10::prim::CallMethod) {
87 torch::ArrayRef<torch::jit::Value *> inputs = n->inputs();
88 TORCH_CHECK(inputs.size() == 3);
89
90 torch::jit::TypePtr type = inputs[0]->type();
91
92 auto class_type = type->cast<torch::jit::ClassType>();
93 TORCH_CHECK(class_type->str() == "__torch__.m");
94 if (!class_type) {
95 std::cout << "Not a class type: " << type->str() << "\n";
96 continue;
97 }
98 // defined by the macro "CREATE_ACCESSOR()" in ir/ir.h
99 const std::string &function_name = n->s(c10::attr::name);
100 // const std::string &function_name = n->s(torch::jit::attr::name);
101 TORCH_CHECK(function_name == "add");
102
103 TORCH_CHECK(torch::jit::attr::name == c10::attr::name);
104
105 torch::jit::Function &function = class_type->getMethod(function_name);
106 if (!function.isGraphFunction()) {
107 std::cout << function_name << " is not a graph function"
108 << "\n";
109 continue;
110 }
111 std::string class_type_str =
112 torch::jit::removeTorchMangle(class_type->str());
113 // remove __torch__., which is 10 characters long
114 std::string no_torch_class_type_str = class_type_str.substr(10);
115 }
116 }
117}
118
119int main() {
120 // TestRemoveTorchMangle();
121 // TestSimple();
122 TestFunctionCall();
123 return 0;
124}