Conv2d

PyTorch input layout: NCHW (batch, channels, height, width).

MLX input layout: NHWC (batch, height, width, channels), i.e. channels last.
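The two input layouts hold the same values with the channel axis moved, so converting between them is a single transpose. This is a minimal sketch, assuming the activations are NumPy arrays (for example obtained with `tensor.numpy()` in PyTorch). The helper names `nchw_to_nhwc` and `nhwc_to_nchw` are hypothetical and not part of either library:

```python
import numpy as np


def nchw_to_nhwc(x: np.ndarray) -> np.ndarray:
    # Hypothetical helper: (N, C, H, W) -> (N, H, W, C), channel axis moved to the end
    return np.transpose(x, (0, 2, 3, 1))


def nhwc_to_nchw(x: np.ndarray) -> np.ndarray:
    # Hypothetical helper: (N, H, W, C) -> (N, C, H, W), channel axis moved back to position 1
    return np.transpose(x, (0, 3, 1, 2))


x_torch_layout = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)
x_mlx_layout = nchw_to_nhwc(x_torch_layout)
print(x_mlx_layout.shape)  # (2, 4, 5, 3)
```

Element `[n, c, h, w]` of the NCHW array ends up at `[n, h, w, c]` in the NHWC array, and `nhwc_to_nchw` restores the original exactly.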

PyTorch weight shape: (out_channels, in_channels/groups, kernel_size[0], kernel_size[1])

MLX weight shape: (out_channels, kernel_size[0], kernel_size[1], in_channels/groups)
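Porting a PyTorch Conv2d weight to the MLX layout is also one transpose: move `in_channels/groups` from axis 1 to the end. This sketch assumes NumPy arrays, for example `weight.detach().numpy()` before the result is wrapped in `mx.array`. It checks the conversion against a naive cross-correlation written for each layout. `torch_weight_to_mlx`, `conv2d_nchw` and `conv2d_nhwc` are hypothetical helpers limited to stride 1, no padding and groups=1. They are not either library's implementation:

```python
import numpy as np


def torch_weight_to_mlx(w: np.ndarray) -> np.ndarray:
    # Hypothetical helper: (out, in/groups, kH, kW) -> (out, kH, kW, in/groups)
    return np.transpose(w, (0, 2, 3, 1))


def conv2d_nchw(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # Naive cross-correlation, stride 1, no padding, groups=1.
    # x: (N, C, H, W), w: (O, C, kH, kW) -> (N, O, H-kH+1, W-kW+1)
    n, _, h, wd = x.shape
    o, _, kh, kw = w.shape
    out = np.zeros((n, o, h - kh + 1, wd - kw + 1), dtype=x.dtype)
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # (N, C, kH, kW)
            out[:, :, i, j] = np.einsum("nchw,ochw->no", patch, w)
    return out


def conv2d_nhwc(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # The same operation on channels-last tensors.
    # x: (N, H, W, C), w: (O, kH, kW, C) -> (N, H-kH+1, W-kW+1, O)
    n, h, wd, _ = x.shape
    o, kh, kw, _ = w.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, o), dtype=x.dtype)
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (N, kH, kW, C)
            out[:, i, j, :] = np.einsum("nhwc,ohwc->no", patch, w)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 6, 8))  # NCHW input
w = rng.standard_normal((4, 3, 2, 5))  # PyTorch weight layout, kH=2, kW=5
y_nchw = conv2d_nchw(x, w)
y_nhwc = conv2d_nhwc(np.transpose(x, (0, 2, 3, 1)), torch_weight_to_mlx(w))
print(np.allclose(np.transpose(y_nchw, (0, 2, 3, 1)), y_nhwc))  # True
```

The weight transpose does not depend on `groups`, because axis 1 holds `in_channels/groups` in the PyTorch layout and the last axis holds it in the MLX layout. Only the naive reference convolutions above assume groups=1.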