einsum
./code/einsum-test.py
1#!/usr/bin/env python3
2
3import torch
4
5
# Matrix multiplication expressed with einsum.
def test_case1():
    """Show that ``torch.einsum`` reproduces matrix multiplication.

    Compares ``torch.matmul`` against three einsum spellings:
    - the implicit-output form;
    - the explicit ``->ik`` form;
    - ``->ki``, which yields the transposed product.

    All four results are printed.
    """
    lhs = torch.rand(2, 3)
    rhs = torch.rand(3, 5)
    expected = torch.matmul(lhs, rhs)

    # Implicit output: the subscripts that appear only once (i and k),
    # in alphabetical order.
    implicit = torch.einsum("ij,jk", lhs, rhs)
    assert torch.allclose(expected, implicit)

    # Writing "->ik" explicitly gives the same result, so the arrow is optional.
    explicit = torch.einsum("ij,jk->ik", lhs, rhs)
    assert torch.allclose(expected, explicit)

    # Swapping the output subscripts transposes the product.
    transposed = torch.einsum("ij,jk->ki", lhs, rhs)
    assert torch.allclose(expected.t(), transposed)

    for tensor in (expected, implicit, explicit, transposed):
        print(tensor)
25
26
# Diagonal extraction expressed with einsum.
def test_case2():
    """Show that ``einsum("ii->i", ...)`` extracts a matrix diagonal.

    Builds a 5x5 integer matrix and takes its main diagonal with both
    ``torch.einsum`` and ``torch.diag``. It checks that the two match
    exactly, then prints the tensors and their shapes and dtypes.
    """
    square = torch.arange(25).reshape(5, 5)

    # A subscript repeated within a single operand selects its diagonal.
    via_einsum = torch.einsum("ii->i", square)
    via_diag = torch.diag(square, diagonal=0)
    # torch.arange with integer bounds gives an integer tensor, so exact
    # equality is the right check.
    assert torch.equal(via_einsum, via_diag)

    tensors = (square, via_einsum, via_diag)
    for t in tensors:
        print(t)
    print(*(t.shape for t in tensors), *(t.dtype for t in tensors))
39
40
# Element-wise (Hadamard) product expressed with einsum.
def test_case3():
    """Show that einsum with identical subscripts is an element-wise product.

    ``"ij,ij->ij"`` matches ``a * b``. Repeating one operand three times
    with ``"ij,ij,ij->ij"`` matches ``a ** 3``.
    """
    x = torch.rand(3, 5)
    y = torch.rand(3, 5)

    # Every subscript appears in the output, so nothing is summed: entries
    # are simply multiplied position by position.
    hadamard = torch.einsum("ij,ij->ij", x, y)
    assert x.shape == y.shape == hadamard.shape, (x.shape, y.shape, hadamard.shape)
    assert torch.allclose(hadamard, x * y)

    # Three copies of the same operand give the element-wise cube.
    cubed = torch.einsum("ij,ij,ij->ij", x, x, x)
    assert torch.allclose(cubed, x**3)
59
60
def my_einsum(a, b, c):
    """Contract three 4-D tensors with ``"ijkl, ijlm, ikml -> ijkm"``.

    Subscript ``l`` appears in every operand but not in the output, so it is
    summed over, and the result is indexed by ``(i, j, k, m)``.

    Sizes must agree on shared subscripts:
    - ``a`` is ``(i, j, k, l)``;
    - ``b`` is ``(i, j, l, m)``;
    - ``c`` is ``(i, k, m, l)``.
    """
    equation = "ijkl, ijlm, ikml -> ijkm"
    operands = (a, b, c)
    return torch.einsum(equation, *operands)
63
64
def test_einsum():
    """Check that ``torch.compile`` preserves a three-operand einsum contraction.

    Computes ``"ijkl, ijlm, ikml -> ijkm"`` eagerly, then two more ways:
    - through a compiled ``my_einsum``;
    - through a compiled ``torch.einsum``.

    Asserts that all three results agree.
    """
    # Dimension sizes, named after the einsum subscripts they bind to.
    i, j, k, l, m = 2, 3, 4, 5, 6  # noqa: E741 - `l` mirrors the subscript
    a = torch.rand(i, j, k, l)
    b = torch.rand(i, j, l, m)
    c = torch.rand(i, k, m, l)

    equation = "ijkl, ijlm, ikml -> ijkm"
    scores = torch.einsum(equation, a, b, c)
    # l is summed away; the remaining subscripts define the output shape.
    assert scores.shape == (i, j, k, m), scores.shape

    # Compiled callables get their own names so the dimension sizes above
    # (notably `m`) are never rebound to a function.
    compiled_wrapper = torch.compile(my_einsum)
    scores_2 = compiled_wrapper(a, b, c)
    assert torch.allclose(scores, scores_2)

    compiled_einsum = torch.compile(torch.einsum)
    scores_3 = compiled_einsum(equation, a, b, c)
    assert torch.allclose(scores, scores_3)
84
85
def main():
    """Run the enabled einsum demos.

    Only ``test_einsum`` is enabled. ``test_case1``, ``test_case2`` and
    ``test_case3`` also exist in this module; add them to ``demos`` to
    run them as well.
    """
    demos = (test_einsum,)
    for demo in demos:
        demo()
91
92
if __name__ == "__main__":
    # Seed torch's global RNG so the torch.rand inputs are the same on every run.
    torch.manual_seed(seed=20250623)
    main()