ivalue
./code/ivalue/main.cc
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

#include "torch/script.h"
2
3static void TestVectorOfTensor() {
4 torch::jit::Module m("m");
5 m.define(R"(
6 def forward(self, x, y):
7 return [x, y]
8 )");
9 auto x = torch::tensor({1, 2, 3});
10 auto y = torch::tensor({4, 5, 6});
11 auto i = m.run_method("forward", x, y);
12
13 assert(i.tagKind() == "GenericList");
14
15 torch::ArrayRef<torch::IValue> tensor_list = i.toListRef();
16 TORCH_CHECK(torch::allclose(x, tensor_list[0].toTensor()));
17 TORCH_CHECK(torch::allclose(y, tensor_list[1].toTensor()));
18
19 torch::List<torch::IValue> k = i.toList();
20
21 torch::List<torch::Tensor> o =
22 c10::impl::toTypedList<torch::Tensor>(std::move(k));
23
24 TORCH_CHECK(torch::allclose(o[0], x));
25 TORCH_CHECK(torch::allclose(o[1], y));
26
27 std::vector<torch::Tensor> p = o.vec();
28 TORCH_CHECK(torch::allclose(p[0], x));
29 TORCH_CHECK(torch::allclose(p[1], y));
30}
31
32static void TestVectorOfTensor2() {
33 torch::jit::Module m("m");
34 m.define(R"(
35 def forward(self, x):
36 return [[x], [x,x]]
37 )");
38 auto x = torch::tensor({1, 2, 3});
39 auto i = m.run_method("forward", x);
40 TORCH_CHECK(i.tagKind() == "GenericList");
41
42 torch::List<torch::IValue> list = i.toList();
43 torch::Tensor a = list.get(0).toListRef()[0].toTensor();
44 TORCH_CHECK(torch::allclose(a, x));
45
46 std::vector<torch::Tensor> b =
47 c10::impl::toTypedList<torch::Tensor>(list.get(1).toList()).vec();
48 TORCH_CHECK(torch::allclose(b[0], x));
49 TORCH_CHECK(torch::allclose(b[1], x));
50}
51
52static void TestVectorOfTensor3() {
53 torch::jit::Module m("m");
54 m.define(R"(
55 def forward(self, x: List[torch.Tensor]):
56 return x[0] + x[1]
57 )");
58
59 std::vector<torch::Tensor> v;
60 v.push_back(torch::tensor({1, 2}));
61 v.push_back(torch::tensor({3, 4}));
62 c10::List<torch::Tensor> ilist(v);
63
64 c10::impl::GenericList generic_list = c10::impl::toList(ilist);
65
66 c10::List<torch::Tensor> l2 =
67 c10::impl::toTypedList<torch::Tensor>(generic_list);
68
69 TORCH_CHECK(torch::allclose(l2[0], v[0]));
70 TORCH_CHECK(torch::allclose(l2[1], v[1]));
71
72 auto r = m.run_method("forward", generic_list);
73 TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1]));
74
75 // Note: We can pass a vector directly
76 r = m.run_method("forward", v);
77 TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1]));
78
79 r = m.run_method("forward", ilist); // also OK
80 TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1]));
81}
82
83static void TestVectorOfTensor4() {
84 torch::jit::Module m("m");
85 m.define(R"(
86 def forward(self, x: Tuple[List[torch.Tensor]]):
87 return x[0][0] + x[0][1]
88 )");
89
90 std::vector<torch::Tensor> v;
91 v.push_back(torch::tensor({1, 2}));
92 v.push_back(torch::tensor({3, 4}));
93 auto t = torch::ivalue::Tuple::create(v);
94
95 auto r = m.run_method("forward", t);
96 TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1]));
97}
98
99static void TestVectorOfTensor5() {
100 torch::jit::Module m("m");
101 m.define(R"(
102 def forward(self, x: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]):
103 return x[0][0][0] + x[0][0][1] + x[1][0] + x[1][1]
104 )");
105
106 std::vector<torch::Tensor> v;
107 v.push_back(torch::tensor({1, 2}));
108 v.push_back(torch::tensor({3, 4}));
109
110 std::vector<std::vector<torch::Tensor>> vv;
111 vv.push_back(v);
112 vv.push_back(v);
113
114 auto t = torch::ivalue::Tuple::create(vv, v);
115
116 auto r = m.run_method("forward", t);
117 TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1] + v[0] + v[1]));
118}
119
120static void TestVectorOfTensor6() {
121 // List[List[Tensor]]
122 std::vector<torch::Tensor> v;
123 v.push_back(torch::tensor({1, 2}));
124 v.push_back(torch::tensor({3, 4}));
125
126 c10::List<torch::Tensor> ilist(v);
127 torch::IValue ivalue(ilist);
128 TORCH_CHECK(ivalue.tagKind() == "GenericList");
129
130 c10::List<c10::List<torch::Tensor>> ilist2(ilist);
131 ilist2.push_back(ilist);
132 ilist2.push_back(ilist);
133
134 torch::IValue ivalue2(ilist2);
135 TORCH_CHECK(ivalue2.tagKind() == "GenericList");
136
137 c10::List<torch::IValue> a0 = ivalue2.toList();
138 c10::List<c10::List<torch::Tensor>> a1 =
139 c10::impl::toTypedList<c10::List<torch::Tensor>>(a0);
140
141 c10::ArrayRef<torch::IValue> a = ivalue2.toListRef();
142
143 torch::List<torch::Tensor> b =
144 c10::impl::toTypedList<torch::Tensor>(a[0].toList());
145 for (int32_t i = 0; i != b.size(); ++i) {
146 std::cout << b[i] << "\n";
147 }
148 std::vector<std::vector<torch::Tensor>> v2{v};
149 torch::List<torch::List<torch::Tensor>> c;
150 for (auto k : v2) {
151 c10::List<torch::Tensor> dd{torch::ArrayRef<torch::Tensor>(k)};
152 c.push_back(std::move(dd));
153 }
154}
155
156int main() {
157 TestVectorOfTensor();
158 TestVectorOfTensor2();
159 TestVectorOfTensor3();
160 TestVectorOfTensor4();
161 TestVectorOfTensor5();
162 TestVectorOfTensor6();
163 return 0;
164}