LSTM
LSTM in mlx does not support multiple layers; there is no num_layers argument.
The batch is always at dim 0, i.e., batch_first is effectively always True.
It also does not support bidirectional LSTM.
There is only a single bias (if bias is not None); torch keeps two separate biases, bias_ih and bias_hh, so they have to be summed when converting.
https://github.com/ml-explore/mlx-examples/blob/main/encodec/encodec.py#L14 has a Metal kernel implementation of LSTM.
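
To illustrate the call convention, here is a minimal sketch (the sizes are made up; the return values are the hidden and cell states at every time step, which matches the tests below):

import mlx.core as mx
import mlx.nn as nn

lstm = nn.LSTM(input_size=2, hidden_size=4)  # no num_layers / bidirectional args
x = mx.random.uniform(shape=(3, 5, 2))  # (batch, time, features), batch always first
h, c = lstm(x)  # all time steps, not just the last one
print(h.shape, c.shape)  # (3, 5, 4) (3, 5, 4)
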
./code/test_lstm_1_layer.py
#!/usr/bin/env python3

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_map, tree_unflatten


class TorchLstm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=2, hidden_size=6, batch_first=True)

    def forward(self, x, h=None, c=None):
        if h is not None:
            y, (h, c) = self.lstm(x, (h, c))
        else:
            y, (h, c) = self.lstm(x)
        return y, h, c


class MlxLstm(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=2, hidden_size=6)

    def forward(self, x, h=None, c=None):
        """
        Args:
            x: (N, L, input_size)
            h: None or (N, hidden_size)
            c: None or (N, hidden_size)
        Returns:
            y: (N, L, hidden_size)
            h: (N, hidden_size)
            c: (N, hidden_size)
        """
        if h is None:
            h, c = self.lstm(x)
        else:
            h, c = self.lstm(x, h, c)
        # now both h and c are (N, L, hidden_size)

        return h, h[:, -1, :], c[:, -1, :]


def test_with_seq_len(torch_lstm, mx_lstm, N, T):
    x = torch.rand(N, T, 2)
    print("x", x.shape)

    y, h, c = torch_lstm(x)
    assert y.shape == (N, T, 6), y.shape
    assert h.shape == (1, N, 6), h.shape  # 1 is the number of layers
    assert c.shape == (1, N, 6), c.shape

    ix = mx.array(x)
    mx_y, mx_h, mx_c = mx_lstm.forward(ix)
    assert mx_y.shape == (N, T, 6), mx_y.shape
    assert mx_h.shape == (N, 6), mx_h.shape
    assert mx_c.shape == (N, 6), mx_c.shape

    assert torch.allclose(y, torch.from_numpy(np.array(mx_y)))
    assert torch.allclose(h[0], torch.from_numpy(np.array(mx_h)))
    assert torch.allclose(c[0], torch.from_numpy(np.array(mx_c)))

    # now with states
    y, h, c = torch_lstm(x, h, c)
    assert y.shape == (N, T, 6), y.shape
    assert h.shape == (1, N, 6), h.shape  # 1 is the number of layers
    assert c.shape == (1, N, 6), c.shape

    mx_y, mx_h, mx_c = mx_lstm.forward(ix, mx_h, mx_c)
    assert mx_y.shape == (N, T, 6), mx_y.shape
    assert mx_h.shape == (N, 6), mx_h.shape
    assert mx_c.shape == (N, 6), mx_c.shape

    assert torch.allclose(y, torch.from_numpy(np.array(mx_y)))
    assert torch.allclose(h[0], torch.from_numpy(np.array(mx_h)))
    assert torch.allclose(c[0], torch.from_numpy(np.array(mx_c)))


def test_single_layer():
    torch_lstm = TorchLstm()
    mx_lstm = MlxLstm()

    # ['lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0']
    print(list(torch_lstm.state_dict().keys()))
    for k, v in torch_lstm.state_dict().items():
        print(k, v.shape)
    """
    lstm.weight_ih_l0 torch.Size([24, 2])
    lstm.weight_hh_l0 torch.Size([24, 6])
    lstm.bias_ih_l0 torch.Size([24])
    lstm.bias_hh_l0 torch.Size([24])
    """

    # dict_keys(['lstm'])
    print(mx_lstm.parameters().keys())

    # dict_keys(['Wx', 'Wh', 'bias'])
    print(mx_lstm.parameters()["lstm"].keys())
    for k, v in mx_lstm.parameters()["lstm"].items():
        print(k, v.shape)
    """
    Wx (24, 2)
    Wh (24, 6)
    bias (24,)
    """

    # convert
    state_dict = torch_lstm.state_dict()
    new_state_dict = dict()
    for k, v in state_dict.items():
        basename, pname = k.rsplit(".", 1)
        if "lstm" not in basename:
            # we convert only lstm in this for loop
            continue
        w_or_b, ih_or_hh, ln = pname.split("_")
        if w_or_b == "weight":
            new_name = "Wx" if ih_or_hh == "ih" else "Wh"
        elif w_or_b == "bias" and ih_or_hh == "ih":
            continue
        else:
            # mlx keeps a single bias, so sum the two torch biases
            v = v + state_dict[k.replace("_hh_", "_ih_")]
            new_name = "bias"
        # k = basename + "." + ln[1:] + "." + new_name  # for a multi-layer lstm
        k = basename + "." + new_name  # for a single-layer lstm
        new_state_dict[k] = v

    # convert torch.Tensor to mx.array
    new_state_dict = tree_map(mx.array, new_state_dict)
    mx_lstm.update(tree_unflatten(list(new_state_dict.items())))

    for N in [1, 2, 3]:
        for T in [1, 2, 3]:
            test_with_seq_len(torch_lstm=torch_lstm, mx_lstm=mx_lstm, N=N, T=T)


@torch.no_grad()
def main():
    test_single_layer()


if __name__ == "__main__":
    torch.manual_seed(20250716)
    main()
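
The conversion loop above is written so it also generalizes, but for a single layer the whole thing boils down to three assignments (a compact equivalent, using the same names as in the script):

sd = torch_lstm.state_dict()
mx_lstm.update(tree_unflatten([
    ("lstm.Wx", mx.array(sd["lstm.weight_ih_l0"])),
    ("lstm.Wh", mx.array(sd["lstm.weight_hh_l0"])),
    # mlx has a single bias, so the two torch biases are summed
    ("lstm.bias", mx.array(sd["lstm.bias_ih_l0"] + sd["lstm.bias_hh_l0"])),
]))
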
./code/test_lstm_multi_layers.py
#!/usr/bin/env python3

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_map, tree_unflatten


class TorchLstm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=2, hidden_size=5, num_layers=3, batch_first=True
        )

    def forward(self, x, h=None, c=None):
        if h is not None:
            y, (h, c) = self.lstm(x, (h, c))
        else:
            y, (h, c) = self.lstm(x)
        return y, h, c


class MlxLstm(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = []
        input_size = 2
        hidden_size = 5
        for i in range(3):
            in_size = input_size if i == 0 else hidden_size
            self.lstm.append(nn.LSTM(input_size=in_size, hidden_size=hidden_size))

    def forward(self, x, states=None):
        """
        Args:
            x: (N, L, input_size)
            states: None or a list containing 2*num_layers tensors
                - states[2*i] is the h of the i-th layer, of shape (N, hidden_size)
                - states[2*i+1] is the c of the i-th layer, of shape (N, hidden_size)
        Returns:
            y: (N, L, hidden_size)
            states: a list containing 2*num_layers tensors, laid out as above
        """
        if states is None:
            states = [None] * (len(self.lstm) * 2)

        new_states = []
        for i, layer in enumerate(self.lstm):
            h = states[2 * i]
            c = states[2 * i + 1]
            # each layer returns all hidden states and all cell states,
            # both (N, L, hidden_size); the hidden states feed the next layer
            x, c = layer(x, h, c)
            new_states.append(x[:, -1, :])
            new_states.append(c[:, -1, :])
        return x, new_states


def test_with_seq_len(torch_lstm, mx_lstm, N, T):
    x = torch.rand(N, T, 2)
    print("x", x.shape)

    y, h, c = torch_lstm(x)
    assert y.shape == (N, T, 5), y.shape
    assert h.shape == (3, N, 5), h.shape  # 3 is the number of layers
    assert c.shape == (3, N, 5), c.shape

    ix = mx.array(x)
    mx_y, states = mx_lstm.forward(ix)

    assert mx_y.shape == (N, T, 5), mx_y.shape
    assert torch.allclose(y, torch.from_numpy(np.array(mx_y)))

    for i in range(3):
        assert torch.allclose(h[i], torch.from_numpy(np.array(states[2 * i])))
        assert torch.allclose(c[i], torch.from_numpy(np.array(states[2 * i + 1])))
    print("mx_y", mx_y.shape, [s.shape for s in states])

    # now with states
    y, h, c = torch_lstm(x, h, c)
    assert y.shape == (N, T, 5), y.shape
    assert h.shape == (3, N, 5), h.shape  # 3 is the number of layers
    assert c.shape == (3, N, 5), c.shape

    mx_y, states = mx_lstm.forward(ix, states)
    assert mx_y.shape == (N, T, 5), mx_y.shape
    assert torch.allclose(y, torch.from_numpy(np.array(mx_y)))

    for i in range(3):
        assert torch.allclose(h[i], torch.from_numpy(np.array(states[2 * i])))
        assert torch.allclose(c[i], torch.from_numpy(np.array(states[2 * i + 1])))
    print("mx_y", mx_y.shape, [s.shape for s in states])


def test_multi_layers():
    torch_lstm = TorchLstm()
    mx_lstm = MlxLstm()

    # ['lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0',
    #  'lstm.bias_hh_l0', 'lstm.weight_ih_l1', 'lstm.weight_hh_l1',
    #  'lstm.bias_ih_l1', 'lstm.bias_hh_l1', 'lstm.weight_ih_l2',
    #  'lstm.weight_hh_l2', 'lstm.bias_ih_l2', 'lstm.bias_hh_l2']
    print(list(torch_lstm.state_dict().keys()))
    for k, v in torch_lstm.state_dict().items():
        print(k, v.shape)
    """
    lstm.weight_ih_l0 torch.Size([20, 2])
    lstm.weight_hh_l0 torch.Size([20, 5])
    lstm.bias_ih_l0 torch.Size([20])
    lstm.bias_hh_l0 torch.Size([20])
    lstm.weight_ih_l1 torch.Size([20, 5])
    lstm.weight_hh_l1 torch.Size([20, 5])
    lstm.bias_ih_l1 torch.Size([20])
    lstm.bias_hh_l1 torch.Size([20])
    lstm.weight_ih_l2 torch.Size([20, 5])
    lstm.weight_hh_l2 torch.Size([20, 5])
    lstm.bias_ih_l2 torch.Size([20])
    lstm.bias_hh_l2 torch.Size([20])
    """

    # dict_keys(['lstm'])
    print(mx_lstm.parameters().keys())
    assert isinstance(mx_lstm.parameters()["lstm"], list)
    assert len(mx_lstm.parameters()["lstm"]) == 3, len(mx_lstm.parameters()["lstm"])

    # dict_keys(['Wx', 'Wh', 'bias'])
    print(mx_lstm.parameters()["lstm"][0].keys())
    for k, v in mx_lstm.parameters()["lstm"][0].items():
        print(k, v.shape)
    """
    Wx (20, 2)
    Wh (20, 5)
    bias (20,)
    """
    print(mx_lstm.parameters()["lstm"][1].keys())
    for k, v in mx_lstm.parameters()["lstm"][1].items():
        print(k, v.shape)
    """
    Wx (20, 5)
    Wh (20, 5)
    bias (20,)
    """
    # tree_flatten(mx_lstm.parameters()) gives a list of tuples like
    # [('lstm.0.Wx', array(...)), ('lstm.1.Wh', array(...)), ...],
    # which is why the converted keys below contain the layer index

    # convert
    state_dict = torch_lstm.state_dict()
    new_state_dict = dict()
    for k, v in state_dict.items():
        basename, pname = k.rsplit(".", 1)
        if "lstm" not in basename:
            # we convert only lstm in this for loop
            continue
        w_or_b, ih_or_hh, ln = pname.split("_")
        if w_or_b == "weight":
            new_name = "Wx" if ih_or_hh == "ih" else "Wh"
        elif w_or_b == "bias" and ih_or_hh == "ih":
            continue
        else:
            # mlx keeps a single bias, so sum the two torch biases
            v = v + state_dict[k.replace("_hh_", "_ih_")]
            new_name = "bias"
        # ln is "l0", "l1", ..., so ln[1:] is the layer index
        k = basename + "." + ln[1:] + "." + new_name
        new_state_dict[k] = v

    # convert torch.Tensor to mx.array
    new_state_dict = tree_map(mx.array, new_state_dict)
    mx_lstm.update(tree_unflatten(list(new_state_dict.items())))
    print("updated")

    for N in [1, 2, 3]:
        for T in [1, 2, 3]:
            test_with_seq_len(torch_lstm=torch_lstm, mx_lstm=mx_lstm, N=N, T=T)


@torch.no_grad()
def main():
    test_multi_layers()


if __name__ == "__main__":
    torch.manual_seed(20250716)
    main()
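
One payoff of threading the states through explicitly is chunked (streaming) inference: since an LSTM is recurrent, feeding the sequence in pieces while carrying the states over gives the same output as a single full pass, up to floating point error. A minimal sketch on top of the MlxLstm above (the chunk size is arbitrary):

def run_chunked(mx_lstm, x, chunk=2):
    # x: (N, L, C_in); process `chunk` frames at a time, carrying the states
    states = None
    ys = []
    for start in range(0, x.shape[1], chunk):
        y, states = mx_lstm.forward(x[:, start : start + chunk, :], states)
        ys.append(y)
    return mx.concatenate(ys, axis=1)  # should match mx_lstm.forward(x)[0]
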