torch.quantize_per_channel

See `<https://pytorch.org/docs/stable/generated/torch.quantize_per_channel.html#torch.quantize_per_channel>`_

def test_quantize_per_channel_2d():
    """Quantize a (N, C) float tensor per channel along axis 1 and verify it.

    Checks the quantized dtype, the stored scales, zero points, qscheme and
    axis, the raw int8 representation, and that both dequantize entry points
    give back the original values.
    """
    # (N, C)
    x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
    num_rows, num_channels = 2, 3
    assert x.shape == (num_rows, num_channels)

    # One scale and one zero point per entry along axis 1.
    channel_scales = torch.tensor([0.125, 0.25, 0.5])
    # Passed in as int32; the quantized tensor reports int64 (asserted below).
    channel_zero_points = torch.tensor([10, 20, 30], dtype=torch.int32)

    q = torch.quantize_per_channel(
        input=x,
        scales=channel_scales,
        zero_points=channel_zero_points,
        axis=1,
        dtype=torch.qint8,
    )
    assert q.dtype == torch.qint8

    stored_scales = q.q_per_channel_scales()
    assert stored_scales.dtype == torch.float64
    assert torch.all(torch.eq(stored_scales, channel_scales))

    stored_zero_points = q.q_per_channel_zero_points()
    assert stored_zero_points.dtype == torch.int64
    assert torch.all(torch.eq(stored_zero_points, channel_zero_points))

    assert str(q.qscheme()) == "torch.per_channel_affine"

    assert q.q_per_channel_axis() == 1

    raw = q.int_repr()
    expected_raw = torch.tensor([[18, 28, 36], [42, 40, 42]], dtype=torch.int8)
    assert raw.dtype == torch.int8
    assert torch.all(torch.eq(raw, expected_raw))

    # Every stored integer equals value / scale + zero_point for its channel.
    for row in range(num_rows):
        for ch in range(num_channels):
            assert (
                raw[row][ch].item()
                == x[row][ch].item() / channel_scales[ch] + channel_zero_points[ch]
            )

    # The method and the free function must both recover the input.
    for restored in (q.dequantize(), torch.dequantize(q)):
        assert torch.all(torch.eq(restored, x))

    # Output of print(q), kept for reference:
    """
    tensor([[1., 2., 3.],
            [4., 5., 6.]], size=(2, 3), dtype=torch.qint8,
           quantization_scheme=torch.per_channel_affine,
           scale=tensor([0.1250, 0.2500, 0.5000], dtype=torch.float64),
           zero_point=tensor([10, 20, 30]), axis=1)
    """