type

See:
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type.h
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/type_ptr.h

torch::Type stores its kind in a member of type torch::TypeKind, which is returned by kind(). torch::SharedType is a subclass of both torch::Type and std::enable_shared_from_this<torch::SharedType>.

// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h#L637
// NOTE: line anchors on the master branch drift as upstream edits the file;
// search for "using TypePtr" if the link no longer lands on this line.
// TypePtr is the generic handle to a torch::Type. As its name (and type_ptr.h)
// suggests, it holds either a pointer to a global singleton type (such as the
// object returned by torch::NumberType::get()) or a shared_ptr-managed type.
using TypePtr = SingletonOrSharedTypePtr<Type>;
./code/type/main.cc
#include <string>

#include "torch/script.h"

 3static void TestTypeKind() {
 4  // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h
 5  torch::TypeKind k = torch::TypeKind::AnyType;
 6  TORCH_CHECK(torch::typeKindToString(k) == std::string("AnyType"));
 7
 8  // NamedType is not a member of TypeKind
 9}
10
11static void TestNumberType() {
12  // torch::NumberType::get() returns a static object!
13  // so p and q are actually the same
14  torch::NumberTypePtr p = torch::NumberType::get();
15  torch::NumberTypePtr q = torch::NumberType::get();
16
17  TORCH_CHECK(p.get() == q.get());
18
19  TORCH_CHECK(p->str() == "Scalar");
20  TORCH_CHECK(p->kind() == torch::NumberType::Kind);
21  TORCH_CHECK(p->kind() == torch::TypeKind::NumberType);
22}
23
24static void TestIntType() {
25  torch::IntTypePtr p = torch::IntType::get();
26  TORCH_CHECK(p->str() == "int");
27  TORCH_CHECK(p->kind() == torch::TypeKind::IntType);
28  TORCH_CHECK(p->kind() == torch::IntType::Kind);
29  TORCH_CHECK(p->isSubtypeOf(torch::NumberType::get()) == true);
30}
31
32static void TestFloatType() {
33  torch::FloatTypePtr p = torch::FloatType::get();
34  TORCH_CHECK(p->str() == "float");
35  TORCH_CHECK(p->kind() == torch::TypeKind::FloatType);
36  TORCH_CHECK(p->kind() == torch::FloatType::Kind);
37  TORCH_CHECK(p->isSubtypeOf(torch::NumberType::get()) == true);
38  TORCH_CHECK(p->isSubtypeOf(torch::IntType::get()) == false);
39}
40
41static void TestBoolType() {
42  torch::BoolTypePtr p = torch::BoolType::get();
43  TORCH_CHECK(p->str() == "bool");
44  TORCH_CHECK(p->kind() == torch::TypeKind::BoolType);
45  TORCH_CHECK(p->kind() == torch::BoolType::Kind);
46  TORCH_CHECK(p->isSubtypeOf(torch::NumberType::get()) == true);
47  TORCH_CHECK(p->isSubtypeOf(torch::IntType::get()) == false);
48}
49
50static void TestNamedType() {
51  // torch::Type is an abstract class!
52  //
53  // torch::NamedType is an abstract class!
54  //
55  // torch::NamedType t(torch::TypeKind::AnyType, "foo.bar"); // error
56  // TORCH_CHECK(t.name()->qualifiedName() == "foo.bar");
57}
58
59static void TestAnyType() {
60  torch::AnyTypePtr p = torch::AnyType::get();
61  TORCH_CHECK(p->Kind == torch::TypeKind::AnyType);
62  TORCH_CHECK(p->kind() == torch::TypeKind::AnyType);
63  TORCH_CHECK(p->str() == "Any");
64  TORCH_CHECK(p->requires_grad() == false);
65
66  TORCH_CHECK(p == torch::AnyType::get());
67
68  // available in newer versions of PyTorch
69  // TORCH_CHECK(p->equals(torch::AnyType::get()));
70
71  TORCH_CHECK(torch::toString(p) == "Any");
72}
73
74int main() {
75  TestTypeKind();
76  TestNumberType();
77  TestIntType();
78  TestFloatType();
79  TestNamedType();
80  TestAnyType();
81  return 0;
82}