type
See:
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type.h
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/type_ptr.h
`torch::Type` contains a member of type `torch::TypeKind` that identifies which concrete type it is.
`torch::SharedType` is a subclass of both `torch::Type` and `std::enable_shared_from_this<torch::SharedType>`.
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h#L637
// Quoted from upstream: TypePtr is a pointer-like handle to the base
// torch::Type. Judging by the template's name, SingletonOrSharedTypePtr can
// refer either to a singleton type object or to a shared (SharedType) one --
// NOTE(review): confirm the exact semantics in type_ptr.h upstream.
using TypePtr = SingletonOrSharedTypePtr<Type>;
./code/type/main.cc
1#include "torch/script.h"
2
3static void TestTypeKind() {
4 // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/jit_type_base.h
5 torch::TypeKind k = torch::TypeKind::AnyType;
6 TORCH_CHECK(torch::typeKindToString(k) == std::string("AnyType"));
7
8 // NamedType is not a member of TypeKind
9}
10
11static void TestNumberType() {
12 // torch::NumberType::get() returns a static object!
13 // so p and q are actually the same
14 torch::NumberTypePtr p = torch::NumberType::get();
15 torch::NumberTypePtr q = torch::NumberType::get();
16
17 TORCH_CHECK(p.get() == q.get());
18
19 TORCH_CHECK(p->str() == "Scalar");
20 TORCH_CHECK(p->kind() == torch::NumberType::Kind);
21 TORCH_CHECK(p->kind() == torch::TypeKind::NumberType);
22}
23
24static void TestIntType() {
25 torch::IntTypePtr p = torch::IntType::get();
26 TORCH_CHECK(p->str() == "int");
27 TORCH_CHECK(p->kind() == torch::TypeKind::IntType);
28 TORCH_CHECK(p->kind() == torch::IntType::Kind);
29 TORCH_CHECK(p->isSubtypeOf(torch::NumberType::get()) == true);
30}
31
32static void TestFloatType() {
33 torch::FloatTypePtr p = torch::FloatType::get();
34 TORCH_CHECK(p->str() == "float");
35 TORCH_CHECK(p->kind() == torch::TypeKind::FloatType);
36 TORCH_CHECK(p->kind() == torch::FloatType::Kind);
37 TORCH_CHECK(p->isSubtypeOf(torch::NumberType::get()) == true);
38 TORCH_CHECK(p->isSubtypeOf(torch::IntType::get()) == false);
39}
40
41static void TestBoolType() {
42 torch::BoolTypePtr p = torch::BoolType::get();
43 TORCH_CHECK(p->str() == "bool");
44 TORCH_CHECK(p->kind() == torch::TypeKind::BoolType);
45 TORCH_CHECK(p->kind() == torch::BoolType::Kind);
46 TORCH_CHECK(p->isSubtypeOf(torch::NumberType::get()) == true);
47 TORCH_CHECK(p->isSubtypeOf(torch::IntType::get()) == false);
48}
49
50static void TestNamedType() {
51 // torch::Type is an abstract class!
52 //
53 // torch::NamedType is an abstract class!
54 //
55 // torch::NamedType t(torch::TypeKind::AnyType, "foo.bar"); // error
56 // TORCH_CHECK(t.name()->qualifiedName() == "foo.bar");
57}
58
59static void TestAnyType() {
60 torch::AnyTypePtr p = torch::AnyType::get();
61 TORCH_CHECK(p->Kind == torch::TypeKind::AnyType);
62 TORCH_CHECK(p->kind() == torch::TypeKind::AnyType);
63 TORCH_CHECK(p->str() == "Any");
64 TORCH_CHECK(p->requires_grad() == false);
65
66 TORCH_CHECK(p == torch::AnyType::get());
67
68 // available in newer versions of PyTorch
69 // TORCH_CHECK(p->equals(torch::AnyType::get()));
70
71 TORCH_CHECK(torch::toString(p) == "Any");
72}
73
74int main() {
75 TestTypeKind();
76 TestNumberType();
77 TestIntType();
78 TestFloatType();
79 TestNamedType();
80 TestAnyType();
81 return 0;
82}