This section covers the basics of TorchScript. See https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html for the official tutorial.
torch.jit.script as a decorator
./code/1-ex.py
import torch


@torch.jit.script
def adder(x: int):
    return x + 1


def test_adder():
    assert isinstance(adder, torch.jit.ScriptFunction)
    print(adder.graph)
    print("-" * 10)
    print(adder.code)
    adder.save("adder.pt")

    my_adder = torch.jit.load("adder.pt")

    assert isinstance(my_adder, torch.jit._script.RecursiveScriptModule)
    assert isinstance(my_adder, torch.jit.ScriptModule)
    assert not isinstance(my_adder, torch.jit.ScriptFunction)
    print(my_adder(torch.tensor([3])))


"""
graph(%x.1 : int):
  %2 : int = prim::Constant[value=1]() # ./1-ex.py:8:15
  %3 : int = aten::add(%x.1, %2) # ./1-ex.py:8:11
  return (%3)

----------
def adder(x: int) -> int:
  return torch.add(x, 1)

4
"""
torch.Value has the following attributes (output of dir()):
['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__',
'__ge__', '__getattribute__', '__gt__', '__hash__', '__init__',
'__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__',
'__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__',
'__str__', '__subclasshook__', 'copyMetadata', 'debugName', 'inferTypeFrom',
'isCompleteTensor', 'node', 'offset', 'replaceAllUsesAfterNodeWith',
'replaceAllUsesWith', 'requiresGrad', 'requires_grad', 'setDebugName',
'setType', 'setTypeAs', 'toIValue', 'type', 'unique', 'uses']
./code/1-ex.py
def print_graph():
    assert isinstance(adder.graph, torch._C.Graph)
    assert isinstance(adder.graph, torch.Graph)

    # It should have only 1 input
    assert len(list(adder.graph.inputs())) == 1

    x = next(adder.graph.inputs())
    assert isinstance(x, torch.Value)
    assert isinstance(x.debugName(), str)
    assert x.debugName() == "x.1"
    print(type(x.uses()[0]))
    print(dir(x.uses()[0]))
    print(x.uses()[0].user)
    assert isinstance(x.uses()[0].user, torch.Node)

    x.setDebugName("x.2")
    assert next(adder.graph.inputs()).debugName() == "x.2"
    assert isinstance(x.type(), torch.IntType)

    print(x.node())
    assert isinstance(x.node(), torch.Node)
    print(dir(x.node()))
    n = x.node()
    assert isinstance(n.kind(), str)
    assert n.kind() == "prim::Param", n.kind()
    print(n.kind())
    # A node has inputs and outputs; the param node has no inputs
    assert list(n.inputs()) == []

    # n has only one output, i.e., x
    assert len(list(n.outputs())) == 1
    x2 = next(n.outputs())  # its type is torch.Value
    assert x2 is x
    assert len(list(n.blocks())) == 0
torch.Node has the following attributes (output of dir()):
['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__',
'__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__',
'__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__',
'__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__',
'__str__', '__subclasshook__', 'addBlock', 'addInput', 'addOutput',
'attributeNames', 'blocks', 'c', 'c_', 'cconv', 'copyAttributes', 'copyMetadata',
'destroy', 'eraseOutput', 'f', 'f_', 'findAllNodes', 'findNode', 'fs',
'fs_', 'g', 'g_', 'getModuleHierarchy', 'gs', 'gs_', 'hasAttribute',
'hasAttributes', 'hasMultipleOutputs', 'hasUses', 'i', 'i_', 'input',
'inputs', 'inputsAt', 'inputsSize', 'insertAfter', 'insertBefore', 'is',
'isAfter', 'isBefore', 'isNondeterministic', 'is_', 'ival', 'ival_', 'kind',
'kindOf', 'matches', 'moveAfter', 'moveBefore', 'mustBeNone', 'namedInput',
'output', 'outputs', 'outputsAt', 'outputsSize', 'owningBlock', 'prev',
'pyname', 'pyobj', 'removeAllInputs', 'removeAttribute', 'removeInput',
'replaceAllUsesWith', 'replaceInput', 'replaceInputWith', 's', 's_', 'scalar_args',
'schema', 'scopeName', 'sourceRange', 'ss', 'ss_', 't', 't_', 'ts', 'ts_',
'ty_', 'tys_', 'z', 'z_', 'zs', 'zs_']
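Putting torch.Value and torch.Node together, the sketch below walks every node in the graph of adder from 1-ex.py. walk_graph is a hypothetical helper; it assumes Graph.nodes() and Graph.findNode(), the Graph-level counterparts of the Node methods listed above, are available.

import torch


@torch.jit.script
def adder(x: int):
    return x + 1


def walk_graph(graph: torch.Graph):
    # Iterate over the nodes in the graph body (excludes the param node)
    for node in graph.nodes():
        print(node.kind(), "->", [v.debugName() for v in node.outputs()])

    # Find the first node of a given kind, e.g. the add
    add_node = graph.findNode("aten::add")
    assert add_node is not None
    assert add_node.kind() == "aten::add"
    # Its two inputs are %x.1 and the constant %2
    assert len(list(add_node.inputs())) == 2


walk_graph(adder.graph)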
torch.jit.script as a function
./code/2-ex.py
import torch


def adder(x: int):
    return x + 2


def test_adder():
    adder_func = torch.jit.script(adder)
    assert isinstance(adder_func, torch.jit.ScriptFunction)
    print(adder_func.graph)
    print(adder_func(3))


"""
graph(%x.1 : int):
  %2 : int = prim::Constant[value=2]() # ./2-ex.py:6:15
  %3 : int = aten::add(%x.1, %2) # ./2-ex.py:6:11
  return (%3)

5
"""
Script a module
./code/3-ex.py
import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.p = torch.nn.Parameter(torch.tensor([2.0]))

    def forward(self, x: torch.Tensor):
        return self.p * x


def test_my_model():
    model = MyModel()
    scripted_model = torch.jit.script(model)
    print(scripted_model.graph)
    print("-" * 10)
    print(scripted_model.code)
    print(scripted_model(torch.tensor([10])))


"""
graph(%self : __torch__.MyModel,
      %x.1 : Tensor):
  %p : Tensor = prim::GetAttr[name="p"](%self)
  %4 : Tensor = aten::mul(%p, %x.1) # ./3-ex.py:12:15
  return (%4)

----------
def forward(self,
    x: Tensor) -> Tensor:
  p = self.p
  return torch.mul(p, x)
"""
Trace a module
./code/trace/ex0.py
#!/usr/bin/env python3

import torch
import torch.nn as nn
from typing import List


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x)


def test_foo():
    f = Foo()
    m = torch.jit.trace(f, torch.rand(2, 3))

    print(m(torch.rand(2)))
    print(m(torch.rand(2, 3, 4)))
    # Note: the input shape is dynamic, not fixed to the example input's shape.


def simple(x: List[torch.Tensor], y: torch.Tensor):
    x = x[0].item()
    if x > 2:
        return y + x + 1
    elif x < 1:
        return y
    else:
        return y + x


def test_simple():
    f0 = torch.jit.trace(simple, ([torch.tensor([0])], torch.rand(2, 3)))
    # print(dir(f0))
    """
    ['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__',
     '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__',
     '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__',
     '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__',
     '__sizeof__', '__str__', '__subclasshook__', '_debug_flush_compilation_cache',
     'code', 'get_debug_state', 'graph', 'graph_for', 'inlined_graph', 'name',
     'qualified_name', 'save', 'save_to_buffer', 'schema']
    """
    # print(f0.schema)  # simple(Tensor[] x, Tensor y) -> (Tensor)
    # print(f0.code)
    """
    def simple(x: List[Tensor],
        y: Tensor) -> Tensor:
      return y
    """
    # print(f0.graph)
    """
    graph(%x : Tensor[],
          %y : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
      return (%y)
    """
    # print(f0.inlined_graph)  # same as the above one
    # print(f0.name)  # simple
    print(f0.qualified_name)  # __torch__.simple

    # Note: the trace recorded only the branch taken for the example input
    # (x == 0, so the traced function always returns y unchanged).


def main():
    # test_foo()
    test_simple()


if __name__ == "__main__":
    main()
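The traced f0 kept only the branch taken for the example input. For comparison, here is a minimal sketch that scripts rather than traces the same logic. The item() result gets its own name v (with an int() cast for clarity), since TorchScript does not let a variable change type the way x = x[0].item() does:

import torch
from typing import List


# Same logic as `simple` above, except the result of item() gets its
# own name: reassigning x from List[Tensor] to a number would fail to
# compile under torch.jit.script.
def simple2(x: List[torch.Tensor], y: torch.Tensor):
    v = int(x[0].item())
    if v > 2:
        return y + v + 1
    elif v < 1:
        return y
    else:
        return y + v


scripted = torch.jit.script(simple2)
print(scripted.code)  # all three branches survive, unlike the trace

y = torch.zeros(2)
assert torch.equal(scripted([torch.tensor([5])], y), y + 6)
assert torch.equal(scripted([torch.tensor([0])], y), y)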
Export and ignore methods
- Use the @torch.jit.export decorator to export a method.
- Use a torch.jit.export function call to export a method.
- Use the @torch.jit.ignore decorator to ignore a method.
- Use a torch.jit.ignore function call to ignore a method.
- Use @torch.jit.unused (as a decorator or a function call) to replace a
  method with a stub that raises an exception when called from compiled code.
See Load in C++ to load the saved file.
./code/4-ex.py
import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.p = torch.nn.Parameter(torch.tensor([2.0]))

    def foobar(self, x: torch.Tensor):
        return x + 3

    def foo(self, x: torch.Tensor):
        return self.foobar(x)

    def bar(self, x: torch.Tensor):
        return self.p - x

    @torch.jit.export
    def baz(self, x: torch.Tensor):
        return self.p + x + 2

    def forward(self, x: torch.Tensor):
        return self.p * x


def test_my_model():
    MyModel.foo = torch.jit.export(MyModel.foo)  # manually export

    # Note: forward is exported by default; we ignore it manually here.
    MyModel.forward = torch.jit.ignore(MyModel.forward)

    model = MyModel()
    scripted_model = torch.jit.script(model)
    assert hasattr(scripted_model, "foo")
    assert hasattr(scripted_model, "baz")
    assert hasattr(scripted_model, "foobar")  # because it is called by `foo`
    assert not hasattr(scripted_model, "bar")

    scripted_model.save("foo.pt")

    m = torch.jit.load("foo.pt")
    print(m.foo(torch.tensor([1])))
    print(m.baz(torch.tensor([1])))