ICLR 2021 Oral Vision Transformer Google Brain Foundation Model

An Image is Worth 16x16 Words
Transformers for Image Recognition at Scale

Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, et al.

尽管 Transformer 已成为 NLP 领域的事实标准,在计算机视觉中其应用仍受限于 CNN 的辅助角色。本文证明这种依赖并非必要:一个纯粹的 Transformer 直接应用于图像切块序列,在大规模数据预训练后即可以更少的算力达到惊人的 SOTA 表现,彻底改变了 CV 领域的发展轨迹。

研究动机与问题背景

从 AlexNet 到 ResNet,CNN 凭借天然自带的归纳偏置(Inductive Bias)——平移等变性与局部性——牢牢统治着计算机视觉。一个自然的问题浮现:Transformer 在 NLP 中展现的惊人扩展能力,能否直接迁移到视觉任务?

存在的痛点与局限

  • CNN 主导统治:卷积操作天然具备的局部性和平移不变性,使其在中小规模数据上极具优势。但这些人工先验也限制了模型容量的天花板。
  • Transformer 跨界遇阻:随着 BERT/GPT 展现出超凡的 Scalability,大量研究尝试将 Attention 引入 CV——Non-local Block、轴向注意力等。但这些方法要么仍依赖 CNN 骨干,要么在现代硬件(TPU/GPU)上效率极低,难以大规模扩展
  • 归纳偏置的两面性:CNN 的局部性先验在小数据上是优势,但也意味着模型被预设了如何看图,无法自主学习最优的注意力模式。

核心问题

我们能否以最少量的架构修改,将标准 NLP Transformer 直接应用到图像上?换言之——剥离掉所有预设的"图像先验",仅依赖大规模数据驱动,纯 Transformer 能否看懂图片?

💡

学术价值:ViT 是计算机视觉走向"大统一"架构的开山之作。它证明了在大规模数据面前,模型容量和扩展性比人类手工设计的局部性先验更加重要,直接催生了 Swin Transformer、MAE 等一系列革命性工作。

数学表示与建模

ViT 的设计哲学是:尽可能贴近原始 NLP Transformer(Vaswani et al., 2017),直接开箱即用经过高度优化的现有实现。

核心变量定义

符号含义维度说明
$x$原始输入图像$\mathbb{R}^{H \times W \times C}$
$x_p$展平后的图像切块序列$\mathbb{R}^{N \times (P^2 \cdot C)}$
$(P, P)$每个切块的分辨率通常为 $16\times16$ 或 $32\times32$
$N$切块数量(序列长度)$N = HW / P^2$
$D$Transformer 的隐层维度贯穿所有层的常量(如 768)
$E$可学习线性投影矩阵$\mathbb{R}^{(P^2 \cdot C) \times D}$
$E_{pos}$可学习 1D 位置编码$\mathbb{R}^{(N+1) \times D}$

Step 1:Patch Embedding — 图像切块与线性投影

标准 Transformer 接收 1D 词向量序列。ViT 将 2D 图像重塑为展平的 2D 切块序列,通过可学习的线性投影矩阵 $E$ 映射到 $D$ 维。借鉴 BERT,在序列开头拼接一个可学习的分类向量 $x_{\text{class}}$,并加上可学习 1D 位置编码:

$$ z_0 = [\,x_{\text{class}};\; x_p^1 E;\; x_p^2 E;\; \cdots;\; x_p^N E\,] + E_{pos} $$

其中 $z_0 \in \mathbb{R}^{(N+1) \times D}$。位置编码使用标准的可学习 1D 编码,实验表明 2D 位置编码并无显著优势——模型能从 1D 编码中自主学到 2D 网格结构。

Step 2:Transformer Encoder — MSA + MLP

编码器包含 $L$ 层交替的多头自注意力块(MSA)和多层感知机块(MLP)。每个块之前应用 LayerNorm (LN),之后使用残差连接。MLP 包含两层及 GELU 激活函数:

$$ z'_l = \text{MSA}(\text{LN}(z_{l-1})) + z_{l-1}, \quad l = 1 \ldots L $$ $$ z_l = \text{MLP}(\text{LN}(z'_l)) + z'_l, \quad l = 1 \ldots L $$

Step 3:分类头输出

经过 $L$ 层处理后,提取序列第一个位置(即 [class] token)的特征 $z_L^0$,通过 LayerNorm 得到图像的最终表征:

$$ y = \text{LN}(z_L^0) $$

预训练时分类头为含一个隐层的 MLP,微调时替换为单层线性层。

Vision Transformer 架构概览 输入图像 H x W x C 切块 ... N 个切块 线性 投影 E [CLS] + E_pos Transformer Encoder LayerNorm + Multi-Head SA + LayerNorm + MLP (GELU) + x L z_0 LN(z_L^0) 分类头 (MLP / Linear) Class: 猫 完整流程:图像 → 切块展平 → 线性投影 + [CLS] + 位置编码 → Transformer Encoder x L → 分类输出 核心设计:零 CNN 组件 直接复用 NLP Transformer,无任何魔改 MSA 残差 MLP 残差 [CLS] / 分类 残差连接 (Skip)

Figure 1:Vision Transformer 架构概览。图像被切分为固定大小的 patches,经线性投影后拼接 [CLS] token 并加上位置编码,送入标准 Transformer Encoder,最终通过 [CLS] token 输出分类结果。

核心逻辑伪代码 (PyTorch 风格)
# Input x: [B, C, H, W]
patches = extract_patches(x, patch_size=16)    # [B, N, P*P*C]
patch_embeds = linear_projection(patches)      # [B, N, D]

# Prepend class token
cls_tokens = repeat(cls_token_param, B)        # [B, 1, D]
z = concat([cls_tokens, patch_embeds], dim=1)  # [B, N+1, D]

# Add positional embeddings
z = z + pos_embedding                          # [B, N+1, D]

# Transformer Encoder (L layers)
for layer in range(L):
    z = MSA(LayerNorm(z)) + z                  # Multi-Head Self-Attention + Residual
    z = MLP(LayerNorm(z)) + z                  # Feed-Forward + Residual

# Classification
img_repr = LayerNorm(z[:, 0])                  # [B, D]  (class token)
logits = classification_head(img_repr)         # [B, num_classes]

实验设置与复现细节

论文严格沿用 BERT 的配置范式,主要测试了三种尺寸的模型。后缀数字代表 patch size(如 ViT-L/16 表示 Large 模型 + 16x16 切块)。

模型变体

ViT 模型家族配置
模型层数 L隐维 DMLP 维度注意力头数参数量
ViT-Base1276830721286M
ViT-Large241024409616307M
ViT-Huge321280512016632M

数据集

  • 预训练:ImageNet-1k (1.3M)、ImageNet-21k (14M)、JFT-300M (303M, Google 内部非公开数据集)
  • 下游评估:ImageNet、ImageNet-ReaL、CIFAR-10/100、Oxford-IIIT Pets、Oxford Flowers-102、VTAB (19 个任务基准)。
  • 分辨率调整:微调时使用更高分辨率(如 ViT-H/14 在 518 分辨率微调)。序列长度 $N$ 增加时,使用2D 插值扩展预训练位置编码——这是 ViT 唯一手动注入 2D 拓扑先验的地方。
训练超参数细节 (基于 Appendix Table 3 & 4)
阶段超参数细节描述
预训练优化器Adam ($\beta_1\!=\!0.9,\;\beta_2\!=\!0.999$)
Batch Size4096
Weight Decay0.1 (很高,对迁移学习有益)
学习率$\sim 8\times10^{-4}$,Linear Warmup (10k steps) + Linear Decay
Dropout直接训练时 0.1;大容量数据集上设为 0.0
微调优化器SGD with momentum (0.9)
Batch Size512
学习率网格搜索 $\{0.001, 0.003, 0.01, 0.03\}$,Cosine 衰减
分辨率高于预训练(ViT-L/16: 512, ViT-H/14: 518),结合 Polyak EMA

实验结果与核心结论

核心数据

88.55%
ImageNet Top-1
632M
ViT-H 参数量
1/4
训练成本 vs BiT
16x16
Patch 大小

主流基准 SOTA 对比

将 JFT-300M 预训练的 ViT 与 SOTA 卷积网络(BiT-L、Noisy Student)对比。指标为分类 Top-1 Accuracy (%)

SOTA 对比 (JFT-300M 预训练 → 下游微调)
模型ImageNetImageNet ReaLCIFAR-100VTAB (19 tasks)TPUv3 Core-days
ViT-H/14 (JFT) 88.55 90.72 94.55 77.63 2.5k
ViT-L/16 (JFT) 87.76 90.54 93.90 76.28 0.68k
BiT-L (ResNet152x4) 87.54 90.54 93.51 76.29 9.9k
Noisy Student (Eff-L2) 88.4 90.55 - - 12.3k

核心发现:ViT-H/14 全面击败极强的 CNN 基线,其预训练算力仅为 BiT-L 的 1/4、Noisy Student 的 1/5。这是惊人的效率革命。

预训练数据规模的影响

论文的金句:"Large scale training trumps inductive bias"——大规模训练胜过归纳偏置。

数据规模对 ViT vs CNN 性能的影响 90% 85% 80% 75% 70% ImageNet (1.3M) ImageNet-21k (14M) JFT-300M (303M) 预训练数据规模 ImageNet Top-1 Acc 交叉点:~10M 数据 ViT 开始反超 CNN ViT 胜出 大数据释放全部潜力 CNN 胜出 归纳偏置占优 ViT (Transformer) BiT (ResNet/CNN)

Figure 2:预训练数据规模对 ViT 与 CNN 性能的影响。在 ImageNet (1.3M) 上 CNN 凭借归纳偏置碾压 ViT;但随着数据量增至 JFT-300M,ViT 从海量数据中学到的规律完美填补了先验缺陷,并爆发出更强的上限。

🔍

深层洞见:当数据量达到千万级(ImageNet-21k)乃至数亿级(JFT-300M),直接从数据中"学习"到的规律就能完美替代人工先验(平移不变性、局部性)。模型没有被预设感受野所限制,因而释放出更强上限。

内部表征可视化:ViT 如何"看"图?

  • 位置编码提取了空间几何:计算学到的位置编码的余弦相似度,发现模型自发学到了 2D 行列网格结构(同一行/列的 patch 相似度更高),说明强加 2D 位置编码是多余的。
  • 注意力距离 (Attention Distance):ViT 中的"感受野"。底层网络中,部分注意力头只看近处像素(类似小卷积核),另一些头直接看全图;随着网络加深,注意力距离整体上升,高层几乎完全在做全局信息整合。

犀利短评

优点

  • 极简之美,大道至简:直接把图像切成 16x16 的"词汇"送进最原始的 NLP Transformer,没有任何花哨的魔改架构(局部注意力、轴向注意力等)。在硬件加速器上执行效率极高,证明了只要数据管够,"通用架构"可以抹平模态差异。
  • 惊人的 Scalability:论文不仅跑赢了 SOTA,更展示了一条清晰的"性能与算力成正比且尚未饱和"的 Scaling Law 曲线。模型越大、数据越多,ViT 的回报越显著。

不足与疑问

  • 对算力和数据的饥渴度极高:普通实验室根本没有 JFT-300M 这种规模的数据。在标准的百万级 ImageNet 数据下,原版 ViT 打不过 ResNet。这也是后来 DeiT 通过知识蒸馏抢救 Data Efficiency 的原因。
  • $O(N^2)$ 二次方计算灾难:Self-Attention 复杂度与序列长度 $N$ 成二次方关系。面对医疗影像或遥感等超高分辨率图像,16x16 patch 导致序列长度爆炸。这催生了 Swin Transformer 的局部窗口注意力。
  • 密集预测任务的短板:论文主要做分类任务,输出仅一个 [class] token。如果做目标检测或像素级分割,单一尺度的特征图显得笨拙,缺乏 CNN 的多尺度金字塔结构。

One More Thing

被忽略的实验细节:CNN 的优化器选择玄学

在附录 D.1 中,作者提到了一个反直觉的现象。传统观念里,从 PyTorch 官方教程到何恺明原论文,训练 ResNet 的标配永远是 SGD with momentum

但在本文的 JFT 预训练阶段,作者给所有 ResNet 基线(即 BiT)使用的是 Adam 优化器。对照实验表明:在极高训练规模下,Adam 预训练的 ResNet152x2 在 ImageNet 微调达到 84.97%,而传统 SGD 只有 84.37%。

🔍

启示:当你把训练规模(Batch size、数据量)推向极端时,过去深信不疑的经验法则可能需要重新审视。所谓的"最佳实践"往往是特定 scale 下的局部最优。

Masked Patch Prediction:MAE 的前身

作者还尝试了类似 BERT 的 Masked Patch Prediction 进行自监督预训练——随机遮盖部分 patch 并让模型预测。虽然取得了 79.9% 的准确率(较从头训练有 2% 提升),但离监督学习仍有 4% 的差距。

这个"半成品"的坑,在一年后被何恺明用极简的 MAE (Masked Autoencoders) 架构彻底填平——75% 的遮盖率 + 非对称 encoder-decoder 设计,完成了视觉自监督大一统拼图的最后一块。