random

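Note that we have to save the random number generator (RNG) state of both the CPU and the GPU: `Foo.forward` draws its noise with `torch.rand` on the CPU and then applies dropout on the input's device, so reproducing an output computed on a GPU requires restoring both generators.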

./code/random_test.py
  1#!/usr/bin/env python3
  2
  3import torch
  4
  5
  6class Foo(torch.nn.Module):
  7    def __init__(self):
  8        super().__init__()
  9
 10    def forward(self, x):
 11        y = torch.rand(*x.shape).to(x.device)
 12        return torch.nn.functional.dropout(x + y, p=0.3)
 13
 14
 15def test_cpu():
 16    f = Foo()
 17
 18    x1 = torch.rand(3, 5)
 19    x2 = x1.clone()
 20    x3 = x1.clone()
 21
 22    cpu_state = torch.get_rng_state()
 23    y1 = f(x1)
 24    print("y1", y1, y1.sum(), y1.mean())
 25    with torch.random.fork_rng(devices=[]):
 26        y2 = f(x2)
 27
 28    with torch.random.fork_rng(devices=[]):
 29        torch.set_rng_state(cpu_state)
 30        y3 = f(x3)
 31
 32    print("y2", y2, y2.sum(), y2.mean())
 33    print("y3", y3, y3.sum(), y3.mean())
 34
 35
 36def test_cuda():
 37    f = Foo()
 38    device = torch.device("cuda", 0)
 39
 40    x1 = torch.rand(3, 5).to(device)
 41    x2 = x1.clone()
 42    x3 = x1.clone()
 43
 44    cpu_state = torch.get_rng_state()
 45    cuda_state = torch.cuda.get_rng_state(device)
 46    print(
 47        "cuda_state",
 48        type(cuda_state),
 49        cuda_state.device,
 50        cuda_state.dtype,
 51        cuda_state.shape,
 52    )
 53
 54    y1 = f(x1)
 55    print("y1", y1, y1.sum(), y1.mean())
 56    with torch.random.fork_rng(devices=[]):
 57        y2 = f(x2)
 58
 59    with torch.random.fork_rng(devices=[device]):
 60        torch.set_rng_state(cpu_state)
 61        torch.cuda.set_rng_state(cuda_state, device)
 62        y3 = f(x3)
 63
 64    print("y2", y2, y2.sum(), y2.mean())
 65    print("y3", y3, y3.sum(), y3.mean())
 66
 67
 68def main():
 69    test_cpu()
 70    print(torch.cuda.is_available())
 71    if torch.cuda.is_available():
 72        test_cuda()
 73
 74
 75if __name__ == "__main__":
 76    torch.manual_seed(20241030)
 77    main()
 78
 79"""
 80----------macos----------
 81y1 tensor([[1.8172, 0.9755, 1.4394, 1.3970, 0.0000],
 82        [1.0299, 2.4723, 1.1365, 0.0000, 0.7647],
 83        [2.0160, 1.8454, 1.9144, 1.8337, 1.5052]]) tensor(20.1471) tensor(1.3431)
 84y2 tensor([[2.1100, 0.6028, 0.9254, 0.0000, 1.3935],
 85        [1.9948, 0.0000, 1.4811, 1.2179, 0.8196],
 86        [2.1118, 1.3885, 1.5176, 1.2972, 2.2623]]) tensor(19.1227) tensor(1.2748)
 87y3 tensor([[1.8172, 0.9755, 1.4394, 1.3970, 0.0000],
 88        [1.0299, 2.4723, 1.1365, 0.0000, 0.7647],
 89        [2.0160, 1.8454, 1.9144, 1.8337, 1.5052]]) tensor(20.1471) tensor(1.3431)
 90False
 91----------Linux----------
 92y1 tensor([[1.8172, 0.9755, 1.4394, 1.3970, 0.0000],
 93        [1.0299, 2.4723, 1.1365, 0.0000, 0.7647],
 94        [2.0160, 1.8454, 1.9144, 1.8337, 1.5052]]) tensor(20.1471) tensor(1.3431)
 95y2 tensor([[2.1100, 0.6028, 0.9254, 0.0000, 1.3935],
 96        [1.9948, 0.0000, 1.4811, 1.2179, 0.8196],
 97        [2.1118, 1.3885, 1.5176, 1.2972, 2.2623]]) tensor(19.1227) tensor(1.2748)
 98y3 tensor([[1.8172, 0.9755, 1.4394, 1.3970, 0.0000],
 99        [1.0299, 2.4723, 1.1365, 0.0000, 0.7647],
100        [2.0160, 1.8454, 1.9144, 1.8337, 1.5052]]) tensor(20.1471) tensor(1.3431)
101True
102cuda_state <class 'torch.Tensor'> cpu torch.uint8 torch.Size([16])
103y1 tensor([[1.2276, 0.0716, 0.0000, 0.5980, 1.4526],
104        [1.5889, 0.5063, 0.0000, 1.0267, 1.5081],
105        [1.7808, 1.3360, 1.5424, 1.8120, 0.0000]], device='cuda:0') tensor(14.4510, device='cuda:0') tensor(0.9634, device='cuda:0')
106y2 tensor([[1.8274, 0.0000, 1.1841, 0.6805, 1.0811],
107        [1.4730, 1.7636, 1.4561, 1.1214, 0.0000],
108        [1.6906, 1.0212, 1.7333, 1.2885, 2.6000]], device='cuda:0') tensor(18.9209, device='cuda:0') tensor(1.2614, device='cuda:0')
109y3 tensor([[1.2276, 0.0716, 0.0000, 0.5980, 1.4526],
110        [1.5889, 0.5063, 0.0000, 1.0267, 1.5081],
111        [1.7808, 1.3360, 1.5424, 1.8120, 0.0000]], device='cuda:0') tensor(14.4510, device='cuda:0') tensor(0.9634, device='cuda:0')
112"""