Vision Transformer (ViT) 论文精读

March 3, 2026

Vision Transformer (ViT) 论文精读

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

Dosovitskiy et al. — ICLR 2021 (Google Research, Brain Team)

Vision Transformer(ViT)由 Google 于 2020 年提出,首次将 NLP 领域的 Transformer 架构直接应用于图像分类任务。其核心思想是将图像划分为固定大小的 Patch 序列,类比文本中的 Token,输入标准 Transformer Encoder 进行建模,在大规模数据预训练下取得了超越 CNN 的性能。这一工作证明了纯 Transformer 架构在视觉任务中的可行性,为后续 DeiT、Swin Transformer、MAE 等一系列视觉 Transformer 工作奠定了基础。

研究动机

Transformer 在 NLP 任务中大获成功(如 BERT 处理长度为 512 的序列),但直接将其迁移到视觉领域面临一个关键挑战:序列长度过大。以常见的 224×224224 \times 224 分辨率图像为例,如果将每个像素作为一个 Token,序列长度将达到 50176,远超 Transformer 的处理能力。目标检测、实例分割等任务的分辨率更高,问题更加严峻。

ViT 的核心创新在于:用 Patch 切分将图像转化为可控长度的序列,从而用标准 Transformer 替代 CNN 作为视觉主干网络。与 NLP 中常见的无监督预训练不同,ViT 采用有监督学习进行预训练。

模型架构

ViT 整体架构:图像经 Patch 切分、线性投影和位置编码后输入 Transformer Encoder

图像切分为 Patch

输入图像被均匀切分为固定大小的非重叠 Patch。以 224×224224 \times 224 的图像、16×1616 \times 16 的 Patch 大小为例:

  • Patch 数量:(224/16)×(224/16)=196(224 / 16) \times (224 / 16) = 196
  • 每个 Patch 展平后的维度:16×16×3=76816 \times 16 \times 3 = 768

这样,一张图像被转化为长度为 196 的序列,每个元素是一个 768 维向量。

Patch Embedding

每个展平的 Patch 通过一个线性投影层(全连接层)映射为固定长度的嵌入向量。以 ViT-Base 为例,投影矩阵维度为 768×768768 \times 768,其中输入维度 768 由 Patch 展平大小决定,输出维度 768 为超参数设定。这一过程类似于 NLP 中的词嵌入(Word Embedding)。

位置编码(Position Embedding)

Transformer 的自注意力机制对序列中元素做两两交互,但本身不包含位置信息。由于图像 Patch 具有空间顺序,需要为每个 Patch Embedding 加上可学习的位置编码以保留空间结构。

ViT 论文中关于位置编码的消融实验

ViT 论文探讨了三种位置编码策略:

  • 1D 位置编码:将 Patch 按光栅顺序排列,每个位置分配一个长度为 DD 的可学习向量。这是 ViT 的默认方案。
  • 2D 位置编码:分别学习 X 轴和 Y 轴方向的编码,各自长度为 D/2D/2,拼接后仍为长度 DD。序列数量从 NN 变为 N\sqrt{N}(每个轴方向)。
  • 相对位置编码:基于 Patch 之间的相对距离编码空间信息。

实验结果表明三种方案性能接近,1D 编码已能隐式学到 2D 的空间结构信息。

CLS Token

借鉴 BERT 的设计,ViT 在 Patch 序列前拼接一个可学习的 CLS Token(维度为 1×7681 \times 768),用于聚合全局图像特征并完成最终分类。

因此,Transformer Encoder 的输入维度为 197×768197 \times 768(196 个 Patch + 1 个 CLS Token),整个输入 = Patch Embedding + CLS Token + Position Embedding(逐元素相加)。

Transformer Encoder

ViT 直接采用标准 Transformer Encoder 结构,由 LL 层堆叠而成,每层包含:

  1. Layer Norm + 多头自注意力(Multi-Head Self-Attention, MSA)+ 残差连接
  2. Layer Norm + 前馈网络(MLP,含 GELU 激活)+ 残差连接

数学表示如下:

z0=[xclass;xp1E;xp2E;;xpNE]+Epos\mathbf{z}_0 = [\mathbf{x}_{\text{class}};\, \mathbf{x}_p^1 \mathbf{E};\, \mathbf{x}_p^2 \mathbf{E};\, \cdots;\, \mathbf{x}_p^N \mathbf{E}] + \mathbf{E}_{\text{pos}} z=MSA(LN(z1))+z1,=1L\mathbf{z}'_\ell = \text{MSA}(\text{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, \quad \ell = 1 \ldots L z=MLP(LN(z))+z,=1L\mathbf{z}_\ell = \text{MLP}(\text{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, \quad \ell = 1 \ldots L y=LN(zL0)\mathbf{y} = \text{LN}(\mathbf{z}_L^0)

其中 PP 为 Patch 大小,CC 为通道数,DD 为嵌入维度,NN 为 Patch 数量,LL 为 Encoder 层数。最终取 zL0\mathbf{z}_L^0(即 CLS Token 的输出)作为图像的全局表示。

分类头(Classification Head)

Transformer Encoder 输出的 CLS Token 向量经过一个 MLP 分类头,输出最终的类别概率。预训练时使用一个隐藏层的 MLP,微调时替换为单层线性层。

混合模型(Hybrid Architecture)

ViT 还提出了一种混合架构:

  1. 先用 CNN(如 ResNet)提取特征图
  2. 将特征图切分为 Patch 序列
  3. 对特征 Patch 进行线性投影(若 CNN 输出通道数已匹配 Transformer 维度则可省略)
  4. 添加位置编码和 CLS Token 后输入 Transformer Encoder

微调与高分辨率处理

ViT 采用典型的预训练-微调范式:先在大规模数据集上预训练,再迁移到下游任务。微调时:

  • 移除预训练的分类头
  • 添加新的全连接层(D×KD \times KKK 为下游任务类别数),使用零初始化

微调时通常使用比预训练更高的分辨率以提升性能。由于 Patch 大小保持不变,更高分辨率意味着更多的 Patch、更长的序列。Transformer 本身可处理任意长度序列(受内存限制),但预训练的位置编码需要调整。

解决方案:对预训练的位置编码进行 2D 插值,根据 Patch 在原始图像中的位置进行空间插值。

实验分析

数据规模的影响

数据集规模对 ViT 性能影响显著。在小规模数据集(如 ImageNet)上训练时,ViT 不如同等大小的 ResNet,这表明 Transformer 缺乏 CNN 内置的归纳偏置(如局部性和平移不变性)。但随着数据量增加到 JFT-300M,ViT 性能优势逐渐显现并超越 CNN,说明大规模数据能弥补归纳偏置的缺失。

迁移学习能力

在 19 个下游分类任务中,ViT 平均比基于 ResNet 的 BiT 模型高出 2.7% 的准确率,在 CIFAR-10/100、Oxford Flowers 等数据集上表现尤为突出,展现了优秀的泛化能力。

计算效率

与达到相同性能的 EfficientNet 相比,ViT 所需训练计算资源显著更少,在 TPU 上训练速度快 2.5 倍以上,得益于 Transformer 架构的高度并行性。

可视化分析

ViT 可视化分析:线性投影滤波器、位置编码余弦相似度、注意力距离

上图展示了三个关键发现:

  • 线性投影滤波器(左):类似 CNN 低层的 Gabor 滤波器,说明模型学到了基础视觉特征
  • 位置编码相似度(中):相邻位置的编码更相似,呈现明显的 2D 空间结构
  • 注意力距离(右):浅层注意力头关注局部特征(短距离),深层捕捉全局信息(长距离),类似 CNN 的层次特征提取

消融实验要点

实验项目结论
位置编码1D、2D 和相对位置编码效果接近,1D 已能隐式学到 2D 信息
Patch 大小较小 Patch(如 14×1414 \times 14)效果更好,但计算开销更大
注意力机制浅层关注局部,深层关注全局,无需显式局部性偏置

总结

ViT 证明了纯 Transformer 架构在视觉任务中的可行性:当预训练数据充足时,无需卷积操作即可取得卓越的图像分类性能,同时保持更高的计算效率和更强的可扩展性。这一工作为视觉领域的预训练-微调范式开辟了新的方向。

代码实战

完整的 ViT 代码实现(CIFAR-10 图像分类),包含源代码实现与 nn.TransformerEncoder 简洁实现两种方式的对比:

Open In Colab

本节展示 ViT 的 PyTorch 实现,包含源代码实现(逐组件手写)和简洁实现(使用 PyTorch 内置模块)两种方式,均在 CIFAR-10 子集上训练验证。

Patch Embedding

将输入图像切分为不重叠的 Patch,通过线性投影映射到 DD 维空间:

xpiRP2CERD\mathbf{x}_p^i \in \mathbb{R}^{P^2 \cdot C} \xrightarrow{\mathbf{E}} \mathbb{R}^D
class PatchEmbeddingScratch(nn.Module):
    """将图像切分为 Patch 并线性投影"""
 
    def __init__(self, img_size, patch_size, in_channels, d_model):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        patch_dim = patch_size * patch_size * in_channels
        self.proj = nn.Linear(patch_dim, d_model)
 
    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        x = x.reshape(B, C, H // p, p, W // p, p)
        x = x.permute(0, 2, 4, 3, 5, 1)
        x = x.reshape(B, self.num_patches, -1)
        x = self.proj(x)
        return x

多头自注意力

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
class MultiHeadAttention(nn.Module):
    """多头自注意力(手写实现)"""
 
    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5
 
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.proj = nn.Linear(d_model, d_model)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)
 
    def forward(self, x):
        B, N, D = x.shape
        h = self.num_heads
        qkv = self.qkv(x).reshape(B, N, 3, h, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
 
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
 
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        out = self.proj_drop(self.proj(out))
        return out

Transformer Encoder Block(Pre-LN)

ViT 使用 Pre-LN 变体,LayerNorm 在子层之前:

z=MSA(LN(z1))+z1\mathbf{z}'_\ell = \text{MSA}(\text{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1} z=MLP(LN(z))+z\mathbf{z}_\ell = \text{MLP}(\text{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell
class TransformerBlock(nn.Module):
    """Transformer Encoder Block(Pre-LN 变体)"""
 
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
 
    def forward(self, x):
        x = self.attn(self.norm1(x)) + x
        x = self.mlp(self.norm2(x)) + x
        return x

完整 ViT 模型

组装所有组件:Patch Embedding、CLS Token、Position Embedding、Transformer Encoder、分类头。

class ViTScratch(nn.Module):
    """Vision Transformer 源代码实现"""
 
    def __init__(self, img_size, patch_size, in_channels, d_model,
                 num_heads, num_layers, d_ff, num_classes, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbeddingScratch(
            img_size, patch_size, in_channels, d_model
        )
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
        self.pos_drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)
 
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
 
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = self.pos_drop(x + self.pos_embed)
        for block in self.blocks:
            x = block(x)
        cls_out = self.norm(x[:, 0])
        logits = self.head(cls_out)
        return logits

简洁实现

使用 nn.Conv2d 实现 Patch Embedding,nn.TransformerEncoder 封装 Encoder 层:

class ViTConcise(nn.Module):
    """Vision Transformer 简洁实现"""
 
    def __init__(self, img_size, patch_size, in_channels, d_model,
                 num_heads, num_layers, d_ff, num_classes, dropout=0.1):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(
            in_channels, d_model,
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
        self.pos_drop = nn.Dropout(dropout)
 
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=d_ff,
            dropout=dropout, activation='gelu', batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)
 
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
 
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.encoder(x)
        cls_out = self.norm(x[:, 0])
        logits = self.head(cls_out)
        return logits

两种实现架构等价(参数量均为 142,026),在 CIFAR-10 子集上训练 50 个 epoch 后测试准确率约 50%。准确率较低是因为使用了极小的模型(D=64D=64,4 层)和少量数据(5000 样本),仅用于演示 ViT 的工作原理。

预训练模型微调

使用 ImageNet 预训练的 ViT-B/16(86M 参数),冻结主干网络仅训练分类头(7,690 参数),在 CIFAR-10 上 3 个 epoch 即可达到 87% 测试准确率,充分体现了预训练-微调范式的价值。

参考文献