Tensor Creation
See `TensorDataContainer`.
Note: the data is copied into the returned tensor!
`torch::tensor()` supports creating a tensor from the following data:
- a `std::vector<T>`
- a scalar
- an initializer list
- an `ArrayRef<T>`
From std::vector
./code/tensor-creation/main.cc
1static void FromStdVecotr() {
2 torch::Tensor t1 = torch::tensor(std::vector<int32_t>{1, 2, 3});
3 TORCH_CHECK(t1.scalar_type() == torch::kLong);
4 t1 = t1.to(torch::kInt);
5 const int32_t *p1 = t1.data_ptr<int32_t>();
6 TORCH_CHECK(p1[0] == 1);
7 TORCH_CHECK(p1[1] == 2);
8 TORCH_CHECK(p1[2] == 3);
9
10 torch::Tensor t2 = torch::tensor(std::vector<float>{1, 2, 3});
11 TORCH_CHECK(t2.scalar_type() == torch::kFloat);
12
13 torch::Tensor t3 =
14 torch::tensor(std::vector<double>{1, 2, 3}, torch::kDouble);
15 TORCH_CHECK(t3.scalar_type() == torch::kDouble);
16
17 torch::Tensor t4 =
18 torch::tensor(std::vector<double>{1, 2, 3},
19 torch::dtype(torch::kDouble).device("cuda:0"));
20 TORCH_CHECK(t4.is_cuda());
21}
From scalar
./code/tensor-creation/main.cc
1static void FromScalar() {
2 torch::Tensor t = torch::tensor(3);
3 TORCH_CHECK(t.item<int64_t>() == 3);
4
5 torch::Tensor t2 = torch::tensor(0.5);
6 TORCH_CHECK(t2.scalar_type() == torch::kFloat);
7}
From initializer list
./code/tensor-creation/main.cc
1static void FromInitializerList() {
2 torch::Tensor t1 = torch::tensor({1, 2, 3});
3 torch::Tensor t2 = torch::tensor(std::vector<int32_t>{1, 2, 3});
4 TORCH_CHECK(torch::allclose(t1, t2));
5
6 torch::Tensor t3 = torch::tensor({{1, 2, 3}, {4, 5, 6}});
7 TORCH_CHECK(t3.dim() == 2);
8
9 torch::Tensor t4 = torch::tensor({1, 2, 3});
10 torch::Tensor t5 = torch::tensor({4, 5, 6});
11 TORCH_CHECK(torch::allclose(t3[0], t4));
12 TORCH_CHECK(torch::allclose(t3[1], t5));
13}
From ArrayRef
./code/tensor-creation/main.cc
1static void FromArrayRef() {
2 int32_t i[] = {1, 2, 3};
3 torch::ArrayRef<int32_t> a(i);
4 torch::Tensor t = torch::tensor(a);
5 // Data is copied to t
6
7 TORCH_CHECK(t[0].item<int64_t>(), 1);
8 TORCH_CHECK(t[1].item<int64_t>(), 2);
9 TORCH_CHECK(t[2].item<int64_t>(), 3);
10}