ivalue

./code/ivalue/main.cc
#include <iostream>
#include <utility>
#include <vector>

#include "torch/script.h"

  3static void TestVectorOfTensor() {
  4  torch::jit::Module m("m");
  5  m.define(R"(
  6    def forward(self, x, y):
  7      return [x, y]
  8  )");
  9  auto x = torch::tensor({1, 2, 3});
 10  auto y = torch::tensor({4, 5, 6});
 11  auto i = m.run_method("forward", x, y);
 12
 13  assert(i.tagKind() == "GenericList");
 14
 15  torch::ArrayRef<torch::IValue> tensor_list = i.toListRef();
 16  TORCH_CHECK(torch::allclose(x, tensor_list[0].toTensor()));
 17  TORCH_CHECK(torch::allclose(y, tensor_list[1].toTensor()));
 18
 19  torch::List<torch::IValue> k = i.toList();
 20
 21  torch::List<torch::Tensor> o =
 22      c10::impl::toTypedList<torch::Tensor>(std::move(k));
 23
 24  TORCH_CHECK(torch::allclose(o[0], x));
 25  TORCH_CHECK(torch::allclose(o[1], y));
 26
 27  std::vector<torch::Tensor> p = o.vec();
 28  TORCH_CHECK(torch::allclose(p[0], x));
 29  TORCH_CHECK(torch::allclose(p[1], y));
 30}
 31
 32static void TestVectorOfTensor2() {
 33  torch::jit::Module m("m");
 34  m.define(R"(
 35    def forward(self, x):
 36      return [[x], [x,x]]
 37  )");
 38  auto x = torch::tensor({1, 2, 3});
 39  auto i = m.run_method("forward", x);
 40  TORCH_CHECK(i.tagKind() == "GenericList");
 41
 42  torch::List<torch::IValue> list = i.toList();
 43  torch::Tensor a = list.get(0).toListRef()[0].toTensor();
 44  TORCH_CHECK(torch::allclose(a, x));
 45
 46  std::vector<torch::Tensor> b =
 47      c10::impl::toTypedList<torch::Tensor>(list.get(1).toList()).vec();
 48  TORCH_CHECK(torch::allclose(b[0], x));
 49  TORCH_CHECK(torch::allclose(b[1], x));
 50}
 51
 52static void TestVectorOfTensor3() {
 53  torch::jit::Module m("m");
 54  m.define(R"(
 55    def forward(self, x: List[torch.Tensor]):
 56      return x[0] + x[1]
 57  )");
 58
 59  std::vector<torch::Tensor> v;
 60  v.push_back(torch::tensor({1, 2}));
 61  v.push_back(torch::tensor({3, 4}));
 62  c10::List<torch::Tensor> ilist(v);
 63
 64  c10::impl::GenericList generic_list = c10::impl::toList(ilist);
 65
 66  c10::List<torch::Tensor> l2 =
 67      c10::impl::toTypedList<torch::Tensor>(generic_list);
 68
 69  TORCH_CHECK(torch::allclose(l2[0], v[0]));
 70  TORCH_CHECK(torch::allclose(l2[1], v[1]));
 71
 72  auto r = m.run_method("forward", generic_list);
 73  TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1]));
 74
 75  // Note: We can pass a vector directly
 76  r = m.run_method("forward", v);
 77  TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1]));
 78
 79  r = m.run_method("forward", ilist); // also OK
 80  TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1]));
 81}
 82
 83static void TestVectorOfTensor4() {
 84  torch::jit::Module m("m");
 85  m.define(R"(
 86    def forward(self, x: Tuple[List[torch.Tensor]]):
 87      return x[0][0] + x[0][1]
 88  )");
 89
 90  std::vector<torch::Tensor> v;
 91  v.push_back(torch::tensor({1, 2}));
 92  v.push_back(torch::tensor({3, 4}));
 93  auto t = torch::ivalue::Tuple::create(v);
 94
 95  auto r = m.run_method("forward", t);
 96  TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1]));
 97}
 98
 99static void TestVectorOfTensor5() {
100  torch::jit::Module m("m");
101  m.define(R"(
102    def forward(self, x: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]):
103      return x[0][0][0] + x[0][0][1] + x[1][0] + x[1][1]
104  )");
105
106  std::vector<torch::Tensor> v;
107  v.push_back(torch::tensor({1, 2}));
108  v.push_back(torch::tensor({3, 4}));
109
110  std::vector<std::vector<torch::Tensor>> vv;
111  vv.push_back(v);
112  vv.push_back(v);
113
114  auto t = torch::ivalue::Tuple::create(vv, v);
115
116  auto r = m.run_method("forward", t);
117  TORCH_CHECK(torch::allclose(r.toTensor(), v[0] + v[1] + v[0] + v[1]));
118}
119
120static void TestVectorOfTensor6() {
121  // List[List[Tensor]]
122  std::vector<torch::Tensor> v;
123  v.push_back(torch::tensor({1, 2}));
124  v.push_back(torch::tensor({3, 4}));
125
126  c10::List<torch::Tensor> ilist(v);
127  torch::IValue ivalue(ilist);
128  TORCH_CHECK(ivalue.tagKind() == "GenericList");
129
130  c10::List<c10::List<torch::Tensor>> ilist2(ilist);
131  ilist2.push_back(ilist);
132  ilist2.push_back(ilist);
133
134  torch::IValue ivalue2(ilist2);
135  TORCH_CHECK(ivalue2.tagKind() == "GenericList");
136
137  c10::List<torch::IValue> a0 = ivalue2.toList();
138  c10::List<c10::List<torch::Tensor>> a1 =
139      c10::impl::toTypedList<c10::List<torch::Tensor>>(a0);
140
141  c10::ArrayRef<torch::IValue> a = ivalue2.toListRef();
142
143  torch::List<torch::Tensor> b =
144      c10::impl::toTypedList<torch::Tensor>(a[0].toList());
145  for (int32_t i = 0; i != b.size(); ++i) {
146    std::cout << b[i] << "\n";
147  }
148  std::vector<std::vector<torch::Tensor>> v2{v};
149  torch::List<torch::List<torch::Tensor>> c;
150  for (auto k : v2) {
151    c10::List<torch::Tensor> dd{torch::ArrayRef<torch::Tensor>(k)};
152    c.push_back(std::move(dd));
153  }
154}
155
156int main() {
157  TestVectorOfTensor();
158  TestVectorOfTensor2();
159  TestVectorOfTensor3();
160  TestVectorOfTensor4();
161  TestVectorOfTensor5();
162  TestVectorOfTensor6();
163  return 0;
164}