Node

./code/node/main.cc
  1#include "torch/csrc/jit/passes/quantization/helper.h" // for removeTorchMangle
  2#include "torch/script.h"
  3
  4static void TestRemoveTorchMangle() {
  5  std::string s = torch::jit::removeTorchMangle("a.___torch_mangle_1.foo");
  6  TORCH_CHECK(s == "a.foo");
  7
  8  s = torch::jit::removeTorchMangle("a.___torch_mangle_123.foo");
  9  TORCH_CHECK(s == "a.foo");
 10}
 11
 12static void TestSimple() {
 13  torch::jit::Module m("m");
 14  m.define(R"(
 15    def forward(self, x: torch.Tensor, y: torch.Tensor):
 16      a = x + 2
 17      b = y * 3
 18      return a + b
 19  )");
 20  std::shared_ptr<torch::jit::Graph> graph = m.get_method("forward").graph();
 21  std::cout << "graph string: \n" << graph->toString() << "\n";
 22  // Or we can use graph->dump();
 23  torch::jit::Block *block = graph->block();
 24  for (auto it = block->nodes().begin(), end = block->nodes().end();
 25       it != end;) {
 26    torch::jit::Node *n = *it++;
 27    torch::jit::NodeKind k = n->kind();
 28    std::cout << "node kind: " << k << " " << k.toQualString() << "\n";
 29  }
 30#if 0
 31graph string:
 32graph(%self : __torch__.m,
 33      %x.1 : Tensor,
 34      %y.1 : Tensor):
 35  %5 : int = prim::Constant[value=1]()
 36  %4 : int = prim::Constant[value=2]() # <string>:3:14
 37  %8 : int = prim::Constant[value=3]() # <string>:4:14
 38  %a.1 : Tensor = aten::add(%x.1, %4, %5) # <string>:3:10
 39  %b.1 : Tensor = aten::mul(%y.1, %8) # <string>:4:10
 40  %13 : Tensor = aten::add(%a.1, %b.1, %5) # <string>:5:13
 41  return (%13)
 42
 43node kind: 14 prim::Constant
 44node kind: 14 prim::Constant
 45node kind: 14 prim::Constant
 46node kind: 534 aten::add
 47node kind: 241 aten::mul
 48node kind: 534 aten::add
 49#endif
 50}
 51
 52static void TestFunctionCall() {
 53  torch::jit::Module m("m");
 54  m.define(R"(
 55    def add(self, x: torch.Tensor, y: torch.Tensor):
 56      '''my add doc'''
 57      return x + y + 3
 58
 59    def forward(self, x: torch.Tensor, y: torch.Tensor):
 60      c = self.add(x, y)
 61      return c
 62  )");
 63  std::shared_ptr<torch::jit::Graph> graph = m.get_method("forward").graph();
 64  std::cout << "graph string: \n" << graph->toString() << "\n";
 65  torch::jit::Block *block = graph->block();
 66  for (auto it = block->nodes().begin(), end = block->nodes().end();
 67       it != end;) {
 68    torch::jit::Node *n = *it++;
 69    torch::jit::NodeKind k = n->kind();
 70    std::cout << "node kind: " << k << " " << k.toQualString() << "\n";
 71  }
 72#if 0
 73graph string:
 74graph(%self.1 : __torch__.m,
 75      %x.1 : Tensor,
 76      %y.1 : Tensor):
 77  %c.1 : Tensor = prim::CallMethod[name="add"](%self.1, %x.1, %y.1) # <string>:6:10
 78  return (%c.1)
 79
 80node kind: 149 prim::CallMethod
 81#endif
 82  for (auto it = block->nodes().begin(), end = block->nodes().end();
 83       it != end;) {
 84    torch::jit::Node *n = *it++;
 85    torch::jit::NodeKind k = n->kind();
 86    if (k == c10::prim::CallMethod) {
 87      torch::ArrayRef<torch::jit::Value *> inputs = n->inputs();
 88      TORCH_CHECK(inputs.size() == 3);
 89
 90      torch::jit::TypePtr type = inputs[0]->type();
 91
 92      auto class_type = type->cast<torch::jit::ClassType>();
 93      TORCH_CHECK(class_type->str() == "__torch__.m");
 94      if (!class_type) {
 95        std::cout << "Not a class type: " << type->str() << "\n";
 96        continue;
 97      }
 98      // defined by the macro "CREATE_ACCESSOR()" in ir/ir.h
 99      const std::string &function_name = n->s(c10::attr::name);
100      // const std::string &function_name = n->s(torch::jit::attr::name);
101      TORCH_CHECK(function_name == "add");
102
103      TORCH_CHECK(torch::jit::attr::name == c10::attr::name);
104
105      torch::jit::Function &function = class_type->getMethod(function_name);
106      if (!function.isGraphFunction()) {
107        std::cout << function_name << " is not a graph function"
108                  << "\n";
109        continue;
110      }
111      std::string class_type_str =
112          torch::jit::removeTorchMangle(class_type->str());
113      // remove __torch__., which is 10 characters long
114      std::string no_torch_class_type_str = class_type_str.substr(10);
115    }
116  }
117}
118
119int main() {
120  // TestRemoveTorchMangle();
121  // TestSimple();
122  TestFunctionCall();
123  return 0;
124}