Observer
./code/observer/ex0.py
1#!/usr/bin/env python3
2
3import torch
4from torch.ao.quantization.observer import _with_args, MinMaxObserver
5
6
class Foo:
    """Minimal two-attribute class used to demonstrate ``_with_args`` chaining.

    Args:
        a: value stored unchanged on ``self.a`` (default 1).
        b: value stored unchanged on ``self.b`` (default 2).
    """

    def __init__(self, a=1, b=2):
        # Keep both constructor arguments verbatim; no validation is done.
        self.a, self.b = a, b
11
12
def test_with_args():
    """Show how ``_with_args`` chains keyword defaults into a reusable factory.

    ``Foo`` is given a ``with_args`` classmethod at runtime. This sets an
    attribute on the module-level class, so it persists after the function
    returns. Each chained ``with_args`` call adds keyword arguments; when the
    same keyword is supplied twice, the later value wins. Calling the
    resulting builder constructs a fresh ``Foo`` every time.
    """
    # NOTE(review): ``_with_args`` is underscore-prefixed (private to torch);
    # it is not a stable public API and may change between releases.
    Foo.with_args = classmethod(_with_args)
    foo_builder = Foo.with_args(a=3).with_args(b=4).with_args(a=10)

    f = foo_builder()
    assert f.a == 10  # the last a=10 replaces the first a=3
    assert f.b == 4

    # The builder is a factory, not a cache: every call yields a new object
    # carrying the same accumulated keyword arguments.
    f2 = foo_builder()
    assert f2 is not f
    assert (f2.a, f2.b) == (10, 4)
22
23
def test_min_max_observer():
    """Feed two batches through a qint8 ``MinMaxObserver`` and check its stats.

    The observer keeps a running min/max across calls. ``calculate_qparams``
    then turns that range into an affine ``scale`` / ``zero_point`` pair. The
    prints are kept for readers running this file as a script. The asserts
    make the function actually fail if the observed numbers drift from the
    expected output recorded in the comments.
    """
    ob = MinMaxObserver(dtype=torch.qint8)
    print(ob)  # MinMaxObserver(min_val=inf, max_val=-inf)

    # Integer input; the observer reports the range as floats (see output).
    ob(torch.tensor([1, 2, 3]))
    print(ob)  # MinMaxObserver(min_val=1.0, max_val=3.0)
    assert ob.min_val.item() == 1.0
    assert ob.max_val.item() == 3.0

    # A second batch is merged with the first: the range only widens.
    ob(torch.tensor([-1, 30]))
    print(ob)  # MinMaxObserver(min_val=-1.0, max_val=30.0)
    assert ob.min_val.item() == -1.0
    assert ob.max_val.item() == 30.0

    scale, zero_point = ob.calculate_qparams()
    print("scale", scale)  # scale tensor([0.1216])
    print("zero_point", zero_point)  # zero_point tensor([-120], dtype=torch.int32)
    # Matches the recorded output: with qint8's default range [-128, 127]
    # (255 steps), scale = (30 - (-1)) / 255 and
    # zero_point = -128 - round(-1 / scale) = -128 + 8 = -120.
    assert abs(scale.item() - 31 / 255) < 1e-6
    assert zero_point.item() == -120
36
37
def main():
    """Run each demo in this file, in declaration order."""
    for demo in (test_with_args, test_min_max_observer):
        demo()


if __name__ == "__main__":
    # Only run the demos when the file is executed directly, not on import.
    main()