Tensor Creation

See `torch::detail::TensorDataContainer`, the type through which `torch::tensor()` accepts its input data.

Note

The input data is always copied into the returned tensor, so later changes to the source do not affect the tensor.

`torch::tensor()` can create a tensor from any of the following inputs:

  • From a `std::vector<T>`

  • From a scalar

  • From a (possibly nested) initializer list

  • From an `ArrayRef<T>`

From std::vector

./code/tensor-creation/main.cc
 1static void FromStdVecotr() {
 2  torch::Tensor t1 = torch::tensor(std::vector<int32_t>{1, 2, 3});
 3  TORCH_CHECK(t1.scalar_type() == torch::kLong);
 4  t1 = t1.to(torch::kInt);
 5  const int32_t *p1 = t1.data_ptr<int32_t>();
 6  TORCH_CHECK(p1[0] == 1);
 7  TORCH_CHECK(p1[1] == 2);
 8  TORCH_CHECK(p1[2] == 3);
 9
10  torch::Tensor t2 = torch::tensor(std::vector<float>{1, 2, 3});
11  TORCH_CHECK(t2.scalar_type() == torch::kFloat);
12
13  torch::Tensor t3 =
14      torch::tensor(std::vector<double>{1, 2, 3}, torch::kDouble);
15  TORCH_CHECK(t3.scalar_type() == torch::kDouble);
16
17  torch::Tensor t4 =
18      torch::tensor(std::vector<double>{1, 2, 3},
19                    torch::dtype(torch::kDouble).device("cuda:0"));
20  TORCH_CHECK(t4.is_cuda());
21}

From scalar

./code/tensor-creation/main.cc
1static void FromScalar() {
2  torch::Tensor t = torch::tensor(3);
3  TORCH_CHECK(t.item<int64_t>() == 3);
4
5  torch::Tensor t2 = torch::tensor(0.5);
6  TORCH_CHECK(t2.scalar_type() == torch::kFloat);
7}

From initializer list

./code/tensor-creation/main.cc
 1static void FromInitializerList() {
 2  torch::Tensor t1 = torch::tensor({1, 2, 3});
 3  torch::Tensor t2 = torch::tensor(std::vector<int32_t>{1, 2, 3});
 4  TORCH_CHECK(torch::allclose(t1, t2));
 5
 6  torch::Tensor t3 = torch::tensor({{1, 2, 3}, {4, 5, 6}});
 7  TORCH_CHECK(t3.dim() == 2);
 8
 9  torch::Tensor t4 = torch::tensor({1, 2, 3});
10  torch::Tensor t5 = torch::tensor({4, 5, 6});
11  TORCH_CHECK(torch::allclose(t3[0], t4));
12  TORCH_CHECK(torch::allclose(t3[1], t5));
13}

From ArrayRef

./code/tensor-creation/main.cc
 1static void FromArrayRef() {
 2  int32_t i[] = {1, 2, 3};
 3  torch::ArrayRef<int32_t> a(i);
 4  torch::Tensor t = torch::tensor(a);
 5  // Data is copied to t
 6
 7  TORCH_CHECK(t[0].item<int64_t>(), 1);
 8  TORCH_CHECK(t[1].item<int64_t>(), 2);
 9  TORCH_CHECK(t[2].item<int64_t>(), 3);
10}