random
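Note that we have to save the random number generator (RNG) state of both the CPU and the GPU: `Foo.forward` draws its noise from the CPU generator (`torch.rand`), while dropout on a CUDA tensor draws from the GPU generator.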
./code/random_test.py
1#!/usr/bin/env python3
2
3import torch
4
5
class Foo(torch.nn.Module):
    """Toy module whose forward pass consumes random numbers.

    ``forward`` adds uniform noise drawn from the CPU generator (``torch.rand``
    is called without a device and the result is then moved to ``x.device``)
    and applies dropout with ``p=0.3``, which draws from the generator of the
    device ``x`` lives on. Reproducing its output for a CUDA input therefore
    requires restoring both the CPU and the CUDA RNG state.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``dropout(x + noise)``, where ``noise ~ U[0, 1)`` has ``x``'s shape.

        Dropout is only active in training mode, which is the default for a
        freshly constructed module. After ``.eval()`` the noisy input is
        returned unchanged, matching ``torch.nn.Dropout``.
        """
        # Drawn on the CPU and then moved, so the noise always comes from the
        # CPU generator regardless of where ``x`` lives.
        y = torch.rand(*x.shape).to(x.device)
        # Honour train()/eval(): F.dropout defaults to training=True, which
        # would keep dropping activations even in evaluation mode.
        return torch.nn.functional.dropout(x + y, p=0.3, training=self.training)
13
14
def test_cpu():
    """Demonstrate that restoring a saved CPU RNG state replays the same draws.

    ``y1`` and ``y3`` start from the same CPU generator state and print
    identically. ``y2`` continues the stream after ``y1`` and differs. Each
    ``fork_rng(devices=[])`` block forks only the CPU generator and rolls the
    global state back when it exits.
    """
    model = Foo()

    base = torch.rand(3, 5)
    inp_a, inp_b, inp_c = base, base.clone(), base.clone()

    saved_cpu = torch.get_rng_state()
    out_a = model(inp_a)
    print("y1", out_a, out_a.sum(), out_a.mean())

    # Continue from the current stream; the global state is rolled back on exit.
    with torch.random.fork_rng(devices=[]):
        out_b = model(inp_b)

    # Rewind to the state captured before y1 so the same draws are replayed.
    with torch.random.fork_rng(devices=[]):
        torch.set_rng_state(saved_cpu)
        out_c = model(inp_c)

    for label, out in (("y2", out_b), ("y3", out_c)):
        print(label, out, out.sum(), out.mean())
34
35
def test_cuda():
    """Demonstrate that restoring saved CPU and CUDA RNG states replays the same draws.

    For a GPU input, ``Foo`` draws its noise from the CPU generator and its
    dropout mask from the CUDA generator. Both states must therefore be
    captured and restored for ``y3`` to match ``y1``.
    """
    f = Foo()
    device = torch.device("cuda", 0)

    x1 = torch.rand(3, 5).to(device)
    x2 = x1.clone()
    x3 = x1.clone()

    cpu_state = torch.get_rng_state()
    # The CUDA generator state comes back as a CPU uint8 tensor.
    cuda_state = torch.cuda.get_rng_state(device)
    print(
        "cuda_state",
        type(cuda_state),
        cuda_state.device,
        cuda_state.dtype,
        cuda_state.shape,
    )

    y1 = f(x1)
    print("y1", y1, y1.sum(), y1.mean())
    # Fork the CUDA generator too: dropout on a CUDA tensor consumes it, and
    # with devices=[] only the CPU state would be rolled back on exit.
    with torch.random.fork_rng(devices=[device]):
        y2 = f(x2)

    with torch.random.fork_rng(devices=[device]):
        torch.set_rng_state(cpu_state)
        torch.cuda.set_rng_state(cuda_state, device)
        y3 = f(x3)

    print("y2", y2, y2.sum(), y2.mean())
    print("y3", y3, y3.sum(), y3.mean())
66
67
def main():
    """Run the CPU demo, report CUDA availability, and run the CUDA demo if possible."""
    test_cpu()
    has_cuda = torch.cuda.is_available()
    print(has_cuda)
    if has_cuda:
        test_cuda()
73
74
if __name__ == "__main__":
    # Fixed global seed so every run prints the same tensors.
    _SEED = 20241030
    torch.manual_seed(_SEED)
    main()
78
79"""
80----------macos----------
81y1 tensor([[1.8172, 0.9755, 1.4394, 1.3970, 0.0000],
82 [1.0299, 2.4723, 1.1365, 0.0000, 0.7647],
83 [2.0160, 1.8454, 1.9144, 1.8337, 1.5052]]) tensor(20.1471) tensor(1.3431)
84y2 tensor([[2.1100, 0.6028, 0.9254, 0.0000, 1.3935],
85 [1.9948, 0.0000, 1.4811, 1.2179, 0.8196],
86 [2.1118, 1.3885, 1.5176, 1.2972, 2.2623]]) tensor(19.1227) tensor(1.2748)
87y3 tensor([[1.8172, 0.9755, 1.4394, 1.3970, 0.0000],
88 [1.0299, 2.4723, 1.1365, 0.0000, 0.7647],
89 [2.0160, 1.8454, 1.9144, 1.8337, 1.5052]]) tensor(20.1471) tensor(1.3431)
90False
91----------Linux----------
92y1 tensor([[1.8172, 0.9755, 1.4394, 1.3970, 0.0000],
93 [1.0299, 2.4723, 1.1365, 0.0000, 0.7647],
94 [2.0160, 1.8454, 1.9144, 1.8337, 1.5052]]) tensor(20.1471) tensor(1.3431)
95y2 tensor([[2.1100, 0.6028, 0.9254, 0.0000, 1.3935],
96 [1.9948, 0.0000, 1.4811, 1.2179, 0.8196],
97 [2.1118, 1.3885, 1.5176, 1.2972, 2.2623]]) tensor(19.1227) tensor(1.2748)
98y3 tensor([[1.8172, 0.9755, 1.4394, 1.3970, 0.0000],
99 [1.0299, 2.4723, 1.1365, 0.0000, 0.7647],
100 [2.0160, 1.8454, 1.9144, 1.8337, 1.5052]]) tensor(20.1471) tensor(1.3431)
101True
102cuda_state <class 'torch.Tensor'> cpu torch.uint8 torch.Size([16])
103y1 tensor([[1.2276, 0.0716, 0.0000, 0.5980, 1.4526],
104 [1.5889, 0.5063, 0.0000, 1.0267, 1.5081],
105 [1.7808, 1.3360, 1.5424, 1.8120, 0.0000]], device='cuda:0') tensor(14.4510, device='cuda:0') tensor(0.9634, device='cuda:0')
106y2 tensor([[1.8274, 0.0000, 1.1841, 0.6805, 1.0811],
107 [1.4730, 1.7636, 1.4561, 1.1214, 0.0000],
108 [1.6906, 1.0212, 1.7333, 1.2885, 2.6000]], device='cuda:0') tensor(18.9209, device='cuda:0') tensor(1.2614, device='cuda:0')
109y3 tensor([[1.2276, 0.0716, 0.0000, 0.5980, 1.4526],
110 [1.5889, 0.5063, 0.0000, 1.0267, 1.5081],
111 [1.7808, 1.3360, 1.5424, 1.8120, 0.0000]], device='cuda:0') tensor(14.4510, device='cuda:0') tensor(0.9634, device='cuda:0')
112"""