Menu

浅读DeepGEMM

post on 03 Mar 2025 about 9227words require 31min
CC BY 4.0 (除特别声明或转载文章外)
如果这篇文章帮助到你,可以请我喝一杯咖啡~

五年前我曾经尝试过 Volta 上的 GEMM,能够接近当时 CUTLASS 的水平,可惜其可读性一直没达到能让自己满意的水平。拜读一下 DeepGEMM,一句话评价:比 CUTLASS 简洁、好上手(CUTLASS 为了兼容各种 Shape 和 Case 做了过于多的抽象,难以读懂的同时算法上束手束脚,DeepGEMM 只针对自己用的Contiguous Layout),很适合阅读。比较佩服的一点是 DS 能够自信自己的工程师技术水平优于 NV CUTLASS 团队的水平,敢于开启这个项目:从代码规模来看,整个项目很可能是单挑的(一个精巧的代码设计需要一个自上到下对算法细节全部精通的工程师),给我几个月不一定能写出来(写一个代码结构差不多的有可能,Debug 多久就要看命了)…

本文参考的 commit 是 dff6bb6,接下来逐文件分析代码结构。

include/deep_gemm/

这个目录是项目的核心,包含核心 GEMM 函数及其需要的一些基本操作的 utils 封装。

mma_utils.cuh

该文件封装了以 wgmma 为主的 PTX 指令。在阅读该文件前,推荐阅读:

  • wgmma 指令,Asynchronous Warpgroup Level Matrix Multiply-Accumulate Instructions 详细说明。

The wgmma instructions perform warpgroup level matrix multiply-and-accumulate operation…A warpgroup is a set of four contiguous warps such that the warp-rank of the first warp is a multiple of 4.

struct SM90_64x16x32_F32E4M3E4M3_SS 为例,其封装了 wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3,其含义是:

  • 矩阵乘的 shape 是 m64n16k32wgmma 指令对于 e4m3e5m2 的 dense 计算均要求 m=64,k=32,而n是8的倍数,大于等于 8 小于等于 256。具体的存储格式详见这里,如果过于展开也不过是对原文的翻译,本文的篇幅就要爆炸了。
  • sync 表明该指令会导致一次同步;aligned 表明调用该指令的线程必须调用同一地址的指令,不能在代码的不同分支分别执行。
  • 两个 e4m3 代表输入的格式,f32 代表 Accumulator D的格式。相信本文的读者不需要对 e4m3f32 的解释。
  • SS 意义不明,推测为 A 和 B 都在 SMEM 上。

随后使用 FP8MMASelector<N> 对上述不同的 N 进行了编译期的选择。

warpgroup_wait<N> 封装了warpgroup的同步操作。由于一个warpgroup中有128个线程,一个block中最多1024个线程,这里N一定小于8。

还封装了一些访存指令,例如ld_shared。注意该函数有一个 int4 版本,因为 NVGPU 一次访存的宽度是 128 bit,一次读四个效率是最高的。封装访存指令,而不是直接写 a = b[i] 等着编译器弄,应该是希望显式指定读取的地址是 SMEM,不要留到运行时由 ld 指令自己判断。

tma_utils.h

这一页的内容基本上是为了最后一个tma_copy<N>服务,调用 CUTLASS 中的 cute 库实现对 Hopper TMA 单元的利用。

utils.cuh

DG_HOST_ASSERTDG_DEVICE_ASSERT 等宏; ceil_div 向上取整,注意是 constexpr 的。

scheduler.cuh

这个文件用于计算不同的 GEMM 算法下当前的 block 在计算哪些块。这里引用此处的说法:

  • 常规稠密 GEMM:通过函数 deep_gemm.gemm_fp8_fp8_bf16_nt 调用,适用于常规矩阵乘法。
  • 分组 GEMM(连续布局,Contiguous Layout):针对 MoE 模型优化,仅对 M 轴分组,N 和 K 保持固定。这种设计适用于 MoE 专家共享相同形状的情况。将多个专家的 token 拼接成单一连续张量,适用于训练前向或推理预填充阶段。每个专家段需对齐到 GEMM 的 M 块大小。
  • 分组 GEMM(掩码分组,Masked Grouped GEMM):支持推理解码阶段,结合 CUDA Graph,适应动态 token 分配。这种分组策略与 CUTLASS 的传统分组 GEMM 不同,体现了 DeepGEMM 对 MoE 模型的针对性优化。
1
2
3
4
5
enum class GemmType {
    Normal,
    GroupedContiguous,
    GroupedMasked
};

根据 kernel launch 处的启动参数,启动的 block 数量是 num_sms,换言之启动的block数量与m、n、k无关,运行的block需要自行决定自己的计算任务(persistent kernel)。我认为这样做的好处是有利于Group GEMM,方便支持不同的 m、n。

get_next_block将这个物理block的任务从当前逻辑block切换到下一逻辑block。核心逻辑通过 get_swizzled_block_idx 函数实现,该函数用于计算m_block_idxn_block_idx48行的注释说这样做有助于L2利用,目的是尽量使一个block计算的 C 能够共享更多的 A/B,详细推导可参考我的这篇博客

1
2
3
4
5
6
7
8
9
10
11
12
template <bool kIgnoreGroupedForGroupedContiguous=true>
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
                                                   const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
    if constexpr (kGemmType == GemmType::Normal) {
        return block_idx * block_size;
    } else if (kGemmType == GemmType::GroupedContiguous) {
        auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
        return offset * shape_dim + block_idx * block_size;
    } else if (kGemmType == GemmType::GroupedMasked) {
        return curr_group_idx * shape_dim + block_idx * block_size;
    }
}

fp8_gemm.h

这个文件是整个库的核心。首先要介绍的是,这个算法使用了 WarpSpecialization (WS),这项技术最早出现在 SC’11 的 “CudaDMA: Optimizing GPU Memory Bandwidth via Warp Specialization” 这篇文章中。另一篇推荐阅读的文章是 HPCA’24 的 “WASP: Exploiting GPU Pipeline Parallelism with Hardware-Accelerated Automatic Warp Specialization”。在 2011 年 WS 其实是很不靠谱的做法,当时 gpu 多线程能力太差。但是 Hopper 中引入了以 TMA 为首的大量异步操作(NV 称之为第一代异步 GPU),WS 得以大放光彩。可能 DeepGEMM 胜过 cutlass 的很大原因是后者因为架构包袱没有及时跟上 WS 算法(未查证 cutlass 实现)?

简单来说,使用了 128 个线程做 TMA(生产者),128/256 个线程做计算(消费者)。这样做的好处我能想到两点:计算访存重叠(TLP相较于ILP重叠的机会更大,GEMM时活跃线程数很小,可以充分利用剩下的线程做生产者);sync() 的粒度可以变小。

366 行有个注释,因为 153 行使用了 setmaxnreg.aligned 指令释放了生产者线程使用的多余的寄存器(由于 TMA 直接将数据搬到 SMEM 不经过寄存器,其需求远小于消费者线程),所以生产者线程需要是128的倍数。wgmma 指令需要的一个 warp group 是 128 个,所以消费者数量需要是 128 的倍数。根据 377行,blockDim 在 block_m 是 64 的时候是 256,要么是 384。这两个数字我这样理解:根据 145 行 中所说,消费者线程需要 232 个寄存器,生产者线程是 40 个寄存器,一个 block 最多使用 $256\times 232+128\times 40=64512$ 个寄存器,一个 SM 里一共 65536 个寄存器,算上每 block 保留的 1K 寄存器,刚好能塞进去。如果没有 setmaxnreg.aligned,那消费者线程最多只能有 128 个。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K, BLOCK_M, BLOCK_N, BLOCK_K,
                              kNumGroups, kNumStages, kNumTMAThreads, kNumMathThreadsPerGroup,
                              kNumTMAMulticast, kGemmType>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);

// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
config.dynamicSmemBytes = smem_size;
config.stream = stream;

接下来进入正戏。30行 的几个模板参数含义如下:

1
2
3
4
5
6
template <uint32_t SHAPE_N, uint32_t SHAPE_K, // 全局矩阵维度N和K;shape_m 是函数参数
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, // 线程块处理的分块维度
          uint32_t kNumGroups /* 分组数(支持 MoE 等分组 GEMM)*/, uint32_t kNumStages /* 流水线阶段数 */,
          uint32_t kNumTMAThreads /* TMA 和计算线程数 */, uint32_t kNumMathThreadsPerGroup /* TMA 多播集群大小 */,
          uint32_t kNumTMAMulticast, // TMA 多播集群大小
          GemmType kGemmType> // GEMM 类型(普通、连续分组、掩码分组)

44行127行 在做准备工作,主要是多 Buffer 的偏移地址计算。

132行定义了一个 lambda,使用 kNumStages 阶段流水线,通过 launch_k_iterations 处理 K 维度的迭代,支持可整除和不可整除的 K 维度,通过模板元编程优化展开。

1
2
3
4
5
6
7
8
9
10
11
12
13
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [](const auto& func) {
    if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
        for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
            func(k_iter, DivisibleK{});
    } else {
        for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
            func(k_iter, DivisibleK{});
        func(kNumIterations - 1, NotDivisibleK{});
    }
};

DivisibleKNotDivisibleK160行 被用到,用于判断当前iter需要做多少Stage(kNumInnerStages)里,这个模板元体操真的绕了我一下!

1
2
3
4
5
launch_k_iterations([&](int k_iter, auto type) {
    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
    // ...
}

167行是做 TMA 的生产者线程。对每个 Stage s,通过TMA Broadcast读A,通过TMA读B(没必要 Broadcast),随后对 full_barriers[s].arrive_and_expect_tx188行处理未对齐的剩余的 Stage。

203行 是最核心的消费者线程。作者非常善良,这里基本上每个 Loop 都有注释标注。不过222行 “overlap loading B scales with TMA stores between tasks” 没怎么看懂,推测是强制让不同warp的工作量产生差异从而产生pipeline…接下来和生产者线程类似的进入 Stages。

259行等待生产者读好数据。265行开始正式做核心的矩阵乘,调用了上文中的 FP8MMASelector<>。由于所有的操作都在 SMEM 上进行,所以意外的简洁…

1
2
3
4
5
6
7
8
9
10
11
12
13
// Commit WGMMA instructions
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
    warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
    auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1)
    auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
    WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
    warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();

252行285行计算了矩阵乘的 scal,通过常规的 CUDA Core 实现。scale 和 accum 都是 fp32 以保证精度。

1
2
3
4
5
6
7
8
9
10
11
12
13
// Promote with scales
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
    scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
    bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
    final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
    final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
    final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
    final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}

当消费者线程有256个的时候,两个WG也许会存在TensorCore与CUDA Core的Overlapp,进一步提高硬件资源利用率。

310行开始是一段 Magic Code,负责写回计算结果。

1
2
3
4
5
6
7
8
9
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
    SM90_U32x4_STSM_N<nv_bfloat162>::copy(
        __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}),
        __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}),
        __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}),
        __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}),
        smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)
    );
}

由于前文我刻意跳过了 wgemm 指令数据 Layout 介绍(懒),这里很难细致解释,我引用知乎 @yuukyuu,非常专业,推荐去看一看:

为了简单的解释明白, 先从 SM90_U32x1_STSM_N ( stmatrix.sync.aligned.x1.m8n8.shared.b16 ) 说起, 本质就是一个 warp ( 32 threads ) 每个 thread 给一个 32 位寄存器, 写入一个 8x8 的矩阵 ( 每个元素 16 位 ). 然后连续的每 4 个线程 ( 共 4 个 32 位元素 ) 填充一行, 而传入的另外一个参数就是每一行的首地址 ( 一共 8 个, 由前 8 个 thread 提供 ). 然后再是 SM90_U32x4_STSM_N, 简单理解就是上面的逻辑执行了 4 次, 当然也就需要 32 个首地址, 刚好每个线程提供一个.

考虑到之前 WGMMA 说 CLayout 时, 每个 thread 的 C 元素刚好是临近的两个在 N 方向连续, 连续的 4 个 thread 又刚好组成连续的一行 8 个元素, 因此可以完美的使用此指令进行数据搬运, 剩下的问题就是如何计算这 32 个首地址了.

结合 CLayout ((4, 8, 4), (2, 2, N/8)):((128, 1, 16), (64, 8, 512)), 可以发现每个 warp ( 32 thread ) 的结果刚好是连续的 16 行每行 N 个元素. 又由于使用 U32x4 指令一次最多存储 8 * 32 = 256 个 ( 16 位 ) 元素, 所以代码中每次每个线程存储 8 个元素, 也就是存储了 16 行每行 16 个元素. 因此, 这 32 个首地址分别就是 16 行的第 0 个元素位置, 外加 16 行的第 8 个元素位置.

回到代码可以发现 warp_idx * 16 表示当前 warp 的第 0 行位置 ( 上一节说了每个 warp 处理 16 行 ), 而 lane_idx % 16 表示了地址的行偏移量 ( 由于后 16 个地址行偏移量和前 16 个地址一样, 所以这里 % 16 ). 而列偏移量中, 由于每一批数据写入每行 16 个元素, 因此 i * 16 表示当前这一批数据的列偏移量, 而 8 * (lane_idx / 16) 则是表示前 16 个地址还是后 16 个地址 ( 后 16 个地址是当前行的第 8 个元素)

其他

关于 JIT

由于 N、K 均为模板参数,使得其非常适合通过 JIT 的方式将分块参数/调度算法搜索从运行时移到运行前,这是其优于 cublas/CUTLASS 的重要原因。

关于 SASS 优化

网上很多人说这是很骚的操作。简单来说,作者发现 CUDA 12.3 后 CUTLASS 的性能有提升,对比 SASS 后发现某些 bit 发生了反转,猜测这些 bit 能够放松指令间的依赖关系以获取更高的并行度,于是写了一个脚本专门用于处理生成的 SASS。Respect!

在 SASS 层面优化矩阵乘曾经是 cublas 库的做法,但是逐渐被 CUTLASS 取代;此外,还有一些尝试通过调整 SASS 中寄存器 Layout,减少 Register Conflict 的(不错,register 与 SMEM 一样都有 conflict,但是每一代 GPU 的都不相同,很难从黑箱中推测)。引用此处的说法:

虽然使用 SASS 能充分挖掘 GPU 的性能,但面临有三大问题:

  1. 第三方 NV GPU 汇编器依赖于对 GPU 架构的逆向研究,可能因为没有探究到全部的硬件底层细节而存在未知的 BUG;
  2. 汇编 Kernel 难于开发,更难于调试;
  3. NV 每一代 GPU 的 ISA(指令集)都不尽相同,需要不断开发对应的汇编器和汇编 Kernel。

正因为这几大问题的存在,使得使用 SASS 编写 Kernel 是个费时费力的工作,除非有追求极致性能的需求,否则不建议轻易尝试。

值得吐槽的一点是 NV 的硬件发展速度太快,Ampere 到 Hopper 的改变大到自家的 compile 团队都难以跟上。在我眼里 NV 的 compile 团队应该是最强的…另一方面,CUDA 12.3 CUTLASS 性能的提升也说明 NV 的 arch 团队、compile 团队、CUTLASS 团队三者能够协同进化。反观国内互联网企业是否应当想想,部门间甚至小组间竞争是否多于协作?

总结

DeepGEMM 是一个精炼的矩阵乘法范例,在最重要的矩阵乘消费者线程中给出了恰到好处的注释。亦或者是我只放出了对应代码的行数而非直接给出所有对应代码,本文的篇幅比预期中要小很多,这更加能说明这个库的精妙之处。

(完)

Related posts

Loading comments...