Hello

For background, see the official PyTorch tutorial "Introduction to TorchScript": https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.

torch.jit.script as a decorator

./code/1-ex.py
 1@torch.jit.script
 2def adder(x: int):
 3    return x + 1
 4
 5
 6def test_adder():
 7    assert isinstance(adder, torch.jit.ScriptFunction)
 8    print(adder.graph)
 9    print("-" * 10)
10    print(adder.code)
11    adder.save("adder.pt")
12
13    my_adder = torch.jit.load("adder.pt")
14
15    assert isinstance(my_adder, torch.jit._script.RecursiveScriptModule)
16    assert isinstance(my_adder, torch.jit.ScriptModule)
17    assert not isinstance(my_adder, torch.jit.ScriptFunction)
18    print(my_adder(torch.tensor([3])))
19
20
21"""
22graph(%x.1 : int):
23  %2 : int = prim::Constant[value=1]() # ./1-ex.py:8:15
24  %3 : int = aten::add(%x.1, %2) # ./1-ex.py:8:11
25  return (%3)
26
27----------
28def adder(x: int) -> int:
29  return torch.add(x, 1)
30
314
32"""

torch.Value has the following attributes (output of dir()):

['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__',
'__ge__', '__getattribute__', '__gt__', '__hash__', '__init__',
'__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__',
'__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__',
'__str__', '__subclasshook__', 'copyMetadata', 'debugName', 'inferTypeFrom',
'isCompleteTensor', 'node', 'offset', 'replaceAllUsesAfterNodeWith',
'replaceAllUsesWith', 'requiresGrad', 'requires_grad', 'setDebugName',
'setType', 'setTypeAs', 'toIValue', 'type', 'unique', 'uses']
./code/1-ex.py
 1def print_graph():
 2    assert isinstance(adder.graph, torch._C.Graph)
 3    assert isinstance(adder.graph, torch.Graph)
 4
 5    # It should have only 1 input
 6    assert len(list(adder.graph.inputs())) == 1
 7
 8    x = next(adder.graph.inputs())
 9    assert isinstance(x, torch.Value)
10    assert isinstance(x.debugName(), str)
11    assert x.debugName() == "x.1"
12    print(type(x.uses()[0]))
13    print(dir(x.uses()[0]))
14    print(x.uses()[0].user)
15    assert isinstance(x.uses()[0].user, torch.Node)
16
17    x.setDebugName("x.2")
18    assert next(adder.graph.inputs()).debugName() == "x.2"
19    assert isinstance(x.type(), torch.IntType)
20
21    print(x.node())
22    assert isinstance(x.node(), torch.Node)
23    print(dir(x.node()))
24    n = x.node()
25    assert isinstance(n.kind(), str)
26    assert n.kind() == "prim::Param", n.kind()
27    print(n.kind())
28    # a node as input and output
29    assert list(n.inputs()) == []
30
31    # n has only one output, i.e., x
32    assert len(list(n.outputs())) == 1
33    x2 = next(n.outputs())  # its type is torch.Value
34    assert x2 is x
35    assert len(list(n.blocks())) == 0

torch.Node has the following attributes (output of dir()):

['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__',
'__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__',
'__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__',
'__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__',
'__str__', '__subclasshook__', 'addBlock', 'addInput', 'addOutput',
'attributeNames', 'blocks', 'c', 'c_', 'cconv', 'copyAttributes', 'copyMetadata',
'destroy', 'eraseOutput', 'f', 'f_', 'findAllNodes', 'findNode', 'fs',
'fs_', 'g', 'g_', 'getModuleHierarchy', 'gs', 'gs_', 'hasAttribute',
'hasAttributes', 'hasMultipleOutputs', 'hasUses', 'i', 'i_', 'input',
'inputs', 'inputsAt', 'inputsSize', 'insertAfter', 'insertBefore', 'is',
'isAfter', 'isBefore', 'isNondeterministic', 'is_', 'ival', 'ival_', 'kind',
'kindOf', 'matches', 'moveAfter', 'moveBefore', 'mustBeNone', 'namedInput',
'output', 'outputs', 'outputsAt', 'outputsSize', 'owningBlock', 'prev',
'pyname', 'pyobj', 'removeAllInputs', 'removeAttribute', 'removeInput',
'replaceAllUsesWith', 'replaceInput', 'replaceInputWith', 's', 's_', 'scalar_args',
'schema', 'scopeName', 'sourceRange', 'ss', 'ss_', 't', 't_', 'ts', 'ts_',
'ty_', 'tys_', 'z', 'z_', 'zs', 'zs_']

torch.jit.script as a function

./code/2-ex.py
 1def adder(x: int):
 2    return x + 2
 3
 4
 5def test_adder():
 6    adder_func = torch.jit.script(adder)
 7    assert isinstance(adder_func, torch.jit.ScriptFunction)
 8    print(adder_func.graph)
 9    print(adder_func(3))
10
11
12"""
13graph(%x.1 : int):
14  %2 : int = prim::Constant[value=2]() # ./2-ex.py:6:15
15  %3 : int = aten::add(%x.1, %2) # ./2-ex.py:6:11
16  return (%3)
17
185
19"""

Script a module with torch.jit.script

./code/3-ex.py
 1class MyModel(torch.nn.Module):
 2    def __init__(self):
 3        super().__init__()
 4        self.p = torch.nn.Parameter(torch.tensor([2.0]))
 5
 6    def forward(self, x: torch.Tensor):
 7        return self.p * x
 8
 9
10def test_my_model():
11    model = MyModel()
12    scripted_model = torch.jit.script(model)
13    print(scripted_model.graph)
14    print("-" * 10)
15    print(scripted_model.code)
16    print(scripted_model(torch.tensor([10])))
17
18
19"""
20graph(%self : __torch__.MyModel,
21      %x.1 : Tensor):
22  %p : Tensor = prim::GetAttr[name="p"](%self)
23  %4 : Tensor = aten::mul(%p, %x.1) # ./3-ex.py:12:15
24  return (%4)
25
26----------
27def forward(self,
28    x: Tensor) -> Tensor:
29  p = self.p
30  return torch.mul(p, x)
31"""

Trace a module or a function with torch.jit.trace

./code/trace/ex0.py
 1#!/usr/bin/env python3
 2
 3import torch
 4
 5import torch.nn as nn
 6from typing import List
 7
 8
 9class Foo(nn.Module):
10    def __init__(self):
11        super().__init__()
12        self.relu = nn.ReLU()
13
14    def forward(self, x):
15        return self.relu(x)
16
17
18def test_foo():
19    f = Foo()
20    m = torch.jit.trace(f, torch.rand(2, 3))
21
22    print(m(torch.rand(2)))
23    print(m(torch.rand(2, 3, 4)))
24    # Note: The input shape is dynamic, not fixed.
25
26
27def simple(x: List[torch.Tensor], y: torch.Tensor):
28    x = x[0].item()
29    if x > 2:
30        return y + x + 1
31    elif x < 1:
32        return y
33    else:
34        return y + x
35
36
37def test_simple():
38    f0 = torch.jit.trace(simple, ([torch.tensor([0])], torch.rand(2, 3)))
39    #  print(dir(f0))
40    """
41    ['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__',
42    '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__',
43    '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__',
44    '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__',
45    '__sizeof__', '__str__', '__subclasshook__', '_debug_flush_compilation_cache',
46    'code', 'get_debug_state', 'graph', 'graph_for', 'inlined_graph', 'name',
47    'qualified_name', 'save', 'save_to_buffer', 'schema']
48    """
49    #  print(f0.schema)  # simple(Tensor[] x, Tensor y) -> (Tensor)
50    #  print(f0.code)
51    """
52    def simple(x: List[Tensor],
53        y: Tensor) -> Tensor:
54      return y
55    """
56    #  print(f0.graph)
57    """
58    graph(%x : Tensor[],
59          %y : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
60      return (%y)
61    """
62    #  print(f0.inlined_graph) # same as the above one
63    #  print(f0.name) # simple
64    print(f0.qualified_name)  # __torch__.simple
65
66
67def main():
68    #  test_foo()
69    test_simple()
70
71
72if __name__ == "__main__":
73    main()

Export and ignore methods

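  1. Use the @torch.jit.export decorator to export a method.

  2. Call torch.jit.export as a function to export a method.

  3. Use the @torch.jit.ignore decorator to ignore a method: it is left as a regular Python function and is not compiled.

  4. Call torch.jit.ignore as a function to ignore a method.

  5. Use the @torch.jit.unused decorator (or call torch.jit.unused as a function) to mark a method as unused: instead of being left as Python, it is compiled as a stub that raises an exception when called, so the module can still be scripted and saved.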

See Load in C++ for how to load the saved file from C++.

./code/4-ex.py
 1class MyModel(torch.nn.Module):
 2    def __init__(self):
 3        super().__init__()
 4        self.p = torch.nn.Parameter(torch.tensor([2.0]))
 5
 6    def foobar(self, x: torch.Tensor):
 7        return x + 3
 8
 9    def foo(self, x: torch.Tensor):
10        return self.foobar(x)
11
12    def bar(self, x: torch.Tensor):
13        return self.p - x
14
15    @torch.jit.export
16    def baz(self, x: torch.Tensor):
17        return self.p + x + 2
18
19    def forward(self, x: torch.Tensor):
20        return self.p * x
21
22
23def test_my_model():
24    MyModel.foo = torch.jit.export(MyModel.foo)  # manually export
25
26    # Note: forward is exported by default. We ignore it here manually
27    MyModel.forward = torch.jit.ignore(MyModel.forward)
28
29    model = MyModel()
30    scripted_model = torch.jit.script(model)
31    assert hasattr(scripted_model, "foo")
32    assert hasattr(scripted_model, "baz")
33    assert hasattr(scripted_model, "foobar")  # because it is called by `foo`
34    assert not hasattr(scripted_model, "bar")
35
36    scripted_model.save("foo.pt")
37
38    m = torch.jit.load("foo.pt")
39    print(m.foo(torch.tensor([1])))
40    print(m.baz(torch.tensor([1])))
41
42
43"""
44graph(%self : __torch__.MyModel,
45      %x.1 : Tensor):
46  %p : Tensor = prim::GetAttr[name="p"](%self)
47  %4 : Tensor = aten::mul(%p, %x.1) # ./3-ex.py:12:15
48  return (%4)
49
50----------
51def forward(self,
52    x: Tensor) -> Tensor:
53  p = self.p
54  return torch.mul(p, x)
55"""