Posted on 21 Nov 2024
This post implements an online-softmax operator step by step in Triton and compares it against Torch. For large rows (n_col >= 32768), my online version is a steady ~26% faster than Torch (26.1% at the largest size benchmarked).
\[y_{i,j}=\frac{\exp (x_{i,j}-\max_j\lbrace x_{i,j}\rbrace)}{\max\lbrace\sum_j\exp (x_{i,j}-\max_j\lbrace x_{i,j}\rbrace),\epsilon\rbrace}\]

This post only considers softmax over the contiguous dimension of a 2-D tensor. For any other dimension, it is faster to transpose (rearrange) first and then take softmax over the contiguous dimension.
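As a sanity reference, the formula above is only a few lines of plain PyTorch. This is a sketch for checking the Triton kernels; softmax_dim1_ref is my name for it and is not part of softmax.py:

import torch

def softmax_dim1_ref(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    m = x.max(dim=1, keepdim=True).values  # max_j x_ij, for numerical stability
    e = torch.exp(x - m)                   # exp(x_ij - max_j x_ij)
    return e / e.sum(dim=1, keepdim=True).clamp_min(eps)  # denominator floored at eps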
softmax.py
Three versions of softmax are implemented:

- triton_fuse: the whole row is handled by a single block. Triton's operator-fusion approach performs well while the row fits (n_col < 65536) and degrades beyond that; I suspect the heavy register/SMEM usage hurts occupancy. At n_col > 131072 the kernel fails to launch outright. On average each element costs 1 load and 1 store.
- triton_tile: the row is tiled into blocks and traversed in three passes (max, sum, normalize), costing 3 loads and 1 store per element.
- triton_online: online softmax merges the max and sum passes into one, cutting this to 2 loads and 1 store per element; trading the saved pass for some extra exp rescaling is acceptable (the recurrence is sketched in scalar form right after this list). Note in particular the nan handling in kernel_softmax_online: the "if i:" guard skips the rescale on the first iteration, where mm is still -inf and exp(mm - mm_new) would produce nan. That one genuinely took a long time to debug.
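The rescaling recurrence behind the online version, in scalar form: whenever the running max m rises to m', every term already accumulated in the running sum must be multiplied by exp(m - m'). A toy sketch (online_max_sum, run_max, run_sum are illustrative names, not from softmax.py):

import math

def online_max_sum(xs):
    run_max, run_sum = -math.inf, 0.0
    for i, x in enumerate(xs):
        new_max = max(run_max, x)
        if i:  # mirrors the kernel's guard: with run_max == new_max == -inf (e.g. masked lanes), the rescale would be nan
            run_sum *= math.exp(run_max - new_max)  # rescale old terms to the new max
        run_sum += math.exp(x - new_max)
        run_max = new_max
    return run_max, run_sum  # equals (max(xs), sum(exp(x - max(xs))))

For xs = [1.0, 3.0, 2.0] this returns (3.0, exp(-2) + exp(0) + exp(-1)), exactly what the two-pass computation gives.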
Incidentally, I believe the cross-lane reduction ops tl.max and tl.sum are expensive, so both are hoisted out of the loops.
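Concretely, all three kernels keep only element-wise updates inside the loop and defer the cross-lane reductions; the pattern (excerpted from the listing below) is:

# inside the loop: element-wise accumulation only
#     mm = tl.maximum(mm, x)   # per-lane running max
#     ss += x                  # per-lane running sum
# after the loop: one cross-lane reduction each
#     mm = tl.max(mm)
#     ss = tl.sum(ss)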
I also implemented a CACHE_OPT strategy: each loop walks the blocks in the opposite direction of the previous loop, so it loads first the blocks that were touched most recently and are most likely still resident in cache. Measurements confirm this helps.
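A toy illustration of the traversal order (the block count here is hypothetical): with 4 blocks, one pass visits 0 1 2 3 and the next visits 3 2 1 0, so the second pass begins with block 3, the one the previous pass touched last.

n_blocks = 4  # hypothetical count
pass_a = list(range(n_blocks))              # e.g. the sum pass:  0 1 2 3
pass_b = list(range(n_blocks - 1, -1, -1))  # the next pass:      3 2 1 0 (block 3 is hot)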
num_warps=32 and BLOCK_SIZE were hand-tuned to some extent.
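If one wanted to avoid hand tuning, Triton's autotuner could search these knobs instead. A minimal sketch, where the config values are assumptions and the kernel body is elided (the actual hand-tuned listing follows below):

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**b}, num_warps=w)
        for b in (11, 12, 13, 14)  # assumed search space
        for w in (8, 16, 32)
    ],
    key=["n_cols"],  # cache the best config per row length
)
@triton.jit
def kernel_softmax_online_autotuned(  # hypothetical variant of kernel_softmax_online
    x_ptr, x_row_stride, y_ptr, y_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr, CACHE_OPT: tl.constexpr,
):
    ...  # same body as kernel_softmax_online

With this decorator the call site omits BLOCK_SIZE and num_warps; the autotuner picks them on first launch for each n_cols.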
# spack load [email protected] [email protected]+cuda [email protected] [email protected]
# PATH=/usr/sbin:$PATH python3 softmax.py
import triton
import triton.language as tl
import torch


@triton.jit
def kernel_softmax_fuse(
    x_ptr,
    x_row_stride,
    y_ptr,
    y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    x_ptr += row_idx * x_row_stride
    y_ptr += row_idx * y_row_stride
    # the whole row in one block: 1 load + 1 store per element
    idx = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
    x = tl.exp(x - tl.max(x))
    eps = float(1e-9)
    x /= tl.maximum(tl.sum(x), eps)
    tl.store(y_ptr + idx, x, mask=idx < n_cols)


def triton_softmax_dim1_fuse(x):
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    kernel_softmax_fuse[[n_rows]](
        x,
        x.stride(0),
        y,
        y.stride(0),
        n_cols,
        BLOCK_SIZE=triton.next_power_of_2(n_cols),
        num_warps=32,
    )
    return y


@triton.jit
def kernel_softmax_tile(
    x_ptr,
    x_row_stride,
    y_ptr,
    y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    CACHE_OPT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    x_ptr += row_idx * x_row_stride
    y_ptr += row_idx * y_row_stride
    # pass 1: per-lane running max, reduced across lanes once after the loop
    mm = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf")
    for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
        mm = tl.maximum(mm, x)
    mm = tl.max(mm)
    # pass 2: per-lane running sum of exp(x - mm)
    ss = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    if CACHE_OPT:
        # walk backwards: the blocks pass 1 touched last are still hot in cache
        for i in range(tl.cdiv(n_cols, BLOCK_SIZE) - 1, -1, -1):
            idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
            x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
            x = tl.exp(x - mm)
            ss += x
    else:
        for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
            idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
            x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
            x = tl.exp(x - mm)
            ss += x
    ss = tl.sum(ss)
    eps = float(1e-9)
    ss = tl.maximum(ss, eps)
    # pass 3: normalize and store
    for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
        x = tl.exp(x - mm) / ss
        tl.store(y_ptr + idx, x, mask=idx < n_cols)


def triton_softmax_dim1_tile(x, cache_opt=True):
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    kernel_softmax_tile[[n_rows]](
        x,
        x.stride(0),
        y,
        y.stride(0),
        n_cols,
        BLOCK_SIZE=2**14,
        CACHE_OPT=cache_opt,
        num_warps=32,
    )
    return y


@triton.jit
def kernel_softmax_online(
    x_ptr,
    x_row_stride,
    y_ptr,
    y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    CACHE_OPT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    x_ptr += row_idx * x_row_stride
    y_ptr += row_idx * y_row_stride
    # single online pass: per-lane running max and rescaled running sum
    mm = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf")
    ss = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
        mm_new = tl.maximum(mm, x)
        if i:  # skip round 1: unneeded there, and lanes with mm == mm_new == -inf would make exp(mm - mm_new) nan
            ss *= tl.exp(mm - mm_new)
        x = tl.exp(x - mm_new)
        ss += tl.where(idx < n_cols, x, 0.0)
        mm = mm_new
    # rescale each lane to the global max, then reduce across lanes
    mm_new = tl.max(mm)
    ss *= tl.exp(mm - mm_new)
    ss = tl.sum(ss)
    mm = mm_new
    eps = float(1e-9)
    ss = tl.maximum(ss, eps)
    if CACHE_OPT:
        # walk backwards so the most recently touched blocks are loaded first
        for i in range(tl.cdiv(n_cols, BLOCK_SIZE) - 1, -1, -1):
            idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
            x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
            x = tl.exp(x - mm) / ss
            tl.store(y_ptr + idx, x, mask=idx < n_cols)
    else:
        for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
            idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
            x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
            x = tl.exp(x - mm) / ss
            tl.store(y_ptr + idx, x, mask=idx < n_cols)


def triton_softmax_dim1_online(x, cache_opt=True):
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    kernel_softmax_online[[n_rows]](
        x,
        x.stride(0),
        y,
        y.stride(0),
        n_cols,
        BLOCK_SIZE=2**12,
        CACHE_OPT=cache_opt,
        num_warps=32,
    )
    return y


def test():
    DEVICE = "cuda"  # triton.runtime.driver.active.get_active_torch_device()
    x = torch.rand([2**10, 2**15], device=DEVICE)
    mp = {
        "torch": lambda: torch.softmax(x, dim=1),
        "triton_fuse": lambda: triton_softmax_dim1_fuse(x),
        "triton_tile_no_cache": lambda: triton_softmax_dim1_tile(x, cache_opt=False),
        "triton_tile": lambda: triton_softmax_dim1_tile(x),
        "triton_online_no_cache": lambda: triton_softmax_dim1_online(x, cache_opt=False),
        "triton_online": lambda: triton_softmax_dim1_online(x),
    }
    y_torch = mp["torch"]()
    for k, v in mp.items():
        y_triton = v()
        print("{}: Maxdiff is {}".format(k, torch.max(torch.abs(y_torch - y_triton))))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["n_col"],
        x_vals=[2**i for i in range(8, 18)],  # triton maximum tensor numel (131072)
        line_arg="provider",
        line_vals=[
            "torch",
            "triton_fuse",
            "triton_tile_no_cache",
            "triton_tile",
            "triton_online_no_cache",
            "triton_online",
        ],
        line_names=[
            "Torch",
            "Triton_fuse",
            "Triton_tile_no_cache",
            "Triton_tile",
            "Triton_online_no_cache",
            "Triton_online",
        ],
        plot_name="softmax-time",
        args={},
    )
)
def benchmark(n_col, provider):
    DEVICE = "cuda"  # triton.runtime.driver.active.get_active_torch_device()
    x = torch.rand([2**10, n_col], device=DEVICE)
    mp = {
        "torch": lambda: torch.softmax(x, dim=1),
        "triton_fuse": lambda: triton_softmax_dim1_fuse(x),
        "triton_tile_no_cache": lambda: triton_softmax_dim1_tile(x, cache_opt=False),
        "triton_tile": lambda: triton_softmax_dim1_tile(x),
        "triton_online_no_cache": lambda: triton_softmax_dim1_online(x, cache_opt=False),
        "triton_online": lambda: triton_softmax_dim1_online(x),
    }
    return triton.testing.do_bench(mp[provider])  # ms


if __name__ == "__main__":
    torch.manual_seed(3407)
    test()
    benchmark.run(print_data=True, show_plots=False, save_path=".")
torch: Maxdiff is 0.0
triton_fuse: Maxdiff is 1.0913936421275139e-11
triton_tile_no_cache: Maxdiff is 1.4551915228366852e-11
triton_tile: Maxdiff is 1.4551915228366852e-11
triton_online_no_cache: Maxdiff is 1.4551915228366852e-11
triton_online: Maxdiff is 1.4551915228366852e-11
softmax-time:
n_col Torch Triton_fuse Triton_tile_no_cache Triton_tile Triton_online_no_cache Triton_online
0 256.0 0.008916 0.016236 0.062424 0.063158 0.026684 0.025777
1 512.0 0.011025 0.016656 0.062662 0.063190 0.027446 0.026267
2 1024.0 0.015170 0.017557 0.063223 0.063944 0.028480 0.027188
3 2048.0 0.024382 0.021070 0.064796 0.065755 0.029716 0.028612
4 4096.0 0.035157 0.032228 0.071570 0.072163 0.035416 0.034703
5 8192.0 0.061981 0.056555 0.079998 0.079716 0.061953 0.061838
6 16384.0 0.120682 0.106090 0.121865 0.122116 0.125531 0.122809
7 32768.0 0.362888 0.208714 0.232481 0.229461 0.297884 0.269076
8 65536.0 0.797070 0.630735 0.667213 0.597031 0.606105 0.569532
9 131072.0 1.569379 4.005666 1.566432 1.389664 1.196355 1.159837
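For reference, the headline 26.1% comes from the last row: (1.569379 - 1.159837) / 1.569379 ≈ 26.1%. Triton_online overtakes Torch from n_col = 32768 onward, with a similar gap there (25.9%) and at 65536 (28.5%); below that, Torch is faster.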