Hello
./code/ex1.py
1#!/usr/bin/env python3
2
3import torch
4import torch.nn as nn
5
6
class Model(nn.Module):
    """Toy network: one fully connected layer mapping 1 feature to 1 output."""

    def __init__(self):
        super().__init__()
        # The only layer; it is what quantize_dynamic converts in main().
        self.fc = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Run ``x`` through the single linear layer and return the result."""
        return self.fc(x)
15
16
def main():
    """Dynamically quantize a ``Model`` to int8 and compare it with the float model.

    Builds a float ``Model`` and converts its ``nn.Linear`` layers with
    ``torch.quantization.quantize_dynamic`` using ``torch.qint8``. It then
    checks the structure of the converted ``fc`` layer, runs both models on
    the same input, and prints the outputs.

    Raises:
        AssertionError: if the converted ``fc`` layer does not have a quantized
            qint8 weight, a non-quantized float32 bias, or the
            ``torch.nn.quantized.dynamic.Linear`` type. These are plain
            ``assert`` statements, so they are skipped under ``python -O``.
    """
    m = Model()
    # Only nn.Linear submodules are converted. ``m`` is still used below as the
    # float reference (the default inplace=False returns a converted model).
    model_int8 = torch.quantization.quantize_dynamic(
        model=m,
        qconfig_spec={torch.nn.Linear},
        dtype=torch.qint8,
    )
    print(model_int8)
    print(model_int8.fc)

    # On the dynamically quantized Linear, weight and bias are obtained by
    # calling methods rather than reading attributes.
    weight = model_int8.fc.weight()
    assert weight.is_quantized
    assert weight.dtype == torch.qint8

    # The bias is expected to stay an ordinary float32 tensor.
    bias = model_int8.fc.bias()
    assert not bias.is_quantized
    assert bias.dtype == torch.float32

    # Print the type before asserting on it so a mismatch shows up in the output.
    print(type(model_int8.fc))
    assert isinstance(model_int8.fc, torch.nn.quantized.dynamic.Linear)

    x = torch.tensor([[1.0]], dtype=torch.float32)
    y = m(x)
    # Example output, presumably captured with seed 20220723 set before main()
    # runs; exact digits may differ across torch versions/backends:
    # tensor([[1.]]) tensor([[-1.2900]], grad_fn=<AddmmBackward0>)
    print(x, y)

    qy = model_int8(x)
    # Example output: tensor([[-1.2931]]). It is close to, but not identical
    # to, the float result because of the int8 quantization.
    print(qy)
40
41
if __name__ == "__main__":
    # Seed torch's global RNG *before* main() constructs Model(), so the
    # layer's parameter initialisation, and therefore the printed outputs,
    # repeat from run to run.
    _SEED = 20220723
    torch.manual_seed(_SEED)
    main()