Hello
test-cuda
./code/test-cuda.py
1#!/usr/bin/env python3
2
3import torch
4
# Example output of test() on a Tesla V100-PCIE-32GB:
#
#   L2_cache_size: 6291456
#   gcnArchName: Tesla V100-PCIE-32GB
#   is_integrated: 0
#   is_multi_gpu_board: 0
#   major: 7
#   max_threads_per_multi_processor: 2048
#   minor: 0
#   multi_processor_count: 80
#   name: Tesla V100-PCIE-32GB
#   regs_per_multiprocessor: 65536
#   shared_memory_per_block: 49152
#   shared_memory_per_block_optin: 98304
#   shared_memory_per_multiprocessor: 98304
#   total_memory: 34079899648
#   uuid: 2e9d29fc-608b-1348-6c9d-c190dc3bddbe
#   warp_size: 32


def format_device_properties(props):
    """Return one ``"name: value"`` line per public attribute of *props*.

    Attributes whose names start with an underscore (private and dunder
    members) are skipped. Lines follow the order of ``dir(props)``, which
    Python returns sorted by name.
    """
    return [
        f"{attr}: {getattr(props, attr)}"
        for attr in dir(props)
        if not attr.startswith("_")
    ]


def test(device=0):
    """Print every public property of a CUDA device, one per line.

    Args:
        device: Anything accepted by ``torch.cuda.get_device_properties``
            (device index, ``"cuda:N"`` string or ``torch.device``).
            Defaults to 0, the device the original script always queried.
    """
    props = torch.cuda.get_device_properties(device)
    for line in format_device_properties(props):
        print(line)
32
33
def main() -> None:
    """Script entry point: print the properties reported by ``test()``."""
    test()


if __name__ == "__main__":
    main()
vector addition
./code/vector_addition.py
1#!/usr/bin/env python3
2
3import torch
4import triton
5import triton.language as tl
6
# Torch device reported by Triton's active driver (e.g., cuda:0).
_active_driver = triton.runtime.driver.active
DEVICE = _active_driver.get_active_torch_device()
assert isinstance(DEVICE, torch.device), type(DEVICE)
10
11
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Element-wise output = x + y. Program instance i handles the flat
    # indices [i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE).
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last slice may run past n_elements; mask those lanes off.
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)
22
23
def add(x: torch.Tensor, y: torch.Tensor):
    """Return ``x + y`` computed by the Triton ``add_kernel``.

    Both inputs must live on ``DEVICE`` and have identical shapes: the kernel
    reads ``x.numel()`` elements from *each* input as a flat array, so a
    smaller ``y`` would be read out of bounds.

    Raises:
        AssertionError: on a device or shape mismatch.
    """
    output = torch.empty_like(x)
    # The message previously referenced the misspelled ``x.deivce``, which
    # raised AttributeError instead of reporting the mismatching devices.
    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE, (
        x.device,
        y.device,
        output.device,
    )
    assert x.shape == y.shape, (x.shape, y.shape)
    # NOTE(review): the kernel uses flat offsets, so x and y are presumably
    # expected to share a dense (e.g. contiguous) layout — not checked here.
    n_elements = output.numel()

    def grid(meta):
        # One program per BLOCK_SIZE-wide slice, rounding up for the tail.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
35
36
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
# Show the first few elements of each result, then the worst-case mismatch.
for result in (output_torch, output_triton):
    print(result[:10])
max_abs_diff = (output_torch - output_triton).abs().max()
print(f"The maximum difference between torch and triton is {max_abs_diff}")
vector subtraction
./code/vector_sub.py
1#!/usr/bin/env python3
2import torch
3import triton
4import triton.language as tl
5
device = torch.device("cuda:0")  # inputs to sub() must live on this device
7
8
@triton.jit
def sub_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Element-wise output = x - y; each program instance handles one
    # BLOCK_SIZE-wide slice of the flattened vectors.
    first = tl.program_id(0) * BLOCK_SIZE
    idx = first + tl.arange(0, BLOCK_SIZE)
    # Guard the tail slice, which may run past n_elements.
    valid = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=valid)
    rhs = tl.load(y_ptr + idx, mask=valid)
    tl.store(output_ptr + idx, lhs - rhs, mask=valid)
18
19
def sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``x - y`` computed by the Triton ``sub_kernel``.

    Both inputs must be on ``device`` and have the same shape.

    Raises:
        AssertionError: on a device or shape mismatch.
    """
    assert x.device == device, (x.device, device)
    assert y.device == device, (y.device, device)
    # The kernel reads n = x.nelement() elements from *both* inputs, so a
    # smaller y would otherwise be read out of bounds.
    assert x.shape == y.shape, (x.shape, y.shape)
    n = x.nelement()

    def grid(meta):
        return (triton.cdiv(n, meta["BLOCK_SIZE"]),)

    z = torch.empty_like(x)
    sub_kernel[grid](x, y, z, n, BLOCK_SIZE=1024)
    return z
28
29
def main():
    """Compare the Triton ``sub`` with PyTorch on a small random vector."""
    length = 10
    x = torch.rand(length, device=device)
    y = torch.rand(length, device=device)
    expected = x - y
    actual = sub(x, y)
    for t in (expected, actual):
        print(t)
    print((expected - actual).abs().max())


if __name__ == "__main__":
    main()
vector add scalar
./code/vector_add_scalar.py
1#!/usr/bin/env python3
2import torch
3import triton
4import triton.language as tl
5
device = torch.device("cuda:0")  # inputs to vector_add_scalar() must live here
7
8
@triton.jit
def vector_add_scalar_kernel(x_ptr, y_ptr, scalar, n, BLOCK_SIZE: tl.constexpr):
    # y = x + scalar over one BLOCK_SIZE-wide slice per program instance.
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the vector are masked off.
    in_range = idx < n
    shifted = tl.load(x_ptr + idx, mask=in_range) + scalar
    tl.store(y_ptr + idx, shifted, mask=in_range)
17
18
def vector_add_scalar(x: torch.Tensor, scalar: float):
    """Return ``x + scalar`` computed by ``vector_add_scalar_kernel``.

    ``x`` must be on ``device``; the result has the same shape and dtype.
    """
    assert x.device == device, (x.device, device)
    out = torch.empty_like(x)
    numel = x.nelement()

    def grid(meta):
        return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    vector_add_scalar_kernel[grid](x, out, scalar, numel, BLOCK_SIZE=4096)
    return out
26
27
def main():
    """Check vector_add_scalar against PyTorch's broadcasting add."""
    shift = 2.25
    x = torch.randn(10240, device=device)
    reference = x + shift
    result = vector_add_scalar(x, shift)
    print(torch.max(torch.abs(reference - result)))


if __name__ == "__main__":
    main()
matrix add matrix
./code/matrix_add_matrix.py
1#!/usr/bin/env python3
2import torch
3import triton
4import triton.language as tl
5
device = torch.device("cuda:0")  # inputs to matrix_add_matrix() must live here
7
8
@triton.jit
def matrix_add_matrix_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    num_rows,
    num_cols,
    BLOCK_SIZE_ROW: tl.constexpr,
    BLOCK_SIZE_COL: tl.constexpr,
):
    # 2-D launch: axis 0 tiles the rows, axis 1 tiles the columns.
    rows = tl.program_id(0) * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    cols = tl.program_id(1) * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # Broadcast into a (BLOCK_SIZE_ROW, BLOCK_SIZE_COL) tile of flat
    # row-major offsets, masking rows/columns past the matrix edge.
    in_bounds = (rows[:, None] < num_rows) & (cols[None, :] < num_cols)
    offsets = rows[:, None] * num_cols + cols[None, :]

    lhs = tl.load(x_ptr + offsets, mask=in_bounds)
    rhs = tl.load(y_ptr + offsets, mask=in_bounds)
    tl.store(out_ptr + offsets, lhs + rhs, mask=in_bounds)
32
33
def matrix_add_matrix(x: torch.Tensor, y: torch.Tensor):
    """Return ``x + y`` for two 2-D tensors via ``matrix_add_matrix_kernel``.

    The kernel addresses element ``[r, c]`` as ``r * num_cols + c``, i.e. it
    assumes dense row-major (C-contiguous) storage. Inputs that break that
    assumption are rejected up front instead of silently producing wrong
    values or out-of-bounds reads: extra dimensions, mismatched shapes, or a
    non-contiguous view such as a transpose.

    Raises:
        AssertionError: on a device, rank, shape or contiguity mismatch.
    """
    assert x.device == device, (x.device, device)
    assert y.device == device, (y.device, device)
    assert x.dim() == 2, x.shape
    assert x.shape == y.shape, (x.shape, y.shape)
    assert x.is_contiguous() and y.is_contiguous(), (x.stride(), y.stride())
    z = torch.empty_like(x)
    num_rows, num_cols = x.shape

    def grid(meta):
        return (
            triton.cdiv(num_rows, meta["BLOCK_SIZE_ROW"]),
            triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
        )

    matrix_add_matrix_kernel[grid](
        x, y, z, num_rows, num_cols, BLOCK_SIZE_ROW=16, BLOCK_SIZE_COL=32
    )
    return z
48
49
def main():
    """Compare matrix_add_matrix with PyTorch on a non-tile-aligned shape."""
    shape = (301, 529)
    x = torch.randn(*shape, device=device)
    y = torch.randn(*shape, device=device)
    expected = x + y
    actual = matrix_add_matrix(x, y)
    print((expected - actual).abs().max())


if __name__ == "__main__":
    torch.manual_seed(20250810)
    main()
batched matrix add matrix
./code/batched_matrix_add.py
1#!/usr/bin/env python3
2import torch
3import triton
4import triton.language as tl
5
device = torch.device("cuda:0")  # inputs to batched_add() must live here
7
8
@triton.jit
def add_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    num_rows,
    num_cols,
    BLOCK_SIZE_ROW: tl.constexpr,
    BLOCK_SIZE_COL: tl.constexpr,
):
    # 3-D launch: axis 0 picks the matrix within the batch, axes 1 and 2 pick
    # a (BLOCK_SIZE_ROW x BLOCK_SIZE_COL) tile of that matrix.
    batch = tl.program_id(0)
    rows = tl.program_id(1) * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    cols = tl.program_id(2) * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # Flat row-major offsets; each matrix spans num_rows * num_cols elements.
    base = batch * num_rows * num_cols
    offsets = base + rows[:, None] * num_cols + cols[None, :]
    in_bounds = (rows[:, None] < num_rows) & (cols[None, :] < num_cols)

    lhs = tl.load(a_ptr + offsets, mask=in_bounds)
    rhs = tl.load(b_ptr + offsets, mask=in_bounds)
    tl.store(c_ptr + offsets, lhs + rhs, mask=in_bounds)
35
36
def batched_add(a: torch.Tensor, b: torch.Tensor):
    """Return ``a + b`` for two (batch, rows, cols) tensors via ``add_kernel``.

    ``add_kernel`` addresses element ``[i, r, c]`` as
    ``i * rows * cols + r * cols + c``, i.e. it assumes dense C-contiguous
    storage. Mismatched shapes or non-contiguous views are therefore rejected
    instead of yielding wrong values or out-of-bounds reads.

    Raises:
        AssertionError: on a device, shape or contiguity mismatch.
        ValueError: if ``a`` is not 3-D (from unpacking its shape).
    """
    assert a.device == device, (a.device, device)
    assert b.device == device, (b.device, device)
    assert a.shape == b.shape, (a.shape, b.shape)
    assert a.is_contiguous() and b.is_contiguous(), (a.stride(), b.stride())

    n, r, c = a.shape
    BLOCK_SIZE_ROW = 32
    BLOCK_SIZE_COL = 32
    # One program per (matrix, row tile, column tile).
    grid = (n, triton.cdiv(r, BLOCK_SIZE_ROW), triton.cdiv(c, BLOCK_SIZE_COL))
    out = torch.empty_like(a)
    add_kernel[grid](a, b, out, r, c, BLOCK_SIZE_ROW, BLOCK_SIZE_COL)
    return out
48
49
def main():
    """Compare batched_add with PyTorch on a small random batch."""
    shape = (5, 10, 20)
    a = torch.randn(*shape, device=device)
    b = torch.randn(*shape, device=device)
    expected = a + b
    actual = batched_add(a, b)
    print((expected - actual).abs().max())


if __name__ == "__main__":
    main()
matrix multiplication
./code/matrix_mul.py
1#!/usr/bin/env python3
2
3import torch
4import triton.language as tl
5import triton
6
7
def is_cuda():
    """Return True when Triton's active driver targets the CUDA backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"
10
11
def get_cuda_autotune_config(max_configs=2):
    """Return candidate ``triton.Config`` objects for ``matmul_kernel``.

    Every candidate uses ``GROUP_SIZE_M = 8``; the rest of each candidate
    comes from the table below.

    Args:
        max_configs: Keep only the first ``max_configs`` candidates. The
            default of 2 reproduces the previously hard-coded ``[:2]``
            truncation (presumably there to keep autotuning fast). Pass
            ``None`` to autotune over the full list.
    """
    # (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps)
    tiles = [
        (128, 256, 64, 3, 8),
        (64, 256, 32, 4, 4),
        (128, 128, 32, 4, 4),
        (128, 64, 32, 4, 4),
        (64, 128, 32, 4, 4),
        (128, 32, 32, 4, 4),
        (64, 32, 32, 5, 2),
        (32, 64, 32, 5, 2),
        # Good configs for fp8 inputs.
        (128, 256, 128, 3, 8),
        (256, 128, 128, 3, 8),
        (256, 64, 128, 4, 4),
        (64, 256, 128, 4, 4),
        (128, 128, 128, 4, 4),
        (128, 64, 64, 4, 4),
        (64, 128, 64, 4, 4),
        (128, 32, 64, 4, 4),
    ]
    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": block_m,
                "BLOCK_SIZE_N": block_n,
                "BLOCK_SIZE_K": block_k,
                "GROUP_SIZE_M": 8,
            },
            num_stages=stages,
            num_warps=warps,
        )
        # Slicing with None keeps the whole table.
        for block_m, block_n, block_k, stages, warps in tiles[:max_configs]
    ]
176
177
device = torch.device("cuda:0")  # matmul() inputs and output live here
179
180
def get_autotune_config():
    """Return the autotuning candidates for the active backend (CUDA only)."""
    assert is_cuda()
    configs = get_cuda_autotune_config()
    return configs
184
185
# Tiled GEMM: C = A @ B, optionally followed by leaky_relu. A is (M, K),
# B is (K, N) and C is (M, N); each is addressed through explicit element
# strides. One program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile
# of C, accumulating in float32 and storing the result as float16.
# Block sizes and GROUP_SIZE_M come from get_autotune_config(); the autotuner
# is keyed on (M, N, K), so a new problem size triggers re-tuning.
@triton.autotune(
    configs=get_autotune_config(),
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
    # Base pointers of A, B and C.
    a_ptr,
    b_ptr,
    c_ptr,
    # Problem size: A is M x K, B is K x N, C is M x N.
    M,
    N,
    K,
    # Element strides for stepping along each matrix dimension.
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    # "leaky_relu" applies leaky_relu to the accumulator; any other value
    # applies no activation.
    ACTIVATION: tl.constexpr,
):
    # Map the flat program id to an output tile (pid_m, pid_n). Programs are
    # grouped into bands of GROUP_SIZE_M tile-rows. Within a band, pid_m
    # varies fastest, so consecutive programs work on the same column block
    # of B (presumably for L2 reuse, as in the Triton matmul tutorial).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last band may contain fewer than GROUP_SIZE_M tile-rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Compiler hints: program ids and strides are never negative.
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_am >= 0)
    tl.assume(stride_ak >= 0)
    tl.assume(stride_bk >= 0)
    tl.assume(stride_bn >= 0)
    tl.assume(stride_cm >= 0)
    tl.assume(stride_cn >= 0)

    # Row/column indices past the matrix edge wrap around (% M, % N), so the
    # loads below stay in bounds. Their results are discarded by c_mask at
    # the store.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Pointer tiles for the first K-slice: A is (BLOCK_SIZE_M, BLOCK_SIZE_K),
    # B is (BLOCK_SIZE_K, BLOCK_SIZE_N).
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # Accumulate in float32 regardless of the input dtype.
    accumlator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the tail of K. Masked elements load as 0.0 and so contribute
        # nothing to the dot product.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumlator = tl.dot(a, b, accumlator)

        # Advance both pointer tiles to the next K-slice.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Optional fused activation, applied while still in float32. leaky_relu
    # is the @triton.jit function defined below in this file.
    if ACTIVATION == "leaky_relu":
        accumlator = leaky_relu(accumlator)

    # Down-cast for the store; matmul() allocates C as float16.
    c = accumlator.to(tl.float16)

    # Store only the in-bounds part of the tile (unwrapped indices here).
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn)
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
254
255
@triton.jit
def leaky_relu(x):
    # Leaky ReLU with negative slope 0.01: negative inputs are scaled by
    # 0.01, everything else passes through unchanged.
    return tl.where(x < 0, 0.01 * x, x)
259
260
def matmul(a, b, activation=""):
    """Return ``a @ b`` as a float16 tensor computed by ``matmul_kernel``.

    Args:
        a: (M, K) tensor on ``device``; must be contiguous.
        b: (K, N) tensor on ``device``; its strides are passed to the kernel.
        activation: ``""`` for none, or ``"leaky_relu"`` to apply leaky ReLU
            to the float32 accumulator before the float16 cast.

    Raises:
        ValueError: if *activation* is not one the kernel implements.
            Previously any other string, e.g. a typo, silently meant
            "no activation".
        AssertionError: on a device, shape or contiguity mismatch.
    """
    if activation not in ("", "leaky_relu"):
        raise ValueError(f"unsupported activation: {activation!r}")
    assert a.device == device, (a.device, device)
    assert b.device == device, (b.device, device)
    assert a.shape[1] == b.shape[0], (a.shape, b.shape)
    assert a.is_contiguous()

    M, K = a.shape
    _, N = b.shape

    c = torch.empty((M, N), device=device, dtype=torch.float16)

    def grid(meta):
        # 1-D launch: one program per output tile; the kernel remaps ids.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    matmul_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        ACTIVATION=activation,
    )
    return c
290
291
def main():
    """Compare the Triton matmul with torch.matmul on 512x512 fp16 inputs."""
    size = (512, 512)
    opts = dict(device=device, dtype=torch.float16)
    a = torch.randn(size, **opts) - 0.5
    b = torch.randn(size, **opts) - 0.5
    reference = a @ b
    result = matmul(a, b)
    print((reference - result).abs().max())


if __name__ == "__main__":
    torch.manual_seed(20250811)
    main()