trace
./code/trace/ex0.py
#!/usr/bin/env python3

import torch
import torch.nn as nn
from typing import List


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x)


def test_foo():
    f = Foo()
    m = torch.jit.trace(f, torch.rand(2, 3))

    print(m(torch.rand(2)))
    print(m(torch.rand(2, 3, 4)))
    # Note: the input shape is dynamic, not fixed. Tracing does not
    # hard-code the example input's shape; since ReLU is elementwise,
    # the traced module accepts tensors of any shape.


def simple(x: List[torch.Tensor], y: torch.Tensor):
    x = x[0].item()
    if x > 2:
        return y + x + 1
    elif x < 1:
        return y
    else:
        return y + x


def test_simple():
    f0 = torch.jit.trace(simple, ([torch.tensor([0])], torch.rand(2, 3)))
    # print(dir(f0))
    """
    ['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__',
     '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__',
     '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__',
     '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__',
     '__sizeof__', '__str__', '__subclasshook__', '_debug_flush_compilation_cache',
     'code', 'get_debug_state', 'graph', 'graph_for', 'inlined_graph', 'name',
     'qualified_name', 'save', 'save_to_buffer', 'schema']
    """
    # print(f0.schema)  # simple(Tensor[] x, Tensor y) -> (Tensor)
    # print(f0.code)
    """
    def simple(x: List[Tensor],
               y: Tensor) -> Tensor:
      return y
    """
    # print(f0.graph)
    """
    graph(%x : Tensor[],
          %y : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
      return (%y)
    """
    # print(f0.inlined_graph)  # same as the above one
    # print(f0.name)  # simple
    print(f0.qualified_name)  # __torch__.simple
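    # f0 can also be serialized (note 'save' in the dir() output above)
    # and loaded back with torch.jit.load, e.g.:
    #   f0.save("simple.pt")
    #   loaded = torch.jit.load("simple.pt")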


def main():
    # test_foo()
    test_simple()


if __name__ == "__main__":
    main()
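
For comparison, torch.jit.script compiles the Python source instead of
recording one execution, so the data-dependent control flow in simple
survives. Below is a minimal sketch; simple2 is a hypothetical variant of
simple (the fresh name v and the int() cast are needed because TorchScript
does not let a variable change its type):

import torch
from typing import List


def simple2(x: List[torch.Tensor], y: torch.Tensor):
    v = int(x[0].item())
    if v > 2:
        return y + v + 1
    elif v < 1:
        return y
    else:
        return y + v


f1 = torch.jit.script(simple2)
print(f1.code)  # all three branches are kept, unlike the traced version
print(f1([torch.tensor([3])], torch.zeros(2, 3)))  # v > 2 branch: y + 3 + 1
print(f1([torch.tensor([0])], torch.zeros(2, 3)))  # v < 1 branch: y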
./code/trace/ex1.py
#!/usr/bin/env python3

import torch


def f(a, b):
    c = a + b
    d = c * c
    e = torch.tanh(d * c)
    return d + (e + e)


m = torch.jit.script(f)
print(m.graph)

"""
graph(%a.1 : Tensor,
      %b.1 : Tensor):
  %4 : int = prim::Constant[value=1]()
  %c.1 : Tensor = aten::add(%a.1, %b.1, %4) # ./ex1.py:7:8
  %d.1 : Tensor = aten::mul(%c.1, %c.1) # ./ex1.py:8:8
  %11 : Tensor = aten::mul(%d.1, %c.1) # ./ex1.py:9:19
  %e.1 : Tensor = aten::tanh(%11) # ./ex1.py:9:8
  %17 : Tensor = aten::add(%e.1, %e.1, %4) # ./ex1.py:10:16
  %19 : Tensor = aten::add(%d.1, %17, %4) # ./ex1.py:10:11
  return (%19)
"""

"""
Note: aten::add(a0, a1, a2) computes a0 + a2 * a1, i.e., the third
argument is the alpha scalar that scales the second operand.
See torch/csrc/jit/codegen/fuser/codegen.cpp
"""
assert isinstance(m.graph, torch._C.Graph)

# Every graph has inputs and outputs.
# m.graph.inputs() returns an iterator.
assert len(list(m.graph.inputs())) == 2, "It has two inputs, a and b, in our case"
it = m.graph.inputs()
a = next(it)
b = next(it)

assert isinstance(a, torch._C.Value)
assert isinstance(a.node(), torch._C.Node)

# Every node also has inputs and outputs.
# a.node().inputs() is an iterator.
assert list(a.node().inputs()) == []
assert a.node().kind() == "prim::Param"
assert a.node().inputsSize() == 0
assert a.node().outputsSize() == 2
print(next(a.node().outputs()))

oit = a.node().outputs()
assert next(oit) == a
assert next(oit) == b

assert next(a.node().outputs()) == a

assert a.node().outputsAt(0) == a
assert a.node().outputsAt(1) == b
assert a.node() == b.node()
assert a.node().attributeNames() == [], "this node has no attributes"
assert a.debugName() == "a.1"
assert isinstance(a.type(), torch._C.TensorType)
assert a.type().kind() == "TensorType"
assert a.unique() == 0  # unique() is this value's unique ID within the graph; %a.1 is value 0
assert isinstance(a.uses(), list)
assert isinstance(a.uses()[0], torch._C.Use)
assert isinstance(a.uses()[0].user, torch._C.Node)

c_node = a.uses()[0].user
assert c_node.kind() == "aten::add"
assert c_node.attributeNames() == []
assert len(list(c_node.inputs())) == 3
c_it = c_node.inputs()
assert a == next(c_it)
assert b == next(c_it)
v4 = next(c_it)
assert v4.debugName() == "4"
assert c_node.hasAttributes() is False
assert c_node.hasMultipleOutputs() is False
assert c_node.hasUses() is True
assert (
    c_node.schema()
    == "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)"
)
print(c_node.schema())
print(type(c_node.schema()))
v4_node = v4.node()
assert v4_node.attributeNames() == ["value"]
assert v4_node.hasAttributes() is True
assert v4_node.hasAttribute("value") is True
# The "value" attribute of this prim::Constant is an int, so it is read
# with i(); t() is for tensor attributes and would raise here.
# print(v4_node.i("value"))  # 1
print(dir(v4_node))
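
Besides hopping from value to value via uses(), the graph can also be
walked directly. A minimal sketch, continuing with the m defined in
ex1.py above (nodes() and outputs() are part of the same torch._C.Graph
API used there):

# Walk every node in topological order and print its kind and outputs.
# For the graph above this prints prim::Constant, aten::add, aten::mul, ...
for node in m.graph.nodes():
    print(node.kind(), [v.debugName() for v in node.outputs()])

# The graph's return values are exposed like its inputs.
assert len(list(m.graph.outputs())) == 1
ret = next(m.graph.outputs())
assert ret.node().kind() == "aten::add"  # %19 is produced by the last add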