graph
./code/graph/main.cc
1#include "torch/script.h"
2
3static void TestConv2d() {
4 torch::jit::Module m("m");
5 m.define(R"(
6 def __init__(self):
7 self.conv = torch.nn.Conv2d(2, 3)
8 def forward(self, x: torch.Tensor):
9 return self.conv(x)
10 )");
11 torch::jit::Method method = m.get_method("forward");
12 std::shared_ptr<torch::jit::Graph> g = method.graph();
13 torch::ArrayRef<torch::jit::Value *> inputs = g->inputs();
14 torch::ArrayRef<torch::jit::Value *> outputs = g->outputs();
15 TORCH_CHECK(inputs.size() == 1);
16 TORCH_CHECK(outputs.size() == 1);
17
18 torch::jit::Value *in = inputs[0];
19 std::cout << in->type()->str() << "\n";
20 std::cout << in->debugName() << "\n";
21}
22
23int main() {
24 TestConv2d();
25 return 0;
26}
./code/graph/inline_calls.py
1#!/usr/bin/env python3
2
3from pathlib import Path
4
5import torch
6import torch.nn as nn
7
8
class Foo(nn.Module):
    """Toy module used to produce a TorchScript graph with several submodule calls.

    ``forward`` adds the tensor attribute ``t`` to the input, applies
    ``linear`` once and ``linear2`` twice, and finishes with a functional ELU.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 2)
        # Kept as an attribute, but forward() does not call it: the ReLU step
        # there is commented out.
        self.relu = nn.ReLU()
        # Plain tensor attribute (not wrapped in nn.Parameter and not
        # registered through register_buffer).
        self.t = torch.rand(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``elu(linear2(linear2(linear(x + t))))``."""
        y = self.linear(x + self.t)
        # linear2 is applied twice in a row.
        y = self.linear2(y)
        y = self.linear2(y)
        # z = self.relu(y)
        return nn.functional.elu(y)
24
25
def generate_foo_pt():
    """Trace a fresh ``Foo`` on a random ``(1, 2)`` input and save it to ``foo.pt``."""
    # Build the module before drawing the example input so that random
    # numbers are consumed in the same order as before.
    model = Foo()
    example_input = torch.rand(1, 2)
    torch.jit.trace(model, example_input).save("foo.pt")
31
32
def test_foo_pt():
    """Load ``foo.pt``, sanity-check its ``forward`` method, and print a graph.

    Checks that ``forward`` is a ``ScriptMethod`` whose ``graph`` and
    ``inlined_graph`` are both ``Graph`` objects, prints the graph of the
    ``linear`` submodule, and returns. The node-level inspection below the
    early ``return`` is currently unreachable.
    """
    module = torch.jit.load("foo.pt")
    forward = module.forward
    assert isinstance(forward, torch._C.ScriptMethod)
    assert isinstance(forward.graph, torch._C.Graph)
    assert isinstance(forward.inlined_graph, torch._C.Graph)

    print(module.linear.graph)
    return

    # Unreachable while the early return above is in place.
    print(module.forward.graph)
    # print(module.forward.inlined_graph)
    graph = module.forward.graph
    node_iter = graph.nodes()

    first_node = next(node_iter)
    print(dir(first_node))
    assert first_node.kind() == "prim::GetAttr"
    for value in first_node.inputs():
        assert isinstance(value, torch._C.Value)
        assert value.debugName() == "self.1"
        value_type = value.type()
        assert isinstance(value_type, torch._C.ClassType)
        assert value_type.str() == "__torch__.Foo"
56
57
def main():
    """Script entry point: regenerate ``foo.pt``; the load check is disabled."""
    generate_foo_pt()
    # test_foo_pt()


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()