nn.LSTM
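See https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html for the full API reference.
The script ./code/lstm-test.py below reproduces the output of an nn.LSTM with projection (proj_size) by hand: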
1#!/usr/bin/env python3
2
3import torch
4import torch.nn as nn
5
6
# Parameter shapes of nn.LSTM with a projection, for example
#
#     nn.LSTM(input_size=2, hidden_size=5, num_layers=1, bias=True, proj_size=2)
#
# The four gates (i, f, g, o) are stacked along dim 0, hence the
# 4 * hidden_size leading dimension:
#
#     lstm.weight_ih_l0  [20, 2]  (4 * hidden_size, input_size)
#     lstm.weight_hh_l0  [20, 2]  (4 * hidden_size, proj_size)
#     lstm.bias_ih_l0    [20]     (4 * hidden_size,)
#     lstm.bias_hh_l0    [20]     (4 * hidden_size,)
#     lstm.weight_hr_l0  [2, 5]   (proj_size, hidden_size)


class Foo(nn.Module):
    """Thin wrapper around a single ``nn.LSTM`` with an output projection.

    Args:
      input_size: feature dimension of the input ``x`` (H_in).
      hidden_size: size of the cell state (H_cell).
      proj_size: size of the projected hidden state / output (H_out);
        PyTorch requires it to be smaller than ``hidden_size``.
      num_layers: number of stacked LSTM layers.

    The defaults (3, 5, 4, 1) reproduce the original hard-coded configuration.
    """

    def __init__(self, input_size=3, hidden_size=5, proj_size=4, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=True,
            proj_size=proj_size,
        )

    def forward(self, x, h0, c0):
        """Run the LSTM over a whole sequence.

        Args:
          x:
            (T, N, H_in), H_in is input_size
          h0:
            (num_layers, N, H_out), H_out is proj_size
          c0:
            (num_layers, N, H_cell), H_cell is hidden_size
        Returns:
          A tuple (y, hx, cx):
            y:  (T, N, H_out), last-layer output at every time step
            hx: (num_layers, N, H_out), final hidden state
            cx: (num_layers, N, H_cell), final cell state
        """
        y, (hx, cx) = self.lstm(x, (h0, c0))
        return y, hx, cx
47
48
@torch.no_grad()
def main():
    """Check ``Foo``'s LSTM output against a hand-written forward pass.

    Runs the module on random inputs and prints the results. It then
    recomputes the output sequence and the final hidden/cell states directly
    from the layer-0 parameters, using the LSTM-with-projection equations,
    and raises if the two disagree.
    """
    f = Foo()
    # Read the dimensions from the module so this check cannot drift from
    # Foo's configuration.
    dim_in = f.lstm.input_size
    dim_proj = f.lstm.proj_size
    dim_hidden = f.lstm.hidden_size
    seq_len = 1
    x = torch.rand(seq_len, 1, dim_in)
    h0 = torch.rand(1, 1, dim_proj)
    c0 = torch.rand(1, 1, dim_hidden)
    y, hx, cx = f(x, h0, c0)

    state = f.state_dict()
    w_ih = state["lstm.weight_ih_l0"]
    w_hh = state["lstm.weight_hh_l0"]
    b_ih = state["lstm.bias_ih_l0"]
    b_hh = state["lstm.bias_hh_l0"]
    w_hr = state["lstm.weight_hr_l0"]

    # The four gates are stacked along dim 0 in the order (i, f, g, o),
    # each hidden_size rows tall.
    w_ii, w_if, w_ig, w_io = w_ih.split(dim_hidden, dim=0)
    w_hi, w_hf, w_hg, w_ho = w_hh.split(dim_hidden, dim=0)
    b_ii, b_if, b_ig, b_io = b_ih.split(dim_hidden, dim=0)
    b_hi, b_hf, b_hg, b_ho = b_hh.split(dim_hidden, dim=0)

    print(y, hx, cx)
    print(y.shape)
    print(hx.shape)
    print(cx.shape)

    # Single-layer reference: h0/c0 carry a leading num_layers == 1 axis.
    h = h0[0]
    c = c0[0]
    outputs = []
    for x_t in x:  # iterate over the time axis of x
        i_gate = (x_t @ w_ii.t() + b_ii + h @ w_hi.t() + b_hi).sigmoid()
        f_gate = (x_t @ w_if.t() + b_if + h @ w_hf.t() + b_hf).sigmoid()
        g_gate = (x_t @ w_ig.t() + b_ig + h @ w_hg.t() + b_hg).tanh()
        o_gate = (x_t @ w_io.t() + b_io + h @ w_ho.t() + b_ho).sigmoid()
        c = f_gate * c + i_gate * g_gate
        # With proj_size > 0 the hidden state is projected by weight_hr.
        h = (o_gate * c.tanh()) @ w_hr.t()
        outputs.append(h)

    y_ref = torch.stack(outputs)
    hx_ref = h.unsqueeze(0)
    cx_ref = c.unsqueeze(0)
    print(y_ref, hx_ref, cx_ref)

    # Previously the two results were only printed for eyeballing; compare them.
    torch.testing.assert_close(y, y_ref)
    torch.testing.assert_close(hx, hx_ref)
    torch.testing.assert_close(cx, cx_ref)
89
90
if __name__ == "__main__":
    # Seed before main() so both the weight initialisation inside Foo and the
    # random inputs are reproducible from run to run.
    seed = 20220903
    torch.manual_seed(seed)
    main()