Simple OPs

#!/usr/bin/env python3

from typing import Any, Dict

import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic

import onnx


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Replace the metadata of an ONNX model file in-place.

    The model is loaded from ``filename`` and printed. Every existing
    ``metadata_props`` entry is then removed, and one entry is appended per
    item of ``meta_data`` (each value converted with ``str``). Finally the
    model is saved back to the same path.

    Args:
      filename:
        Path of the ONNX model to rewrite.
      meta_data:
        Key-value pairs to store; values are stringified before storing.
    """
    onnx_model = onnx.load(filename)
    print(onnx_model)

    props = onnx_model.metadata_props
    # Drop any pre-existing metadata so only ``meta_data`` remains.
    while props:
        props.pop()

    for name, raw in meta_data.items():
        entry = props.add()
        entry.key = name
        entry.value = str(raw)

    onnx.save(onnx_model, filename)


class Model(torch.nn.Module):
    """Toy network: a 3 -> 4 fully connected layer followed by ReLU."""

    def __init__(self):
        super().__init__()
        # main() reads ``my_linear.weight``, so keep these attribute names.
        self.my_linear = nn.Linear(in_features=3, out_features=4)
        self.my_relu = nn.ReLU()

    def forward(self, x):
        """Apply the linear layer and then ReLU to ``x`` (last dim must be 3)."""
        return self.my_relu(self.my_linear(x))


def main():
    """Export, annotate, quantize, and inspect a tiny Linear+ReLU model.

    Side effects (all relative to the current working directory):
      * writes ``model.onnx`` (float32 export whose first axis ``N`` is
        dynamic) and replaces its metadata with ``meta_data`` below;
      * writes ``model.int8.onnx`` using ONNX Runtime dynamic quantization of
        ``MatMul``/``Add`` nodes with int8 weights;
      * writes text dumps ``model.onnx.txt`` and ``model.int8.onnx.txt``.

    It also prints the PyTorch parameters and a by-hand reproduction of the
    int8 weight quantization, for comparison with ``model.int8.onnx``.
    """
    model = Model()
    print(list(model.named_parameters()))
    x = torch.rand(2, 5, 3)
    torch.onnx.export(
        model,
        x,
        "model.onnx",
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N"},
            "y": {0: "N"},
        },
    )
    print(list(model.parameters()))
    meta_data = {
        "date": 20240822,
        "author": "me",
        "version": 10,
    }
    add_meta_data("model.onnx", meta_data)

    quantize_dynamic(
        model_input="model.onnx",
        model_output="model.int8.onnx",
        op_types_to_quantize=["MatMul", "Add"],
        weight_type=QuantType.QInt8,
    )

    # Reproduce the symmetric int8 weight quantization by hand. ONNX Runtime
    # maps [-max_w, max_w] onto the symmetric range [-127, 127] (254 steps,
    # zero point 0) and rounds to the nearest integer. A plain
    # ``.to(torch.int8)`` truncates toward zero instead, which disagrees with
    # the weights stored in model.int8.onnx (e.g. 46.7 -> 46 instead of 47).
    qmin, qmax = -127, 127
    w = model.my_linear.weight.detach()
    max_w = w.abs().max().item()
    scale = max_w * 2 / (qmax - qmin)
    print(scale)
    # torch.round rounds half to even; clamp guards against float error
    # pushing the extreme value just outside [qmin, qmax].
    print(torch.round(w / scale).clamp(qmin, qmax).to(torch.int8))

    print("----")

    m = onnx.load("model.onnx")
    with open("model.onnx.txt", "w", encoding="utf-8") as f:
        f.write(str(m))

    m_int8 = onnx.load("model.int8.onnx")
    with open("model.int8.onnx.txt", "w", encoding="utf-8") as f:
        f.write(str(m_int8))


if __name__ == "__main__":
    # Fixed seed so the randomly initialized Linear weights (and the random
    # export input) are the same on every run, keeping printed values stable.
    torch.manual_seed(20240820)
    main()

linear

nn.Linear has two parameters: weight and bias.

In PyTorch, `weight.shape` is `(out_features, in_features)`. After exporting to ONNX, however, model.onnx stores the weight with shape `(in_features, out_features)`, i.e., the transpose of PyTorch's weight.

To apply symmetric int8 quantization to the following weight:

tensor([[ 0.0928, -0.0400,  0.0666],
        [-0.5535, -0.2698,  0.4867],
        [ 0.5245, -0.3856,  0.3486],
        [ 0.4714,  0.2035, -0.4349]], requires_grad=True)

w = model.my_linear.weight
max_w = w.abs().max().item()
scale = max_w * 2 / 255  # ONNX Runtime's symmetric int8 range is [-127, 127], i.e. it divides by 254
print(scale)
print((w / scale).to(torch.int8))  # .to(torch.int8) truncates toward zero; ONNX Runtime rounds to nearest

It prints:

0.0043414938683603325
tensor([[  21,   -9,   15],
        [-127,  -62,  112],
        [ 120,  -88,   80],
        [ 108,   46, -100]], dtype=torch.int8)

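Note that for symmetric quantization, the zero point is 0. Also note that `.to(torch.int8)` truncates toward zero, whereas ONNX Runtime rounds to the nearest integer and uses the scale `2 * max_w / 254` (its symmetric int8 range is [-127, 127]). That is why the value 46 above appears as 47 in the ONNX Runtime output below.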

In model.int8.onnx, the int8 weights are saved as follows (transposed, like the float weights in model.onnx):

[
    [
        21,
        -127,
        120,
        108
    ],
    [
        -9,
        -62,
        -88,
        47
    ],
    [
        15,
        112,
        80,
        -100
    ]
]

Before quantization, the float32 weights in model.onnx are:

[
    [
        0.09278726577758789,
        -0.5535404682159424,
        0.5244883298873901,
        0.4713994264602661
    ],
    [
        -0.04004639387130737,
        -0.26981082558631897,
        -0.3855680227279663,
        0.2035449743270874
    ],
    [
        0.06659042835235596,
        0.48669636249542236,
        0.34855782985687256,
        -0.43489107489585876
    ]
]

The file model.onnx.txt is given below:

The file model.int8.onnx.txt is given below:

Note that ONNX Runtime dequantizes values with:

val_fp32 = scale * (val_quantized - zero_point)

The CPU implementation for DynamicLinearQuantizer can be found at