CS336 第五讲 · GPU 与 GPU 优化:从硬件模型到 FlashAttention 与 PagedAttention

June 14, 2026

CS336 第五讲 · GPU 与 GPU 优化:从硬件模型到 FlashAttention 与 PagedAttention

把一块现代 GPU 比作一座工厂,它的"机器"(计算单元)多到几乎用不完,真正的瓶颈是"传送带"(访存带宽)——原料喂不进来,再多机器也只能空转。这句话听起来像个比喻,却几乎是理解所有 GPU 优化的钥匙:GPU 算得太快,喂不饱数据。于是绝大多数优化技巧,无论名字多花哨,最终都收敛到同一个目标——减少对慢速显存的访问,让数据尽量在快速的片上存储里被反复复用

本文对应斯坦福 CS336 第五讲,也对应 Datawhale diy-llm 第六章。原文内容极其翔实,但也夹杂了大量产品参数罗列,读起来略显冗长。这里把它重新组织成一条主线 + 两个案例:先用"算力 vs 访存"的失衡建立直觉,再依次讲清 GPU 的执行模型、内存模型与屋顶线模型,把五种优化技术统一到"减少访存"这一总纲下,最后落到 FlashAttention 与 PagedAttention 两个把这些原则用到极致的真实系统。核心概念配以重绘的示意图,力求让每个结论都"看得见"。

一条主线:GPU 的瓶颈不在算力,而在访存

要理解 GPU 为什么这样设计,先得看清它和 CPU 在哲学上的根本分歧。

CPU 优化的是延迟(latency):它假设任务之间有强依赖,目标是让单个任务尽可能快地完成。所以 CPU 把大量晶体管投给了控制逻辑和多级缓存,只配几个到几十个强大的核心,每个核心都能乱序执行、分支预测、深度流水。

GPU 优化的是吞吐(throughput):它假设有海量彼此独立的任务,目标是让所有任务整体尽快完成,不在乎单个任务的快慢。所以 GPU 把绝大部分晶体管都投给了计算单元(ALU),用成千上万个简单核心去"人海战术"。

CPU 与 GPU 的设计哲学对比:延迟优化 vs 吞吐优化

这个选择带来一个副作用:当算力被堆到极致,把数据从显存搬到计算单元的速度就成了短板。后文会看到,过去十年里 GPU 的算力增长了约 10 万倍,而显存带宽只增长了约 100 倍——这道越拉越大的鸿沟,正是所有优化技术真正要对抗的敌人。

GPU 的执行模型:Grid / Block / Warp / Thread

GPU 的算力来自一颗芯片上密密麻麻的流式多处理器(Streaming Multiprocessor, SM)。以 A100 为例,一颗 GA100 芯片用 7nm 工艺集成了 542 亿晶体管,包含上百个 SM、几千个 CUDA 核心和数百个专门做矩阵乘法的 Tensor Core。要驱动这么多核心,CUDA 用一套四层的执行模型把任务切碎、铺满硬件:

GPU 的四层执行模型:Grid、Block、Warp、Thread 与 SIMT

  • Thread(线程):最细粒度的执行单元,拥有自己的私有寄存器。
  • Warp(线程束):32 个线程绑成一束,是 SM 调度的最小单位。同一个 warp 里的 32 个线程在同一时刻执行同一条指令,只是各自处理不同的数据——这就是 GPU 的 SIMT(单指令多线程)模型。
  • Block(线程块):若干 warp 组成一个 block,整体被分配到一个 SM 上执行,块内线程共享这个 SM 的共享内存与寄存器。
  • Grid(网格):一次 kernel 启动的所有 block 构成一个 grid,由硬件调度器铺满所有可用的 SM。

SIMT 模型藏着一个性能陷阱:warp 分支发散(Warp Divergence)。如果一个 warp 里的线程因为 if/else 走了不同分支,硬件没法让它们真正并行,只能用掩码让两条路径依次串行执行——走 A 路径时屏蔽走 B 的线程,反之亦然。结果是同一段时间里只有部分线程在干活,算力利用率被打了折扣。这也是为什么 GPU 代码要尽量避免 warp 内的数据相关分支。

GPU 的内存模型:一座延迟金字塔

执行模型决定了"谁来算",内存模型决定了"数据从哪来"。GPU 的存储不是铁板一块,而是一座分层的金字塔:越靠近计算单元的越快、越小、越贵;越远离的越慢、越大、越便宜。

GPU 的分层内存:从寄存器到全局显存的延迟金字塔

存储类型位置容量(A100)延迟可见范围
寄存器 RegisterSM 内~256 KB / SM~1 周期线程私有
L1 / 共享内存SM 内~192 KB / SM~20–40 周期Block 内线程共享
L2 缓存芯片内~40 MB~200 周期所有 SM 共享
全局显存 HBM芯片外40–80 GB~500 周期所有线程可见

最关键的一组数字是延迟那一列:访问寄存器只要 1 个周期,访问片外的全局显存(HBM)却要约 500 个周期——差了整整两个数量级。这意味着,一旦计算单元不得不停下来等一次 HBM 读取,它就白白浪费了能做几百次浮点运算的时间。

成本同样悬殊:最快的寄存器每 GB 成本高达约 100 万美元,最慢的 HBM 显存每 GB 仅约 100 美元。这就是为什么快速存储总是小得可怜——你不可能把所有数据都放进寄存器。

把这两点连起来,GPU 优化的全部哲学就浮现了:HBM 又慢又便宜(所以容量大),片上存储又快又昂贵(所以容量小)。优化,就是想方设法让数据在小而快的片上存储里多停留、被多次复用,少去碰那个又慢又远的 HBM。

失衡的趋势与屋顶线模型

为什么"访存"会成为时代性的瓶颈?因为硬件的两条能力曲线,正在以截然不同的速度增长。

从 K20 到 H100 这十年间,GPU 的计算性能增长了约 10 万倍(接近超指数增长),而显存带宽只增长了约 100 倍,卡间互联带宽的增长更慢。算力一路狂奔,访存原地踏步,二者的差距越拉越大——今天我们买到的算力,相当一部分根本喂不饱,只能眼睁睁看着它闲置。

要判断一段计算到底卡在算力还是访存上,业界用一个经典工具——屋顶线模型(Roofline Model)。它的横轴是算术强度(Arithmetic Intensity),即每从显存读取 1 字节数据能支撑多少次浮点运算(FLOPs / Byte);纵轴是实际能达到的算力。

屋顶线模型:访存受限区与计算受限区

这条曲线由两段拼成:

  • 左侧斜线(访存受限,Memory-bound):算术强度低,意味着每读一点数据只算几下就得再去取数据。此时性能被显存带宽死死卡住,算力再强也用不上。
  • 右侧平台(计算受限,Compute-bound):算术强度足够高,数据一旦读进来就被反复使用,瓶颈才回到计算单元本身,这才是我们想要的状态。

优化的目标,就是把你的程序从左侧的斜坡推到右侧的平台——通过提高数据复用率来抬升算术强度,让 GPU 从"等数据"变成"埋头算"。接下来的所有技术,本质上都是在做这一件事。

五种优化技术,一个共同目标

理解了屋顶线,下面五种听起来各不相干的优化技术,其实指向同一个圆心:减少 HBM 访问、提高数据复用

五种优化技术统一到"减少访存、提高复用"这一目标

低精度计算:用更少的比特换更快的速度

最直接的提速,是让每个数字占更少的位。这一招同时打中了"算力"和"访存"两个层面:

精度格式相对 FP32 提速典型用途
FP16 / BF162–4×混合精度训练
TF325–10×训练默认(A100 起)
INT88–16×推理量化
FP810–20×训练加速 / 推理(H100 起)

提速来自三处叠加:其一,硬件更省——一个 FP16 乘法器的晶体管数大约只有 FP32 的四分之一,同样的芯片面积能塞下约 4 倍的计算单元;其二,访存减半——以 GPT-3 175B 为例,参数从 FP32 的 700 GB 降到 BF16 的 350 GB,搬运的数据量直接砍半;其三,Tensor Core 专用加速——这些低精度格式正是 Tensor Core 用脉动阵列(Systolic Array)高效吞吐的对象。代价则是数值精度下降,需要靠混合精度(敏感的累加仍用高精度)来兜底。

算子融合:别让中间结果反复往返显存

朴素地执行 y = sin(x)² + cos(x)² 这样的复合运算,框架会一步步来:读 x、算 sin、把结果写回 HBM,再读回来算平方……每个算子都和 HBM 之间往返一趟。中间结果在"又慢又远"的显存里搬来搬去,纯属浪费。

算子融合:未融合在 HBM 间反复往返,融合后数据驻留片上

算子融合(Kernel Fusion)把这一连串操作合并成一个 kernel:数据从 HBM 读进片上寄存器后,所有中间运算都在片上一气呵成,只在最后把最终结果写回 HBM。访存次数从"每个算子一趟"压缩到"首尾各一趟",算术强度大幅抬升。这正是后面 FlashAttention 的核心手法之一。

重计算:用过剩的算力换紧缺的带宽

训练时,前向传播产生的激活值(activations)要缓存下来供反向传播使用,这是显存的一大消耗。重计算(Recomputation / Activation Checkpointing)反其道而行:前向时不存这些中间激活,只留必要的输入;等反向真正需要时,再临时重新算一遍

这看似多做了计算,却是一笔精明的交易——因为算力是过剩的、访存才是稀缺的。以三层 sigmoid 堆叠为例,存全部激活需要约 8 次内存访问,而重计算方案只需约 5 次。用闲置的算力,去赎回宝贵的显存与带宽,正合屋顶线模型的胃口。

内存合并:让一个 warp 的访问凑成一次突发传输

DRAM 并不是按单个地址取数的,而是以突发模式(Burst Mode)按块返回——你读一个地址,硬件顺手把它附近的一整段都端出来。利用好这一点,性能可以差出数倍。

内存合并:分散访问 vs 合并访问的突发传输

内存合并(Memory Coalescing)的诀窍是:让同一个 warp 里 32 个线程访问的地址恰好落在同一个突发段内、且连续。这样硬件就能把 32 次零散请求合并成一次突发传输,吞吐量可提升约 4 倍。反之,如果线程访问的地址东一个西一个,每次突发取回的数据大部分被浪费,带宽利用率惨不忍睹。矩阵乘法里"按行连续遍历"之所以比"按列跳跃"快得多,根源就在这里。

分块 Tiling:让搬进来的数据被反复使用

朴素的矩阵乘法中,计算结果的每个元素都要去全局显存里取一整行和一整列。对 N×NN \times N 的矩阵,同一个数据会被反复从 HBM 读取 NN 次——算术强度低得可怜,妥妥的访存受限。

分块(Tiling):把数据搬进共享内存复用,全局访存减少 T 倍

分块(Tiling)的思路是:把大矩阵切成 T×TT \times T 的小块,先把一对小块整体搬进共享内存,然后让块内所有计算都在共享内存里取数、反复复用,算完再换下一对块。这样一来,每个元素从全局显存读取的次数从 NN 降到 N/TN/T全局访存量减少约 TT,算术强度被成倍抬高。

分块也有它的工程复杂性:块的尺寸必须同时适配 SM 的共享内存容量和突发传输的段大小;当矩阵维度不是块大小的整数倍时,还需要填充(padding)把维度补齐,否则边界处的访存会翻倍。把块切得"刚刚好",本身就是一门调参艺术。

案例一:FlashAttention —— 把注意力搬进 SRAM

注意力机制是 Transformer 的算力黑洞。标准实现要先算出完整的 N×NN \times N 注意力分数矩阵 S=QKS = QK^\top,做 softmax,再乘以 VV。问题在于:这个中间矩阵 SS 的大小是 O(N2)O(N^2),会被完整地写进 HBM 再读回来——序列一长,光是搬运这个巨大的矩阵就把带宽榨干了,时间和显存双双爆炸。

FlashAttention 的洞见是:何必把整个 SS 落盘?它把上面讲过的分块算子融合两招用到了极致——

FlashAttention:分块 + Online Softmax,注意力矩阵永不落盘 HBM

QQKKVV 都切成小块,每次只把一小块搬进片上 SRAM;在 SRAM 里完成 QKQK^\top、softmax 和乘 VV全部计算,用一种叫 Online Softmax 的技巧增量地维护每行的最大值与归一化分母,边算边更新输出,完整的 N×NN \times N 矩阵从头到尾都不写进 HBM

# FlashAttention 内层逻辑:Online Softmax 增量更新
m_ij  = rowmax(S_ij)                                  # 当前块每行最大值
m_new = max(m_i, m_ij)                                # 更新全局最大值
l_i   = exp(m_i - m_new) * l_i + sum(exp(S_ij - m_new))   # 重新校准归一化分母
O_i   = exp(m_i - m_new) * O_i + exp(S_ij - m_new) @ V_j  # 增量累加输出
m_i   = m_new
# 所有 K,V 块处理完后再统一归一化
O_i   = O_i / l_i

注意:这是数学上精确等价的重排,而非近似——除了浮点误差,结果和标准注意力完全一致。它把注意力从"访存受限"硬生生拉回了"计算受限",显存占用也从 O(N2)O(N^2) 降到 O(Nd)O(N \cdot d)

FlashAttention 的三代演进,恰好对应着对 GPU 三个层面的逐步压榨:

版本核心改进解决的瓶颈
V1分块 + Online Softmax,注意力矩阵不落盘减少 HBM 访存带宽瓶颈
V2在 K/V 维度切分并行(split-KV),多线程块并行 + 末端归约提升并行度与 GPU 利用率
V3利用 H100 的异步 WGMMA 指令与原生 FP8,计算与数据搬运重叠让访存与计算彻底重叠

性能上,V2 在 A100 上相比 V1 加速约 1.7–2.0×,相比 PyTorch 标准实现加速约 8–10×;V3 在 H100 上用 FP8 可把算力利用率推到 75–80%,单卡 80GB 即可稳定训练 256k 长度的序列。

案例二:PagedAttention —— 给 KV Cache 装上虚拟内存

如果说 FlashAttention 优化的是单次前向算子级的访存效率,那 PagedAttention 解决的是推理阶段一个完全不同的、系统级的问题:KV Cache 的内存管理。

自回归生成时,每个已生成 token 的 Key/Value 都要缓存起来供后续步骤复用。传统做法是按最大长度静态预分配一段连续显存,这带来两种浪费:

PagedAttention:静态预分配的碎片浪费 vs 分页管理

  • 内部碎片:按最大长度 2048 预留,实际只生成了 300 个 token,剩下 1748 个 token 的空间全程闲置。
  • 外部碎片:不同请求结束后留下大小不一的空洞,散落各处无法拼成连续大块,新请求挤不进去。

PagedAttention 借鉴了操作系统的虚拟内存分页思想:把 KV Cache 切成固定大小的(如每页 16 个 token),逻辑上连续的 token 序列,物理上可以散落在任意不连续的页里,靠一张块表(Block Table)维护逻辑页到物理页的映射。

效果立竿见影。同样是 300 个 token、页大小 16 的请求,只需 19 页(实际容纳 304 token),浪费仅 4 个 token——而传统方案要浪费 1748 个。换成复杂度的语言:静态预分配的浪费是 O(Lmax)O(L_{\max}),PagedAttention 的浪费只有 O(页大小)O(\text{页大小})。把省下来的显存用于容纳更多并发请求,单卡的吞吐能提升一个数量级。

这两个系统并不互斥,反而常常协同:FlashAttention 负责算得快,PagedAttention 负责装得多。

维度FlashAttentionPagedAttention
优化目标单次前向的访存(IO)复杂度整个生成生命周期的内存管理
层级算子级(micro)系统级(macro)
作用对象Attention kernelKV Cache 生命周期
适用场景训练 + 推理基本仅推理(decoder-only)

尾声:硬件在演进,原理却恒定

GPU 之外,专为矩阵乘法而生的 TPU 走了一条更纯粹的路线——它放弃通用计算,用标量单元、矢量单元和矩阵乘法单元(MXU)的简洁组合,把"做矩阵乘法"这一件事做到极致,内部高速存储 + 外部 HBM 的二层结构则与 GPU 一脉相承。近两年,沐曦、昆仑芯、海光、摩尔线程、华为昇腾等国产芯片也在快速补齐算力与软件生态的拼图。

但无论硬件如何更迭、招牌换成哪家,本文这条主线始终不变:只要"算力远快于访存"的格局不改变,减少数据搬运、提高片上复用,就永远是 GPU 优化的第一性原理。

总结:GPU 优化速查表

把全文的优化技术浓缩成一张表——每一项的本质,都是在和"访存"这个敌人作战:

技术它减少了什么一句话核心
低精度计算计算量 + 访存量更少的比特,更快的吞吐与更小的占用
算子融合HBM 往返次数中间结果留在片上,只首尾各访存一次
重计算激活值显存占用用过剩算力赎回紧缺带宽
内存合并突发传输浪费让一个 warp 的访问凑成一次连续读取
分块 Tiling全局访存次数(降 TT 倍)搬进共享内存反复复用
FlashAttention注意力矩阵的 HBM 读写分块 + Online Softmax,N2N^2 矩阵不落盘
PagedAttentionKV Cache 的内存碎片借虚拟内存分页,浪费从 O(Lmax)O(L_{\max}) 降到 O()O(\text{页})

记住这张表背后的那条主线,你再看任何一项新的 GPU 优化技术,都能一眼看穿它到底在和"算力"还是"访存"较劲——而答案,几乎总是后者。

参考资料