trace

./code/trace/ex0.py
 1#!/usr/bin/env python3
 2
 3import torch
 4
 5import torch.nn as nn
 6from typing import List
 7
 8
 9class Foo(nn.Module):
10    def __init__(self):
11        super().__init__()
12        self.relu = nn.ReLU()
13
14    def forward(self, x):
15        return self.relu(x)
16
17
def test_foo():
    """Trace ``Foo`` on a (2, 3) example, then call the trace on other shapes.

    The traced module is called with (2,) and (2, 3, 4) inputs, and each
    result is printed.
    """
    traced = torch.jit.trace(Foo(), torch.rand(2, 3))
    # Note: the input shape is dynamic here, not fixed to the (2, 3) example.
    for shape in ((2,), (2, 3, 4)):
        print(traced(torch.rand(*shape)))
25
26
def simple(x: List[torch.Tensor], y: torch.Tensor):
    """Offset ``y`` by a scalar read from ``x[0]``, depending on its range.

    Returns ``y`` unchanged when the scalar is below 1, ``y + v + 1`` when it
    is above 2, and ``y + v`` otherwise (``v`` being the scalar).
    """
    value = x[0].item()  # Python number; the list argument is left as-is
    if value > 2:
        return y + value + 1
    if value < 1:
        return y
    return y + value
36
def test_simple():
    """Trace ``simple`` and print the qualified name of the traced function.

    The example input has ``x[0] == 0``, which takes the ``x < 1`` branch, so
    the recorded code is just ``return y``. The other branches are not captured.
    """
    example_inputs = ([torch.tensor([0])], torch.rand(2, 3))
    traced = torch.jit.trace(simple, example_inputs)

    # Notes recorded from an interactive session:
    #
    # dir(traced): besides the usual object dunders ('__call__', '__class__',
    # '__delattr__', ..., '__subclasshook__') it exposes
    #   '_debug_flush_compilation_cache', 'code', 'get_debug_state', 'graph',
    #   'graph_for', 'inlined_graph', 'name', 'qualified_name', 'save',
    #   'save_to_buffer', 'schema'
    #
    # traced.schema -> simple(Tensor[] x, Tensor y) -> (Tensor)
    #
    # traced.code ->
    #   def simple(x: List[Tensor],
    #       y: Tensor) -> Tensor:
    #     return y
    #
    # traced.graph (traced.inlined_graph prints the same) ->
    #   graph(%x : Tensor[],
    #         %y : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
    #     return (%y)
    #
    # traced.name -> simple
    print(traced.qualified_name)  # __torch__.simple
65
66
def main():
    """Run the enabled tracing demos in order."""
    enabled_demos = (test_simple,)  # add test_foo here to run it as well
    for demo in enabled_demos:
        demo()


if __name__ == "__main__":
    main()
./code/trace/ex1.py
 1#!/usr/bin/env python3
 2
 3import torch
 4
 5
def f(a, b):
    # Toy function compiled with torch.jit.script below.
    # Computes c = a + b, d = c * c, e = tanh(d * c) and returns d + (e + e).
    # The body is documented with comments (not a docstring), and statement and
    # operand order are left untouched. The graph printed and asserted on below
    # is built from this exact body; for example, the first use of `a` must be
    # aten::add(a, b, alpha).
    c = a + b
    d = c * c
    e = torch.tanh(d * c)
    return d + (e + e)
11
12
# Compile `f` with TorchScript and print its IR graph. The output captured
# when these notes were written is kept in the string literal below. Value
# names/numbers and source locations may differ between PyTorch versions, and
# several assertions further down rely on them.
m = torch.jit.script(f)
print(m.graph)

"""
graph(%a.1 : Tensor,
      %b.1 : Tensor):
  %4 : int = prim::Constant[value=1]()
  %c.1 : Tensor = aten::add(%a.1, %b.1, %4) # ./ex1.py:7:8
  %d.1 : Tensor = aten::mul(%c.1, %c.1) # ./ex1.py:8:8
  %11 : Tensor = aten::mul(%d.1, %c.1) # ./ex1.py:9:19
  %e.1 : Tensor = aten::tanh(%11) # ./ex1.py:9:8
  %17 : Tensor = aten::add(%e.1, %e.1, %4) # ./ex1.py:10:16
  %19 : Tensor = aten::add(%d.1, %17, %4) # ./ex1.py:10:11
  return (%19)
"""

# The third input of aten::add is the `alpha` scalar of its schema (printed
# below): add(Tensor self, Tensor other, *, Scalar alpha=1). The note below
# therefore reads as `self + alpha * other`.
"""
Note: for aten::add(a0, a1, a2), it does  a0 + a2 * a1.
See torch/csrc/jit/codegen/fuser/codegen.cpp

"""
assert isinstance(m.graph, torch._C.Graph)

# Every graph has inputs and outputs.
# m.graph.inputs() returns an iterator over the graph's input Values.
assert len(list(m.graph.inputs())) == 2, "It has two inputs: a, b, in our case"
it = m.graph.inputs()
a = next(it)
b = next(it)

# Graph inputs are torch._C.Value objects, each produced by a torch._C.Node.
assert isinstance(a, torch._C.Value)
assert isinstance(a.node(), torch._C.Node)

# Every node has inputs and outputs; a.node().inputs() is an iterator.
# Both graph inputs are outputs of a single prim::Param node. That node has
# no inputs of its own, and its outputs are (a, b) in order.
assert list(a.node().inputs()) == []
assert a.node().kind() == "prim::Param"
assert a.node().inputsSize() == 0
assert a.node().outputsSize() == 2
# Prints the first output of the Param node, i.e. `a` itself.
print(next(a.node().outputs()))

oit = a.node().outputs()
assert next(oit) == a
assert next(oit) == b

# Each call to outputs() yields a new iterator that starts at the first output.
assert next(a.node().outputs()) == a

# outputsAt(i) indexes the node's outputs directly.
assert a.node().outputsAt(0) == a
assert a.node().outputsAt(1) == b
assert a.node() == b.node()
assert a.node().attributeNames() == [], "this node has no attributes"
# Debug names match the graph printout (%a.1).
assert a.debugName() == "a.1"
assert isinstance(a.type(), torch._C.TensorType)
assert a.type().kind() == "TensorType"
# NOTE(review): unique() looks like the Value's numeric id within its graph.
# Unnamed values are printed by such numbers, e.g. %4. If `a` is the first
# value created, that would explain the 0; confirm in torch/csrc/jit/ir/ir.h.
assert a.unique() == 0  # TODO(fangjun): confirm meaning (see note above)
# uses() is a list of torch._C.Use records; .user is the consuming Node.
assert isinstance(a.uses(), list)
assert isinstance(a.uses()[0], torch._C.Use)
assert isinstance(a.uses()[0].user, torch._C.Node)

# `a` appears only in `c = a + b`, so its first use is that aten::add node.
c_node = a.uses()[0].user
assert c_node.kind() == "aten::add"
assert c_node.attributeNames() == []
# Its inputs are (a, b, alpha); alpha is the constant %4 from the graph.
assert len(list(c_node.inputs())) == 3
c_it = c_node.inputs()
assert a == next(c_it)
assert b == next(c_it)
v4 = next(c_it)
# Values without a Python variable name get a purely numeric debug name,
# as with %4, %11, %17 and %19 in the graph.
assert v4.debugName() == "4"
assert c_node.hasAttributes() is False
assert c_node.hasMultipleOutputs() is False
assert c_node.hasUses() is True
# schema() returns the matched operator overload's signature as a string.
assert (
    c_node.schema()
    == "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)"
)
print(c_node.schema())
print(type(c_node.schema()))
# v4 is produced by the prim::Constant node, which holds the literal 1 in
# its "value" attribute (prim::Constant[value=1] in the graph).
v4_node = v4.node()
assert v4_node.attributeNames() == ["value"]
assert v4_node.hasAttributes() is True
assert v4_node.hasAttribute("value") is True
# NOTE(review): the graph types this constant as int (%4 : int), while t()
# presumably reads a tensor-valued attribute. That is likely why this line is
# disabled. An int accessor such as v4_node.i("value") is probably the one
# to use; verify.
#  print(v4_node.t("value"))
# Lists the Node methods exposed to Python (exploration aid).
print(dir(v4_node))