torch.gather
./code/gather.py
1#!/usr/bin/env python3
2
3import torch
4
5
def _shift_with_gather(x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
    """Return ``out[..., i, j] = x[..., i, H - 1 - i + j]`` built with ``torch.gather``.

    ``x`` must have shape ``(N, T, H, W)`` with ``W >= 2 * H - 1 + left_context``
    so that every gathered column index is in range. The result has shape
    ``(N, T, H, H + left_context)``.
    """
    N, T, H, W = x.shape
    # Row i of every (H, W) slab starts reading at column H - 1 - i.
    rows = torch.arange(
        start=H - 1, end=-1, step=-1, device=x.device
    ).unsqueeze(-1)  # (H, 1)
    cols = torch.arange(H + left_context, device=x.device)  # (H + left_context,)
    # Repeat the (H, H + left_context) pattern once per (n, t) slab so it lines
    # up row-for-row with x flattened to (N * T * H, W).
    index = torch.tile(rows + cols, (N * T, 1))
    out = torch.gather(x.reshape(-1, W), dim=1, index=index)
    return out.reshape(N, T, H, -1)


def _shift_with_as_strided(x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
    """Return the same values as ``_shift_with_gather`` via ``as_strided``.

    Stepping ``W - 1`` storage elements per row moves one row down and one
    column left, which reproduces the ``H - 1 - i + j`` column pattern without
    building an index tensor.
    """
    x = x.contiguous()  # the strides below assume a dense row-major layout
    N, T, H, W = x.shape
    return x.as_strided(
        (N, T, H, H + left_context),
        (T * H * W, H * W, W - 1, 1),
        # as_strided offsets are absolute within the underlying storage.
        storage_offset=x.storage_offset() + H - 1,
    )


def main(N: int = 1, T: int = 1, H: int = 5, left_context: int = 0) -> None:
    """Check that a ``torch.gather`` shift equals an ``as_strided`` view.

    Builds ``a`` of shape ``(N, T, H, W)`` with ``W = 2 * H - 1 + left_context``.
    ``H`` is the dimension the original comments call ``time1``. ``a`` is
    filled with ``0 .. a.numel() - 1`` so every element is distinguishable.
    The two shifted results, each of shape ``(N, T, H, H + left_context)``,
    are then compared.

    Args:
        N: first leading dimension; must be >= 1.
        T: second leading dimension; must be >= 1.
        H: rows per (H, W) slab; must be >= 1.
        left_context: extra columns kept per row; must be >= 0.

    Raises:
        ValueError: if a size argument is out of range.
        AssertionError: if the two constructions disagree.
    """
    if min(N, T, H) < 1 or left_context < 0:
        raise ValueError(
            f"need N, T, H >= 1 and left_context >= 0, got "
            f"N={N}, T={T}, H={H}, left_context={left_context}"
        )
    W = 2 * H - 1 + left_context
    a = torch.arange(N * T * H * W).reshape(N, T, H, W)

    b = _shift_with_gather(a, left_context)
    c = _shift_with_as_strided(a, left_context)
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not torch.equal(b, c):
        raise AssertionError((b, c))
40
41
if __name__ == "__main__":
    # Pin the global RNG before running the demo so any random draws are
    # reproducible run to run. torch.manual_seed is a re-export of
    # torch.random.manual_seed.
    torch.random.manual_seed(20220727)
    main()