Observer

./code/observer/ex0.py
 1#!/usr/bin/env python3
 2
 3import torch
 4from torch.ao.quantization.observer import _with_args, MinMaxObserver
 5
 6
 7class Foo:
 8    def __init__(self, a=1, b=2):
 9        self.a = a
10        self.b = b
11
12
def test_with_args():
    """Check that chained ``with_args`` calls produce a reusable factory.

    ``_with_args`` (a private torch helper, note the leading underscore, so
    its behavior may change between releases) is attached to ``Foo`` as a
    classmethod. This assignment mutates ``Foo`` for the rest of the process.
    Each ``with_args`` call wraps the previous builder. When the same keyword
    is given more than once, the last value wins.
    """
    Foo.with_args = classmethod(_with_args)
    foo_builder = Foo.with_args(a=3).with_args(b=4).with_args(a=10)

    f = foo_builder()
    assert f.a == 10  # the last a=10 replaces the first a=3
    assert f.b == 4

    # The builder is a factory: each call constructs a fresh Foo with the
    # same bound keyword arguments. Identity is compared with ``is not``
    # rather than comparing ``id()`` values.
    f2 = foo_builder()
    assert f is not f2
    assert (f2.a, f2.b) == (10, 4)
22
23
def test_min_max_observer():
    """Feed two tensors through a ``MinMaxObserver`` and check its state.

    The observer's recorded ``min_val``/``max_val`` widen as tensors pass
    through it. ``calculate_qparams`` then converts that range into a
    (scale, zero_point) pair for ``torch.qint8``. The expected values that
    were previously only noted in comments are now asserted, so a
    regression fails instead of printing silently.
    """
    ob = MinMaxObserver(dtype=torch.qint8)
    print(ob)  # MinMaxObserver(min_val=inf, max_val=-inf)

    ob(torch.tensor([1, 2, 3]))
    print(ob)  # MinMaxObserver(min_val=1.0, max_val=3.0)
    assert ob.min_val.item() == 1.0
    assert ob.max_val.item() == 3.0

    # Both ends of the observed range move: -1 < 1 and 30 > 3.
    ob(torch.tensor([-1, 30]))
    print(ob)  # MinMaxObserver(min_val=-1.0, max_val=30.0)
    assert ob.min_val.item() == -1.0
    assert ob.max_val.item() == 30.0

    scale, zero_point = ob.calculate_qparams()
    print("scale", scale)  # scale tensor([0.1216])
    # NOTE(review): the printed zero_point dtype depends on the torch
    # version (int32 was observed here); only the value is asserted.
    print("zero_point", zero_point)  # zero_point tensor([-120], dtype=torch.int32)
    # Assuming the default qint8 range [-128, 127]:
    #   scale = (30 - (-1)) / 255 ~= 0.1216
    #   zero_point = -128 - round(-1 / scale) = -120
    assert torch.allclose(scale, torch.tensor([31 / 255]))
    assert zero_point.item() == -120
36
37
def main():
    """Run each demo in this module, in file order."""
    for demo in (test_with_args, test_min_max_observer):
        demo()


if __name__ == "__main__":
    main()