Hello

./code/ex1.py
 1#!/usr/bin/env python3
 2
 3import torch
 4import torch.nn as nn
 5
 6
 7class Model(torch.nn.Module):
 8    def __init__(self):
 9        super().__init__()
10        self.fc = nn.Linear(1, 1)
11
12    def forward(self, x):
13        x = self.fc(x)
14        return x
15
16
def main():
    """Dynamically quantize a ``Model`` to int8, check the converted layer, and compare outputs.

    Prints the quantized model, its ``fc`` layer and that layer's type, then
    runs the same single-element input through the float model and the
    quantized model and prints both results.
    """
    float_model = Model()
    quantized = torch.quantization.quantize_dynamic(
        float_model,
        qconfig_spec={nn.Linear},
        dtype=torch.qint8,
    )
    print(quantized)

    quantized_fc = quantized.fc
    print(quantized_fc)

    # The converted layer exposes weight and bias through accessor methods.
    weight = quantized_fc.weight()
    assert weight.is_quantized
    assert weight.dtype == torch.qint8

    # Only the weight is expected to be quantized; the bias stays float32.
    bias = quantized_fc.bias()
    assert bias.is_quantized is False
    assert bias.dtype == torch.float32

    assert isinstance(quantized_fc, torch.nn.quantized.dynamic.Linear)
    print(type(quantized_fc))

    sample = torch.tensor([[1.0]], dtype=torch.float32)

    # Trailing comments show the sample output recorded with the original listing.
    float_out = float_model(sample)
    print(sample, float_out)  # tensor([[1.]]) tensor([[-1.2900]], grad_fn=<AddmmBackward0>)

    quantized_out = quantized(sample)
    print(quantized_out)  # tensor([[-1.2931]])
40
41
if __name__ == "__main__":
    # Seed PyTorch's RNG before main() builds the model, so that runs of this
    # script start from the same random state.
    torch.manual_seed(20220723)
    main()