研究动机与问题背景
从 AlexNet 到 ResNet,CNN 凭借天然自带的归纳偏置(Inductive Bias)——平移等变性与局部性——牢牢统治着计算机视觉。一个自然的问题浮现:Transformer 在 NLP 中展现的惊人扩展能力,能否直接迁移到视觉任务?
存在的痛点与局限
- CNN 主导统治:卷积操作天然具备的局部性和平移不变性,使其在中小规模数据上极具优势。但这些人工先验也限制了模型容量的天花板。
- Transformer 跨界遇阻:随着 BERT/GPT 展现出超凡的 Scalability,大量研究尝试将 Attention 引入 CV——Non-local Block、轴向注意力等。但这些方法要么仍依赖 CNN 骨干,要么在现代硬件(TPU/GPU)上效率极低,难以大规模扩展。
- 归纳偏置的两面性:CNN 的局部性先验在小数据上是优势,但也意味着模型被预设了如何看图,无法自主学习最优的注意力模式。
核心问题
我们能否以最少量的架构修改,将标准 NLP Transformer 直接应用到图像上?换言之——剥离掉所有预设的"图像先验",仅依赖大规模数据驱动,纯 Transformer 能否看懂图片?
学术价值:ViT 是计算机视觉走向"大统一"架构的开山之作。它证明了在大规模数据面前,模型容量和扩展性比人类手工设计的局部性先验更加重要,直接催生了 Swin Transformer、MAE 等一系列革命性工作。
数学表示与建模
ViT 的设计哲学是:尽可能贴近原始 NLP Transformer(Vaswani et al., 2017),直接开箱即用经过高度优化的现有实现。
核心变量定义
| 符号 | 含义 | 维度说明 |
|---|---|---|
| $x$ | 原始输入图像 | $\mathbb{R}^{H \times W \times C}$ |
| $x_p$ | 展平后的图像切块序列 | $\mathbb{R}^{N \times (P^2 \cdot C)}$ |
| $(P, P)$ | 每个切块的分辨率 | 通常为 $16\times16$ 或 $32\times32$ |
| $N$ | 切块数量(序列长度) | $N = HW / P^2$ |
| $D$ | Transformer 的隐层维度 | 贯穿所有层的常量(如 768) |
| $E$ | 可学习线性投影矩阵 | $\mathbb{R}^{(P^2 \cdot C) \times D}$ |
| $E_{pos}$ | 可学习 1D 位置编码 | $\mathbb{R}^{(N+1) \times D}$ |
Step 1:Patch Embedding — 图像切块与线性投影
标准 Transformer 接收 1D 词向量序列。ViT 将 2D 图像重塑为展平的 2D 切块序列,通过可学习的线性投影矩阵 $E$ 映射到 $D$ 维。借鉴 BERT,在序列开头拼接一个可学习的分类向量 $x_{\text{class}}$,并加上可学习 1D 位置编码:
其中 $z_0 \in \mathbb{R}^{(N+1) \times D}$。位置编码使用标准的可学习 1D 编码,实验表明 2D 位置编码并无显著优势——模型能从 1D 编码中自主学到 2D 网格结构。
Step 2:Transformer Encoder — MSA + MLP
编码器包含 $L$ 层交替的多头自注意力块(MSA)和多层感知机块(MLP)。每个块之前应用 LayerNorm (LN),之后使用残差连接。MLP 包含两层及 GELU 激活函数:
Step 3:分类头输出
经过 $L$ 层处理后,提取序列第一个位置(即 [class] token)的特征 $z_L^0$,通过 LayerNorm 得到图像的最终表征:
预训练时分类头为含一个隐层的 MLP,微调时替换为单层线性层。
Figure 1:Vision Transformer 架构概览。图像被切分为固定大小的 patches,经线性投影后拼接 [CLS] token 并加上位置编码,送入标准 Transformer Encoder,最终通过 [CLS] token 输出分类结果。
核心逻辑伪代码 (PyTorch 风格)
# Input x: [B, C, H, W]
patches = extract_patches(x, patch_size=16) # [B, N, P*P*C]
patch_embeds = linear_projection(patches) # [B, N, D]
# Prepend class token
cls_tokens = repeat(cls_token_param, B) # [B, 1, D]
z = concat([cls_tokens, patch_embeds], dim=1) # [B, N+1, D]
# Add positional embeddings
z = z + pos_embedding # [B, N+1, D]
# Transformer Encoder (L layers)
for layer in range(L):
z = MSA(LayerNorm(z)) + z # Multi-Head Self-Attention + Residual
z = MLP(LayerNorm(z)) + z # Feed-Forward + Residual
# Classification
img_repr = LayerNorm(z[:, 0]) # [B, D] (class token)
logits = classification_head(img_repr) # [B, num_classes]
实验设置与复现细节
论文严格沿用 BERT 的配置范式,主要测试了三种尺寸的模型。后缀数字代表 patch size(如 ViT-L/16 表示 Large 模型 + 16x16 切块)。
模型变体
| 模型 | 层数 L | 隐维 D | MLP 维度 | 注意力头数 | 参数量 |
|---|---|---|---|---|---|
| ViT-Base | 12 | 768 | 3072 | 12 | 86M |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
数据集
- 预训练:ImageNet-1k (1.3M)、ImageNet-21k (14M)、JFT-300M (303M, Google 内部非公开数据集)。
- 下游评估:ImageNet、ImageNet-ReaL、CIFAR-10/100、Oxford-IIIT Pets、Oxford Flowers-102、VTAB (19 个任务基准)。
- 分辨率调整:微调时使用更高分辨率(如 ViT-H/14 在 518 分辨率微调)。序列长度 $N$ 增加时,使用2D 插值扩展预训练位置编码——这是 ViT 唯一手动注入 2D 拓扑先验的地方。
训练超参数细节 (基于 Appendix Table 3 & 4)
| 阶段 | 超参数 | 细节描述 |
|---|---|---|
| 预训练 | 优化器 | Adam ($\beta_1\!=\!0.9,\;\beta_2\!=\!0.999$) |
| Batch Size | 4096 | |
| Weight Decay | 0.1 (很高,对迁移学习有益) | |
| 学习率 | $\sim 8\times10^{-4}$,Linear Warmup (10k steps) + Linear Decay | |
| Dropout | 直接训练时 0.1;大容量数据集上设为 0.0 | |
| 微调 | 优化器 | SGD with momentum (0.9) |
| Batch Size | 512 | |
| 学习率 | 网格搜索 $\{0.001, 0.003, 0.01, 0.03\}$,Cosine 衰减 | |
| 分辨率 | 高于预训练(ViT-L/16: 512, ViT-H/14: 518),结合 Polyak EMA |
实验结果与核心结论
核心数据
主流基准 SOTA 对比
将 JFT-300M 预训练的 ViT 与 SOTA 卷积网络(BiT-L、Noisy Student)对比。指标为分类 Top-1 Accuracy (%):
| 模型 | ImageNet | ImageNet ReaL | CIFAR-100 | VTAB (19 tasks) | TPUv3 Core-days |
|---|---|---|---|---|---|
| ViT-H/14 (JFT) | 88.55 | 90.72 | 94.55 | 77.63 | 2.5k |
| ViT-L/16 (JFT) | 87.76 | 90.54 | 93.90 | 76.28 | 0.68k |
| BiT-L (ResNet152x4) | 87.54 | 90.54 | 93.51 | 76.29 | 9.9k |
| Noisy Student (Eff-L2) | 88.4 | 90.55 | - | - | 12.3k |
核心发现:ViT-H/14 全面击败极强的 CNN 基线,其预训练算力仅为 BiT-L 的 1/4、Noisy Student 的 1/5。这是惊人的效率革命。
预训练数据规模的影响
论文的金句:"Large scale training trumps inductive bias"——大规模训练胜过归纳偏置。
Figure 2:预训练数据规模对 ViT 与 CNN 性能的影响。在 ImageNet (1.3M) 上 CNN 凭借归纳偏置碾压 ViT;但随着数据量增至 JFT-300M,ViT 从海量数据中学到的规律完美填补了先验缺陷,并爆发出更强的上限。
深层洞见:当数据量达到千万级(ImageNet-21k)乃至数亿级(JFT-300M),直接从数据中"学习"到的规律就能完美替代人工先验(平移不变性、局部性)。模型没有被预设感受野所限制,因而释放出更强上限。
内部表征可视化:ViT 如何"看"图?
- 位置编码提取了空间几何:计算学到的位置编码的余弦相似度,发现模型自发学到了 2D 行列网格结构(同一行/列的 patch 相似度更高),说明强加 2D 位置编码是多余的。
- 注意力距离 (Attention Distance):ViT 中的"感受野"。底层网络中,部分注意力头只看近处像素(类似小卷积核),另一些头直接看全图;随着网络加深,注意力距离整体上升,高层几乎完全在做全局信息整合。
犀利短评
优点
- 极简之美,大道至简:直接把图像切成 16x16 的"词汇"送进最原始的 NLP Transformer,没有任何花哨的魔改架构(局部注意力、轴向注意力等)。在硬件加速器上执行效率极高,证明了只要数据管够,"通用架构"可以抹平模态差异。
- 惊人的 Scalability:论文不仅跑赢了 SOTA,更展示了一条清晰的"性能与算力成正比且尚未饱和"的 Scaling Law 曲线。模型越大、数据越多,ViT 的回报越显著。
不足与疑问
- 对算力和数据的饥渴度极高:普通实验室根本没有 JFT-300M 这种规模的数据。在标准的百万级 ImageNet 数据下,原版 ViT 打不过 ResNet。这也是后来 DeiT 通过知识蒸馏抢救 Data Efficiency 的原因。
- $O(N^2)$ 二次方计算灾难:Self-Attention 复杂度与序列长度 $N$ 成二次方关系。面对医疗影像或遥感等超高分辨率图像,16x16 patch 导致序列长度爆炸。这催生了 Swin Transformer 的局部窗口注意力。
- 密集预测任务的短板:论文主要做分类任务,输出仅一个
[class]token。如果做目标检测或像素级分割,单一尺度的特征图显得笨拙,缺乏 CNN 的多尺度金字塔结构。
One More Thing
被忽略的实验细节:CNN 的优化器选择玄学
在附录 D.1 中,作者提到了一个反直觉的现象。传统观念里,从 PyTorch 官方教程到何恺明原论文,训练 ResNet 的标配永远是 SGD with momentum。
但在本文的 JFT 预训练阶段,作者给所有 ResNet 基线(即 BiT)使用的是 Adam 优化器。对照实验表明:在极高训练规模下,Adam 预训练的 ResNet152x2 在 ImageNet 微调达到 84.97%,而传统 SGD 只有 84.37%。
启示:当你把训练规模(Batch size、数据量)推向极端时,过去深信不疑的经验法则可能需要重新审视。所谓的"最佳实践"往往是特定 scale 下的局部最优。
Masked Patch Prediction:MAE 的前身
作者还尝试了类似 BERT 的 Masked Patch Prediction 进行自监督预训练——随机遮盖部分 patch 并让模型预测。虽然取得了 79.9% 的准确率(较从头训练有 2% 提升),但离监督学习仍有 4% 的差距。
这个"半成品"的坑,在一年后被何恺明用极简的 MAE (Masked Autoencoders) 架构彻底填平——75% 的遮盖率 + 非对称 encoder-decoder 设计,完成了视觉自监督大一统拼图的最后一块。