Tensor
See
Common methods
./code/tensor/main.cc (Not recommended constructors)
// Tour of common torch::Tensor queries (rank, shape, strides, dtype, device)
// and mutators (fill_, zero_, to, raw/accessor element access, requires_grad).
1static void TestCommonMethods() {
2 torch::Tensor t = torch::rand({2, 3, 4});
3
4 TORCH_CHECK(t.dim() == 3); // 3-d tensor
5 TORCH_CHECK(t.ndimension() == t.dim()); // same
6 TORCH_CHECK(t.numel() == 2 * 3 * 4);
7 TORCH_CHECK(t.is_contiguous() == true);
8 TORCH_CHECK(t.contiguous().is_contiguous() == true);
9
10 t.fill_(10); // fill all entries with 10
11 t.zero_(); // zero out all entries
12
13 t = t.to(torch::kInt);
14 TORCH_CHECK(t.is_floating_point() == false);
15 TORCH_CHECK(t.is_signed() == true);
16
17 TORCH_CHECK(t.size(0) == 2);
18 TORCH_CHECK(t.size(1) == 3);
19 TORCH_CHECK(t.size(2) == 4);
20 TORCH_CHECK(t.sizes() == torch::ArrayRef<int64_t>({2, 3, 4}));
21
22 t = t.contiguous();
23 TORCH_CHECK(t.stride(0) == 3 * 4);
24 TORCH_CHECK(t.stride(1) == 4);
25 TORCH_CHECK(t.stride(2) == 1);
26 TORCH_CHECK(t.strides() == torch::ArrayRef<int64_t>({12, 4, 1}));
27
28 TORCH_CHECK(t.defined() == true);
29 {
30 torch::Tensor a; // default-constructed: undefined
31 TORCH_CHECK(a.defined() == false);
32 a = t;
33 TORCH_CHECK(a.defined() == true);
34 a.reset(); // drops the reference; a is undefined again
35 TORCH_CHECK(a.defined() == false);
36 }
37
38 t = t.to(torch::kShort);
39 TORCH_CHECK(t.itemsize() == sizeof(int16_t));
40 TORCH_CHECK(t.nbytes() == t.numel() * t.itemsize());
41 TORCH_CHECK(t.itemsize() == t.element_size()); // same
42
43 TORCH_CHECK(t.scalar_type() == torch::kShort);
44 TORCH_CHECK(t.dtype() == caffe2::TypeMeta::Make<int16_t>());
45 TORCH_CHECK(t.dtype().toScalarType() == torch::kShort);
46
47 TORCH_CHECK(t.device() == torch::Device("cpu"));
48 TORCH_CHECK(t.device() == torch::Device(torch::kCPU));
49
50 // Note: t.device() returns an instance of torch::Device;
51 // t.get_device() returns the device index.
52 TORCH_CHECK(t.get_device() == t.device().index());
53
54 TORCH_CHECK(t.is_cpu() == true);
55 TORCH_CHECK(t.is_cuda() == false);
56
57 t = t.to(torch::kInt);
58 int32_t *p = t.data_ptr<int32_t>(); // raw pointer to the first element
59 p[0] = 100;
60
61 torch::TensorAccessor<int32_t, 3> acc = t.accessor<int32_t, 3>();
62 TORCH_CHECK(acc[0][0][0] == p[0]);
63 p[12] = -2; // with strides {12, 4, 1}, flat offset 12 is element [1][0][0]
64 TORCH_CHECK(acc[1][0][0] == -2);
65
66 acc[1][1][2] = 3; // flat offset 1 * 12 + 1 * 4 + 2 * 1
67 TORCH_CHECK(*(p + 12 + 4 + 2) == 3);
68
69 t = t.to(torch::kFloat);
70 t.set_requires_grad(true);
71 TORCH_CHECK(t.requires_grad() == true);
72
73 t.set_requires_grad(false);
74 TORCH_CHECK(t.requires_grad() == false);
75
76 t = t.cuda(); // NOTE(review): requires a CUDA-enabled build and a GPU
77 TORCH_CHECK(t.device().type() == torch::kCUDA);
78 t = t.cpu();
79
80 torch::TensorOptions opts = t.options(); // dtype/device/layout of t
81 TORCH_CHECK(opts.device() == t.device());
slice
torch::slice
// Tensor::slice returns a view: writing through the slice is visible in the
// source tensor because both share the same memory.
1static void TestSlice() {
2 auto t = torch::tensor({1, 2, 3, 4, 5}, torch::kInt);
3 torch::TensorAccessor<int32_t, 1> acc = t.accessor<int32_t, 1>();
4
5 // t2 = t[1:3]
6 torch::Tensor t2 = t.slice(/*dim*/ 0, /*start*/ 1,
7 /*end, exclusive*/ 3); // memory is shared
8 torch::TensorAccessor<int32_t, 1> acc2 = t2.accessor<int32_t, 1>();
9 TORCH_CHECK(acc2[0] == 2);
10 TORCH_CHECK(acc2[1] == 3);
11
12 acc2[0] = 10; // also changes t since the memory is shared
13 TORCH_CHECK(acc[1] == 10);
torch::slice
1void TestSlice2() {
2 auto t = torch::full({2, 3}, -1);
3 std::cout << t << "\n";
4
5 // set the last column to 0
6 t.index({torch::indexing::Slice(), -1}) = 0;
7 std::cout << t << "\n";
8}
torch::slice
1void TestSlice3() {
2 auto a = torch::tensor({1, 3, 10}, torch::kFloat);
3 auto count = torch::zeros({10}, torch::kFloat);
4
5 count.slice(0, 1, 4).add_(a);
6 // count: 0 1 3 10 0 0 0 0 0 0
7 std::cout << "count: " << count << "\n";
8 count.slice(0, 2, 5).add_(a);
9 std::cout << "count: " << count << "\n";
10 // count: 0 1 4 13 10 0 0 0 0 0
11}
topk
torch::topk
// Tensor::topk returns a (values, indices) pair; the indices tensor is int64
// even when the input is float.
1// https://pytorch.org/docs/stable/generated/torch.topk.html
2static void TestTopK() {
3 auto t = torch::tensor({1, 0, 3, -1}, torch::kInt).to(torch::kFloat);
4 torch::Tensor values, indexes;
5 std::tie(values, indexes) =
6 t.topk(/*k*/ 2, /*dim*/ 0, /*largest*/ true, /*sorted*/ true);
7 auto values_acc = values.accessor<float, 1>();
8 auto indexes_acc = indexes.accessor<int64_t, 1>(); // Note: it is int64_t
9
10 TORCH_CHECK(values.numel() == 2); // k in topk is 2
11 TORCH_CHECK(values_acc[0] == 3); // the largest value is 3, at t[2]
12 TORCH_CHECK(values_acc[1] == 1); // the second largest value is 1, at t[0]
13 //
14 TORCH_CHECK(indexes_acc[0] == 2); // the largest value is t[2]
15 TORCH_CHECK(indexes_acc[1] == 0); // the second largest value is t[0]
floor_divide
torch::floor_divide
// torch::floor_divide on an int tensor keeps the int dtype (read back as
// int32). All inputs are non-negative, so floor division equals C++'s
// truncating integer `/`, which is what the expected values use.
1static void TestFloorDivide() {
2 auto t = torch::tensor({1, 0, 3, 5, 9}, torch::kInt);
3 auto p = torch::floor_divide(t, 2);
4 auto acc = p.accessor<int32_t, 1>();
5 TORCH_CHECK(acc[0] == 1 / 2);
6 TORCH_CHECK(acc[1] == 0 / 2);
7 TORCH_CHECK(acc[2] == 3 / 2);
8 TORCH_CHECK(acc[3] == 5 / 2);
9 TORCH_CHECK(acc[4] == 9 / 2);
div
torch::div
// torch::div with rounding_mode "trunc" rounds toward zero like C++ integer
// division, and the result is read back as int32 (integer dtype kept).
1// https://pytorch.org/docs/stable/generated/torch.div.html
2static void TestDiv() {
3 auto t = torch::tensor({1, 0, 3, 5, 9}, torch::kInt);
4 // the rounding mode is supported in torch >= 1.8.0
5 auto p = torch::div(t, 2, /*rounding_mode*/ "trunc");
6 auto acc = p.accessor<int32_t, 1>();
7 TORCH_CHECK(acc[0] == 1 / 2);
8 TORCH_CHECK(acc[1] == 0 / 2);
9 TORCH_CHECK(acc[2] == 3 / 2);
10 TORCH_CHECK(acc[3] == 5 / 2);
11 TORCH_CHECK(acc[4] == 9 / 2);
remainder
torch::remainder
// Element-wise remainder of an int tensor by a scalar: {1, 3, 8} % 3 == {1, 0, 2}.
1// https://pytorch.org/docs/1.6.0/generated/torch.remainder.html
2static void TestRemainder() {
3 auto t = torch::tensor({1, 3, 8}, torch::kInt);
4 auto p = torch::remainder(t, 3);
5 auto acc = p.accessor<int32_t, 1>();
6 TORCH_CHECK(acc[0] == 1);
7 TORCH_CHECK(acc[1] == 0);
8 TORCH_CHECK(acc[2] == 2);
empty
torch::empty
// torch::empty allocates a tensor of the requested shape and dtype; only that
// metadata is checked here, the element values are never read.
1static void TestEmpty() {
2 auto t = torch::empty({3}, torch::kInt);
3 TORCH_CHECK(t.scalar_type() == torch::kInt);
4 TORCH_CHECK(t.numel() == 3);
stack
torch::stack
// torch::stack inserts a new axis, whose size is the number of inputs (2
// here), at position `dim` of the result.
1static void TestStack() {
2 auto t = torch::empty({6, 5}, torch::kInt);
3 auto a = torch::stack({t, t}, /*dim*/ 1);
4 TORCH_CHECK(a.sizes() == torch::ArrayRef<int64_t>({6, 2, 5}));
5
6 a = torch::stack({t, t}, /*dim*/ 0);
7 TORCH_CHECK(a.sizes() == torch::ArrayRef<int64_t>({2, 6, 5}));
8
9 a = torch::stack({t, t}, /*dim*/ 2);
10 TORCH_CHECK(a.sizes() == torch::ArrayRef<int64_t>({6, 5, 2}));
unbind
torch::unbind
// torch::unbind removes `dim` and returns one tensor per index along it: here
// 6 tensors of shape (4, 5).
1static void TestUnbind() {
2 auto t = torch::empty({4, 6, 5}, torch::kInt);
3 std::vector<torch::Tensor> v = torch::unbind(t, /*dim*/ 1);
4 TORCH_CHECK(v.size() == t.size(1));
5 for (int32_t i = 0; i != v.size(); ++i) { // NOTE(review): signed/unsigned compare; size_t index would be cleaner
6 TORCH_CHECK(v[i].sizes() == torch::ArrayRef<int64_t>({4, 5}));
7 }
full
torch::full
// torch::full creates a tensor with every element set to the given value.
1static void TestFull() {
2 auto t = torch::full({2, 3}, 10, torch::kInt);
3 const int32_t *p = t.data_ptr<int32_t>(); // flat walk assumes a freshly created (contiguous) tensor
4 for (int32_t i = 0; i != t.numel(); ++i) {
5 TORCH_CHECK(p[i] == 10);
6 }
split
torch::split
// Tensor::split(size, dim) cuts the tensor into chunks of `size` along `dim`.
1static void TestSplit() {
2 auto t = torch::arange(6).reshape({2, 3});
3 std::vector<torch::Tensor> s = t.split(1); // chunks of 1 along dim 0 (the default)
4 TORCH_CHECK(s.size() == 2);
5 TORCH_CHECK(s[0].sizes() == torch::ArrayRef<int64_t>({1, 3}));
6 TORCH_CHECK(s[1].sizes() == torch::ArrayRef<int64_t>({1, 3}));
7
8 s = t.split(1, /*dim*/ 1);
9 TORCH_CHECK(s.size() == 3);
10 TORCH_CHECK(s[0].sizes() == torch::ArrayRef<int64_t>({2, 1}));
11 TORCH_CHECK(s[1].sizes() == torch::ArrayRef<int64_t>({2, 1}));
12 TORCH_CHECK(s[2].sizes() == torch::ArrayRef<int64_t>({2, 1}));
zeros
torch::zeros
// Creates a 2x3 float tensor filled with zeros.
1static void TestZeros() {
2 auto t = torch::zeros({2, 3}, torch::kFloat);
cat
torch::cat
// torch::cat joins tensors along an existing dim; the sizes along that dim add
// up (5 inputs of size 3 along dim 1 give size 15).
1static void TestCat() {
2 auto t = torch::arange(24).reshape({2, 3, 4});
3 std::vector<torch::Tensor> v(5, t); // five handles to the same tensor t
4 auto p = torch::cat(v, /*dim*/ 1);
5 TORCH_CHECK(p.sizes() == torch::ArrayRef<int64_t>({2, 3 * 5, 4}));
division
test division
// `/` on an integer tensor is true division (float result); converting the
// float result back to int truncates.
1static void TestDivision() {
2 auto t = torch::arange(4).to(torch::kInt);
3 auto b = t / 2; // true division: b is float, see the check below
4 TORCH_CHECK(b.scalar_type() == torch::kFloat);
5
6 const float *p = b.data_ptr<float>();
7 TORCH_CHECK(p[0] == 0 / 2.);
8 TORCH_CHECK(p[1] == 1 / 2.);
9 TORCH_CHECK(p[2] == 2 / 2.);
10 TORCH_CHECK(p[3] == 3 / 2.);
11
12 auto c = b.to(torch::kInt); // float -> int drops the fractional part
13
14 const int32_t *q = c.data_ptr<int32_t>();
15 TORCH_CHECK(q[0] == 0 / 2);
16 TORCH_CHECK(q[1] == 1 / 2);
17 TORCH_CHECK(q[2] == 2 / 2);
18 TORCH_CHECK(q[3] == 3 / 2);
default constructed
test default constructed
// A default-constructed torch::Tensor is undefined (t.defined() == false).
// NOTE(review): query it with defined(); size()/dim() are not reliable on it.
1void TestDefaultConstructed() {
2 torch::Tensor t;
copy
test rowwise copy
// Row-wise copy: assigning to a row slice of `t` writes values into `t`.
// Expected t: [1 2 3], [4 5 6], [11 12 13], [14 15 16].
1void TestCopy() {
2 auto t0 = torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
3 auto t = torch::empty({4, 3}, torch::kFloat);
4
5 t.slice(/*dim*/ 0, 0, 1) = t0.slice(0, 0, 1);
6 t.slice(/*dim*/ 0, 1, 2) = t0.slice(0, 1, 2);
7 t.slice(/*dim*/ 0, 2, 3) = t0.slice(0, 0, 1) + 10;
8 t.slice(/*dim*/ 0, 3, 4) = t0.slice(0, 1, 2) + 10;
9
10 std::cout << t << "\n";
addmm
test addmm
// In-place addmm_: a <- a + m @ v (default beta = alpha = 1).
1void TestAddmm() {
2 std::cout << "---TestAddmm---\n";
3 // 1 2 3
4 // 4 5 6
5 torch::Tensor m =
6 torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
7
8 torch::Tensor v = torch::tensor({1, 1, -1}, torch::kFloat).unsqueeze(1);
9
10 // 10 20 (column vector, shape (2, 1))
11 torch::Tensor a = torch::tensor({10, 20}, torch::kFloat).unsqueeze(1);
12 a.addmm_(m, v); // m @ v = [0, 3], so a becomes [10, 23]
13 std::cout << a << "\n";
14 std::cout << a.squeeze(1) << "\n";
elementwise operation
test elementwise operation
// Tensor `*` and `/` are element-wise (not matrix) operations; a scalar on the
// left-hand side is applied to every element.
1void TestElementwiseOp() {
2 std::cout << "---TestElementwiseOp---\n";
3 torch::Tensor a = torch::tensor({1, 2, 3, 40}, torch::kFloat).reshape({2, 2});
4 torch::Tensor b =
5 torch::tensor({10, 20, 30, 4}, torch::kFloat).reshape({2, 2});
6 torch::Tensor c = a * b;
7 torch::Tensor d = a / b;
8 torch::Tensor e = 1.0 / a;
9 std::cout << c << "\n"; // [[10, 40], [90, 160]]
10 std::cout << d << "\n"; // [[0.1, 0.1], [0.1, 10]]
11 std::cout << e << "\n"; // [[1.0, 0.5], [0.3333, 0.0250]], float32
torch.roll
torch.roll
// Tensor::roll shifts elements along `dim`; elements pushed past the end wrap
// around to the front.
1void TestRoll() {
2 // 1 2 3
3 // 4 5 6
4 torch::Tensor a =
5 torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
6 torch::Tensor b = a.roll(1 /*shift right 1 column*/, 1 /*dim*/);
7 // Now b is (the last column wraps around to the front)
8 // 3 1 2
9 // 6 4 5
10
11 // ----------
12 // 1 2 3 4
13 // 5 6 7 8
14 //
15 // 9 10 11 12
16 // 13 14 15 16
17 a = torch::tensor({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
18 torch::kInt)
19 .reshape({2, 2, 4});
20 b = a.roll(1 /*shift right 1 column*/, 2 /*dim*/);
21 // now b is
22 // 4 1 2 3
23 // 8 5 6 7
24 //
25 // 12 9 10 11
26 // 16 13 14 15
27 std::cout << b;
torch.mean
torch.mean
// (tail of TestCat) Concatenating five copies of t along dim 1 multiplies that
// dim's size by 5.
1 auto t = torch::arange(24).reshape({2, 3, 4});
2 std::vector<torch::Tensor> v(5, t);
3 auto p = torch::cat(v, /*dim*/ 1);
4 TORCH_CHECK(p.sizes() == torch::ArrayRef<int64_t>({2, 3 * 5, 4}));
5}
6
7static void TestDivision() {
8 auto t = torch::arange(4).to(torch::kInt);
9 auto b = t / 2;
10 TORCH_CHECK(b.scalar_type() == torch::kFloat);
11
12 const float *p = b.data_ptr<float>();
13 TORCH_CHECK(p[0] == 0 / 2.);
14 TORCH_CHECK(p[1] == 1 / 2.);
15 TORCH_CHECK(p[2] == 2 / 2.);
16 TORCH_CHECK(p[3] == 3 / 2.);
17
18 auto c = b.to(torch::kInt);
19
20 const int32_t *q = c.data_ptr<int32_t>();
21 TORCH_CHECK(q[0] == 0 / 2);
22 TORCH_CHECK(q[1] == 1 / 2);
23 TORCH_CHECK(q[2] == 2 / 2);
24 TORCH_CHECK(q[3] == 3 / 2);
25}
26
27void TestDefaultConstructed() {
28 torch::Tensor t;
29 TORCH_CHECK(t.size(0) == 0);
30}
31
32void TestCopy() {
33 auto t0 = torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
34 auto t = torch::empty({4, 3}, torch::kFloat);
35
36 t.slice(/*dim*/ 0, 0, 1) = t0.slice(0, 0, 1);
37 t.slice(/*dim*/ 0, 1, 2) = t0.slice(0, 1, 2);
38 t.slice(/*dim*/ 0, 2, 3) = t0.slice(0, 0, 1) + 10;
39 t.slice(/*dim*/ 0, 3, 4) = t0.slice(0, 1, 2) + 10;
40
41 std::cout << t << "\n";
42}
43
44void TestAddmm() {
45 std::cout << "---TestAddmm---\n";
46 // 1 2 3
47 // 4 5 6
48 torch::Tensor m =
49 torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
50
51 torch::Tensor v = torch::tensor({1, 1, -1}, torch::kFloat).unsqueeze(1);
52
53 // 10 20 30
54 torch::Tensor a = torch::tensor({10, 20}, torch::kFloat).unsqueeze(1);
55 a.addmm_(m, v);
56 std::cout << a << "\n";
57 std::cout << a.squeeze(1) << "\n";
58}
59
60void TestElementwiseOp() {
61 std::cout << "---TestElementwiseOp---\n";
62 torch::Tensor a = torch::tensor({1, 2, 3, 40}, torch::kFloat).reshape({2, 2});
63 torch::Tensor b =
64 torch::tensor({10, 20, 30, 4}, torch::kFloat).reshape({2, 2});
65 torch::Tensor c = a * b;
66 torch::Tensor d = a / b;
67 torch::Tensor e = 1.0 / a;
68 std::cout << c << "\n"; // [[10, 40], [90, 160]]
69 std::cout << d << "\n"; // [[0.1, 0.1], [0.1, 10]]
70 std::cout << e << "\n"; // [[1.0, 0.5], [0.3333, 0.0250]], float32
71}
72
73void TestRoll() {
74 // 1 2 3
75 // 4 5 6
76 torch::Tensor a =
77 torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
78 torch::Tensor b = a.roll(1 /*shift right 1 column*/, 1 /*dim*/);
79 // Now b is
80 // 3 2 1
81 // 6 4 5
82
83 // ----------
84 // 1 2 3 4
85 // 5 6 7 8
86 //
87 // 9 10 11 12
88 // 13 14 15 16
89 a = torch::tensor({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
90 torch::kInt)
91 .reshape({2, 2, 4});
92 b = a.roll(1 /*shift right 1 column*/, 2 /*dim*/);
93 // now b is
94 // 4 1 2 3
95 // 8 5 6 7
96 //
97 // 12 9 10 11
98 // 16 13 14 15
99 std::cout << b;
100}
101
// Tensor::mean over dim 1 (row means), with and without keeping the reduced dim.
102void TestMean() {
103 // 1 2 3
104 // 4 5 6
105 torch::Tensor a =
106 torch::tensor({1, 2, 3, 4, 5, 6}, torch::kFloat).reshape({2, 3});
107 torch::Tensor b = a.mean(1 /*dim*/, true /*keep_dim*/);
108 std::cout << b;
109 // Now b is: (shape (2, 1))
110 // 2
111 // 5
112 //----------
113 b = a.mean(1 /*dim*/, false /*keep_dim*/);
114 std::cout << b << "\n";
// Now b is 1-d, shape (2,): 2 5
torch.slice
torch.slice
// Zeroes the last column of a 2x3 tensor through the indexing API.
1void TestSlice2() {
2 auto t = torch::full({2, 3}, -1);
3 std::cout << t << "\n";
4
5 // set the last column to 0 (Python: t[:, -1] = 0)
6 t.index({torch::indexing::Slice(), -1}) = 0;
7 std::cout << t << "\n";
torch.as_strided
torch.as_strided
// Tensor::as_strided builds overlapping sliding windows over a 1-d tensor:
// with sizes {4, 3} and strides {2, 1}, b[i][j] reads a[2 * i + j].
1void TestAsStrided() {
2 /*
3 0 1 2 3 4 5 6 7 8 9
4 */
5
6 /*
7 0 1 2
8 2 3 4
9 4 5 6
10 6 7 8
11 */
12 torch::Tensor a = torch::arange(0, 10).to(torch::kFloat);
13 // rows = (10 - 3) / 2 + 1 = 4 full windows of width 3, step 2
14 torch::Tensor b = a.as_strided({4, 3}, {2, 1});
15 std::cout << a << "\n";
16 std::cout << b << "\n";
torch.argmax
torch.argmax
// Tensor::argmax(dim) returns the index of the maximum along `dim`; the result
// has that dim removed. from_blob wraps v's buffer without copying, so v must
// outlive `a`.
1void TestArgMax() {
2 std::vector<float> v = {
3 //
4 0.2, 0.5, 0.1, 0.4,
5 //
6 0.9, 0.2, 0.0, 0.3,
7 //
8 0.8, 0.99, 0.1, 0.3
9 //
10 };
11 torch::Tensor a = torch::from_blob(v.data(), {3, 4}, torch::kFloat);
12 std::cout << a << "\n"; // shape (3, 4)
13
14 torch::Tensor b = a.argmax(1);
15 std::cout << b << "\n"; // 1-d, shape: (3,)
16 // 1, 0, 1
17
18 // test 3-d
19 a = torch::from_blob(v.data(), {2, 3, 2}, torch::kFloat);
20 std::cout << a << "\n"; // shape (2, 3, 2)
21
22 b = a.argmax(-1);
23 std::cout << b << "\n"; // 2-d, shape: (2, 3)
24 // 1, 1, 0
25 // 1, 1, 1
torch.index
torch.index
// Tensor::index({i}) is the C++ spelling of Python a[i]: it selects row i.
1void TestIndex() {
2 // see https://pytorch.org/cppdocs/notes/tensor_indexing.html
3 std::vector<float> v = {
4 //
5 0.2, 0.5, 0.1, 0.4,
6 //
7 0.9, 0.2, 0.0, 0.3,
8 //
9 0.8, 0.99, 0.1, 0.3
10 //
11 };
12 torch::Tensor a = torch::from_blob(v.data(), {3, 4}, torch::kFloat);
13 std::cout << a << "\n"; // shape (3, 4)
14
15 torch::Tensor b = a.index({0});
16 std::cout << b << "\n"; // 1-d, shape (4,) 0.2, 0.5, 0.1, 0.4
17 //
18 b = a.index({2});
19 std::cout << b << "\n"; // 1-d, shape (4,) 0.8, 0.99, 0.1, 0.3
torch.nn.functional.pad
torch.nn.functional.pad
1void TestPad() {
2 std::vector<float> v = {0, 1, 2, 3, 4, 5};
3 torch::Tensor a = torch::from_blob(v.data(), {2, 3}, torch::kFloat);
4
5 int32_t padding = 1;
6
7 // pad dim 1
8#ifdef __ANDROID__
9 auto padding_value = torch::zeros(
10 {a.size(0), padding}, torch::dtype(torch::kFloat).device(a.device()));
11
12 torch::Tensor b = torch::cat({a, padding_value}, 1);
13#else
14 torch::Tensor b = torch::nn::functional::pad(
15 a, torch::nn::functional::PadFuncOptions({0, padding})
16 .mode(torch::kConstant)
17 .value(0));
18#endif
19 /*
20 0 1 2 0
21 3 4 5 0
22 */
23 std::cout << b << "\n";
24
25 // pad dim 1
26#ifdef __ANDROID__
27 padding_value = torch::zeros({padding, a.size(1)},
28 torch::dtype(torch::kFloat).device(a.device()));
29
30 torch::Tensor c = torch::cat({a, padding_value}, 0);
31#else
32 torch::Tensor c = torch::nn::functional::pad(
33 a, torch::nn::functional::PadFuncOptions({0, 0, 0, padding})
34 .mode(torch::kConstant)
35 .value(0));
36#endif
37 /*
38 0 1 2
39 3 4 5
40 0 0 0
41 */
42 std::cout << c << "\n";
43}
torch.index_put
1void TestIndexPut() {
2 // https://pytorch.org/cppdocs/notes/tensor_indexing.html#setter
3 std::vector<float> v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
4 torch::Tensor a = torch::from_blob(v.data(), {12}, torch::kFloat);
5 torch::Tensor indexes =
6 torch::tensor({0, 1, 3, 5}, torch::dtype(torch::kLong));
7
8 a.index_put_({indexes}, 0);
9 std::cout << a << "\n";
10 /*
11 0 0 2 0 4 0 6 7 8 9 10 11
12
13 */
14
15 // 2-d
16
17 v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
18 /*
19 0 1 2
20 3 4 5
21 6 7 8
22 9 10 11
23 */
24 a = torch::from_blob(v.data(), {4, 3}, torch::kFloat);
25
26 indexes = torch::tensor({0, 2}, torch::dtype(torch::kLong));
27 a.index_put_({indexes}, 100);
28
29 /*
30 100 100 100
31 3 4 5
32 100 100 100
33 9 10 11
34 */
35 std::cout << a << "\n";
36
37 v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
38 a = torch::from_blob(v.data(), {4, 3}, torch::kFloat);
39
40 indexes = torch::tensor({0, 2}, torch::dtype(torch::kLong));
41 a.index_put_({"...", indexes}, 99);
42
43 /*
44 99 1 99
45 99 4 99
46 99 7 99
47 99 10 99
48 */
49 std::cout << a << "\n";
50
51 v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
52 a = torch::from_blob(v.data(), {4, 3}, torch::kFloat);
53
54 indexes = torch::tensor({0, 2}, torch::dtype(torch::kLong));
55 a.index_put_({"...", 1}, 6666);
56 /*
57 0 6666 2
58 3 6666 5
59 6 6666 8
60 9 6666 11
61 */
62 std::cout << a << "\n";
63
64 // 3-d
65 v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
66 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
67 a = torch::from_blob(v.data(), {2, 3, 4}, torch::kFloat);
68
69 /*
70 0 1 2 3
71 4 5 6 7
72 8 9 10 11
73
74 ---
75
76 12 13 14 15
77 16 17 18 19
78 20 21 22 23
79 */
80 std::cout << a << "\n";
81
82 indexes = torch::tensor({0}, torch::dtype(torch::kLong));
83 a.index_put_({indexes}, 88); // dim 0
84 std::cout << a << "\n";
85 /*
86 88 88 88 88
87 88 88 88 88
88 88 88 88 88
89
90 ---
91
92 12 13 14 15
93 16 17 18 19
94 20 21 22 23
95 */
96
97 v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
98 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
99 a = torch::from_blob(v.data(), {2, 3, 4}, torch::kFloat);
100 indexes = torch::tensor({0, 3}, torch::dtype(torch::kLong));
101 a.index_put_({"...", indexes}, 66); // the last dim, i.e., dim 2
102 std::cout << a << "\n";
103 /*
104 66 1 2 66
105 66 5 6 66
106 66 9 10 66
107
108 ---
109
110 66 13 14 66
111 66 17 18 66
112 66 21 22 66
113 */
114 v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
115 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
116 a = torch::from_blob(v.data(), {2, 3, 4}, torch::kFloat);
117 indexes = torch::tensor({1}, torch::dtype(torch::kLong));
118 a.index_put_({torch::indexing::None, indexes}, 55); // dim 1
119 std::cout << a << "\n";
120 /*
121 0 1 2 3
122 4 5 6 7
123 8 9 10 11
124
125 ---
126
127 55 55 55 55
128 55 55 55 55
129 55 55 55 55
130 */
131
132 v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
133 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
134 a = torch::from_blob(v.data(), {2, 3, 4}, torch::kFloat);
135 indexes = torch::tensor({1}, torch::dtype(torch::kLong));
136 a.index_put_(
137 {torch::indexing::Slice(torch::indexing::None, torch::indexing::None,
138 torch::indexing::None),
139 torch::indexing::Slice(torch::indexing::None, torch::indexing::None,
140 torch::indexing::None),
141 indexes},
142 33); // dim 2
143 std::cout << a << "\n";
144 /*
145 0 33 2 3
146 4 33 6 7
147 8 33 10 11
148
149 ---
150
151 12 33 14 15
152 16 33 18 19
153 20 33 22 23
154 */
155}
torch.nonzero
torch.nonzero
1void TestNonZero() {
2 auto t = torch::tensor({0, 2, 0, 0, 5, 0, 1}, torch::kInt);
3 std::cout << t << "\n"; // 1-d, shape (7,)
4 auto indexes = t.nonzero(); //
5 std::cout << indexes << "\n"; // 2-d, shape (3, 1) 1 4 6
6
7 indexes = indexes.squeeze();
8 auto v = t.index_select(0, indexes);
9 std::cout << v << "\n";
10}
torch.nonzero
1void TestNonZero2() {
2 auto a = torch::tensor({1, 5, 5}, torch::dtype(torch::kInt)).reshape({3, 1});
3 torch::Tensor b = (a == 5).nonzero();
4 // 1 0
5 // 2 0
6 std::cout << b << "\n"; // shape is (2, 2)
7
8 torch::Tensor c = (a.squeeze() == 5).nonzero();
9 // 1
10 // 2
11 std::cout << c << "\n"; // shape is (2, 1)
12
13 torch::Tensor d = (a == 50).nonzero();
14 std::cout << d.numel() << "\n";
15}