把一块现代 GPU 比作一座工厂,它的"机器"(计算单元)多到几乎用不完,真正的瓶颈是"传送带"(访存带宽)——原料喂不进来,再多机器也只能空转。这句话听起来像个比喻,却几乎是理解所有 GPU 优化的钥匙:GPU 算得太快,喂不饱数据。于是绝大多数优化技巧,无论名字多花哨,最终都收敛到同一个目标——减少对慢速显存的访问,让数据尽量在快速的片上存储里被反复复用。
本文对应斯坦福 CS336 第五讲,也对应 Datawhale diy-llm 第六章。原文内容极其翔实,但也夹杂了大量产品参数罗列,读起来略显冗长。这里把它重新组织成一条主线 + 两个案例:先用"算力 vs 访存"的失衡建立直觉,再依次讲清 GPU 的执行模型、内存模型与屋顶线模型,把五种优化技术统一到"减少访存"这一总纲下,最后落到 FlashAttention 与 PagedAttention 两个把这些原则用到极致的真实系统。核心概念配以重绘的示意图,力求让每个结论都"看得见"。
一条主线:GPU 的瓶颈不在算力,而在访存
要理解 GPU 为什么这样设计,先得看清它和 CPU 在哲学上的根本分歧。
CPU 优化的是延迟(latency):它假设任务之间有强依赖,目标是让单个任务尽可能快地完成。所以 CPU 把大量晶体管投给了控制逻辑和多级缓存,只配几个到几十个强大的核心,每个核心都能乱序执行、分支预测、深度流水。
GPU 优化的是吞吐(throughput):它假设有海量彼此独立的任务,目标是让所有任务整体尽快完成,不在乎单个任务的快慢。所以 GPU 把绝大部分晶体管都投给了计算单元(ALU),用成千上万个简单核心去"人海战术"。
这个选择带来一个副作用:当算力被堆到极致,把数据从显存搬到计算单元的速度就成了短板。后文会看到,过去十年里 GPU 的算力增长了约 10 万倍,而显存带宽只增长了约 100 倍——这道越拉越大的鸿沟,正是所有优化技术真正要对抗的敌人。
GPU 的执行模型:Grid / Block / Warp / Thread
GPU 的算力来自一颗芯片上密密麻麻的流式多处理器(Streaming Multiprocessor, SM)。以 A100 为例,一颗 GA100 芯片用 7nm 工艺集成了 542 亿晶体管,包含上百个 SM、几千个 CUDA 核心和数百个专门做矩阵乘法的 Tensor Core。要驱动这么多核心,CUDA 用一套四层的执行模型把任务切碎、铺满硬件:
- Thread(线程):最细粒度的执行单元,拥有自己的私有寄存器。
- Warp(线程束):32 个线程绑成一束,是 SM 调度的最小单位。同一个 warp 里的 32 个线程在同一时刻执行同一条指令,只是各自处理不同的数据——这就是 GPU 的 SIMT(单指令多线程)模型。
- Block(线程块):若干 warp 组成一个 block,整体被分配到一个 SM 上执行,块内线程共享这个 SM 的共享内存与寄存器。
- Grid(网格):一次 kernel 启动的所有 block 构成一个 grid,由硬件调度器铺满所有可用的 SM。
SIMT 模型藏着一个性能陷阱:warp 分支发散(Warp Divergence)。如果一个 warp 里的线程因为 if/else 走了不同分支,硬件没法让它们真正并行,只能用掩码让两条路径依次串行执行——走 A 路径时屏蔽走 B 的线程,反之亦然。结果是同一段时间里只有部分线程在干活,算力利用率被打了折扣。这也是为什么 GPU 代码要尽量避免 warp 内的数据相关分支。
GPU 的内存模型:一座延迟金字塔
执行模型决定了"谁来算",内存模型决定了"数据从哪来"。GPU 的存储不是铁板一块,而是一座分层的金字塔:越靠近计算单元的越快、越小、越贵;越远离的越慢、越大、越便宜。
| 存储类型 | 位置 | 容量(A100) | 延迟 | 可见范围 |
|---|---|---|---|---|
| 寄存器 Register | SM 内 | ~256 KB / SM | ~1 周期 | 线程私有 |
| L1 / 共享内存 | SM 内 | ~192 KB / SM | ~20–40 周期 | Block 内线程共享 |
| L2 缓存 | 芯片内 | ~40 MB | ~200 周期 | 所有 SM 共享 |
| 全局显存 HBM | 芯片外 | 40–80 GB | ~500 周期 | 所有线程可见 |
最关键的一组数字是延迟那一列:访问寄存器只要 1 个周期,访问片外的全局显存(HBM)却要约 500 个周期——差了整整两个数量级。这意味着,一旦计算单元不得不停下来等一次 HBM 读取,它就白白浪费了能做几百次浮点运算的时间。
成本同样悬殊:最快的寄存器每 GB 成本高达约 100 万美元,最慢的 HBM 显存每 GB 仅约 100 美元。这就是为什么快速存储总是小得可怜——你不可能把所有数据都放进寄存器。
把这两点连起来,GPU 优化的全部哲学就浮现了:HBM 又慢又便宜(所以容量大),片上存储又快又昂贵(所以容量小)。优化,就是想方设法让数据在小而快的片上存储里多停留、被多次复用,少去碰那个又慢又远的 HBM。
失衡的趋势与屋顶线模型
为什么"访存"会成为时代性的瓶颈?因为硬件的两条能力曲线,正在以截然不同的速度增长。
从 K20 到 H100 这十年间,GPU 的计算性能增长了约 10 万倍(接近超指数增长),而显存带宽只增长了约 100 倍,卡间互联带宽的增长更慢。算力一路狂奔,访存原地踏步,二者的差距越拉越大——今天我们买到的算力,相当一部分根本喂不饱,只能眼睁睁看着它闲置。
要判断一段计算到底卡在算力还是访存上,业界用一个经典工具——屋顶线模型(Roofline Model)。它的横轴是算术强度(Arithmetic Intensity),即每从显存读取 1 字节数据能支撑多少次浮点运算(FLOPs / Byte);纵轴是实际能达到的算力。
这条曲线由两段拼成:
- 左侧斜线(访存受限,Memory-bound):算术强度低,意味着每读一点数据只算几下就得再去取数据。此时性能被显存带宽死死卡住,算力再强也用不上。
- 右侧平台(计算受限,Compute-bound):算术强度足够高,数据一旦读进来就被反复使用,瓶颈才回到计算单元本身,这才是我们想要的状态。
优化的目标,就是把你的程序从左侧的斜坡推到右侧的平台——通过提高数据复用率来抬升算术强度,让 GPU 从"等数据"变成"埋头算"。接下来的所有技术,本质上都是在做这一件事。
五种优化技术,一个共同目标
理解了屋顶线,下面五种听起来各不相干的优化技术,其实指向同一个圆心:减少 HBM 访问、提高数据复用。
低精度计算:用更少的比特换更快的速度
最直接的提速,是让每个数字占更少的位。这一招同时打中了"算力"和"访存"两个层面:
| 精度格式 | 相对 FP32 提速 | 典型用途 |
|---|---|---|
| FP16 / BF16 | 2–4× | 混合精度训练 |
| TF32 | 5–10× | 训练默认(A100 起) |
| INT8 | 8–16× | 推理量化 |
| FP8 | 10–20× | 训练加速 / 推理(H100 起) |
提速来自三处叠加:其一,硬件更省——一个 FP16 乘法器的晶体管数大约只有 FP32 的四分之一,同样的芯片面积能塞下约 4 倍的计算单元;其二,访存减半——以 GPT-3 175B 为例,参数从 FP32 的 700 GB 降到 BF16 的 350 GB,搬运的数据量直接砍半;其三,Tensor Core 专用加速——这些低精度格式正是 Tensor Core 用脉动阵列(Systolic Array)高效吞吐的对象。代价则是数值精度下降,需要靠混合精度(敏感的累加仍用高精度)来兜底。
算子融合:别让中间结果反复往返显存
朴素地执行 y = sin(x)² + cos(x)² 这样的复合运算,框架会一步步来:读 x、算 sin、把结果写回 HBM,再读回来算平方……每个算子都和 HBM 之间往返一趟。中间结果在"又慢又远"的显存里搬来搬去,纯属浪费。
算子融合(Kernel Fusion)把这一连串操作合并成一个 kernel:数据从 HBM 读进片上寄存器后,所有中间运算都在片上一气呵成,只在最后把最终结果写回 HBM。访存次数从"每个算子一趟"压缩到"首尾各一趟",算术强度大幅抬升。这正是后面 FlashAttention 的核心手法之一。
重计算:用过剩的算力换紧缺的带宽
训练时,前向传播产生的激活值(activations)要缓存下来供反向传播使用,这是显存的一大消耗。重计算(Recomputation / Activation Checkpointing)反其道而行:前向时不存这些中间激活,只留必要的输入;等反向真正需要时,再临时重新算一遍。
这看似多做了计算,却是一笔精明的交易——因为算力是过剩的、访存才是稀缺的。以三层 sigmoid 堆叠为例,存全部激活需要约 8 次内存访问,而重计算方案只需约 5 次。用闲置的算力,去赎回宝贵的显存与带宽,正合屋顶线模型的胃口。
内存合并:让一个 warp 的访问凑成一次突发传输
DRAM 并不是按单个地址取数的,而是以突发模式(Burst Mode)按块返回——你读一个地址,硬件顺手把它附近的一整段都端出来。利用好这一点,性能可以差出数倍。
内存合并(Memory Coalescing)的诀窍是:让同一个 warp 里 32 个线程访问的地址恰好落在同一个突发段内、且连续。这样硬件就能把 32 次零散请求合并成一次突发传输,吞吐量可提升约 4 倍。反之,如果线程访问的地址东一个西一个,每次突发取回的数据大部分被浪费,带宽利用率惨不忍睹。矩阵乘法里"按行连续遍历"之所以比"按列跳跃"快得多,根源就在这里。
分块 Tiling:让搬进来的数据被反复使用
朴素的矩阵乘法中,计算结果的每个元素都要去全局显存里取一整行和一整列。对 的矩阵,同一个数据会被反复从 HBM 读取 次——算术强度低得可怜,妥妥的访存受限。
分块(Tiling)的思路是:把大矩阵切成 的小块,先把一对小块整体搬进共享内存,然后让块内所有计算都在共享内存里取数、反复复用,算完再换下一对块。这样一来,每个元素从全局显存读取的次数从 降到 ,全局访存量减少约 倍,算术强度被成倍抬高。
分块也有它的工程复杂性:块的尺寸必须同时适配 SM 的共享内存容量和突发传输的段大小;当矩阵维度不是块大小的整数倍时,还需要填充(padding)把维度补齐,否则边界处的访存会翻倍。把块切得"刚刚好",本身就是一门调参艺术。
案例一:FlashAttention —— 把注意力搬进 SRAM
注意力机制是 Transformer 的算力黑洞。标准实现要先算出完整的 注意力分数矩阵 ,做 softmax,再乘以 。问题在于:这个中间矩阵 的大小是 ,会被完整地写进 HBM 再读回来——序列一长,光是搬运这个巨大的矩阵就把带宽榨干了,时间和显存双双爆炸。
FlashAttention 的洞见是:何必把整个 落盘?它把上面讲过的分块和算子融合两招用到了极致——
把 、、 都切成小块,每次只把一小块搬进片上 SRAM;在 SRAM 里完成 、softmax 和乘 的全部计算,用一种叫 Online Softmax 的技巧增量地维护每行的最大值与归一化分母,边算边更新输出,完整的 矩阵从头到尾都不写进 HBM。
# FlashAttention 内层逻辑:Online Softmax 增量更新
m_ij = rowmax(S_ij) # 当前块每行最大值
m_new = max(m_i, m_ij) # 更新全局最大值
l_i = exp(m_i - m_new) * l_i + sum(exp(S_ij - m_new)) # 重新校准归一化分母
O_i = exp(m_i - m_new) * O_i + exp(S_ij - m_new) @ V_j # 增量累加输出
m_i = m_new
# 所有 K,V 块处理完后再统一归一化
O_i = O_i / l_i注意:这是数学上精确等价的重排,而非近似——除了浮点误差,结果和标准注意力完全一致。它把注意力从"访存受限"硬生生拉回了"计算受限",显存占用也从 降到 。
FlashAttention 的三代演进,恰好对应着对 GPU 三个层面的逐步压榨:
| 版本 | 核心改进 | 解决的瓶颈 |
|---|---|---|
| V1 | 分块 + Online Softmax,注意力矩阵不落盘 | 减少 HBM 访存带宽瓶颈 |
| V2 | 在 K/V 维度切分并行(split-KV),多线程块并行 + 末端归约 | 提升并行度与 GPU 利用率 |
| V3 | 利用 H100 的异步 WGMMA 指令与原生 FP8,计算与数据搬运重叠 | 让访存与计算彻底重叠 |
性能上,V2 在 A100 上相比 V1 加速约 1.7–2.0×,相比 PyTorch 标准实现加速约 8–10×;V3 在 H100 上用 FP8 可把算力利用率推到 75–80%,单卡 80GB 即可稳定训练 256k 长度的序列。
案例二:PagedAttention —— 给 KV Cache 装上虚拟内存
如果说 FlashAttention 优化的是单次前向算子级的访存效率,那 PagedAttention 解决的是推理阶段一个完全不同的、系统级的问题:KV Cache 的内存管理。
自回归生成时,每个已生成 token 的 Key/Value 都要缓存起来供后续步骤复用。传统做法是按最大长度静态预分配一段连续显存,这带来两种浪费:
- 内部碎片:按最大长度 2048 预留,实际只生成了 300 个 token,剩下 1748 个 token 的空间全程闲置。
- 外部碎片:不同请求结束后留下大小不一的空洞,散落各处无法拼成连续大块,新请求挤不进去。
PagedAttention 借鉴了操作系统的虚拟内存分页思想:把 KV Cache 切成固定大小的页(如每页 16 个 token),逻辑上连续的 token 序列,物理上可以散落在任意不连续的页里,靠一张块表(Block Table)维护逻辑页到物理页的映射。
效果立竿见影。同样是 300 个 token、页大小 16 的请求,只需 19 页(实际容纳 304 token),浪费仅 4 个 token——而传统方案要浪费 1748 个。换成复杂度的语言:静态预分配的浪费是 ,PagedAttention 的浪费只有 。把省下来的显存用于容纳更多并发请求,单卡的吞吐能提升一个数量级。
这两个系统并不互斥,反而常常协同:FlashAttention 负责算得快,PagedAttention 负责装得多。
| 维度 | FlashAttention | PagedAttention |
|---|---|---|
| 优化目标 | 单次前向的访存(IO)复杂度 | 整个生成生命周期的内存管理 |
| 层级 | 算子级(micro) | 系统级(macro) |
| 作用对象 | Attention kernel | KV Cache 生命周期 |
| 适用场景 | 训练 + 推理 | 基本仅推理(decoder-only) |
尾声:硬件在演进,原理却恒定
GPU 之外,专为矩阵乘法而生的 TPU 走了一条更纯粹的路线——它放弃通用计算,用标量单元、矢量单元和矩阵乘法单元(MXU)的简洁组合,把"做矩阵乘法"这一件事做到极致,内部高速存储 + 外部 HBM 的二层结构则与 GPU 一脉相承。近两年,沐曦、昆仑芯、海光、摩尔线程、华为昇腾等国产芯片也在快速补齐算力与软件生态的拼图。
但无论硬件如何更迭、招牌换成哪家,本文这条主线始终不变:只要"算力远快于访存"的格局不改变,减少数据搬运、提高片上复用,就永远是 GPU 优化的第一性原理。
总结:GPU 优化速查表
把全文的优化技术浓缩成一张表——每一项的本质,都是在和"访存"这个敌人作战:
| 技术 | 它减少了什么 | 一句话核心 |
|---|---|---|
| 低精度计算 | 计算量 + 访存量 | 更少的比特,更快的吞吐与更小的占用 |
| 算子融合 | HBM 往返次数 | 中间结果留在片上,只首尾各访存一次 |
| 重计算 | 激活值显存占用 | 用过剩算力赎回紧缺带宽 |
| 内存合并 | 突发传输浪费 | 让一个 warp 的访问凑成一次连续读取 |
| 分块 Tiling | 全局访存次数(降 倍) | 搬进共享内存反复复用 |
| FlashAttention | 注意力矩阵的 HBM 读写 | 分块 + Online Softmax, 矩阵不落盘 |
| PagedAttention | KV Cache 的内存碎片 | 借虚拟内存分页,浪费从 降到 |
记住这张表背后的那条主线,你再看任何一项新的 GPU 优化技术,都能一眼看穿它到底在和"算力"还是"访存"较劲——而答案,几乎总是后者。
参考资料
- CS336: Language Modeling from Scratch(Stanford)
- Datawhale diy-llm · 第六章 GPU 和 GPU 相关的优化
- FlashAttention(arXiv:2205.14135)
- FlashAttention-2(arXiv:2307.08691)
- FlashAttention-3(arXiv:2407.08608)
- PagedAttention / vLLM(arXiv:2309.06180)
- Roofline: An Insightful Visual Performance Model
- NVIDIA A100 Tensor Core GPU 架构白皮书