NVIDIA闪存注意力实操:极速提升4倍性能

在本文中,我们深入探讨了现代人工智能中至关重要的工作负载:闪存注意力。你将学习如何使用NVIDIA cuTile实现闪存注意力,并进行生产级实现的全面代码演示。我们将带你穿越“陷阱与救援”的优化之旅,并掌握FMA模式、快速数学、循环拆分以及自适应平铺等高级技术,以实现最大性能。
首先,我们来了解实现闪存注意力的基本环境要求:需要CUDA 13.1以及更高版本。而GPU架构则以NVIDIA Blackwell系列为例,如NVIDIA B200或GeForce RTX 50系列。此外,Python版本需为3.10或更高。
什么是注意力机制?
注意力机制是Transformers模型的计算核心。对给定的序列令牌,注意力机制允许每个令牌“查看”其他令牌并决定权重比例。其数学公式为:
[O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V]
内存带宽问题
对于序列长度为(N = 16,384)的现代语言模型,注意力矩阵(QK^{T})包含(N^2 = 268)百万个元素。在FP16下,每个注意力头、每个批量项需要中间存储512MB。
标准注意力实现
- 计算完整的(N \times N)注意力矩阵并将其写入全局内存(速度慢)。
- 按行应用softmax。
- 读回矩阵并与(V)相乘。
这种方法由于GPU将大部分时间用于等待数据在HBM和计算单元之间移动,而不是进行计算,因此成为内存瓶颈。
闪存注意力如何解决内存带宽问题
闪存注意力是一种IO意识算法,它从不合成完整的(N \times N)矩阵。相反,它:
- 将计算切分:将(Q, K, V)处理为适合快速片上SMEM的小块。
- 使用在线softmax:不需要完整行就能逐渐计算softmax。
- 融合操作:将矩阵乘法和softmax结合成一个内核通道。
结果是速度提高2-4倍,显著的内存节省,支持更长的上下文化。
在线Softmax的理解
闪存注意力的关键算法洞察是在线softmax技巧。安全的softmax需计算整行的最大值:
[\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}]
但当我们处理小块时,并没有整行的访问权限。在线softmax通过更新逐渐累积的数据解决了这一问题。
在线Softmax算法
我们为每一行维护两个运行值:
- (m_i):到目前为止看到的最大值(用于数值稳定性)。
- (l_i):到目前为止看到的指数和(softmax分母)。
处理新块时:
- 更新最大值:(m_{new} = \max(m_i, \max(x_{new})))。
- 计算修正因子:(\alpha = e^{m_i - m_{new}})(缩放先前计算)。
- 更新和:(l_i = l_i \cdot \alpha + \sum e^{x_{new} - m_{new}})。
- 更新累加器:(acc = acc \cdot \alpha + P_{new} \cdot V_{tile})。
最终归一化:
[O = acc / l_i]
这样无需存储完整行即可计算精确的softmax。
因果注意力与分组查询注意力
因果注意力
在自回归语言模型如GPT、LLaMA等中,每个令牌只能关注序列中之前的令牌,而不关注未来的令牌。这防止模型在训练中“作弊”,提前查看预测的下一个词。我们在注意力得分上应用三角掩码:
[\text{mask}_{ij} = \begin{cases} 0 & \text{if } i \geq j \text{ (query position ≥ key position)} \ -\infty & \text{if } i < j \text{ (future tokens)} \end{cases}]
掩码注意力可以这样表达:
[O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + \text{mask}\right)V]
在softmax后,向未来位置加上(-\infty)确保它们变为零,有效地阻止了来自未来令牌的信息流动。
使用因果掩码,约半个注意力矩阵被掩码(上三角)。这一优化为K循环拆分提供了至关重要的2倍算法速度提升。
分组查询注意力
标准多头注意力为每个注意力头分配单独的(K,V)矩阵,导致高内存使用:
- 多头注意力(MHA):32个查询头 → 32个K/V头(1:1比例)。
- 分组查询注意力(GQA):32个查询头 → 4个K/V头(8:1比例)。
- 多查询注意力(MQA):32个查询头 → 1个K/V头(32:1比例)。
在GQA中,多查询头共享相同的K/V头。这在推理期间将K/V缓存大小减少8倍,对于长上下文模型部署至关重要。现代语言模型如LlamA 2、Llama 3、Mistral和Qwen大量使用GQA。
在闪存注意力中实现时,每个CUDA块计算一个查询头的注意力,并加载合适的共享K/V头:
head_idx = bid_y % num_heads
kv_head_idx = head_idx // query_group_size
当查询组大小为8时,所有查询头0-7都映射到kv_head_idx = 0,共享内存中相同的K/V块。
Part 1: CUDA Tile中的闪存注意力内核实现
我们分步实现闪存注意力。我们的基线使用小的64×64块和直观代码,对但尚未优化。
1. 定义内核接口
在cuTile中,@ct.kernel装饰器标记一个Python函数作为GPU内核。我们通过ct.Constant[T]类型注解传递编译时常量。
import math
import cuda.tile as ct
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
INV_LOG_2 = 1.0 / math.log(2)
@ct.kernel()
def fmha_kernel(
Q, K, V, Out,
qk_scale: float,
input_pos: int,
TILE_D: ConstInt,
H: ConstInt,
TILE_M: ConstInt,
TILE_N: ConstInt,
QUERY_GROUP_SIZE: ConstInt,
CAUSAL: ConstBool,
EVEN_K: ConstBool,
):
2. 块ID映射
每个CUDA块计算输出的一个块。使用ct.bid将2D网格映射到批次/头索引:
bid_x = ct.bid(0)
bid_y = ct.bid(1)
batch_idx = bid_y // H
head_idx = bid_y % H
off_kv_h = head_idx // QUERY_GROUP_SIZE
3. 初始化累加器
在主循环之前,我们初始化在线softmax状态和输出累加器:
qk_scale = qk_scale * INV_LOG_2
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m += input_pos
offs_m = offs_m[:, None]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :]
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
我们使用float32来保持迭代softmax计算的数值精度。
4. 加载查询块
一次加载查询块,并在所有K/V迭代中复用:
q = ct.load(
Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
).reshape((TILE_M, TILE_D))
ct.load函数在块超出张量边缘时自动处理边界条件。
5. K/V块的主循环
这是闪存注意力的核心。我们迭代K/V块:
m_end = input_pos + (bid_x + 1) * TILE_M
k_seqlen = K.shape[2]
if CAUSAL:
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
for j in range(0, Tc):
k = ct.load(
K, index=(batch_idx, off_kv_h, 0, j),
shape=(1, 1, TILE_D, TILE_N), order=(0, 1, 3, 2), latency=2
).reshape((TILE_D, TILE_N))
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k, qk)
然后使用ct.mma执行cuTile矩阵乘法累加。
6. 应用因果掩码
对于自回归模型(GPT、Llama等),每个令牌只能关注之前的令牌:
if CAUSAL or not EVEN_K:
offs_n = j * TILE_N + offs_n_tile
mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)
if not EVEN_K:
mask = mask & (offs_n < k_seqlen)
if CAUSAL:
mask = mask & (offs_m >= offs_n)
mask = ct.where(mask, 0.0, -math.inf)
qk += mask
掩码后添加(-\infty)确保位置在softmax后变为零。
7. 在线Softmax更新
qk_max = ct.max(qk, axis=-1, keepdims=True)
qk_max_scaled = qk_max * qk_scale
m_ij = max(m_i, qk_max_scaled)
qk = qk * qk_scale
qk = qk - m_ij
p = ct.exp2(qk)
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij)
l_i = l_i * alpha
l_i = l_i + l_ij
acc = acc * alpha
8. 累加输出
最终加载值块并累加:
v = ct.load(
V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D), latency=4
).reshape((TILE_N, TILE_D))
p = p.astype(Q.dtype)
acc = ct.mma(p, v, acc)
m_i = m_ij
9. 最终规范化与存储
处理完所有块后,通过总和归一化并写入结果:
acc = ct.truediv(acc, l_i)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
启动内核:主机代码
我们来看看启动内核的主机代码:
import torch
from math import ceil
def tile_fmha(q, k, v, sm_scale=None, is_causal=True):
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(q.size(-1))
batch_size, num_heads, seq_len, head_dim = q.shape
_, num_kv_heads, _, _ = k.shape
query_group_size = num_heads // num_kv_heads
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
o = torch.empty_like(q)
TILE_M, TILE_N = 64, 64
grid_x = ceil(seq_len / TILE_M)
grid_y = batch_size * num_heads
grid = (grid_x, grid_y, 1)
EVEN_K = (k.shape[2] % TILE_N) == 0
ct.launch(
torch.cuda.current_stream(),
grid,
fmha_kernel,
(q, k, v, o, sm_scale, 0, head_dim, num_heads, TILE_M, TILE_N, query_group_size, is_causal, EVEN_K)
)
return o
Part 2: 陷阱与救援优化之旅
我们在以下配置进行基准测试:
- 硬件:NVIDIA B200
- 批量:4
- 头数:32
- 头维度:128
- 注意力:因果
- Dtype:FP16
- 序列长度:1024, 2048, 4096, 8192, 16384
我们使用Nsight Compute解析每个步骤:
- LaunchStats
- Occupancy
- SpeedOfLight
- ComputeWorkloadAnalysis
- MemoryWorkloadAnalysis
基线性能
| SeqLen | Throughput (TFLOPS) |
|---|---|
| 1,024 | 330 |
| 2,048 | 441 |
| 4,096 | 511 |
| 8,192 | 546 |
| 16,384 | 566 |
这是我们用64×64块和无优化开始的点。
1. 较大块的陷阱
在GPU编程中常有的直觉是“更大块=更好性能”。更大块可:
- 分摊内存访问开销。
- 改善L2缓存利用率。
- 减少每个元素的内核启动开销。
我们将块大小从64×64增至256×128:
TILE_M, TILE_N = 256, 128
预期的结果是带宽利用率更好→更快。但实际上TFLOPS的结果是:
| SeqLen | Baseline (64×64) | Larger tiles (256×128) | Performance Degradation |
|---|---|---|---|
| 1,024 | 330 | 187 | -43% |
| 2,048 | 441 | 268 | -39% |
| 4,096 | 511 | 347 | -32% |
| 8,192 | 546 | 415 | -24% |
| 16,384 | 566 | 463 | -18% |
性能在所有序列长度上下降了18-43%。这是陷阱,大块反而使性能下降。
为什么会发生这种情况?
- 计算瓶颈:块包含更多元素,效率不高的运算成为瓶颈。
- 指令开销:块工作量大意味着更多指令。
教训:块大小和计算效率是互相依存的。大块只有在计算效率足够高时才有帮助。
2. 使用快速数学拯救
其中的瓶颈是特殊函数:exp2(指数)和truediv(除法)。默认情况下,它们是IEEE-754精确的——高精度,但速度慢。为深度学习,我们可以在牺牲一丁点精度的情况下大幅加速:
之前(精确运算):
p = ct.exp2(qk)
alpha = ct.exp2(m_i - m_ij)
acc = ct.truediv(acc, l_i)
之后(快速数学):
p = ct.exp2(qk, flush_to_zero=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
使用快速数学,我们“拯救”了大块,TFLOPS结果是:
| SeqLen | Larger tiles (trap) | Fast math (rescue) | Improvement |
|---|---|---|---|
| 1,024 | 187 | 322 | +72% |
| 2,048 | 268 | 436 | +63% |
| 4,096 | 347 | 524 | +51% |
| 8,192 | 415 | 585 | +41% |
| 16,384 | 463 | 620 | +34% |
3. K循环拆分
对于因果注意力,我们应用三角掩码:每个查询仅能关注早于位置的键。在我们的基线中,在每次循环迭代中检查CAUSAL: mask…。但想想看,对位置1000的查询块,大多键块(0-900)根本不需要掩码。只有接近对角线的块需要掩码。而位置超出查询的块则被完全掩码(我们可以完全跳过它们)。
优化将循环分为阶段:
mask_start = (input_pos + bid_x * TILE_M) // TILE_N
mask_start = min(mask_start, k_seqlen // TILE_N)
if CAUSAL:
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
for j in range(0, Tc):
if (CAUSAL or not EVEN_K) and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)
if not EVEN_K:
mask = mask & (offs_n < k_seqlen)
if CAUSAL:
mask = mask & (offs_m >= offs_n)
mask = ct.where(mask, 0.0, -math.inf)
qk += mask
为何这至关重要:
- 对于16K序列和256令牌块:
- ~50%的块完全未掩码(无分支,无掩码计算)。
- 每行约1个块部分掩码(完整逻辑)。
- 其余块完全被跳过(提前退出)。
TFLOPS结果:
| SeqLen | Fast math | Loop split | Improvement |
|---|---|---|---|
| 1,024 | 322 | 373 | +16% |
| 2,048 | 436 | 552 | +27% |
| 4,096 | 524 | 684 | +31% |
| 8,192 | 585 | 770 | +32% |
| 16,384 | 620 | 813 | +31% |
4. ProgramId重映射
对于因果注意力,一个微妙的优化是反转块顺序。反向处理块(右下到左上)时,由于因果掩码,后启动的块工作量更少。这改善了负载均衡,减少了尾效应。
之前(标准顺序):
bid_x = ct.bid(0)
之后(因果反转):
if CAUSAL:
bid_x = NUM_M_BLOCKS - 1 - ct.bid(0)
else:
bid_x = ct.bid(0)
TFLOPS结果:
| SeqLen | Loop split | Remapping | Improvement |
|---|---|---|---|
| 1,024 | 373 | 377 | +1% |
| 2,048 | 552 | 560 | +1.5% |
| 4,096 | 684 | 696 | +1.8% |
| 8,192 | 770 | 781 | +1.5% |
| 16,384 | 813 | 835 | +2.6% |
5. 自动调优
我们优化了大块,但有一个问题:短序列仍然需要小块。为什么?对于1,024令牌序列和256令牌块,我们只有4个块。这不足以充分利用B200上的所有SM。小块(64×64)给我们16个块,更好地填满GPU。
而不是手动选择阈值,我们可以让cuTile的自动调优器对多个配置进行基准测试并为每个输入形状缓存最佳配置。
自动调优器的工作方式:
def _fmha_autotune_configs():
gpu_capability = torch.cuda.get_device_capability()
if gpu_capability in [(12, 0), (12, 1)]:
yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)
else:
yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)
yield SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=2)
yield SimpleNamespace(TILE_M=256, TILE_N=128, num_ctas=1, occupancy=1)
如何启动自动调优:
import cuda.tile_experimental as ct_experimental
def autotune_launch_fmha(
stream, q, k, v, o, sm_scale, input_pos, hidden_size, num_heads, query_group_size, is_causal
):
batch_size, _, q_len, _ = q.shape
def _grid_fn(cfg):
return (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1)
def _args_fn(cfg):
num_m_blocks = math.ceil(q_len / cfg.TILE_M)
even_k = (k.shape[2] % cfg.TILE_N) == 0
return (
q, k, v, o, sm_scale, input_pos, hidden_size, num_heads,
cfg.TILE_M, cfg.TILE_N, query_group_size, is_causal, even_k, num_m_blocks
)
ct_experimental.autotune_launch(
stream, grid_fn=_grid_fn, kernel=fmha_kernel, args_fn=_args_fn,
hints_fn=lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy},
search_space=_fmha_autotune_configs,
)
自动调优器智能发现:
- 第一次调用
seq_len=1024:对所有3个配置进行基准测试,缓存最优的。 - 第一次调用
seq_len=2048:对所有3个配置进行基准测试,缓存最优的。 - 后续调用:使用缓存配置(零开销)。
TFLOPS结果:
| SeqLen | Baseline | Remapping | Autotune | Speedup vs baseline |
|---|---|---|---|---|
| 1,024 | 330 | 377 | 548 | 1.66x |
| 2,048 | 441 | 560 | 708 | 1.61x |
| 4,096 | 511 | 696 | 817 | 1.60x |
| 8,192 | 546 | 781 | 887 | 1.62x |
| 16,384 | 566 | 835 | 918 | 1.62x |
自动调优器发现对于短序列≤2048,64×64块最佳,然后为长序列切换到大块。与固定大块相比,短序列提供额外45%的性能,并同时保持长序列的峰值性能。
结论
编写高性能内核很少是关于找到一个“魔法”设置。正如我们在“陷阱与救援”中看到的:
- 优化是相互依存的:大块在我们修复数学问题之前更慢。无法在孤立情境下评估块大小。
- 数学重要性:诸如
flush_to_zero和APPROX标志对于解锁张量核心吞吐至关重要。深度学习精确数学往往过于冗余。 - 算法成功复合:K循环拆分通过避免不必要的工作给我们带来了最大的单一提升。
- 自动调优胜于手动启发式:cuTile的自动调优器发现了序列长度的最佳块大小,通过交替设置实现15-46%的提升。
- 累积效应是乘法:完整优化堆栈在所有序列长度上提供了令人印象深刻的加速——远超单独任何一个优化。
新媒网跨境认为,这项研究揭示了深入思考每一个细节的重要性,为我们提供了宝贵的实战经验与洞察。
新媒网(公号: 新媒网跨境发布),是一个专业的跨境电商、游戏、支付、贸易和广告社区平台,为百万跨境人传递最新的海外淘金精准资讯情报。
本文来源:新媒网 https://nmedialink.com/posts/nvidia-flash-attention-boost-4x-performance.html


粤公网安备 44011302004783号 













