graph

./code/graph/main.cc
#include <iostream>

#include "torch/script.h"
 2
 3static void TestConv2d() {
 4  torch::jit::Module m("m");
 5  m.define(R"(
 6    def __init__(self):
 7      self.conv = torch.nn.Conv2d(2, 3)
 8    def forward(self, x: torch.Tensor):
 9      return self.conv(x)
10  )");
11  torch::jit::Method method = m.get_method("forward");
12  std::shared_ptr<torch::jit::Graph> g = method.graph();
13  torch::ArrayRef<torch::jit::Value *> inputs = g->inputs();
14  torch::ArrayRef<torch::jit::Value *> outputs = g->outputs();
15  TORCH_CHECK(inputs.size() == 1);
16  TORCH_CHECK(outputs.size() == 1);
17
18  torch::jit::Value *in = inputs[0];
19  std::cout << in->type()->str() << "\n";
20  std::cout << in->debugName() << "\n";
21}
22
23int main() {
24  TestConv2d();
25  return 0;
26}
./code/graph/inline_calls.py
 1#!/usr/bin/env python3
 2
 3from pathlib import Path
 4
 5import torch
 6import torch.nn as nn
 7
 8
 9class Foo(nn.Module):
10    def __init__(self):
11        super().__init__()
12        self.linear = nn.Linear(2, 2)
13        self.linear2 = nn.Linear(2, 2)
14        self.relu = nn.ReLU()
15        self.t = torch.rand(2)
16
17    def forward(self, x: torch.Tensor):
18        y = self.linear(x + self.t)
19        y = self.linear2(y)
20        y = self.linear2(y)
21        #  z = self.relu(y)
22        return nn.functional.elu(y)
23        return z
24
25
def generate_foo_pt(path="foo.pt"):
    """Trace a fresh ``Foo`` with a random example input and save it.

    Args:
        path: destination file (``str`` or ``pathlib.Path``). Defaults to
            ``"foo.pt"`` in the current working directory.

    Returns:
        The traced module that was written to ``path``.
    """
    f = Foo()
    # Foo.linear is nn.Linear(2, 2), so the example input is (1, 2).
    x = torch.rand(1, 2)
    m = torch.jit.trace(f, x)
    m.save(str(path))
    return m
31
32
def test_foo_pt(path="foo.pt"):
    """Load the traced ``Foo`` from ``path`` and check its ``forward`` graph.

    Prints the graphs of ``linear`` and ``forward``. It then asserts that the
    first node of ``forward`` is a ``prim::GetAttr`` whose single input is the
    module itself (debug name ``self.1``, class type ``__torch__.Foo``).

    Args:
        path: file produced by ``generate_foo_pt``. Defaults to ``"foo.pt"``.
    """
    m = torch.jit.load(str(path))
    assert isinstance(m.forward, torch._C.ScriptMethod)
    assert isinstance(m.forward.graph, torch._C.Graph)
    assert isinstance(m.forward.inlined_graph, torch._C.Graph)

    print(m.linear.graph)
    print(m.forward.graph)

    g = m.forward.graph
    first = next(g.nodes())
    assert first.kind() == "prim::GetAttr", first.kind()

    # Check the count explicitly so the per-input checks below cannot be
    # skipped silently by an empty input list.
    inputs = list(first.inputs())
    assert len(inputs) == 1, inputs
    (owner,) = inputs
    assert isinstance(owner, torch._C.Value)
    assert owner.debugName() == "self.1", owner.debugName()

    owner_type = owner.type()
    assert isinstance(owner_type, torch._C.ClassType)
    # NOTE(review): "__torch__.Foo" assumes Foo was defined in __main__ when
    # it was traced; tracing from an imported module changes the qualified
    # name.
    assert owner_type.str() == "__torch__.Foo", owner_type.str()
56
57
def main():
    """Script entry point: trace ``Foo`` and write it via ``generate_foo_pt``.

    The load-and-inspect step (``test_foo_pt``) is left disabled.
    """
    generate_foo_pt()
    # test_foo_pt()


if __name__ == "__main__":
    main()