Conv2d
Input layout in PyTorch: NCHW (batch, channels, height, width).
Input layout in MLX: NHWC (batch, height, width, channels).
Weight shape in PyTorch: (out_channels, in_channels/groups, kernel_size[0], kernel_size[1])
Weight shape in MLX: (out_channels, kernel_size[0], kernel_size[1], in_channels/groups)
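Both conversions amount to moving the channel axis to the end, i.e. transposing with axes (0, 2, 3, 1). The sketch below uses NumPy arrays as a stand-in for real framework tensors. With actual tensors, the same axis order would typically go to `torch.Tensor.permute` or `mlx.core.transpose`. The helper names are hypothetical. The naive `conv2d_nchw` / `conv2d_nhwc` functions are illustrative reference loops that assume stride 1, no padding and groups=1; they are not either library's implementation. They only check that the reordered input and weights produce the same result as the PyTorch-layout computation.

```python
import numpy as np


def torch_input_to_mlx(x_nchw):
    """Hypothetical helper: reorder an NCHW input to NHWC."""
    return np.transpose(x_nchw, (0, 2, 3, 1))


def torch_weight_to_mlx(w):
    """Hypothetical helper: (out, in/groups, kH, kW) -> (out, kH, kW, in/groups)."""
    return np.transpose(w, (0, 2, 3, 1))


def mlx_output_to_torch(y_nhwc):
    """Hypothetical helper: reorder an NHWC output back to NCHW."""
    return np.transpose(y_nhwc, (0, 3, 1, 2))


def conv2d_nchw(x, w):
    """Naive cross-correlation in PyTorch layout (stride 1, no padding, groups=1)."""
    n, c, h, wd = x.shape
    o, ci, kh, kw = w.shape
    assert ci == c, "groups=1 assumed"
    out = np.zeros((n, o, h - kh + 1, wd - kw + 1))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # (n, c, kH, kW)
            out[:, :, i, j] = np.einsum("nckl,ockl->no", patch, w)
    return out


def conv2d_nhwc(x, w):
    """Naive cross-correlation in MLX layout (stride 1, no padding, groups=1)."""
    n, h, wd, c = x.shape
    o, kh, kw, ci = w.shape
    assert ci == c, "groups=1 assumed"
    out = np.zeros((n, h - kh + 1, wd - kw + 1, o))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (n, kH, kW, c)
            out[:, i, j, :] = np.einsum("nklc,oklc->no", patch, w)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 8))  # NCHW input
w = rng.standard_normal((4, 3, 3, 3))  # PyTorch weight layout
y_torch = conv2d_nchw(x, w)
y_mlx = conv2d_nhwc(torch_input_to_mlx(x), torch_weight_to_mlx(w))
print(y_torch.shape, y_mlx.shape)  # (2, 4, 6, 6) (2, 6, 6, 4)
print(np.allclose(y_torch, mlx_output_to_torch(y_mlx)))  # True
```

The output also comes back channels-last, so results must be transposed with (0, 3, 1, 2) before comparing them against PyTorch. The weight transpose itself does not depend on `groups`, because the in_channels/groups axis is simply moved to the end.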