passes
_jit_pass_fuse_add_relu
This pass rewrites an aten::add followed by an aten::relu into a single fused aten::_add_relu node.
See https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/fuse_relu.cpp
./code/passes/fuse_add_relu.py
#!/usr/bin/env python3

import torch


class Foo(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        a = torch.nn.functional.relu(x + y)
        return a + 10


def main():
    f = Foo()
    # trace the module to obtain a TorchScript graph
    m = torch.jit.trace(f, (torch.rand(3), torch.rand(3)))
    g = m.graph

    # dump the IR before running the pass
    with open("fuse_add_relu-before.txt", "w") as fout:
        print(g, file=fout)

    # the pass mutates the graph in place
    torch._C._jit_pass_fuse_add_relu(g)

    # dump the IR after running the pass
    with open("fuse_add_relu-after.txt", "w") as fout:
        print(g, file=fout)


if __name__ == "__main__":
    main()
./code/passes/fuse_add_relu-before.txt
graph(%self : __torch__.Foo,
      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
      %y : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %5 : int = prim::Constant[value=1]() # ./fuse_add_relu.py:8:0
  %input : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%x, %y, %5) # ./fuse_add_relu.py:8:0
  %a : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::relu(%input) # /Users/fangjun/py38/lib/python3.8/site-packages/torch/nn/functional.py:1457:0
  %8 : Long(requires_grad=0, device=cpu) = prim::Constant[value={10}]() # ./fuse_add_relu.py:9:0
  %9 : int = prim::Constant[value=1]() # ./fuse_add_relu.py:9:0
  %10 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%a, %8, %9) # ./fuse_add_relu.py:9:0
  return (%10)
./code/passes/fuse_add_relu-after.txt
graph(%self : __torch__.Foo,
      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
      %y : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %5 : int = prim::Constant[value=1]() # ./fuse_add_relu.py:8:0
  %11 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::_add_relu(%x, %y, %5)
  %8 : Long(requires_grad=0, device=cpu) = prim::Constant[value={10}]() # ./fuse_add_relu.py:9:0
  %9 : int = prim::Constant[value=1]() # ./fuse_add_relu.py:9:0
  %10 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%11, %8, %9) # ./fuse_add_relu.py:9:0
  return (%10)
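In the rewritten graph, the aten::add/aten::relu pair has been replaced by a single aten::_add_relu node, while the second add (of the constant 10) is left untouched. As a quick sanity check, the fused op can be compared against the unfused computation. The sketch below assumes aten::_add_relu is reachable via torch.ops.aten; it is an internal, undocumented op, so this is only a verification sketch:

#!/usr/bin/env python3

import torch

x = torch.randn(3)
y = torch.randn(3)

# the fused op computes relu(x + y) in a single kernel
fused = torch.ops.aten._add_relu(x, y)
unfused = torch.relu(x + y)
assert torch.allclose(fused, unfused)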
In-place relu_
The pass also matches the in-place variant aten::relu_ and rewrites it to the same fused aten::_add_relu node.
./code/passes/fuse_add_relu_.py
#!/usr/bin/env python3

import torch


class Foo(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        a = torch.nn.functional.relu(x + y, inplace=True)
        return a + 10


def main():
    f = Foo()
    # trace the module to obtain a TorchScript graph
    m = torch.jit.trace(f, (torch.rand(3), torch.rand(3)))
    g = m.graph

    # dump the IR before running the pass
    with open("fuse_add_relu_-before.txt", "w") as fout:
        print(g, file=fout)

    # the pass mutates the graph in place
    torch._C._jit_pass_fuse_add_relu(g)

    # dump the IR after running the pass
    with open("fuse_add_relu_-after.txt", "w") as fout:
        print(g, file=fout)


if __name__ == "__main__":
    main()
./code/passes/fuse_add_relu_-before.txt
graph(%self : __torch__.Foo,
      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
      %y : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %5 : int = prim::Constant[value=1]() # ./fuse_add_relu_.py:8:0
  %input : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%x, %y, %5) # ./fuse_add_relu_.py:8:0
  %a : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::relu_(%input) # /Users/fangjun/py38/lib/python3.8/site-packages/torch/nn/functional.py:1455:0
  %8 : Long(requires_grad=0, device=cpu) = prim::Constant[value={10}]() # ./fuse_add_relu_.py:9:0
  %9 : int = prim::Constant[value=1]() # ./fuse_add_relu_.py:9:0
  %10 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%a, %8, %9) # ./fuse_add_relu_.py:9:0
  return (%10)
./code/passes/fuse_add_relu_-after.txt
graph(%self : __torch__.Foo,
      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
      %y : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %5 : int = prim::Constant[value=1]() # ./fuse_add_relu_.py:8:0
  %11 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::_add_relu(%x, %y, %5)
  %8 : Long(requires_grad=0, device=cpu) = prim::Constant[value={10}]() # ./fuse_add_relu_.py:9:0
  %9 : int = prim::Constant[value=1]() # ./fuse_add_relu_.py:9:0
  %10 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%11, %8, %9) # ./fuse_add_relu_.py:9:0
  return (%10)
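The resulting graph is identical to the out-of-place case. A compact way to verify programmatically that the in-place variant is fused as well is to check the printed IR for the fused op; this is a minimal sketch, not part of the original example:

#!/usr/bin/env python3

import torch


class Foo(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return torch.nn.functional.relu(x + y, inplace=True) + 10


m = torch.jit.trace(Foo(), (torch.rand(3), torch.rand(3)))
g = m.graph
torch._C._jit_pass_fuse_add_relu(g)

# after the pass, the fused op is present and relu_ is gone
assert "aten::_add_relu" in str(g)
assert "aten::relu_" not in str(g)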
_jit_pass_fuse_linear
This pass rewrites decomposed linear patterns, e.g., aten::t followed by aten::matmul (optionally with a trailing aten::add for the bias), into a single aten::linear node.
See https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/fuse_linear.cpp
./code/passes/fuse_linear.py
#!/usr/bin/env python3

import torch


class Foo(torch.nn.Module):
    def forward(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
        return torch.matmul(x, w.t()) + b


def main():
    f = Foo()
    # trace the module to obtain a TorchScript graph
    m = torch.jit.trace(f, (torch.rand(3), torch.rand(3, 3), torch.rand(3)))
    g = m.graph

    # dump the IR before running the pass
    with open("fuse_linear-before.txt", "w") as fout:
        print(g, file=fout)

    # the pass mutates the graph in place
    torch._C._jit_pass_fuse_linear(g)

    # dump the IR after running the pass
    with open("fuse_linear-after.txt", "w") as fout:
        print(g, file=fout)


if __name__ == "__main__":
    main()
./code/passes/fuse_linear-before.txt
graph(%self : __torch__.Foo,
      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
      %w : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu),
      %b : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %6 : Float(3, 3, strides=[1, 3], requires_grad=0, device=cpu) = aten::t(%w) # ./fuse_linear.py:8:0
  %7 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::matmul(%x, %6) # ./fuse_linear.py:8:0
  %8 : int = prim::Constant[value=1]() # ./fuse_linear.py:8:0
  %9 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%7, %b, %8) # ./fuse_linear.py:8:0
  return (%9)
./code/passes/fuse_linear-after.txt
graph(%self : __torch__.Foo,
      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
      %w : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu),
      %b : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %11 : Tensor? = prim::Constant()
  %13 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::linear(%x, %w, %11) # ./fuse_linear.py:8:0
  %8 : int = prim::Constant[value=1]() # ./fuse_linear.py:8:0
  %9 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%13, %b, %8) # ./fuse_linear.py:8:0
  return (%9)
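Here only the transpose + matmul is fused: the graph now calls aten::linear(%x, %w, %11) with a None bias (%11), and the bias add remains a separate aten::add node. The rewrite does not change the computed values, since aten::linear computes x @ w.T + b. A minimal sketch using only public APIs to confirm the equivalence:

#!/usr/bin/env python3

import torch

x = torch.rand(3)
w = torch.rand(3, 3)
b = torch.rand(3)

# torch.nn.functional.linear computes x @ w.T + b,
# i.e., the same value as the unfused matmul + add
fused = torch.nn.functional.linear(x, w, b)
unfused = torch.matmul(x, w.t()) + b
assert torch.allclose(fused, unfused)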