Tensor

See

Common methods

./code/tensor/main.cc (Common methods)
 1static void TestCommonMethods() {
 2  torch::Tensor t = torch::rand({2, 3, 4});
 3
 4  TORCH_CHECK(t.dim() == 3);              // 3-d tensor
 5  TORCH_CHECK(t.ndimension() == t.dim()); // same
 6  TORCH_CHECK(t.numel() == 2 * 3 * 4);
 7  TORCH_CHECK(t.is_contiguous() == true);
 8  TORCH_CHECK(t.contiguous().is_contiguous() == true);
 9
10  t.fill_(10); // fill all entries with 10
11  t.zero_();   // zero out all entries
12
13  t = t.to(torch::kInt);
14  TORCH_CHECK(t.is_floating_point() == false);
15  TORCH_CHECK(t.is_signed() == true);
16
17  TORCH_CHECK(t.size(0) == 2);
18  TORCH_CHECK(t.size(1) == 3);
19  TORCH_CHECK(t.size(2) == 4);
20  TORCH_CHECK(t.sizes() == torch::ArrayRef<int64_t>({2, 3, 4}));
21
22  t = t.contiguous();
23  TORCH_CHECK(t.stride(0) == 3 * 4);
24  TORCH_CHECK(t.stride(1) == 4);
25  TORCH_CHECK(t.stride(2) == 1);
26  TORCH_CHECK(t.strides() == torch::ArrayRef<int64_t>({12, 4, 1}));
27
28  TORCH_CHECK(t.defined() == true);
29  {
30    torch::Tensor a;
31    TORCH_CHECK(a.defined() == false);
32    a = t;
33    TORCH_CHECK(a.defined() == true);
34    a.reset();
35    TORCH_CHECK(a.defined() == false);
36  }
37
38  t = t.to(torch::kShort);
39  TORCH_CHECK(t.itemsize() == sizeof(int16_t));
40  TORCH_CHECK(t.nbytes() == t.numel() * t.itemsize());
41  TORCH_CHECK(t.itemsize() == t.element_size()); // same
42
43  TORCH_CHECK(t.scalar_type() == torch::kShort);
44  TORCH_CHECK(t.dtype() == caffe2::TypeMeta::Make<int16_t>());
45  TORCH_CHECK(t.dtype().toScalarType() == torch::kShort);
46
47  TORCH_CHECK(t.device() == torch::Device("cpu"));
48  TORCH_CHECK(t.device() == torch::Device(torch::kCPU));
49
50  // Note: t.device() returns an instance of torch::Device
51  // t.get_device() returns the device index.
52  TORCH_CHECK(t.get_device() == t.device().index());
53
54  TORCH_CHECK(t.is_cpu() == true);
55  TORCH_CHECK(t.is_cuda() == false);
56
57  t = t.to(torch::kInt);
58  int32_t *p = t.data_ptr<int32_t>();
59  p[0] = 100;
60
61  torch::TensorAccessor<int32_t, 3> acc = t.accessor<int32_t, 3>();
62  TORCH_CHECK(acc[0][0][0] == p[0]);
63  p[12] = -2;
64  TORCH_CHECK(acc[1][0][0] == -2);
65
66  acc[1][1][2] = 3;
67  TORCH_CHECK(*(p + 12 + 4 + 2) == 3);
68
69  t = t.to(torch::kFloat);
70  t.set_requires_grad(true);
71  TORCH_CHECK(t.requires_grad() == true);
72
73  t.set_requires_grad(false);
74  TORCH_CHECK(t.requires_grad() == false);
75
76  t = t.cuda();
77  TORCH_CHECK(t.device().type() == torch::kCUDA);
78  t = t.cpu();
79
80  torch::TensorOptions opts = t.options();
81  TORCH_CHECK(opts.device() == t.device());

slice

torch::slice
 1static void TestSlice() {
 2  auto t = torch::tensor({1, 2, 3, 4, 5}, torch::kInt);
 3  torch::TensorAccessor<int32_t, 1> acc = t.accessor<int32_t, 1>();
 4
 5  // t2 = t[1:3]
 6  torch::Tensor t2 = t.slice(/*dim*/ 0, /*start*/ 1,
 7                             /*end, exclusive*/ 3); // memory is shared
 8  torch::TensorAccessor<int32_t, 1> acc2 = t2.accessor<int32_t, 1>();
 9  TORCH_CHECK(acc2[0] == 2);
10  TORCH_CHECK(acc2[1] == 3);
11
12  acc2[0] = 10; // also changes t since the memory is shared
13  TORCH_CHECK(acc[1] == 10);
torch::slice
1void TestSlice2() {
2  auto t = torch::full({2, 3}, -1);
3  std::cout << t << "\n";
4
5  // set the last column to 0
6  t.index({torch::indexing::Slice(), -1}) = 0;
7  std::cout << t << "\n";
8}
torch::slice
 1void TestSlice3() {
 2  auto a = torch::tensor({1, 3, 10}, torch::kFloat);
 3  auto count = torch::zeros({10}, torch::kFloat);
 4
 5  count.slice(0, 1, 4).add_(a);
 6  // count: 0 1 3 10 0 0 0 0 0 0
 7  std::cout << "count: " << count << "\n";
 8  count.slice(0, 2, 5).add_(a);
 9  std::cout << "count: " << count << "\n";
10  // count: 0 1 4 13 10 0 0 0 0 0
11}

topk

torch::topk
 1// https://pytorch.org/docs/stable/generated/torch.topk.html
 2static void TestTopK() {
 3  auto t = torch::tensor({1, 0, 3, -1}, torch::kInt).to(torch::kFloat);
 4  torch::Tensor values, indexes;
 5  std::tie(values, indexes) =
 6      t.topk(/*k*/ 2, /*dim*/ 0, /*largest*/ true, /*sorted*/ true);
 7  auto values_acc = values.accessor<float, 1>();
 8  auto indexes_acc = indexes.accessor<int64_t, 1>(); // Note: it is int64_t
 9
10  TORCH_CHECK(values.numel() == 2); // k in topk is 2
11  TORCH_CHECK(values_acc[0] == 3);  // the largest value is 3, at t[2]
12  TORCH_CHECK(values_acc[1] == 1);  // the second largest value is 1, at t[0]
13                                    //
14  TORCH_CHECK(indexes_acc[0] == 2); // the largest value is t[2]
15  TORCH_CHECK(indexes_acc[1] == 0); // the second largest value is t[0]

floor_divide

torch::floor_divide
1static void TestFloorDivide() {
2  auto t = torch::tensor({1, 0, 3, 5, 9}, torch::kInt);
3  auto p = torch::floor_divide(t, 2);
4  auto acc = p.accessor<int32_t, 1>();
5  TORCH_CHECK(acc[0] == 1 / 2);
6  TORCH_CHECK(acc[1] == 0 / 2);
7  TORCH_CHECK(acc[2] == 3 / 2);
8  TORCH_CHECK(acc[3] == 5 / 2);
9  TORCH_CHECK(acc[4] == 9 / 2);

div

torch::div
 1// https://pytorch.org/docs/stable/generated/torch.div.html
 2static void TestDiv() {
 3  auto t = torch::tensor({1, 0, 3, 5, 9}, torch::kInt);
 4  // the rounding mode is supported in torch >= 1.8.0
 5  auto p = torch::div(t, 2, /*rounding_mode*/ "trunc");
 6  auto acc = p.accessor<int32_t, 1>();
 7  TORCH_CHECK(acc[0] == 1 / 2);
 8  TORCH_CHECK(acc[1] == 0 / 2);
 9  TORCH_CHECK(acc[2] == 3 / 2);
10  TORCH_CHECK(acc[3] == 5 / 2);
11  TORCH_CHECK(acc[4] == 9 / 2);

remainder

torch::remainder
1// https://pytorch.org/docs/1.6.0/generated/torch.remainder.html
2static void TestRemainder() {
3  auto t = torch::tensor({1, 3, 8}, torch::kInt);
4  auto p = torch::remainder(t, 3);
5  auto acc = p.accessor<int32_t, 1>();
6  TORCH_CHECK(acc[0] == 1);
7  TORCH_CHECK(acc[1] == 0);
8  TORCH_CHECK(acc[2] == 2);

empty

torch::empty
1static void TestEmpty() {
2  auto t = torch::empty({3}, torch::kInt);
3  TORCH_CHECK(t.scalar_type() == torch::kInt);
4  TORCH_CHECK(t.numel() == 3);

stack

torch::stack
 1static void TestStack() {
 2  auto t = torch::empty({6, 5}, torch::kInt);
 3  auto a = torch::stack({t, t}, /*dim*/ 1);
 4  TORCH_CHECK(a.sizes() == torch::ArrayRef<int64_t>({6, 2, 5}));
 5
 6  a = torch::stack({t, t}, /*dim*/ 0);
 7  TORCH_CHECK(a.sizes() == torch::ArrayRef<int64_t>({2, 6, 5}));
 8
 9  a = torch::stack({t, t}, /*dim*/ 2);
10  TORCH_CHECK(a.sizes() == torch::ArrayRef<int64_t>({6, 5, 2}));

unbind

torch::unbind
1static void TestUnbind() {
2  auto t = torch::empty({4, 6, 5}, torch::kInt);
3  std::vector<torch::Tensor> v = torch::unbind(t, /*dim*/ 1);
4  TORCH_CHECK(v.size() == t.size(1));
5  for (int32_t i = 0; i != v.size(); ++i) {
6    TORCH_CHECK(v[i].sizes() == torch::ArrayRef<int64_t>({4, 5}));
7  }

full

torch::full
1static void TestFull() {
2  auto t = torch::full({2, 3}, 10, torch::kInt);
3  const int32_t *p = t.data_ptr<int32_t>();
4  for (int32_t i = 0; i != t.numel(); ++i) {
5    TORCH_CHECK(p[i] == 10);
6  }

split

torch::split
 1static void TestSplit() {
 2  auto t = torch::arange(6).reshape({2, 3});
 3  std::vector<torch::Tensor> s = t.split(1);
 4  TORCH_CHECK(s.size() == 2);
 5  TORCH_CHECK(s[0].sizes() == torch::ArrayRef<int64_t>({1, 3}));
 6  TORCH_CHECK(s[1].sizes() == torch::ArrayRef<int64_t>({1, 3}));
 7
 8  s = t.split(1, /*dim*/ 1);
 9  TORCH_CHECK(s.size() == 3);
10  TORCH_CHECK(s[0].sizes() == torch::ArrayRef<int64_t>({2, 1}));
11  TORCH_CHECK(s[1].sizes() == torch::ArrayRef<int64_t>({2, 1}));
12  TORCH_CHECK(s[2].sizes() == torch::ArrayRef<int64_t>({2, 1}));

zeros

torch::zeros
1static void TestZeros() {
2  auto t = torch::zeros({2, 3}, torch::kFloat);

cat

torch::cat
1static void TestCat() {
2  auto t = torch::arange(24).reshape({2, 3, 4});
3  std::vector<torch::Tensor> v(5, t);
4  auto p = torch::cat(v, /*dim*/ 1);
5  TORCH_CHECK(p.sizes() == torch::ArrayRef<int64_t>({2, 3 * 5, 4}));

division

test division
 1static void TestDivision() {
 2  auto t = torch::arange(4).to(torch::kInt);
 3  auto b = t / 2;
 4  TORCH_CHECK(b.scalar_type() == torch::kFloat);
 5
 6  const float *p = b.data_ptr<float>();
 7  TORCH_CHECK(p[0] == 0 / 2.);
 8  TORCH_CHECK(p[1] == 1 / 2.);
 9  TORCH_CHECK(p[2] == 2 / 2.);
10  TORCH_CHECK(p[3] == 3 / 2.);
11
12  auto c = b.to(torch::kInt);
13
14  const int32_t *q = c.data_ptr<int32_t>();
15  TORCH_CHECK(q[0] == 0 / 2);
16  TORCH_CHECK(q[1] == 1 / 2);
17  TORCH_CHECK(q[2] == 2 / 2);
18  TORCH_CHECK(q[3] == 3 / 2);

default constructed

test default constructed
1void TestDefaultConstructed() {
2  torch::Tensor t;

copy

test rowwise copy
 1void TestCopy() {
 2  auto t0 = torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
 3  auto t = torch::empty({4, 3}, torch::kFloat);
 4
 5  t.slice(/*dim*/ 0, 0, 1) = t0.slice(0, 0, 1);
 6  t.slice(/*dim*/ 0, 1, 2) = t0.slice(0, 1, 2);
 7  t.slice(/*dim*/ 0, 2, 3) = t0.slice(0, 0, 1) + 10;
 8  t.slice(/*dim*/ 0, 3, 4) = t0.slice(0, 1, 2) + 10;
 9
10  std::cout << t << "\n";

addmm

test addmm
 1void TestAddmm() {
 2  std::cout << "---TestAddmm---\n";
 3  // 1 2 3
 4  // 4 5 6
 5  torch::Tensor m =
 6      torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
 7
 8  torch::Tensor v = torch::tensor({1, 1, -1}, torch::kFloat).unsqueeze(1);
 9
10  // 10 20 (a column vector of shape (2, 1))
11  torch::Tensor a = torch::tensor({10, 20}, torch::kFloat).unsqueeze(1);
12  a.addmm_(m, v);
13  std::cout << a << "\n";
14  std::cout << a.squeeze(1) << "\n";

elementwise operation

test elementwise operation
 1void TestElementwiseOp() {
 2  std::cout << "---TestElementwiseOp---\n";
 3  torch::Tensor a = torch::tensor({1, 2, 3, 40}, torch::kFloat).reshape({2, 2});
 4  torch::Tensor b =
 5      torch::tensor({10, 20, 30, 4}, torch::kFloat).reshape({2, 2});
 6  torch::Tensor c = a * b;
 7  torch::Tensor d = a / b;
 8  torch::Tensor e = 1.0 / a;
 9  std::cout << c << "\n"; // [[10, 40], [90, 160]]
10  std::cout << d << "\n"; // [[0.1, 0.1], [0.1, 10]]
11  std::cout << e << "\n"; // [[1.0, 0.5], [0.3333, 0.0250]], float32

torch.roll

torch.roll
 1void TestRoll() {
 2  // 1 2 3
 3  // 4 5 6
 4  torch::Tensor a =
 5      torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
 6  torch::Tensor b = a.roll(1 /*shift right 1 column*/, 1 /*dim*/);
 7  // Now b is
 8  // 3 1 2
 9  // 6 4 5
10
11  // ----------
12  // 1 2 3 4
13  // 5 6 7 8
14  //
15  // 9 10 11 12
16  // 13 14 15 16
17  a = torch::tensor({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
18                    torch::kInt)
19          .reshape({2, 2, 4});
20  b = a.roll(1 /*shift right 1 column*/, 2 /*dim*/);
21  // now b is
22  // 4 1 2 3
23  // 8 5 6 7
24  //
25  // 12 9 10 11
26  // 16 13 14 15
27  std::cout << b;

torch.mean

torch.mean
  1  auto t = torch::arange(24).reshape({2, 3, 4});
  2  std::vector<torch::Tensor> v(5, t);
  3  auto p = torch::cat(v, /*dim*/ 1);
  4  TORCH_CHECK(p.sizes() == torch::ArrayRef<int64_t>({2, 3 * 5, 4}));
  5}
  6
  7static void TestDivision() {
  8  auto t = torch::arange(4).to(torch::kInt);
  9  auto b = t / 2;
 10  TORCH_CHECK(b.scalar_type() == torch::kFloat);
 11
 12  const float *p = b.data_ptr<float>();
 13  TORCH_CHECK(p[0] == 0 / 2.);
 14  TORCH_CHECK(p[1] == 1 / 2.);
 15  TORCH_CHECK(p[2] == 2 / 2.);
 16  TORCH_CHECK(p[3] == 3 / 2.);
 17
 18  auto c = b.to(torch::kInt);
 19
 20  const int32_t *q = c.data_ptr<int32_t>();
 21  TORCH_CHECK(q[0] == 0 / 2);
 22  TORCH_CHECK(q[1] == 1 / 2);
 23  TORCH_CHECK(q[2] == 2 / 2);
 24  TORCH_CHECK(q[3] == 3 / 2);
 25}
 26
 27void TestDefaultConstructed() {
 28  torch::Tensor t;
 29  TORCH_CHECK(t.size(0) == 0);
 30}
 31
 32void TestCopy() {
 33  auto t0 = torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
 34  auto t = torch::empty({4, 3}, torch::kFloat);
 35
 36  t.slice(/*dim*/ 0, 0, 1) = t0.slice(0, 0, 1);
 37  t.slice(/*dim*/ 0, 1, 2) = t0.slice(0, 1, 2);
 38  t.slice(/*dim*/ 0, 2, 3) = t0.slice(0, 0, 1) + 10;
 39  t.slice(/*dim*/ 0, 3, 4) = t0.slice(0, 1, 2) + 10;
 40
 41  std::cout << t << "\n";
 42}
 43
 44void TestAddmm() {
 45  std::cout << "---TestAddmm---\n";
 46  // 1 2 3
 47  // 4 5 6
 48  torch::Tensor m =
 49      torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
 50
 51  torch::Tensor v = torch::tensor({1, 1, -1}, torch::kFloat).unsqueeze(1);
 52
 53  // 10 20 (a column vector of shape (2, 1))
 54  torch::Tensor a = torch::tensor({10, 20}, torch::kFloat).unsqueeze(1);
 55  a.addmm_(m, v);
 56  std::cout << a << "\n";
 57  std::cout << a.squeeze(1) << "\n";
 58}
 59
 60void TestElementwiseOp() {
 61  std::cout << "---TestElementwiseOp---\n";
 62  torch::Tensor a = torch::tensor({1, 2, 3, 40}, torch::kFloat).reshape({2, 2});
 63  torch::Tensor b =
 64      torch::tensor({10, 20, 30, 4}, torch::kFloat).reshape({2, 2});
 65  torch::Tensor c = a * b;
 66  torch::Tensor d = a / b;
 67  torch::Tensor e = 1.0 / a;
 68  std::cout << c << "\n"; // [[10, 40], [90, 160]]
 69  std::cout << d << "\n"; // [[0.1, 0.1], [0.1, 10]]
 70  std::cout << e << "\n"; // [[1.0, 0.5], [0.3333, 0.0250]], float32
 71}
 72
 73void TestRoll() {
 74  // 1 2 3
 75  // 4 5 6
 76  torch::Tensor a =
 77      torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
 78  torch::Tensor b = a.roll(1 /*shift right 1 column*/, 1 /*dim*/);
 79  // Now b is
 80  // 3 1 2
 81  // 6 4 5
 82
 83  // ----------
 84  // 1 2 3 4
 85  // 5 6 7 8
 86  //
 87  // 9 10 11 12
 88  // 13 14 15 16
 89  a = torch::tensor({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
 90                    torch::kInt)
 91          .reshape({2, 2, 4});
 92  b = a.roll(1 /*shift right 1 column*/, 2 /*dim*/);
 93  // now b is
 94  // 4 1 2 3
 95  // 8 5 6 7
 96  //
 97  // 12 9 10 11
 98  // 16 13 14 15
 99  std::cout << b;
100}
101
102void TestMean() {
103  // 1 2 3
104  // 4 5 6
105  torch::Tensor a =
106      torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
107  torch::Tensor b = a.mean(1 /*dim*/, true /*keep_dim*/);
108  std::cout << b;
109  // Now b is:
110  // 2
111  // 5
112  //----------
113  b = a.mean(1 /*dim*/, false /*keep_dim*/);
114  std::cout << b << "\n";

torch.slice

torch.slice
1void TestSlice2() {
2  auto t = torch::full({2, 3}, -1);
3  std::cout << t << "\n";
4
5  // set the last column to 0
6  t.index({torch::indexing::Slice(), -1}) = 0;
7  std::cout << t << "\n";

torch.as_strided

torch.as_strided
 1void TestAsStrided() {
 2  /*
 3   0 1 2 3 4 5 6 7 8 9
 4   */
 5
 6  /*
 7    0 1 2
 8    2 3 4
 9    4 5 6
10    6 7 8
11   */
12  torch::Tensor a = torch::arange(0, 10).to(torch::kFloat);
13  // (10 - 3) // 2 + 1 = 4
14  torch::Tensor b = a.as_strided({4, 3}, {2, 1});
15  std::cout << a << "\n";
16  std::cout << b << "\n";

torch.argmax

torch.argmax
 1void TestArgMax() {
 2  std::vector<float> v = {
 3      //
 4      0.2, 0.5, 0.1, 0.4,
 5      //
 6      0.9, 0.2, 0.0, 0.3,
 7      //
 8      0.8, 0.99, 0.1, 0.3
 9      //
10  };
11  torch::Tensor a = torch::from_blob(v.data(), {3, 4}, torch::kFloat);
12  std::cout << a << "\n"; // shape (3, 4)
13
14  torch::Tensor b = a.argmax(1);
15  std::cout << b << "\n"; // 1-d, shape: (3,)
16  // 1, 0, 1
17
18  // test 3-d
19  a = torch::from_blob(v.data(), {2, 3, 2}, torch::kFloat);
20  std::cout << a << "\n"; // shape (2, 3, 2)
21
22  b = a.argmax(-1);
23  std::cout << b << "\n"; // 2-d, shape: (2, 3)
24                          // 1, 1, 0
25                          // 1, 1, 1

torch.index

torch.index
 1void TestIndex() {
 2  // see https://pytorch.org/cppdocs/notes/tensor_indexing.html
 3  std::vector<float> v = {
 4      //
 5      0.2, 0.5, 0.1, 0.4,
 6      //
 7      0.9, 0.2, 0.0, 0.3,
 8      //
 9      0.8, 0.99, 0.1, 0.3
10      //
11  };
12  torch::Tensor a = torch::from_blob(v.data(), {3, 4}, torch::kFloat);
13  std::cout << a << "\n"; // shape (3, 4)
14
15  torch::Tensor b = a.index({0});
16  std::cout << b << "\n"; // 1-d, shape (4,) 0.2, 0.5, 0.1, 0.4
17                          //
18  b = a.index({2});
19  std::cout << b << "\n"; // 1-d, shape (4,) 0.8, 0.99, 0.1, 0.3

torch.nn.functional.pad

torch.nn.functional.pad
 1void TestPad() {
 2  std::vector<float> v = {0, 1, 2, 3, 4, 5};
 3  torch::Tensor a = torch::from_blob(v.data(), {2, 3}, torch::kFloat);
 4
 5  int32_t padding = 1;
 6
 7  // pad dim 1
 8#ifdef __ANDROID__
 9  auto padding_value = torch::zeros(
10      {a.size(0), padding}, torch::dtype(torch::kFloat).device(a.device()));
11
12  torch::Tensor b = torch::cat({a, padding_value}, 1);
13#else
14  torch::Tensor b = torch::nn::functional::pad(
15      a, torch::nn::functional::PadFuncOptions({0, padding})
16             .mode(torch::kConstant)
17             .value(0));
18#endif
19  /*
20   0 1 2 0
21   3 4 5 0
22   */
23  std::cout << b << "\n";
24
25  // pad dim 0
26#ifdef __ANDROID__
27  padding_value = torch::zeros({padding, a.size(1)},
28                               torch::dtype(torch::kFloat).device(a.device()));
29
30  torch::Tensor c = torch::cat({a, padding_value}, 0);
31#else
32  torch::Tensor c = torch::nn::functional::pad(
33      a, torch::nn::functional::PadFuncOptions({0, 0, 0, padding})
34             .mode(torch::kConstant)
35             .value(0));
36#endif
37  /*
38    0 1 2
39    3 4 5
40    0 0 0
41   */
42  std::cout << c << "\n";
43}

torch.index_put

  1void TestIndexPut() {
  2  // https://pytorch.org/cppdocs/notes/tensor_indexing.html#setter
  3  std::vector<float> v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  4  torch::Tensor a = torch::from_blob(v.data(), {12}, torch::kFloat);
  5  torch::Tensor indexes =
  6      torch::tensor({0, 1, 3, 5}, torch::dtype(torch::kLong));
  7
  8  a.index_put_({indexes}, 0);
  9  std::cout << a << "\n";
 10  /*
 11   0 0 2 0 4 0 6 7 8 9 10 11
 12
 13   */
 14
 15  // 2-d
 16
 17  v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
 18  /*
 19   0 1 2
 20   3 4 5
 21   6 7 8
 22   9 10 11
 23   */
 24  a = torch::from_blob(v.data(), {4, 3}, torch::kFloat);
 25
 26  indexes = torch::tensor({0, 2}, torch::dtype(torch::kLong));
 27  a.index_put_({indexes}, 100);
 28
 29  /*
 30    100 100 100
 31    3   4   5
 32    100 100 100
 33    9   10  11
 34   */
 35  std::cout << a << "\n";
 36
 37  v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
 38  a = torch::from_blob(v.data(), {4, 3}, torch::kFloat);
 39
 40  indexes = torch::tensor({0, 2}, torch::dtype(torch::kLong));
 41  a.index_put_({"...", indexes}, 99);
 42
 43  /*
 44   99 1 99
 45   99 4 99
 46   99 7 99
 47   99 10 99
 48   */
 49  std::cout << a << "\n";
 50
 51  v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
 52  a = torch::from_blob(v.data(), {4, 3}, torch::kFloat);
 53
 54  indexes = torch::tensor({0, 2}, torch::dtype(torch::kLong));
 55  a.index_put_({"...", 1}, 6666);
 56  /*
 57   0 6666 2
 58   3 6666 5
 59   6 6666 8
 60   9 6666 11
 61   */
 62  std::cout << a << "\n";
 63
 64  // 3-d
 65  v = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 66       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
 67  a = torch::from_blob(v.data(), {2, 3, 4}, torch::kFloat);
 68
 69  /*
 70    0 1 2 3
 71    4 5 6 7
 72    8 9 10 11
 73
 74    ---
 75
 76    12 13 14 15
 77    16 17 18 19
 78    20 21 22 23
 79   */
 80  std::cout << a << "\n";
 81
 82  indexes = torch::tensor({0}, torch::dtype(torch::kLong));
 83  a.index_put_({indexes}, 88); // dim 0
 84  std::cout << a << "\n";
 85  /*
 86    88 88 88 88
 87    88 88 88 88
 88    88 88 88 88
 89
 90    ---
 91
 92    12 13 14 15
 93    16 17 18 19
 94    20 21 22 23
 95   */
 96
 97  v = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 98       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
 99  a = torch::from_blob(v.data(), {2, 3, 4}, torch::kFloat);
100  indexes = torch::tensor({0, 3}, torch::dtype(torch::kLong));
101  a.index_put_({"...", indexes}, 66); // the last dim, i.e., dim 2
102  std::cout << a << "\n";
103  /*
104    66 1 2 66
105    66 5 6 66
106    66 9 10 66
107
108    ---
109
110    66 13 14 66
111    66 17 18 66
112    66 21 22 66
113   */
114  v = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
115       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
116  a = torch::from_blob(v.data(), {2, 3, 4}, torch::kFloat);
117  indexes = torch::tensor({1}, torch::dtype(torch::kLong));
118  a.index_put_({torch::indexing::None, indexes}, 55); // dim 0 of a: None only adds a new leading axis
119  std::cout << a << "\n";
120  /*
121    0 1 2 3
122    4 5 6 7
123    8 9 10 11
124
125    ---
126
127    55 55 55 55
128    55 55 55 55
129    55 55 55 55
130   */
131
132  v = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
133       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
134  a = torch::from_blob(v.data(), {2, 3, 4}, torch::kFloat);
135  indexes = torch::tensor({1}, torch::dtype(torch::kLong));
136  a.index_put_(
137      {torch::indexing::Slice(torch::indexing::None, torch::indexing::None,
138                              torch::indexing::None),
139       torch::indexing::Slice(torch::indexing::None, torch::indexing::None,
140                              torch::indexing::None),
141       indexes},
142      33); // dim 2
143  std::cout << a << "\n";
144  /*
145    0 33 2 3
146    4 33 6 7
147    8 33 10 11
148
149    ---
150
151    12 33 14 15
152    16 33 18 19
153    20 33 22 23
154   */
155}

torch.nonzero

torch.nonzero
 1void TestNonZero() {
 2  auto t = torch::tensor({0, 2, 0, 0, 5, 0, 1}, torch::kInt);
 3  std::cout << t << "\n";       // 1-d, shape (7,)
 4  auto indexes = t.nonzero();   //
 5  std::cout << indexes << "\n"; // 2-d, shape (3, 1) 1 4 6
 6
 7  indexes = indexes.squeeze();
 8  auto v = t.index_select(0, indexes);
 9  std::cout << v << "\n";
10}
torch.nonzero
 1void TestNonZero2() {
 2  auto a = torch::tensor({1, 5, 5}, torch::dtype(torch::kInt)).reshape({3, 1});
 3  torch::Tensor b = (a == 5).nonzero();
 4  // 1 0
 5  // 2 0
 6  std::cout << b << "\n"; // shape is (2, 2)
 7
 8  torch::Tensor c = (a.squeeze() == 5).nonzero();
 9  // 1
10  // 2
11  std::cout << c << "\n"; // shape is (2, 1)
12
13  torch::Tensor d = (a == 50).nonzero();
14  std::cout << d.numel() << "\n";
15}