Posted on 05 Jul 2025 · about 11754 words · 40 min read
After a friendly exchange with folks in the TileLang group about my earlier questions on the L2-optimized matrix multiplication algorithm, this post implements a matmul operator in TileLang to verify the L2 tiling algorithm. Something is probably wrong somewhere in my code: the current version performs far below expectations and also leaks GPU memory. I still need to consult the group further.
The code is based on the official example_gemm_schedule.py. On top of the official example I added a variant kernel that skips T.use_swizzle entirely (kernel_matmul_no_l2_opt, reported as Tilelang_ps_-10 in the benchmark below), alongside the swizzled kernel with different panel_size values. I also dump the successively lowered intermediate representations, to see the concrete logic of the operator TileLang generates.
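To recap what the swizzle is meant to buy: without it, thread blocks are launched in plain row-major order over the output tiles, so a long run of consecutively scheduled blocks reuses a single A row-tile but touches a different B tile every time, getting little out of L2. A panel-style rasterization instead walks the grid a few block-rows at a time, so blocks that are resident together keep hitting the same small set of A and B tiles. The sketch below is only my own illustration of such a remapping (the function name and the exact layout are made up for the example); it is not TileLang's actual tl::rasterization2DRow implementation.

# Illustration only: remap a linear block id so that blocks are issued
# panel by panel (panel_size block-rows at a time), column-major inside
# each panel. Consecutive blocks then reuse at most panel_size A row-tiles,
# and each B column-tile is consumed by panel_size blocks in a row.
def swizzled_block_id(linear_id, grid_m, grid_n, panel_size):
    panel = linear_id // (panel_size * grid_n)        # which group of block-rows
    within = linear_id % (panel_size * grid_n)        # position inside the panel
    rows_in_panel = min(panel_size, grid_m - panel * panel_size)
    by = panel * panel_size + within % rows_in_panel  # output tile row
    bx = within // rows_in_panel                      # output tile column
    return bx, by

# With a 4x8 grid of output tiles and panel_size=2, the first eight blocks
# stay inside two block-rows:
print([swizzled_block_id(i, 4, 8, 2) for i in range(8)])
# [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]
# Row-major launch order would instead give bx = 0..7 with by = 0: eight
# different B tiles for one A tile; here it is four B tiles and two A tiles.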
matmul.py
# spack load [email protected] [email protected]+cuda py-triton py-matplotlib py-pandas
# LD_LIBRARY_PATH=$(spack location -i cuda)/lib64:$LD_LIBRARY_PATH python3 matmul.py
import tilelang
import tilelang.language as T
import torch
import triton


# Tiled GEMM with T.use_swizzle (threadblock rasterization) for L2 reuse.
def kernel_matmul(
    M, N, K, block_M, block_N, block_K, dtype, accum_dtype, panel_size=10
):
    @T.prim_func
    def gemm_schedule(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            T.use_swizzle(panel_size=panel_size)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_schedule


# Identical to kernel_matmul except that T.use_swizzle is not called;
# panel_size is kept for a uniform signature but unused.
def kernel_matmul_no_l2_opt(
    M, N, K, block_M, block_N, block_K, dtype, accum_dtype, panel_size=10
):
    @T.prim_func
    def gemm_schedule(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_schedule


# Correctness check against torch.matmul; also dumps the lowered TIR and
# the generated CUDA source.
def test():
    M, N, K = 2**12, 2**11, 2**10
    kernel = kernel_matmul(M, N, K, 128, 128, 32, "float16", "float")
    with open("matmul.kernel.py", "w") as f:
        f.write(str(kernel))
    tilelang_matmul = tilelang.compile(kernel, out_idx=-1)
    with open("matmul.kernel.cu", "w") as f:
        f.write(tilelang_matmul.get_kernel_source())
    DEVICE = "cuda"
    a = torch.rand([M, K], device=DEVICE, dtype=torch.float16)
    b = torch.rand([K, N], device=DEVICE, dtype=torch.float16)
    torch_c = torch.matmul(a, b)
    triton_c = tilelang_matmul(a, b)
    print("Maxdiff is {}".format(torch.max(torch.abs(torch_c - triton_c))))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=[1024 * i for i in range(12, 21)],
        line_arg="provider",
        line_vals=["torch"] + ["tilelang_ps_" + str(i * 10) for i in range(-1, 3)],
        line_names=["Torch"] + ["Tilelang_ps_" + str(i * 10) for i in range(-1, 3)],
        plot_name="matmul-tflops",
        args={},
    )
)
def benchmark(M, N, K, provider):
    DEVICE = "cuda"
    a = torch.rand([M, K], device=DEVICE, dtype=torch.float16)
    b = torch.rand([K, N], device=DEVICE, dtype=torch.float16)
    mp = {"torch": lambda: torch.matmul(a, b)}
    for i in range(-1, 3):
        kernel = (
            kernel_matmul(M, N, K, 128, 128, 32, "float16", "float", i * 10)
            if i >= 0
            else kernel_matmul_no_l2_opt(
                M, N, K, 128, 128, 32, "float16", "float", i * 10
            )
        )
        tilelang_matmul = tilelang.compile(kernel, out_idx=-1)
        mp["tilelang_ps_" + str(i * 10)] = lambda: tilelang_matmul(a, b)
    ms = triton.testing.do_bench(mp[provider])
    torch.cuda.empty_cache()
    tflops = 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return tflops


if __name__ == "__main__":
    torch.manual_seed(3407)
    test()
    benchmark.run(print_data=True, show_plots=False, save_path=".")
In the results, Tilelang_ps_-10 is the variant with the optimization turned off.
Maxdiff is 0.0
matmul-tflops:
M Torch Tilelang_ps_-10 Tilelang_ps_0 Tilelang_ps_10 Tilelang_ps_20
0 12288.0 218.185030 163.195512 160.787932 155.639821 167.653793
1 13312.0 218.176101 174.291067 172.917449 172.878523 176.228789
2 14336.0 214.400675 174.471258 167.645072 168.738852 159.228886
3 15360.0 219.363970 157.663036 169.193896 171.710036 169.996476
4 16384.0 220.999889 173.871241 174.994087 174.047391 170.979700
5 17408.0 220.840380 134.329065 137.557182 130.383652 168.406982
6 18432.0 216.221871 173.599278 123.953247 136.703516 170.494457
7 19456.0 217.925130 163.030749 120.179175 134.268942 175.348523
8 20480.0 197.038250 157.498531 137.681804 137.550856 110.879759
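Since the TFLOPS column comes from the 2*M*N*K FLOP count in benchmark(), any entry can be turned back into a per-call latency as a quick sanity check on the units; for the first Torch entry (the benchmark uses M = N = K):

# First row of the table: M = N = K = 12288, Torch at ~218.19 TFLOPS.
M = N = K = 12288
flops = 2 * M * N * K                    # ~3.71e12 FLOPs per matmul
ms = flops * 1e-12 / 218.185030 * 1e3    # invert tflops = flops*1e-12/(ms*1e-3)
print(f"{ms:.1f} ms per matmul")         # ~17.0 ms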
matmul.kernel.py
# from tvm.script import tir as T

@T.prim_func
def gemm_schedule(A: T.Buffer((4096, 1024), "float16"), B: T.Buffer((1024, 2048), "float16"), C: T.Buffer((4096, 2048), "float16")):
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 16)
    by = T.launch_thread("blockIdx.y", 32)
    tx = T.launch_thread("threadIdx.x", 128)
    ty = T.launch_thread("threadIdx.y", 1)
    tz = T.launch_thread("threadIdx.z", 1)
    with T.block("tilelang_root"):
        T.reads(A[by * 128, 0:993], B[0:1024, bx * 128:bx * 128 + 128], C[by * 128, bx * 128])
        T.writes()
        A_shared = T.alloc_buffer((128, 32), "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer((32, 128), "float16", scope="shared.dyn")
        C_local = T.alloc_buffer((128, 128), scope="local.fragment")
        T.attr(None, "threadblock_swizzle_pattern", "tl::rasterization2DRow<10>")
        T.fill(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 16384, 2), 0)
        for ko in T.serial(32, annotations={"num_stages": 3}):
            T.copy(T.region(A[by * 128, ko * 32], 1, 128, 32), T.region(A_shared[0, 0], 2, 128, 32))
            for k in T.parallel(32):
                for j in T.parallel(128):
                    B_shared[k, j] = B[ko * 32 + k, bx * 128 + j]
            T.gemm(T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 4096, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 4096, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 16384, 3), T.bool(False), T.bool(False), 128, 128, 32, 0, T.bool(False), 1, 0)
        T.copy(T.region(C_local[0, 0], 1, 128, 128), T.region(C[by * 128, bx * 128], 2, 128, 128))
matmul.kernel.cu
The matrix multiplication I wrote myself five years ago (see line 724 of the code in the original post) had roughly the same look…
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
extern "C" __global__ void gemm_schedule_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C);
extern "C" __global__ void __launch_bounds__(128, 1) gemm_schedule_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
float C_local[128];
const dim3 blockIdx = tl::rasterization2DRow<10>();
#pragma unroll
for (int i = 0; i < 64; ++i) {
*(float2*)(C_local + (i * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
}
#pragma unroll
for (int i_1 = 0; i_1 < 4; ++i_1) {
tl::cp_async_gs<16>(buf_dyn_shmem+((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((int)blockIdx.y) * 131072) + (i_1 * 32768)) + ((((int)threadIdx.x) >> 2) * 1024)) + ((((int)threadIdx.x) & 3) * 8)));
}
#pragma unroll
for (int i_2 = 0; i_2 < 4; ++i_2) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((((((((int)threadIdx.x) & 15) >> 3) * 4096) + (i_2 * 1024)) + ((((int)threadIdx.x) >> 4) * 128)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576), B+((((i_2 * 16384) + ((((int)threadIdx.x) >> 4) * 2048)) + (((int)blockIdx.x) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
}
tl::cp_async_commit();
#pragma unroll
for (int i_3 = 0; i_3 < 4; ++i_3) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_3 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 8192), A+(((((((int)blockIdx.y) * 131072) + (i_3 * 32768)) + ((((int)threadIdx.x) >> 2) * 1024)) + ((((int)threadIdx.x) & 3) * 8)) + 32));
}
#pragma unroll
for (int i_4 = 0; i_4 < 4; ++i_4) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((((((((int)threadIdx.x) & 15) >> 3) * 4096) + (i_4 * 1024)) + ((((int)threadIdx.x) >> 4) * 128)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 32768), B+(((((i_4 * 16384) + ((((int)threadIdx.x) >> 4) * 2048)) + (((int)blockIdx.x) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 65536));
}
tl::cp_async_commit();
for (int ko = 0; ko < 30; ++ko) {
__syncthreads();
#pragma unroll
for (int i_5 = 0; i_5 < 4; ++i_5) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((((ko + 2) % 3) * 8192) + (i_5 * 2048)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((((int)blockIdx.y) * 131072) + (i_5 * 32768)) + ((((int)threadIdx.x) >> 2) * 1024)) + (ko * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 64));
}
#pragma unroll
for (int i_6 = 0; i_6 < 4; ++i_6) {
tl::cp_async_gs<16>(buf_dyn_shmem+((((((((((ko + 2) % 3) * 8192) + (((((int)threadIdx.x) & 15) >> 3) * 4096)) + (i_6 * 1024)) + ((((int)threadIdx.x) >> 4) * 128)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576), B+((((((ko * 65536) + (i_6 * 16384)) + ((((int)threadIdx.x) >> 4) * 2048)) + (((int)blockIdx.x) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 131072));
}
tl::cp_async_commit();
tl::cp_async_wait<2>();
__syncthreads();
tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[((ko % 3) * 4096)])), (&(((half_t*)buf_dyn_shmem)[(((ko % 3) * 4096) + 12288)])), (&(C_local[0])));
}
tl::cp_async_wait<1>();
__syncthreads();
tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[12288])), (&(C_local[0])));
tl::cp_async_wait<0>();
__syncthreads();
tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[4096])), (&(((half_t*)buf_dyn_shmem)[16384])), (&(C_local[0])));
#pragma unroll
for (int i_7 = 0; i_7 < 64; ++i_7) {
uint1 __1;
float2 v_ = *(float2*)(C_local + (i_7 * 2));
((half2*)(&(__1.x)))->x = (half_t)(v_.x);
((half2*)(&(__1.x)))->y = (half_t)(v_.y);
*(uint1*)(C + (((((((((((int)blockIdx.y) * 262144) + (((i_7 & 7) >> 1) * 65536)) + (((((int)threadIdx.x) & 63) >> 5) * 32768)) + ((i_7 & 1) * 16384)) + (((((int)threadIdx.x) & 31) >> 2) * 2048)) + (((int)blockIdx.x) * 128)) + ((i_7 >> 3) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = __1;
}
}
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];
extern "C" const char* get_last_error() {
return error_buf;
}
extern "C" int init() {
error_buf[0] = '\0';
cudaError_t result_gemm_schedule_kernel = cudaFuncSetAttribute(gemm_schedule_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 49152);
if (result_gemm_schedule_kernel != CUDA_SUCCESS) {
snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", 49152, cudaGetErrorString(result_gemm_schedule_kernel));
return -1;
}
return 0;
}
extern "C" int call(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C, cudaStream_t stream=cudaStreamDefault) {
gemm_schedule_kernel<<<dim3(16, 32, 1), dim3(128, 1, 1), 49152, stream>>>(A, B, C);
TILELANG_CHECK_LAST_ERROR("gemm_schedule_kernel");
return 0;
}
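One detail of the generated source worth cross-checking is the 49152 bytes of dynamic shared memory set in init() and passed at launch: it is exactly the three pipeline stages of fp16 tiles that the schedule asks for (num_stages=3, a 128×32 A tile and a 32×128 B tile per stage).

# Dynamic shared memory implied by the schedule: 3 pipeline stages, each
# holding one 128x32 A tile and one 32x128 B tile in float16 (2 bytes).
stages, block_M, block_N, block_K, elem_bytes = 3, 128, 128, 32, 2
smem = stages * (block_M * block_K + block_K * block_N) * elem_bytes
print(smem)  # 49152, matching cudaFuncAttributeMaxDynamicSharedMemorySize above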