Passes

_jit_pass_fuse_add_relu

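This pass replaces an aten::add whose result is consumed by aten::relu with a single aten::_add_relu node. In the example below, the later addition of the constant 10 is not followed by a relu, so it is left unchanged. See https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/fuse_relu.cpp for the implementation.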

./code/passes/fuse_add_relu.py
 1#!/usr/bin/env python3
 2
 3import torch
 4
 5
 6class Foo(torch.nn.Module):  # Toy module whose traced graph contains an add feeding a relu.
 7    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 8        a = torch.nn.functional.relu(x + y)  # traced as aten::add + aten::relu; the pass fuses this pair into aten::_add_relu
 9        return a + 10  # a second aten::add, not followed by relu, so the pass leaves it alone
10
11
12def main():
13    f = Foo()
14    m = torch.jit.trace(f, (torch.rand(3), torch.rand(3)))
15    g = m.graph
16
17    with open("fuse_add_relu-before.txt", "w") as f:
18        print(g, file=f)
19
20    torch._C._jit_pass_fuse_add_relu(g)
21
22    with open("fuse_add_relu-after.txt", "w") as f:
23        print(g, file=f)
24
25
26if __name__ == "__main__":
27    main()
./code/passes/fuse_add_relu-before.txt
 1graph(%self : __torch__.Foo,
 2      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
 3      %y : Float(3, strides=[1], requires_grad=0, device=cpu)):
 4  %5 : int = prim::Constant[value=1]() # ./fuse_add_relu.py:8:0
 5  %input : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%x, %y, %5) # ./fuse_add_relu.py:8:0
 6  %a : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::relu(%input) # /Users/fangjun/py38/lib/python3.8/site-packages/torch/nn/functional.py:1457:0
 7  %8 : Long(requires_grad=0, device=cpu) = prim::Constant[value={10}]() # ./fuse_add_relu.py:9:0
 8  %9 : int = prim::Constant[value=1]() # ./fuse_add_relu.py:9:0
 9  %10 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%a, %8, %9) # ./fuse_add_relu.py:9:0
10  return (%10)
11
./code/passes/fuse_add_relu-after.txt
 1graph(%self : __torch__.Foo,
 2      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
 3      %y : Float(3, strides=[1], requires_grad=0, device=cpu)):
 4  %5 : int = prim::Constant[value=1]() # ./fuse_add_relu.py:8:0
 5  %11 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::_add_relu(%x, %y, %5)
 6  %8 : Long(requires_grad=0, device=cpu) = prim::Constant[value={10}]() # ./fuse_add_relu.py:9:0
 7  %9 : int = prim::Constant[value=1]() # ./fuse_add_relu.py:9:0
 8  %10 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%11, %8, %9) # ./fuse_add_relu.py:9:0
 9  return (%10)
10

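In-place relu_

The pass also matches the in-place variant. relu(x + y, inplace=True) traces to aten::relu_, and the add/relu_ pair is likewise replaced by the out-of-place aten::_add_relu.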

./code/passes/fuse_add_relu_.py
 1#!/usr/bin/env python3
 2
 3import torch
 4
 5
 6class Foo(torch.nn.Module):  # Toy module whose traced graph contains an add feeding an in-place relu_.
 7    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 8        a = torch.nn.functional.relu(x + y, inplace=True)  # traced as aten::add + aten::relu_; fused into aten::_add_relu
 9        return a + 10  # a separate aten::add, left untouched by the pass
10
11
12def main():
13    f = Foo()
14    m = torch.jit.trace(f, (torch.rand(3), torch.rand(3)))
15    g = m.graph
16
17    with open("fuse_add_relu_-before.txt", "w") as f:
18        print(g, file=f)
19
20    torch._C._jit_pass_fuse_add_relu(g)
21
22    with open("fuse_add_relu_-after.txt", "w") as f:
23        print(g, file=f)
24
25
26if __name__ == "__main__":
27    main()
./code/passes/fuse_add_relu_-before.txt
 1graph(%self : __torch__.Foo,
 2      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
 3      %y : Float(3, strides=[1], requires_grad=0, device=cpu)):
 4  %5 : int = prim::Constant[value=1]() # ./fuse_add_relu_.py:8:0
 5  %input : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%x, %y, %5) # ./fuse_add_relu_.py:8:0
 6  %a : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::relu_(%input) # /Users/fangjun/py38/lib/python3.8/site-packages/torch/nn/functional.py:1455:0
 7  %8 : Long(requires_grad=0, device=cpu) = prim::Constant[value={10}]() # ./fuse_add_relu_.py:9:0
 8  %9 : int = prim::Constant[value=1]() # ./fuse_add_relu_.py:9:0
 9  %10 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%a, %8, %9) # ./fuse_add_relu_.py:9:0
10  return (%10)
11
./code/passes/fuse_add_relu_-after.txt
 1graph(%self : __torch__.Foo,
 2      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
 3      %y : Float(3, strides=[1], requires_grad=0, device=cpu)):
 4  %5 : int = prim::Constant[value=1]() # ./fuse_add_relu_.py:8:0
 5  %11 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::_add_relu(%x, %y, %5)
 6  %8 : Long(requires_grad=0, device=cpu) = prim::Constant[value={10}]() # ./fuse_add_relu_.py:9:0
 7  %9 : int = prim::Constant[value=1]() # ./fuse_add_relu_.py:9:0
 8  %10 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%11, %8, %9) # ./fuse_add_relu_.py:9:0
 9  return (%10)
10

_jit_pass_fuse_linear

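In the example below, the pass rewrites aten::matmul(x, aten::t(w)) into aten::linear(x, w, None). The bias is not folded into the linear call: the + b remains a separate aten::add node. See https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/fuse_linear.cpp for the implementation.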

./code/passes/fuse_linear.py
 1#!/usr/bin/env python3
 2
 3import torch
 4
 5
 6class Foo(torch.nn.Module):  # Toy module computing matmul(x, w.t()) + b without nn.Linear.
 7    def forward(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
 8        return torch.matmul(x, w.t()) + b  # traced as aten::t + aten::matmul + aten::add; the pass turns t/matmul into aten::linear
 9
10
11def main():
12    f = Foo()
13    m = torch.jit.trace(f, (torch.rand(3), torch.rand(3, 3), torch.rand(3)))
14    g = m.graph
15
16    with open("fuse_linear-before.txt", "w") as f:
17        print(g, file=f)
18
19    torch._C._jit_pass_fuse_linear(g)
20
21    with open("fuse_linear-after.txt", "w") as f:
22        print(g, file=f)
23
24
25if __name__ == "__main__":
26    main()
./code/passes/fuse_linear-before.txt
 1graph(%self : __torch__.Foo,
 2      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
 3      %w : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu),
 4      %b : Float(3, strides=[1], requires_grad=0, device=cpu)):
 5  %6 : Float(3, 3, strides=[1, 3], requires_grad=0, device=cpu) = aten::t(%w) # ./fuse_linear.py:8:0
 6  %7 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::matmul(%x, %6) # ./fuse_linear.py:8:0
 7  %8 : int = prim::Constant[value=1]() # ./fuse_linear.py:8:0
 8  %9 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%7, %b, %8) # ./fuse_linear.py:8:0
 9  return (%9)
10
./code/passes/fuse_linear-after.txt
 1graph(%self : __torch__.Foo,
 2      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
 3      %w : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu),
 4      %b : Float(3, strides=[1], requires_grad=0, device=cpu)):
 5  %11 : Tensor? = prim::Constant()
 6  %13 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::linear(%x, %w, %11) # ./fuse_linear.py:8:0
 7  %8 : int = prim::Constant[value=1]() # ./fuse_linear.py:8:0
 8  %9 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::add(%13, %b, %8) # ./fuse_linear.py:8:0
 9  return (%9)
10