
Getting Started with Tilelang: An L2-Friendly Matmul

Posted on 05 Jul 2025 · about 11,754 words · ~40 min read
CC BY 4.0 (except where otherwise noted or for reposted articles)

After a friendly discussion with folks in the Tilelang group about my earlier questions on L2-aware matmul blocking, this post implements a matmul operator in Tilelang to validate the L2 tiling scheme. Something is probably written wrong somewhere: the current code performs well below expectations and also leaks GPU memory. I will follow up with the group for further advice.

(Figure: matmul-tflops — measured TFLOPS for each variant; the numbers are tabulated below.)

Starting from the official example_gemm_schedule.py, I added a variant that skips T.use_swizzle entirely (kernel_matmul_no_l2_opt, reported as Tilelang_ps_-10 in the benchmark below). I also export the progressively lowered intermediate representations produced at compile time, to inspect the concrete logic of the operator Tilelang generates.
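
To build intuition for what T.use_swizzle buys, here is a minimal Python sketch of row-panel rasterization. The function swizzle_2d_row and its index math are my own illustration, not Tilelang's actual implementation (which emits tl::rasterization2DRow<panel_size> in the generated CUDA below); the point it shows is that consecutive blocks are issued down a panel of panel_size row-tiles, so they keep hitting the same tiles of A and B while those are still hot in L2.

# Illustrative sketch only (my own approximation of the idea behind
# T.use_swizzle / tl::rasterization2DRow, not Tilelang's implementation):
# remap a linear block id onto (bx, by) so blocks are issued panel by panel.
def swizzle_2d_row(linear_id, grid_x, grid_y, panel_size):
    panel = linear_id // (panel_size * grid_x)        # which panel of row-tiles
    within = linear_id % (panel_size * grid_x)        # position inside the panel
    rows = min(panel_size, grid_y - panel * panel_size)
    bx = within // rows                               # column tile (of B / C)
    by = panel * panel_size + within % rows           # row tile (of A / C)
    return bx, by


# With a 4x8 grid and panel_size=2, the first 8 blocks touch only 2 row-tiles
# of A and 4 column-tiles of B, so those tiles stay resident in L2.
for lid in range(8):
    print(lid, swizzle_2d_row(lid, grid_x=4, grid_y=8, panel_size=2))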

Experimental environment

Source code: matmul.py

# spack load cuda py-torch+cuda py-triton py-matplotlib py-pandas
# LD_LIBRARY_PATH=$(spack location -i cuda)/lib64:$LD_LIBRARY_PATH python3 matmul.py
import tilelang
import tilelang.language as T
import torch
import triton


def kernel_matmul(
    M, N, K, block_M, block_N, block_K, dtype, accum_dtype, panel_size=10
):
    @T.prim_func
    def gemm_schedule(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            T.use_swizzle(panel_size=panel_size)  # rasterize threadblocks in row panels to improve L2 reuse
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):  # 3-stage software pipeline over K tiles
                T.copy(A[by * block_M, ko * block_K], A_shared)
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_schedule


def kernel_matmul_no_l2_opt(
    M, N, K, block_M, block_N, block_K, dtype, accum_dtype, panel_size=10
):
    # Same as kernel_matmul but without T.use_swizzle, so threadblocks run in
    # the default order (panel_size is accepted only for a uniform signature
    # and is ignored here).
    @T.prim_func
    def gemm_schedule(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_schedule


def test():
    M, N, K = 2**12, 2**11, 2**10
    kernel = kernel_matmul(M, N, K, 128, 128, 32, "float16", "float")
    with open("matmul.kernel.py", "w") as f:
        f.write(str(kernel))

    tilelang_matmul = tilelang.compile(kernel, out_idx=-1)
    with open("matmul.kernel.cu", "w") as f:
        f.write(tilelang_matmul.get_kernel_source())

    DEVICE = "cuda"
    a = torch.rand([M, K], device=DEVICE, dtype=torch.float16)
    b = torch.rand([K, N], device=DEVICE, dtype=torch.float16)

    torch_c = torch.matmul(a, b)
    tilelang_c = tilelang_matmul(a, b)
    print("Maxdiff is {}".format(torch.max(torch.abs(torch_c - tilelang_c))))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=[1024 * i for i in range(12, 21)],
        line_arg="provider",
        line_vals=["torch"] + ["tilelang_ps_" + str(i * 10) for i in range(-1, 3)],
        line_names=["Torch"] + ["Tilelang_ps_" + str(i * 10) for i in range(-1, 3)],
        plot_name="matmul-tflops",
        args={},
    )
)
def benchmark(M, N, K, provider):
    DEVICE = "cuda"
    a = torch.rand([M, K], device=DEVICE, dtype=torch.float16)
    b = torch.rand([K, N], device=DEVICE, dtype=torch.float16)
    mp = {"torch": lambda: torch.matmul(a, b)}
    for i in range(-1, 3):
        kernel = (
            kernel_matmul(M, N, K, 128, 128, 32, "float16", "float", i * 10)
            if i >= 0
            else kernel_matmul_no_l2_opt(
                M, N, K, 128, 128, 32, "float16", "float", i * 10
            )
        )
        tilelang_matmul = tilelang.compile(kernel, out_idx=-1)
        mp["tilelang_ps_" + str(i * 10)] = lambda: tilelang_matmul(a, b)
    ms = triton.testing.do_bench(mp[provider])
    torch.cuda.empty_cache()
    tflops = 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return tflops


if __name__ == "__main__":
    torch.manual_seed(3407)
    test()
    benchmark.run(print_data=True, show_plots=False, save_path=".")

Program output

Here, Tilelang_ps_-10 is the variant with the swizzle optimization turned off (kernel_matmul_no_l2_opt).

Maxdiff is 0.0
matmul-tflops:
         M       Torch  Tilelang_ps_-10  Tilelang_ps_0  Tilelang_ps_10  Tilelang_ps_20
0  12288.0  218.185030       163.195512     160.787932      155.639821      167.653793
1  13312.0  218.176101       174.291067     172.917449      172.878523      176.228789
2  14336.0  214.400675       174.471258     167.645072      168.738852      159.228886
3  15360.0  219.363970       157.663036     169.193896      171.710036      169.996476
4  16384.0  220.999889       173.871241     174.994087      174.047391      170.979700
5  17408.0  220.840380       134.329065     137.557182      130.383652      168.406982
6  18432.0  216.221871       173.599278     123.953247      136.703516      170.494457
7  19456.0  217.925130       163.030749     120.179175      134.268942      175.348523
8  20480.0  197.038250       157.498531     137.681804      137.550856      110.879759
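
As a sanity check on these numbers, the TFLOPS values follow directly from the formula in benchmark(), tflops = 2*M*N*K*1e-12 / (ms*1e-3). The small script below (illustrative arithmetic only) backs out the per-matmul time implied by the first Torch entry.

# Back out the kernel time implied by the first Torch entry of the table.
M = N = K = 12288
flop = 2 * M * N * K                  # one multiply-add counted as 2 FLOPs
tflops_reported = 218.185             # Torch, first row of the table above
ms = flop * 1e-12 / tflops_reported * 1e3
print(f"{ms:.2f} ms per matmul")      # roughly 17 ms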

(Figure: matmul-tflops plot of the table above.)

matmul.kernel.py

# from tvm.script import tir as T

@T.prim_func
def gemm_schedule(A: T.Buffer((4096, 1024), "float16"), B: T.Buffer((1024, 2048), "float16"), C: T.Buffer((4096, 2048), "float16")):
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 16)
    by = T.launch_thread("blockIdx.y", 32)
    tx = T.launch_thread("threadIdx.x", 128)
    ty = T.launch_thread("threadIdx.y", 1)
    tz = T.launch_thread("threadIdx.z", 1)
    with T.block("tilelang_root"):
        T.reads(A[by * 128, 0:993], B[0:1024, bx * 128:bx * 128 + 128], C[by * 128, bx * 128])
        T.writes()
        A_shared = T.alloc_buffer((128, 32), "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer((32, 128), "float16", scope="shared.dyn")
        C_local = T.alloc_buffer((128, 128), scope="local.fragment")
        T.attr(None, "threadblock_swizzle_pattern", "tl::rasterization2DRow<10>")
        T.fill(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 16384, 2), 0)
        for ko in T.serial(32, annotations={"num_stages": 3}):
            T.copy(T.region(A[by * 128, ko * 32], 1, 128, 32), T.region(A_shared[0, 0], 2, 128, 32))
            for k in T.parallel(32):
                for j in T.parallel(128):
                    B_shared[k, j] = B[ko * 32 + k, bx * 128 + j]
            T.gemm(T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 4096, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 4096, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 16384, 3), T.bool(False), T.bool(False), 128, 128, 32, 0, T.bool(False), 1, 0)
        T.copy(T.region(C_local[0, 0], 1, 128, 128), T.region(C[by * 128, bx * 128], 2, 128, 128))

matmul.kernel.cu

The matrix multiplication I wrote by hand five years ago (see line 724 of the code in that original post) had roughly the same look…

#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>

extern "C" __global__ void gemm_schedule_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C);
extern "C" __global__ void __launch_bounds__(128, 1) gemm_schedule_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  float C_local[128];
  const dim3 blockIdx = tl::rasterization2DRow<10>();
  #pragma unroll
  for (int i = 0; i < 64; ++i) {
    *(float2*)(C_local + (i * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
  }
  #pragma unroll
  for (int i_1 = 0; i_1 < 4; ++i_1) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((int)blockIdx.y) * 131072) + (i_1 * 32768)) + ((((int)threadIdx.x) >> 2) * 1024)) + ((((int)threadIdx.x) & 3) * 8)));
  }
  #pragma unroll
  for (int i_2 = 0; i_2 < 4; ++i_2) {
    tl::cp_async_gs<16>(buf_dyn_shmem+(((((((((((int)threadIdx.x) & 15) >> 3) * 4096) + (i_2 * 1024)) + ((((int)threadIdx.x) >> 4) * 128)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576), B+((((i_2 * 16384) + ((((int)threadIdx.x) >> 4) * 2048)) + (((int)blockIdx.x) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
  }
  tl::cp_async_commit();
  #pragma unroll
  for (int i_3 = 0; i_3 < 4; ++i_3) {
    tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_3 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 8192), A+(((((((int)blockIdx.y) * 131072) + (i_3 * 32768)) + ((((int)threadIdx.x) >> 2) * 1024)) + ((((int)threadIdx.x) & 3) * 8)) + 32));
  }
  #pragma unroll
  for (int i_4 = 0; i_4 < 4; ++i_4) {
    tl::cp_async_gs<16>(buf_dyn_shmem+(((((((((((int)threadIdx.x) & 15) >> 3) * 4096) + (i_4 * 1024)) + ((((int)threadIdx.x) >> 4) * 128)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 32768), B+(((((i_4 * 16384) + ((((int)threadIdx.x) >> 4) * 2048)) + (((int)blockIdx.x) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 65536));
  }
  tl::cp_async_commit();
  for (int ko = 0; ko < 30; ++ko) {
    __syncthreads();
    #pragma unroll
    for (int i_5 = 0; i_5 < 4; ++i_5) {
      tl::cp_async_gs<16>(buf_dyn_shmem+(((((((ko + 2) % 3) * 8192) + (i_5 * 2048)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((((int)blockIdx.y) * 131072) + (i_5 * 32768)) + ((((int)threadIdx.x) >> 2) * 1024)) + (ko * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 64));
    }
    #pragma unroll
    for (int i_6 = 0; i_6 < 4; ++i_6) {
      tl::cp_async_gs<16>(buf_dyn_shmem+((((((((((ko + 2) % 3) * 8192) + (((((int)threadIdx.x) & 15) >> 3) * 4096)) + (i_6 * 1024)) + ((((int)threadIdx.x) >> 4) * 128)) + ((((((int)threadIdx.x) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576), B+((((((ko * 65536) + (i_6 * 16384)) + ((((int)threadIdx.x) >> 4) * 2048)) + (((int)blockIdx.x) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 131072));
    }
    tl::cp_async_commit();
    tl::cp_async_wait<2>();
    __syncthreads();
    tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[((ko % 3) * 4096)])), (&(((half_t*)buf_dyn_shmem)[(((ko % 3) * 4096) + 12288)])), (&(C_local[0])));
  }
  tl::cp_async_wait<1>();
  __syncthreads();
  tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[12288])), (&(C_local[0])));
  tl::cp_async_wait<0>();
  __syncthreads();
  tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[4096])), (&(((half_t*)buf_dyn_shmem)[16384])), (&(C_local[0])));
  #pragma unroll
  for (int i_7 = 0; i_7 < 64; ++i_7) {
    uint1 __1;
    float2 v_ = *(float2*)(C_local + (i_7 * 2));
    ((half2*)(&(__1.x)))->x = (half_t)(v_.x);
    ((half2*)(&(__1.x)))->y = (half_t)(v_.y);
    *(uint1*)(C + (((((((((((int)blockIdx.y) * 262144) + (((i_7 & 7) >> 1) * 65536)) + (((((int)threadIdx.x) & 63) >> 5) * 32768)) + ((i_7 & 1) * 16384)) + (((((int)threadIdx.x) & 31) >> 2) * 2048)) + (((int)blockIdx.x) * 128)) + ((i_7 >> 3) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = __1;
  }
}


#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];

extern "C" const char* get_last_error() {
    return error_buf;
}

extern "C" int init() {
    error_buf[0] = '\0';
    
    cudaError_t result_gemm_schedule_kernel = cudaFuncSetAttribute(gemm_schedule_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 49152);
    if (result_gemm_schedule_kernel != CUDA_SUCCESS) {
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", 49152, cudaGetErrorString(result_gemm_schedule_kernel));
        return -1;
    }

    return 0;
}

extern "C" int call(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C, cudaStream_t stream=cudaStreamDefault) {
	gemm_schedule_kernel<<<dim3(16, 32, 1), dim3(128, 1, 1), 49152, stream>>>(A, B, C);
	TILELANG_CHECK_LAST_ERROR("gemm_schedule_kernel");

	return 0;
}
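
One detail worth checking in this dump: the 49152 bytes requested via cudaFuncSetAttribute and passed as the dynamic shared memory size in call() are exactly what the tile configuration implies, namely three pipeline stages each holding one 128x32 half-precision tile of A and one 32x128 tile of B; this also explains the +24576 offset used when staging B (the three A tiles come first). A quick back-of-the-envelope check (illustrative arithmetic only):

# Sanity-check the dynamic shared memory size requested by the generated kernel.
block_M, block_N, block_K = 128, 128, 32
bytes_per_half = 2
num_stages = 3

a_tile = block_M * block_K * bytes_per_half   # 8192 bytes per A tile
b_tile = block_K * block_N * bytes_per_half   # 8192 bytes per B tile
print(num_stages * a_tile)                    # 24576, offset where B tiles start
print(num_stages * (a_tile + b_tile))         # 49152, dynamic shared memory size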
