einsum

./code/einsum-test.py
 1#!/usr/bin/env python3
 2
 3import torch
 4
 5
 6# mat multiply
 7def test_case1():
 8    a = torch.rand(2, 3)
 9    b = torch.rand(3, 5)
10    c = torch.matmul(a, b)
11    d = torch.einsum("ij,jk", a, b)
12    assert torch.allclose(c, d)
13
14    # ->ik is optional
15    e = torch.einsum("ij,jk->ik", a, b)
16    assert torch.allclose(c, e)
17
18    # also transpose the output
19    f = torch.einsum("ij,jk->ki", a, b)
20    assert torch.allclose(c.t(), f)
21    print(c)
22    print(d)
23    print(e)
24    print(f)
25
26
# extract diagonal
def test_case2():
    """Check that ``einsum("ii->i")`` extracts the main diagonal.

    The input is an integer 5x5 matrix, so the result can be compared
    exactly with ``torch.diag`` via ``torch.equal``. The tensors, their
    shapes and their dtypes are then printed.
    """
    square = torch.arange(25).reshape(5, 5)

    # Repeating "i" across both input axes and keeping it once in the
    # output selects the elements whose row index equals the column index.
    via_einsum = torch.einsum("ii->i", square)
    via_diag = torch.diag(square, diagonal=0)
    assert torch.equal(via_einsum, via_diag)

    tensors = (square, via_einsum, via_diag)
    for t in tensors:
        print(t)
    print(*(t.shape for t in tensors), *(t.dtype for t in tensors))
39
40
# element-wise product
def test_case3():
    """Check einsum element-wise (Hadamard) products.

    When every operand and the output share the same subscripts, no index is
    summed away. The operands are therefore multiplied element by element.
    """
    x = torch.rand(3, 5)
    y = torch.rand(3, 5)

    # element-wise product of two different tensors
    prod = torch.einsum("ij,ij->ij", x, y)
    assert x.shape == y.shape == prod.shape, (x.shape, y.shape, prod.shape)
    assert torch.allclose(prod, x * y)

    # a**3 expressed as the product of three copies of the same operand
    cube = torch.einsum("ij,ij,ij->ij", x, x, x)
    assert torch.allclose(cube, x**3)
59
60
def my_einsum(a, b, c):
    """Contract three 4-D tensors with ``"ijkl, ijlm, ikml -> ijkm"``.

    Index ``l`` appears in every input but not in the output, so it is
    summed over. Indices ``i``, ``j``, ``k`` and ``m`` are kept. With
    ``a: (i, j, k, l)``, ``b: (i, j, l, m)`` and ``c: (i, k, m, l)``, the
    result has shape ``(i, j, k, m)``.
    """
    equation = "ijkl, ijlm, ikml -> ijkm"
    return torch.einsum(equation, a, b, c)
63
64
def test_einsum():
    """Check that ``torch.compile`` preserves a three-operand einsum.

    The same contraction is computed three ways:

    - eagerly with ``torch.einsum``;
    - through the compiled wrapper ``my_einsum``;
    - through ``torch.einsum`` compiled directly.

    The test asserts that all three agree.
    """
    equation = "ijkl, ijlm, ikml -> ijkm"

    # Sizes of the five einsum indices. Kept under their own names so they
    # are never clobbered by the compiled callables created below.
    size_i, size_j, size_k, size_l, size_m = 2, 3, 4, 5, 6
    a = torch.rand(size_i, size_j, size_k, size_l)
    b = torch.rand(size_i, size_j, size_l, size_m)
    c = torch.rand(size_i, size_k, size_m, size_l)

    expected = torch.einsum(equation, a, b, c)

    # Compile a user-defined wrapper around einsum.
    compiled_wrapper = torch.compile(my_einsum)
    assert torch.allclose(expected, compiled_wrapper(a, b, c))

    # Compile torch.einsum itself; the equation is passed at call time.
    compiled_einsum = torch.compile(torch.einsum)
    assert torch.allclose(expected, compiled_einsum(equation, a, b, c))
84
85
def main():
    """Run the einsum check selected for this script.

    Only ``test_einsum`` is run from here. ``test_case1``, ``test_case2``
    and ``test_case3`` are defined in this module but are not called by
    ``main``.
    """
    test_einsum()
91
92
if __name__ == "__main__":
    # Seed the global RNG so that torch.rand yields the same tensors on
    # every run of the script.
    torch.manual_seed(seed=20250623)
    main()