LSTM

  • nn.LSTM in MLX does not support multiple layers; to get a stacked LSTM you have to chain several nn.LSTM modules yourself (see the multi-layer example below)

  • the batch is always at dim 0, i.e., batch_first is effectively always True

  • it also does not support bidirectional LSTM

  • there is only a single bias vector (when bias is enabled); to port PyTorch weights, sum bias_ih and bias_hh into one bias, as sketched below

https://github.com/ml-explore/mlx-examples/blob/main/encodec/encodec.py#L14 has a custom Metal kernel implementation for LSTM.
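
As a quick sketch of the points above (the sizes and variable names here are just an example; the parameter names Wx, Wh, and bias are MLX's nn.LSTM parameters, as printed by the scripts below), a single-layer PyTorch LSTM can be loaded into MLX by copying the two weight matrices and summing the two PyTorch bias vectors:

import mlx.core as mx
import mlx.nn as nn
import torch

torch_lstm = torch.nn.LSTM(input_size=2, hidden_size=6, batch_first=True)
mlx_lstm = nn.LSTM(input_size=2, hidden_size=6)

sd = torch_lstm.state_dict()
mlx_lstm.update(
    {
        "Wx": mx.array(sd["weight_ih_l0"].numpy()),  # (4*hidden, input)
        "Wh": mx.array(sd["weight_hh_l0"].numpy()),  # (4*hidden, hidden)
        # MLX keeps one fused bias, so add PyTorch's two bias vectors
        "bias": mx.array((sd["bias_ih_l0"] + sd["bias_hh_l0"]).numpy()),
    }
)

x = mx.random.uniform(shape=(3, 4, 2))  # (batch, time, feature); batch is dim 0
h_all, c_all = mlx_lstm(x)              # both are (batch, time, hidden)

Feeding the same x to torch_lstm should then give a hidden-state sequence matching h_all; the scripts below check exactly this, including carried-over h/c states.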

./code/test_lstm_1_layer.py
#!/usr/bin/env python3

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten


class TorchLstm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=2, hidden_size=6, batch_first=True)

    def forward(self, x, h=None, c=None):
        if h is not None:
            y, (h, c) = self.lstm(x, (h, c))
        else:
            y, (h, c) = self.lstm(x)
        return y, h, c


class MlxLstm(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=2, hidden_size=6)

    def forward(self, x, h=None, c=None):
        """
        Args:
          x: (N, L, H)
          h: None or (N, H)
          c: None or (N, H)
        Returns:
          y: (N, L, H)
          h: (N, H)
          c: (N, H)
        """
        if h is None:
            h, c = self.lstm(x)
        else:
            h, c = self.lstm(x, h, c)
        # now both h and c are (N, L, H)

        return h, h[:, -1, :], c[:, -1, :]


def test_with_seq_len(torch_lstm, mx_lstm, N, T):
    x = torch.rand(N, T, 2)
    print("x", x.shape)

    y, h, c = torch_lstm(x)
    assert y.shape == (N, T, 6), y.shape
    assert h.shape == (1, N, 6), h.shape  # 1 is number of layers
    assert c.shape == (1, N, 6), c.shape

    ix = mx.array(x)
    mx_y, mx_h, mx_c = mx_lstm.forward(ix)
    assert mx_y.shape == (N, T, 6), mx_y.shape
    assert mx_h.shape == (N, 6), mx_h.shape
    assert mx_c.shape == (N, 6), mx_c.shape

    assert torch.allclose(y, torch.from_numpy(np.array(mx_y)))
    assert torch.allclose(h[0], torch.from_numpy(np.array(mx_h)))
    assert torch.allclose(c[0], torch.from_numpy(np.array(mx_c)))

    # now with states
    y, h, c = torch_lstm(x, h, c)
    assert y.shape == (N, T, 6), y.shape
    assert h.shape == (1, N, 6), h.shape  # 1 is number of layers
    assert c.shape == (1, N, 6), c.shape

    mx_y, mx_h, mx_c = mx_lstm.forward(ix, mx_h, mx_c)
    assert mx_y.shape == (N, T, 6), mx_y.shape
    assert mx_h.shape == (N, 6), mx_h.shape
    assert mx_c.shape == (N, 6), mx_c.shape

    assert torch.allclose(y, torch.from_numpy(np.array(mx_y)))
    assert torch.allclose(h[0], torch.from_numpy(np.array(mx_h)))
    assert torch.allclose(c[0], torch.from_numpy(np.array(mx_c)))


def test_single_layer():
    torch_lstm = TorchLstm()
    mx_lstm = MlxLstm()

    # ['lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0']
    print(list(torch_lstm.state_dict().keys()))
    for k, v in torch_lstm.state_dict().items():
        print(k, v.shape)
    """
    lstm.weight_ih_l0 torch.Size([24, 2])
    lstm.weight_hh_l0 torch.Size([24, 6])
    lstm.bias_ih_l0 torch.Size([24])
    lstm.bias_hh_l0 torch.Size([24])
    """

    # lstm
    print(mx_lstm.parameters().keys())

    # dict_keys(['Wx', 'Wh', 'bias'])
    print(mx_lstm.parameters()["lstm"].keys())
    for k, v in mx_lstm.parameters()["lstm"].items():
        print(k, v.shape)
    """
    Wx (24, 2)
    Wh (24, 6)
    bias (24,)
    """
    # convert
    state_dict = torch_lstm.state_dict()
    new_state_dict = dict()
    for k, v in state_dict.items():
        basename, pname = k.rsplit(".", 1)
        if "lstm" not in basename:
            # we convert only lstm in this for loop
            continue
        w_or_b, ih_or_hh, ln = pname.split("_")
        if w_or_b == "weight":
            new_name = "Wx" if ih_or_hh == "ih" else "Wh"
        elif w_or_b == "bias" and ih_or_hh == "ih":
            continue
        else:
            # mlx keeps a single bias, so fuse bias_hh and bias_ih into one
            v = v + state_dict[k.replace("_hh_", "_ih_")]
            new_name = "bias"
        #  k = basename + "." + ln[1:] + "." + new_name  # for a multi-layer lstm
        k = basename + "." + new_name  # for a single-layer lstm
        new_state_dict[k] = v
    #  print("here", new_state_dict)
    #  print("here2", list(new_state_dict.items()))
    #  print("here3", tree_unflatten(list(new_state_dict.items())))
    #  print("here4", tree_flatten(tree_unflatten(list(new_state_dict.items()))))

    # convert torch.tensor to mx.array
    new_state_dict = tree_map(mx.array, new_state_dict)
    mx_lstm.update(tree_unflatten(list(new_state_dict.items())))

    for N in [1, 2, 3]:
        for T in [1, 2, 3]:
            test_with_seq_len(torch_lstm=torch_lstm, mx_lstm=mx_lstm, N=N, T=T)


@torch.no_grad()
def main():
    test_single_layer()


if __name__ == "__main__":
    torch.manual_seed(20250716)
    main()
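
The multi-layer version below keeps its three nn.LSTM layers in a plain Python list, so MLX exposes them as a list under the "lstm" key and the flat parameter names become lstm.{layer}.{Wx,Wh,bias}. The reason update(tree_unflatten(...)) accepts such keys is that tree_unflatten treats integer path components as list indices. A tiny sketch with toy values (not real weights):

from mlx.utils import tree_unflatten

# Integer components in the dotted keys become list indices, which is what a
# list of nn.LSTM layers needs; this should yield roughly
# {"lstm": [{"Wx": 1}, {"Wx": 2}]}
nested = tree_unflatten([("lstm.0.Wx", 1), ("lstm.1.Wx", 2)])
print(nested)
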
./code/test_lstm_multi_layers
#!/usr/bin/env python3

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten


class TorchLstm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=2, hidden_size=5, num_layers=3, batch_first=True
        )

    def forward(self, x, h=None, c=None):
        if h is not None:
            y, (h, c) = self.lstm(x, (h, c))
        else:
            y, (h, c) = self.lstm(x)
        return y, h, c


class MlxLstm(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = []
        input_size = 2
        hidden_size = 5
        for i in range(3):
            in_size = input_size if i == 0 else hidden_size
            self.lstm.append(nn.LSTM(input_size=in_size, hidden_size=hidden_size))

    def forward(self, x, states=None):
        """
        Args:
          x: (N, L, C)
          states: None or a list containing 2*num_layers tensors
            - states[2*i] is the h of the i-th layer, of shape (N, C)
            - states[2*i+1] is the c of the i-th layer, of shape (N, C)
        Returns:
          y: (N, L, C)
          states: a list containing 2*num_layers tensors
            - states[2*i] is the h of the i-th layer, of shape (N, C)
            - states[2*i+1] is the c of the i-th layer, of shape (N, C)
        """

        if states is None:
            states = [None] * (len(self.lstm) * 2)

        new_states = []
        for i, layer in enumerate(self.lstm):
            h = states[2 * i]
            c = states[2 * i + 1]
            # each layer returns the hidden and cell states for all steps,
            # both (N, L, C); the hidden sequence is the input to the next layer
            x, c = layer(x, h, c)
            new_states.append(x[:, -1, :])
            new_states.append(c[:, -1, :])
        return x, new_states


def test_with_seq_len(torch_lstm, mx_lstm, N, T):
    x = torch.rand(N, T, 2)
    print("x", x.shape)

    y, h, c = torch_lstm(x)
    assert y.shape == (N, T, 5), y.shape
    assert h.shape == (3, N, 5), h.shape  # 3 is number of layers
    assert c.shape == (3, N, 5), c.shape

    ix = mx.array(x)
    mx_y, states = mx_lstm.forward(ix)

    assert mx_y.shape == (N, T, 5), mx_y.shape
    assert torch.allclose(y, torch.from_numpy(np.array(mx_y)))

    for i in range(3):
        assert torch.allclose(h[i], torch.from_numpy(np.array(states[2 * i])))
        assert torch.allclose(c[i], torch.from_numpy(np.array(states[2 * i + 1])))
    print("mx_y", mx_y.shape, [s.shape for s in states])

    # now with states
    y, h, c = torch_lstm(x, h, c)
    assert y.shape == (N, T, 5), y.shape
    assert h.shape == (3, N, 5), h.shape  # 3 is number of layers
    assert c.shape == (3, N, 5), c.shape

    mx_y, states = mx_lstm.forward(ix, states)
    assert torch.allclose(y, torch.from_numpy(np.array(mx_y)))

    for i in range(3):
        assert torch.allclose(h[i], torch.from_numpy(np.array(states[2 * i])))
        assert torch.allclose(c[i], torch.from_numpy(np.array(states[2 * i + 1])))
    print("mx_y", mx_y.shape, [s.shape for s in states])


def test_multi_layers():
    torch_lstm = TorchLstm()
    mx_lstm = MlxLstm()
    #  mx.eval(mx_lstm.parameters())

    # ['lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0',
    # 'lstm.bias_hh_l0', 'lstm.weight_ih_l1', 'lstm.weight_hh_l1',
    # 'lstm.bias_ih_l1', 'lstm.bias_hh_l1', 'lstm.weight_ih_l2',
    # 'lstm.weight_hh_l2', 'lstm.bias_ih_l2', 'lstm.bias_hh_l2']
    print(list(torch_lstm.state_dict().keys()))
    for k, v in torch_lstm.state_dict().items():
        print(k, v.shape)
    """
    lstm.weight_ih_l0 torch.Size([20, 2])
    lstm.weight_hh_l0 torch.Size([20, 5])
    lstm.bias_ih_l0 torch.Size([20])
    lstm.bias_hh_l0 torch.Size([20])
    lstm.weight_ih_l1 torch.Size([20, 5])
    lstm.weight_hh_l1 torch.Size([20, 5])
    lstm.bias_ih_l1 torch.Size([20])
    lstm.bias_hh_l1 torch.Size([20])
    lstm.weight_ih_l2 torch.Size([20, 5])
    lstm.weight_hh_l2 torch.Size([20, 5])
    lstm.bias_ih_l2 torch.Size([20])
    lstm.bias_hh_l2 torch.Size([20])
    """

    # lstm
    print(mx_lstm.parameters().keys())
    assert isinstance(mx_lstm.parameters()["lstm"], list)
    assert len(mx_lstm.parameters()["lstm"]) == 3, len(mx_lstm.parameters())

    # dict_keys(['Wx', 'Wh', 'bias'])
    print(mx_lstm.parameters()["lstm"][0].keys())
    for k, v in mx_lstm.parameters()["lstm"][0].items():
        print(k, v.shape)

    """
    Wx (20, 2)
    Wh (20, 5)
    bias (20,)
    """
    print(mx_lstm.parameters()["lstm"][1].keys())
    for k, v in mx_lstm.parameters()["lstm"][1].items():
        print(k, v.shape)
    """
    Wx (20, 5)
    Wh (20, 5)
    bias (20,)
    """
    # a list of tuple
    # [('lstm.0.Wx', array(...)),
    #  ('lstm.1.Wh', array(...)) ]
    #  print(tree_flatten(mx_lstm.parameters()))

    # convert
    state_dict = torch_lstm.state_dict()
    new_state_dict = dict()
    for k, v in state_dict.items():
        basename, pname = k.rsplit(".", 1)
        if "lstm" not in basename:
            # we convert only lstm in this for loop
            continue
        w_or_b, ih_or_hh, ln = pname.split("_")
        if w_or_b == "weight":
            new_name = "Wx" if ih_or_hh == "ih" else "Wh"
        elif w_or_b == "bias" and ih_or_hh == "ih":
            continue
        else:
            # mlx keeps a single bias, so fuse bias_hh and bias_ih into one
            v = v + state_dict[k.replace("_hh_", "_ih_")]
            new_name = "bias"
        # ln is like "l0", so ln[1:] is the layer index,
        # e.g. lstm.weight_ih_l0 -> lstm.0.Wx
        k = basename + "." + ln[1:] + "." + new_name
        new_state_dict[k] = v

    # convert torch.tensor to mx.array
    new_state_dict = tree_map(mx.array, new_state_dict)
    mx_lstm.update(tree_unflatten(list(new_state_dict.items())))
    print("updated")

    for N in [1, 2, 3]:
        for T in [1, 2, 3]:
            test_with_seq_len(torch_lstm=torch_lstm, mx_lstm=mx_lstm, N=N, T=T)


@torch.no_grad()
def main():
    test_multi_layers()


if __name__ == "__main__":
    torch.manual_seed(20250716)
    main()