torch.quantize_per_channel
See `<https://pytorch.org/docs/stable/generated/torch.quantize_per_channel.html#torch.quantize_per_channel>`_
def test_quantize_per_channel_2d():
    """Check per-channel quantization of a 2-D tensor along axis 1.

    A (2, 3) float32 tensor is quantized to ``torch.qint8`` with one
    (scale, zero_point) pair per column. The test verifies:

    * the quantization metadata stored on the result (dtype, scales,
      zero points, scheme and axis),
    * the raw int8 representation, both against hard-coded values and
      against ``round(x / scale) + zero_point`` evaluated per channel,
    * that both dequantize entry points give back the original input.
    """
    # (N, C)
    a = torch.tensor(
        [
            [1, 2, 3],
            [4, 5, 6],
        ],
        dtype=torch.float32,
    )
    assert a.shape == (2, 3)

    # One entry per channel, i.e. per column of `a` (axis=1 below).
    scales = torch.tensor([0.125, 0.25, 0.5])
    # It will be converted to torch.int64 internally
    zero_points = torch.tensor([10, 20, 30], dtype=torch.int32)

    q = torch.quantize_per_channel(
        input=a,
        scales=scales,
        zero_points=zero_points,
        axis=1,
        dtype=torch.qint8,
    )

    # Quantization parameters are stored on the tensor as float64 scales
    # and int64 zero points, regardless of the dtypes passed in.
    assert q.dtype == torch.qint8
    assert q.q_per_channel_scales().dtype == torch.float64
    assert torch.all(torch.eq(q.q_per_channel_scales(), scales))
    assert q.q_per_channel_zero_points().dtype == torch.int64
    assert torch.all(torch.eq(q.q_per_channel_zero_points(), zero_points))
    assert str(q.qscheme()) == "torch.per_channel_affine"
    assert q.q_per_channel_axis() == 1

    # Underlying integer storage.
    i = q.int_repr()
    expected_i = torch.tensor([[18, 28, 36], [42, 40, 42]], dtype=torch.int8)
    assert i.dtype == torch.int8
    assert torch.all(torch.eq(i, expected_i))

    # Every element must satisfy i = round(a / scale) + zero_point using the
    # scale and zero point of its column. `scales` and `zero_points` have
    # shape (3,) and broadcast over the rows of `a`. Each quotient here is an
    # exact integer, so the rounding leaves the values unchanged; it is
    # spelled out so the check mirrors the quantization formula.
    formula_i = torch.round(a / scales) + zero_points
    assert torch.all(torch.eq(i, formula_i))

    # Both dequantize spellings recover the input exactly for this data.
    d = q.dequantize()
    assert torch.all(torch.eq(d, a))
    f = torch.dequantize(q)
    assert torch.all(torch.eq(f, a))

    # For reference, print(q) shows:
    #   tensor([[1., 2., 3.],
    #           [4., 5., 6.]], size=(2, 3), dtype=torch.qint8,
    #          quantization_scheme=torch.per_channel_affine,
    #          scale=tensor([0.1250, 0.2500, 0.5000], dtype=torch.float64),
    #          zero_point=tensor([10, 20, 30]), axis=1)