torch::Device
See
DeviceType
torch::DeviceType
is defined as enum class DeviceType : int8_t {...}
.
The most commonly used types are torch::DeviceType::CPU
and torch::DeviceType::CUDA
,
which are aliased to torch::kCPU
and torch::kCUDA
.
./code/device/main.cc
1void TestDeviceType() {  // Checks the string forms of a torch::DeviceType holding torch::kCPU.
2 torch::DeviceType d = torch::kCPU;
3 std::ostringstream os;
4 os << d;  // stream the device type so its printed form can be compared below
5 TORCH_CHECK(os.str() == "cpu");  // the streamed form is expected to be lowercase
6
7 TORCH_CHECK(DeviceTypeName(d /*,lower_case=false*/) == "CPU");  // one-argument call is expected to yield the uppercase name
8 TORCH_CHECK(DeviceTypeName(d, /*lower_case*/ true) == "cpu");  // passing true is expected to yield the lowercase name
Device
The torch::Device
class has two members: a torch::DeviceType
and an int8_t index
.
./code/device/main.cc (Constructors)
1void TestDeviceConstructorCPU() {
2 torch::Device d(torch::kCPU);
3 TORCH_CHECK(d.is_cpu() == true);
4 TORCH_CHECK(d.is_cuda() == false);
5 TORCH_CHECK(d.type() == torch::kCPU);
6 TORCH_CHECK(d.has_index() == false);
7 TORCH_CHECK(d.index() == -1);
8 TORCH_CHECK(d.str() == "cpu");
9}
10
11void TestDeviceConstructorCUDA() {
12 torch::Device d(torch::kCUDA, 3);
13 TORCH_CHECK(d.is_cpu() == false);
14 TORCH_CHECK(d.is_cuda() == true);
15 TORCH_CHECK(d.type() == torch::kCUDA);
16 TORCH_CHECK(d.has_index() == true);
17 TORCH_CHECK(d.index() == 3);
18 TORCH_CHECK(d.str() == "cuda:3");
19
20 d.set_index(2);
21 TORCH_CHECK(d.index() == 2);
22 TORCH_CHECK(d.str() == "cuda:2");
23
24 d = torch::Device("cpu");
25 TORCH_CHECK(d.is_cpu() == true);
26
27 d = torch::Device("CPU");
28 TORCH_CHECK(d.is_cpu() == true);
29
30 d = torch::Device("cuda:1");
31 TORCH_CHECK(d.is_cuda() == true);
32 TORCH_CHECK(d.index() == 1);
33
34 d = torch::Device("CUDA:1");
35 TORCH_CHECK(d.is_cuda() == true);
36 TORCH_CHECK(d.index() == 1);
37}