Hello
test-cuda
./code/test-cuda.py
 1#!/usr/bin/env python3
 2
 3import torch
 4
 5"""
 6L2_cache_size: 6291456
 7gcnArchName: Tesla V100-PCIE-32GB
 8is_integrated: 0
 9is_multi_gpu_board: 0
10major: 7
11max_threads_per_multi_processor: 2048
12minor: 0
13multi_processor_count: 80
14name: Tesla V100-PCIE-32GB
15regs_per_multiprocessor: 65536
16shared_memory_per_block: 49152
17shared_memory_per_block_optin: 98304
18shared_memory_per_multiprocessor: 98304
19total_memory: 34079899648
20uuid: 2e9d29fc-608b-1348-6c9d-c190dc3bddbe
21warp_size: 32
22"""
23
24
25def test():
26    #  print(help(torch.cuda.get_device_properties("cuda:0")))
27    props = torch.cuda.get_device_properties(0)
28    for attr in dir(props):
29        # skip internal/private attributes
30        if not attr.startswith("_"):
31            print(f"{attr}: {getattr(props, attr)}")
32
33
def main():
    """Script entry point: dump the properties of CUDA device 0 via test()."""
    test()


if __name__ == "__main__":
    main()
vector addition
./code/vector_addition.py
 1#!/usr/bin/env python3
 2
 3import torch
 4import triton
 5import triton.language as tl
 6
 7# e.g., cuda:0
 8DEVICE = triton.runtime.driver.active.get_active_torch_device()
 9assert isinstance(DEVICE, torch.device), type(DEVICE)
10
11
12@triton.jit
13def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
14    pid = tl.program_id(axis=0)
15    block_start = pid * BLOCK_SIZE
16    offsets = block_start + tl.arange(0, BLOCK_SIZE)
17    mask = offsets < n_elements
18    x = tl.load(x_ptr + offsets, mask=mask)
19    y = tl.load(y_ptr + offsets, mask=mask)
20    output = x + y
21    tl.store(output_ptr + offsets, output, mask=mask)
22
23
def add(x: torch.Tensor, y: torch.Tensor):
    """Return ``x + y`` computed elementwise by ``add_kernel``.

    Both tensors must live on ``DEVICE`` and have the same shape. They are
    made contiguous first because the kernel indexes them as flat 1-D
    buffers. The result has ``x``'s dtype.

    Raises:
        AssertionError: if either tensor is not on ``DEVICE`` or the shapes
            differ.
    """
    # Fixed: the message previously read ``x.deivce``, which raised
    # AttributeError instead of reporting the offending devices.
    assert x.device == DEVICE and y.device == DEVICE, (x.device, y.device, DEVICE)
    # The kernel reads y with x's element count, so a smaller (e.g.
    # broadcastable) y would be read out of bounds.
    assert x.shape == y.shape, (x.shape, y.shape)
    # Flat indexing is only correct for contiguous storage; a no-op when the
    # inputs are already contiguous.
    x = x.contiguous()
    y = y.contiguous()
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
35
36
# Demo: compare the Triton add() against PyTorch's own addition.
torch.manual_seed(0)
# 98432 is not a multiple of the 1024-element block, so the final block is
# only partially in bounds and exercises the kernel's mask.
size = 98_432
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
print(output_torch[:10])
print(output_triton[:10])
print(
    "The maximum difference between torch and triton is "
    f"{(output_torch - output_triton).abs().max()}"
)
vector subtraction
./code/vector_sub.py
 1#!/usr/bin/env python3
 2import torch
 3import triton
 4import triton.language as tl
 5
 6device = torch.device("cuda", 0)
 7
 8
 9@triton.jit
10def sub_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
11    start = tl.program_id(0)
12    offset = start * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
13    mask = offset < n_elements
14    x = tl.load(x_ptr + offset, mask=mask)
15    y = tl.load(y_ptr + offset, mask=mask)
16    z = x - y
17    tl.store(output_ptr + offset, z, mask=mask)
18
19
def sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``x - y`` computed elementwise by ``sub_kernel``.

    Both tensors must live on ``device`` and have the same shape. They are
    made contiguous first because the kernel indexes them as flat 1-D
    buffers. The result has ``x``'s dtype.

    Raises:
        AssertionError: on a device mismatch or differing shapes.
    """
    assert x.device == device, (x.device, device)
    assert y.device == device, (y.device, device)
    # The kernel reads y with x's element count; a smaller y would be read
    # out of bounds.
    assert x.shape == y.shape, (x.shape, y.shape)
    # No-op for contiguous tensors; fixes strided views.
    x = x.contiguous()
    y = y.contiguous()
    n = x.nelement()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    z = torch.empty_like(x)
    sub_kernel[grid](x, y, z, n, BLOCK_SIZE=1024)
    return z
28
29
def main():
    """Check sub() against PyTorch on a small random vector and print both."""
    length = 10
    x = torch.rand(length, device=device)
    y = torch.rand(length, device=device)
    expected = x - y
    actual = sub(x, y)
    print(expected)
    print(actual)
    print((expected - actual).abs().max())


if __name__ == "__main__":
    main()
vector add scalar
./code/vector_add_scalar.py
 1#!/usr/bin/env python3
 2import torch
 3import triton
 4import triton.language as tl
 5
 6device = torch.device("cuda", 0)
 7
 8
 9@triton.jit
10def vector_add_scalar_kernel(x_ptr, y_ptr, scalar, n, BLOCK_SIZE: tl.constexpr):
11    start = tl.program_id(0)
12    offset = start * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
13    mask = offset < n
14    x = tl.load(x_ptr + offset, mask=mask)
15    y = x + scalar
16    tl.store(y_ptr + offset, y, mask=mask)
17
18
def vector_add_scalar(x: torch.Tensor, scalar: float):
    """Return ``x + scalar`` computed elementwise by ``vector_add_scalar_kernel``.

    ``x`` must live on ``device``. It is made contiguous first because the
    kernel walks it as a flat 1-D buffer. Without this, a strided view such as
    ``x[::2]`` would be read at the wrong addresses, since ``torch.empty_like``
    gives such an input a contiguous output.

    Raises:
        AssertionError: if ``x`` is not on ``device``.
    """
    assert x.device == device, (x.device, device)
    x = x.contiguous()  # no-op when x is already contiguous
    y = torch.empty_like(x)
    n = x.nelement()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    vector_add_scalar_kernel[grid](x, y, scalar, n, BLOCK_SIZE=4096)
    return y
26
27
def main():
    """Compare vector_add_scalar() with PyTorch and print the max abs error."""
    x = torch.randn(10240, device=device)
    scalar = 2.25
    reference = x + scalar
    result = vector_add_scalar(x, scalar)
    max_abs_err = (reference - result).abs().max()
    print(max_abs_err)


if __name__ == "__main__":
    main()
matrix add matrix
./code/matrix_add_matrix.py
 1#!/usr/bin/env python3
 2import torch
 3import triton
 4import triton.language as tl
 5
 6device = torch.device("cuda", 0)
 7
 8
 9@triton.jit
10def matrix_add_matrix_kernel(
11    x_ptr,
12    y_ptr,
13    out_ptr,
14    num_rows,
15    num_cols,
16    BLOCK_SIZE_ROW: tl.constexpr,
17    BLOCK_SIZE_COL: tl.constexpr,
18):
19    row_start = tl.program_id(0)
20    col_start = tl.program_id(1)
21
22    row_offset = row_start * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)[:, None]
23    col_offset = col_start * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)[None, :]
24    mask = (row_offset < num_rows) & (col_offset < num_cols)
25
26    offset = row_offset * num_cols + col_offset
27
28    x = tl.load(x_ptr + offset, mask=mask)
29    y = tl.load(y_ptr + offset, mask=mask)
30    out = x + y
31    tl.store(out_ptr + offset, out, mask=mask)
32
33
def matrix_add_matrix(x: torch.Tensor, y: torch.Tensor):
    """Return ``x + y`` for two same-shaped 2-D tensors via the Triton kernel.

    The inputs are made contiguous because the kernel addresses element
    (r, c) as ``r * num_cols + c`` (row-major). Otherwise a transposed or
    sliced view would be read from the wrong locations.

    Raises:
        AssertionError: on a device mismatch, a non-2-D ``x``, or differing
            shapes.
    """
    assert x.device == device, (x.device, device)
    assert y.device == device, (y.device, device)
    # Previously a 3-D x silently used only its first two dims, and a
    # smaller y was read out of bounds.
    assert x.dim() == 2, x.shape
    assert x.shape == y.shape, (x.shape, y.shape)
    x = x.contiguous()  # no-op for already row-major tensors
    y = y.contiguous()
    z = torch.empty_like(x)
    num_rows, num_cols = x.shape
    grid = lambda meta: (
        triton.cdiv(num_rows, meta["BLOCK_SIZE_ROW"]),
        triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
    )
    matrix_add_matrix_kernel[grid](
        x, y, z, num_rows, num_cols, BLOCK_SIZE_ROW=16, BLOCK_SIZE_COL=32
    )
    return z
48
49
def main():
    """Compare matrix_add_matrix() with PyTorch on a random 301 x 529 pair."""
    shape = (301, 529)
    x = torch.randn(*shape, device=device)
    y = torch.randn(*shape, device=device)
    expected = x + y
    actual = matrix_add_matrix(x, y)
    print((expected - actual).abs().max())


if __name__ == "__main__":
    torch.manual_seed(20250810)
    main()
batched matrix add matrix
./code/batched_matrix_add.py
 1#!/usr/bin/env python3
 2import torch
 3import triton
 4import triton.language as tl
 5
 6device = torch.device("cuda", 0)
 7
 8
 9@triton.jit
10def add_kernel(
11    a_ptr,
12    b_ptr,
13    c_ptr,
14    num_rows,
15    num_cols,
16    BLOCK_SIZE_ROW: tl.constexpr,
17    BLOCK_SIZE_COL: tl.constexpr,
18):
19    batch_start = tl.program_id(0)
20    row_start = tl.program_id(1)
21    col_start = tl.program_id(2)
22
23    batch_offset = batch_start * num_rows * num_cols
24
25    row_offset = row_start * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)[:, None]
26    col_offset = col_start * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)[None, :]
27
28    mask = (row_offset < num_rows) & (col_offset < num_cols)
29
30    offset = batch_offset + row_offset * num_cols + col_offset
31    a = tl.load(a_ptr + offset, mask=mask)
32    b = tl.load(b_ptr + offset, mask=mask)
33    c = a + b
34    tl.store(c_ptr + offset, c, mask=mask)
35
36
def batched_add(a: torch.Tensor, b: torch.Tensor):
    """Return ``a + b`` for two same-shaped (batch, rows, cols) tensors.

    The inputs are made contiguous because ``add_kernel`` assumes the
    matrices are packed back to back in row-major order.

    Raises:
        AssertionError: on a device or shape mismatch.
        ValueError: if ``a`` is not 3-D (from the shape unpacking).
    """
    assert a.device == device, (a.device, device)
    assert b.device == device, (b.device, device)
    # The kernel derives all offsets from a's shape; a smaller b would be
    # read out of bounds.
    assert a.shape == b.shape, (a.shape, b.shape)
    a = a.contiguous()  # no-op for already packed tensors
    b = b.contiguous()

    n, r, c = a.shape
    BLOCK_SIZE_ROW = 32
    BLOCK_SIZE_COL = 32
    grid = (n, triton.cdiv(r, BLOCK_SIZE_ROW), triton.cdiv(c, BLOCK_SIZE_COL))
    out = torch.empty_like(a)
    add_kernel[grid](a, b, out, r, c, BLOCK_SIZE_ROW, BLOCK_SIZE_COL)
    return out
48
49
def main():
    """Compare batched_add() with PyTorch on a random (5, 10, 20) batch."""
    shape = (5, 10, 20)
    a = torch.randn(*shape, device=device)
    b = torch.randn(*shape, device=device)
    expected = a + b
    actual = batched_add(a, b)
    print((expected - actual).abs().max())


if __name__ == "__main__":
    main()
matrix multiplication
./code/matrix_mul.py
  1#!/usr/bin/env python3
  2
  3import torch
  4import triton.language as tl
  5import triton
  6
  7
  8def is_cuda():
  9    return triton.runtime.driver.active.get_current_target().backend == "cuda"
 10
 11
 12def get_cuda_autotune_config():
 13    return [
 14        triton.Config(
 15            {
 16                "BLOCK_SIZE_M": 128,
 17                "BLOCK_SIZE_N": 256,
 18                "BLOCK_SIZE_K": 64,
 19                "GROUP_SIZE_M": 8,
 20            },
 21            num_stages=3,
 22            num_warps=8,
 23        ),
 24        triton.Config(
 25            {
 26                "BLOCK_SIZE_M": 64,
 27                "BLOCK_SIZE_N": 256,
 28                "BLOCK_SIZE_K": 32,
 29                "GROUP_SIZE_M": 8,
 30            },
 31            num_stages=4,
 32            num_warps=4,
 33        ),
 34        triton.Config(
 35            {
 36                "BLOCK_SIZE_M": 128,
 37                "BLOCK_SIZE_N": 128,
 38                "BLOCK_SIZE_K": 32,
 39                "GROUP_SIZE_M": 8,
 40            },
 41            num_stages=4,
 42            num_warps=4,
 43        ),
 44        triton.Config(
 45            {
 46                "BLOCK_SIZE_M": 128,
 47                "BLOCK_SIZE_N": 64,
 48                "BLOCK_SIZE_K": 32,
 49                "GROUP_SIZE_M": 8,
 50            },
 51            num_stages=4,
 52            num_warps=4,
 53        ),
 54        triton.Config(
 55            {
 56                "BLOCK_SIZE_M": 64,
 57                "BLOCK_SIZE_N": 128,
 58                "BLOCK_SIZE_K": 32,
 59                "GROUP_SIZE_M": 8,
 60            },
 61            num_stages=4,
 62            num_warps=4,
 63        ),
 64        triton.Config(
 65            {
 66                "BLOCK_SIZE_M": 128,
 67                "BLOCK_SIZE_N": 32,
 68                "BLOCK_SIZE_K": 32,
 69                "GROUP_SIZE_M": 8,
 70            },
 71            num_stages=4,
 72            num_warps=4,
 73        ),
 74        triton.Config(
 75            {
 76                "BLOCK_SIZE_M": 64,
 77                "BLOCK_SIZE_N": 32,
 78                "BLOCK_SIZE_K": 32,
 79                "GROUP_SIZE_M": 8,
 80            },
 81            num_stages=5,
 82            num_warps=2,
 83        ),
 84        triton.Config(
 85            {
 86                "BLOCK_SIZE_M": 32,
 87                "BLOCK_SIZE_N": 64,
 88                "BLOCK_SIZE_K": 32,
 89                "GROUP_SIZE_M": 8,
 90            },
 91            num_stages=5,
 92            num_warps=2,
 93        ),
 94        # Good config for fp8 inputs.
 95        triton.Config(
 96            {
 97                "BLOCK_SIZE_M": 128,
 98                "BLOCK_SIZE_N": 256,
 99                "BLOCK_SIZE_K": 128,
100                "GROUP_SIZE_M": 8,
101            },
102            num_stages=3,
103            num_warps=8,
104        ),
105        triton.Config(
106            {
107                "BLOCK_SIZE_M": 256,
108                "BLOCK_SIZE_N": 128,
109                "BLOCK_SIZE_K": 128,
110                "GROUP_SIZE_M": 8,
111            },
112            num_stages=3,
113            num_warps=8,
114        ),
115        triton.Config(
116            {
117                "BLOCK_SIZE_M": 256,
118                "BLOCK_SIZE_N": 64,
119                "BLOCK_SIZE_K": 128,
120                "GROUP_SIZE_M": 8,
121            },
122            num_stages=4,
123            num_warps=4,
124        ),
125        triton.Config(
126            {
127                "BLOCK_SIZE_M": 64,
128                "BLOCK_SIZE_N": 256,
129                "BLOCK_SIZE_K": 128,
130                "GROUP_SIZE_M": 8,
131            },
132            num_stages=4,
133            num_warps=4,
134        ),
135        triton.Config(
136            {
137                "BLOCK_SIZE_M": 128,
138                "BLOCK_SIZE_N": 128,
139                "BLOCK_SIZE_K": 128,
140                "GROUP_SIZE_M": 8,
141            },
142            num_stages=4,
143            num_warps=4,
144        ),
145        triton.Config(
146            {
147                "BLOCK_SIZE_M": 128,
148                "BLOCK_SIZE_N": 64,
149                "BLOCK_SIZE_K": 64,
150                "GROUP_SIZE_M": 8,
151            },
152            num_stages=4,
153            num_warps=4,
154        ),
155        triton.Config(
156            {
157                "BLOCK_SIZE_M": 64,
158                "BLOCK_SIZE_N": 128,
159                "BLOCK_SIZE_K": 64,
160                "GROUP_SIZE_M": 8,
161            },
162            num_stages=4,
163            num_warps=4,
164        ),
165        triton.Config(
166            {
167                "BLOCK_SIZE_M": 128,
168                "BLOCK_SIZE_N": 32,
169                "BLOCK_SIZE_K": 64,
170                "GROUP_SIZE_M": 8,
171            },
172            num_stages=4,
173            num_warps=4,
174        ),
175    ][:2]
176
177
# Target device for this script; matmul() asserts its inputs live here.
device = torch.device("cuda:0")
179
180
def get_autotune_config():
    """Return the autotuning configs for the active backend (CUDA only)."""
    assert is_cuda()
    configs = get_cuda_autotune_config()
    return configs
184
185
@triton.autotune(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    # Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of
    # C = A @ B. It accumulates in float32, optionally applies leaky_relu,
    # and stores the tile as float16.
    #
    # Map the 1-D program id to a (tile row, tile col) pair in grouped order.
    # Within a group of GROUP_SIZE_M tile rows, consecutive ids walk down the
    # rows of one tile column before moving to the next column. The usual
    # motivation is better cache reuse of A/B tiles.
    pid = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_per_group = GROUP_SIZE_M * tiles_n
    group_first_m = (pid // tiles_per_group) * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M tile rows.
    rows_in_group = min(tiles_m - group_first_m, GROUP_SIZE_M)
    pid_in_group = pid % tiles_per_group
    pid_m = group_first_m + pid_in_group % rows_in_group
    pid_n = pid_in_group // rows_in_group

    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_am >= 0)
    tl.assume(stride_ak >= 0)
    tl.assume(stride_bk >= 0)
    tl.assume(stride_bn >= 0)
    tl.assume(stride_cm >= 0)
    tl.assume(stride_cn >= 0)

    # Row/col indices wrap modulo M/N so the A/B loads stay in bounds; the
    # wrapped rows/cols are discarded by the masked store at the end.
    rows_a = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    cols_b = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    ks = tl.arange(0, BLOCK_SIZE_K)

    a_tile_ptrs = a_ptr + (rows_a[:, None] * stride_am + ks[None, :] * stride_ak)
    b_tile_ptrs = b_ptr + (ks[:, None] * stride_bk + cols_b[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(tl.cdiv(K, BLOCK_SIZE_K)):
        # Only the first K - k * BLOCK_SIZE_K positions of this K-slice are
        # real; the rest load as 0 and add nothing to the dot product.
        k_left = K - k * BLOCK_SIZE_K
        a = tl.load(a_tile_ptrs, mask=ks[None, :] < k_left, other=0.0)
        b = tl.load(b_tile_ptrs, mask=ks[:, None] < k_left, other=0.0)
        acc = tl.dot(a, b, acc)
        a_tile_ptrs += BLOCK_SIZE_K * stride_ak
        b_tile_ptrs += BLOCK_SIZE_K * stride_bk

    if ACTIVATION == "leaky_relu":
        acc = leaky_relu(acc)
    c = acc.to(tl.float16)

    rows_c = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols_c = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_tile_ptrs = c_ptr + (rows_c[:, None] * stride_cm + cols_c[None, :] * stride_cn)
    in_bounds = (rows_c[:, None] < M) & (cols_c[None, :] < N)
    tl.store(c_tile_ptrs, c, mask=in_bounds)
254
255
@triton.jit
def leaky_relu(x):
    # Scale strictly negative entries by 0.01; zeros, positives and NaNs
    # pass through unchanged.
    return tl.where(x < 0, 0.01 * x, x)
259
260
def matmul(a, b, activation=""):
    """Return ``a @ b`` as a new float16 tensor computed by ``matmul_kernel``.

    Args:
        a: (M, K) matrix on ``device``; must be contiguous.
        b: (K, N) matrix on ``device``; its strides are passed to the kernel.
        activation: ``"leaky_relu"`` applies leaky ReLU to the product; any
            other value applies nothing.

    Returns:
        An (M, N) float16 tensor on ``device``.
    """
    assert a.device == device, (a.device, device)
    assert b.device == device, (b.device, device)
    assert a.shape[1] == b.shape[0], (a.shape, b.shape)
    assert a.is_contiguous()

    (M, K), (_, N) = a.shape, b.shape
    c = torch.empty((M, N), device=device, dtype=torch.float16)

    def grid(meta):
        # One program per BLOCK_SIZE_M x BLOCK_SIZE_N output tile.
        tiles_m = triton.cdiv(M, meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n,)

    matmul_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        *a.stride(),
        *b.stride(),
        *c.stride(),
        ACTIVATION=activation,
    )
    return c
290
291
def main():
    """Compare matmul() with torch.matmul on random 512 x 512 fp16 inputs."""
    shape = (512, 512)
    a = torch.randn(*shape, device=device, dtype=torch.float16) - 0.5
    b = torch.randn(*shape, device=device, dtype=torch.float16) - 0.5
    reference = torch.matmul(a, b)
    result = matmul(a, b)
    print((reference - result).abs().max())


if __name__ == "__main__":
    torch.manual_seed(20250811)
    main()