Flash Attention: Triton 实现
Triton 实现
这部分我们详解 Triton 官方提供的 Flash Attention v2 实现,再加一些实际的性能测试。
性能测试
先上性能表现。我自己做了个 benchmark,对比了该实现与 naive attention 的性能表现。Naive Attention 的代码:
def pytorch_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, sm_scale: float) -> torch.Tensor:
scores = torch.matmul(q, k.transpose(-1, -2)) * sm_scale
if causal:
seq_len = q.shape[-2]
causal_mask = torch.triu(
torch.ones((seq_len, seq_len), device=q.device, dtype=torch.bool),
diagonal=1,
)
scores = scores.masked_fill(causal_mask, float("-inf"))
probs = torch.softmax(scores.float(), dim=-1).to(dtype=q.dtype)
return torch.matmul(probs, v)
性能测试结果:


我们注意到 PyTorch 的 sequence length 范围要更小,这是因为显存爆炸了没办法跑。可见无论是性能还是显存占用,flash attention 表现都要更好。
代码详解
Grid 与并行策略
# fa_tutorial.py L550-L551
def grid(META):
return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
Grid 的 x 维度按 Q 的序列维度分块(每个 block 处理 BLOCK_M 行 Q),y 维度展开 batch 和 head。这意味着:
- 每个 thread block 负责一个 (batch, head, Q_block) 三元组
- 不同 block 之间完全独立,无通信
Q-outer 主循环:_attn_fwd
# fa_tutorial.py L191-L194
start_m = tl.program_id(0) # 当前 block 负责的 Q 行起始位置
off_hz = tl.program_id(1) # batch * head 索引
off_z = off_hz // H # batch 索引
off_h = off_hz % H # head 索引
program_id(0) 直接对应 Q 的分块索引 —— 这正是 FA2 “Q 在外层” 的体现。
接着初始化 online softmax 所需的状态变量:
# fa_tutorial.py L216-L218
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # 逐行 running max
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 # 逐行 running sum (分母)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) # 输出累加器
然后加载 Q tile 到 SRAM,整个内层循环中不再重新加载 Q:
# fa_tutorial.py L222-L223
qk_scale *= 1.44269504 # 1/log(2),用 exp2 代替 exp,避免昂贵的 exp 指令
q = desc_q.load([qo_offset_y, 0]) # Q 只加载一次,常驻 SRAM
内层循环:_attn_fwd_inner — 遍历 K/V
# fa_tutorial.py L69
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
每次迭代处理一个 K/V tile。核心计算步骤:
Step 1: 计算 QK^T
# fa_tutorial.py L72-L73
k = desc_k.load([offsetk_y, 0]).T # 加载 K tile 并转置
qk = tl.dot(q, k) # [BLOCK_M, BLOCK_N] = Q @ K^T
Step 2: Online Softmax 的 rescale
这是 online softmax 算法的核心。对于非 causal 路径:
# fa_tutorial.py L80-L81
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) # 更新 running max
qk = qk * qk_scale - m_ij[:, None] # 减去 max 做数值稳定
这里 m_i 是之前所有 K/V tile 的逐行最大值,m_ij 是合并当前 tile 后的新最大值。减去 max 后再做 exp,保证数值稳定性。
Step 3: 计算 attention weights 和修正因子
# fa_tutorial.py L82-L85
p = tl.math.exp2(qk) # P = exp(QK^T / sqrt(d) - max)
alpha = tl.math.exp2(m_i - m_ij) # 修正因子:旧 max 到新 max 的缩放
l_ij = tl.sum(p, 1) # 当前 tile 的 softmax 分母
alpha 是关键的 rescale 因子。因为之前累加的 acc 是基于旧的 m_i 计算的,现在 max 更新为 m_ij,所以需要用 alpha 将旧累加结果缩放到新的 scale 上。
Step 4: 更新累加器
# fa_tutorial.py L95
acc = acc * alpha[:, None] # 用修正因子 rescale 旧累加结果
# L101-L103
p = p.to(dtype)
acc = tl.dot(p, v, acc) # 累加当前 tile 的贡献:acc += P @ V
Step 5: 更新状态
# fa_tutorial.py L106-L107
l_i = l_i * alpha + l_ij # 更新 softmax 分母
m_i = m_ij # 更新 running max
注意 l_i 和 m_i 的更新被放在循环末尾,这是为了减少寄存器压力 —— 让编译器有更多空间调度前面的计算指令。
Causal Mask 的处理:off-band 与 on-band
对于 causal attention,Q 的第 i 行只能 attend 到 K 的前 i 个位置。代码将 K/V 的遍历范围分为两个区域:
# fa_tutorial.py L55-L62
if STAGE == 1: # off-band: Q 行之前的 K(全部合法,无需 mask)
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2: # on-band: 包含对角线的 K(需要 causal mask)
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
else: # non-causal: 全部 K(无需 mask)
lo, hi = 0, N_CTX
off-band 区域的所有 K 位置都在 Q 行之前,整块都满足 causal 条件,不需要逐元素 mask,直接走快速路径。
on-band 区域包含对角线,块内部分位置需要被 mask 掉:
# fa_tutorial.py L74-L78
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) # 非法位置加 -inf
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
这种分区域处理是工程上的优化:大多数 tile 走无 mask 的快速路径,只有对角线附近的少量 tile 才做逐元素 mask,最大化计算效率。
Epilogue:最终归一化
# fa_tutorial.py L242-L247
m_i += tl.math.log2(l_i) # 最终 log-sum-exp 值,存下来给 backward 用
acc = acc / l_i[:, None] # 最终除以 softmax 分母
tl.store(m_ptrs, m_i) # 保存 LSE 到 HBM(backward 需要)
desc_o.store([qo_offset_y, 0], acc.to(dtype)) # 写回输出
这里额外保存了 m_i + log2(l_i)(即 log-sum-exp 值)到 HBM,这是 backward pass 需要的中间结果。FA2 的 backward 需要 forward 的 LSE 来高效计算梯度。