torch.gather

./code/gather.py
 1#!/usr/bin/env python3
 2
 3import torch
 4
 5
 6def main():
 7    left_context = 0
 8    N = 1
 9    T = 1
10    H = 5  # time1
11    W = 2 * H - 1 + left_context  # 2time1 - 1 + left_context
12    a = torch.randn(N, T, H, W)
13    a = torch.arange(N * T * H * W).reshape(N, T, H, W).contiguous()
14
15    if True:
16        rows = torch.arange(start=H - 1, end=-1, step=-1).unsqueeze(-1)
17        cols = torch.arange(H + left_context)
18        indexes = rows + cols
19
20        indexes = torch.tile(indexes, (N * T, 1))
21    else:
22        rows = torch.arange(start=H - 1, end=-1, step=-1)
23        cols = torch.arange(H + left_context)
24        rows = torch.cat([rows] * (N * T)).unsqueeze(-1)
25        indexes = rows + cols
26
27    print(indexes.shape)
28
29    ta = a.reshape(-1, W)
30
31    b = torch.gather(ta, dim=1, index=indexes)
32    b = b.reshape(N, T, H, -1)
33
34    c = a.as_strided(
35        (N, T, H, H + left_context),
36        (T * H * W, H * W, W - 1, 1),
37        storage_offset=H - 1,
38    )
39    assert torch.equal(b, c), (b, c)
40
41
if __name__ == "__main__":
    # Seed torch's global RNG so that any random draws made by main() are
    # reproducible from run to run.
    _seed = 20220727
    torch.manual_seed(_seed)
    main()