nn.LSTM

See https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html

./code/lstm-test.py
 1#!/usr/bin/env python3
 2
 3import torch
 4import torch.nn as nn
 5
 6
 7"""
 8self.lstm = LSTM(
 9    input_size=3,
10    hidden_size=5,
11    num_layers=1,
12    bias=True,
13    proj_size=4,
14)
15
16lstm.weight_ih_l0 [20, 3]
17lstm.weight_hh_l0 [20, 4]
18lstm.bias_ih_l0 [20]
19lstm.bias_hh_l0 [20]
20lstm.weight_hr_l0 [4, 5]
21"""
22
23
class Foo(nn.Module):
    """Wrap a single-layer, biased ``nn.LSTM`` that uses an output projection.

    Args:
      input_size:
        Feature dimension of the input ``x`` (H_in).
      hidden_size:
        Dimension of the cell state (H_cell).
      proj_size:
        Dimension of the projected hidden state (H_out). ``nn.LSTM``
        requires it to be smaller than ``hidden_size``.
    """

    def __init__(
        self,
        input_size: int = 3,
        hidden_size: int = 5,
        proj_size: int = 4,
    ):
        super().__init__()
        # The defaults reproduce the configuration that used to be hard-coded.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            bias=True,
            proj_size=proj_size,
        )

    def forward(self, x, h0, c0):
        """Run the LSTM over ``x`` starting from the state ``(h0, c0)``.

        Args:
          x:
            (T, N, H_in), H_in is the input dimension of x
          h0:
            (num_layers, N, H_out), H_out is proj_size
          c0:
            (num_layers, N, H_cell), H_cell is hidden_size

        Returns:
          A tuple ``(y, hx, cx)``: the per-step outputs returned by the
          LSTM, followed by its final hidden state and final cell state.
        """
        y, (hx, cx) = self.lstm(x, (h0, c0))
        return y, hx, cx
47
48
@torch.no_grad()
def main():
    """Run one LSTM step through ``Foo`` and recompute it by hand.

    Feeds random ``x``, ``h0`` and ``c0`` through the module and prints the
    outputs and their shapes. It then rebuilds the input, forget, cell and
    output gates from the module's raw weights and biases, applies the
    projection matrix, and prints the manually computed hidden and cell
    states so they can be compared with the module's output.
    """
    f = Foo()
    dim_in = 3
    dim_proj = 4
    dim_hidden = 5
    x = torch.rand(1, 1, dim_in)
    h0 = torch.rand(1, 1, dim_proj)
    c0 = torch.rand(1, 1, dim_hidden)
    y, hx, cx = f(x, h0, c0)

    # Build the state dict once rather than once per parameter lookup.
    params = f.state_dict()

    w_ih = params["lstm.weight_ih_l0"]
    w_hh = params["lstm.weight_hh_l0"]

    b_ih = params["lstm.bias_ih_l0"]
    b_hh = params["lstm.bias_hh_l0"]

    w_hr = params["lstm.weight_hr_l0"]

    # The four gates (i, f, g, o) are stacked along dim 0, dim_hidden rows
    # each, so split by dim_hidden instead of a hard-coded size.
    w_ii, w_if, w_ig, w_io = w_ih.split(dim_hidden, dim=0)
    w_hi, w_hf, w_hg, w_ho = w_hh.split(dim_hidden, dim=0)

    b_ii, b_if, b_ig, b_io = b_ih.split(dim_hidden, dim=0)
    b_hi, b_hf, b_hg, b_ho = b_hh.split(dim_hidden, dim=0)

    print(y, hx, cx)
    print(y.shape)
    print(hx.shape)
    print(cx.shape)

    i_gate = (x @ w_ii.t() + b_ii + h0 @ w_hi.t() + b_hi).sigmoid()
    f_gate = (x @ w_if.t() + b_if + h0 @ w_hf.t() + b_hf).sigmoid()
    g_gate = (x @ w_ig.t() + b_ig + h0 @ w_hg.t() + b_hg).tanh()
    o_gate = (x @ w_io.t() + b_io + h0 @ w_ho.t() + b_ho).sigmoid()
    c = f_gate * c0 + i_gate * g_gate

    h = o_gate * c.tanh()
    # Project the hidden state from dim_hidden down to dim_proj.
    h = h @ w_hr.t()

    # h is printed twice so this line lines up with print(y, hx, cx) above:
    # with a single time step and a single layer, y and hx should coincide.
    print(h, h, c)
89
90
if __name__ == "__main__":
    # Seed the global RNG before main() so the randomly initialised weights
    # and random inputs, and therefore the printed tensors, are reproducible.
    torch.manual_seed(20220903)
    main()