GAN(Generative Adversarial Network)由 Goodfellow 等人于 2014 年在 NIPS 上发表,通过生成器与判别器的对抗博弈来隐式学习数据分布,是深度学习领域最具影响力的创新之一。它将复杂的无监督生成问题巧妙转化为有监督的二分类问题,完全摒弃了马尔可夫链和变分推断,仅依赖反向传播即可端到端训练。
研究动机
深度学习在判别式模型(如图像分类)上取得了巨大成功,但在生成式模型上的进展却相对滞后。GAN 之前的主流生成方法存在以下痛点:
- 极大似然估计的计算棘手:涉及无向图模型(如 RBM、DBM)时,配分函数(Partition Function)及其梯度的计算由于存在难以解析的积分/求和,变得极其困难。
- 对马尔可夫链的重度依赖:为了近似上述难解的推断,通常必须求助于马尔可夫链蒙特卡洛(MCMC)方法,但 MCMC 存在严重的混合(Mixing)慢问题,特别是在高维空间中。
- 近似方法的局限:如噪声对比估计(NCE),在模型学习到数据空间的一个小子集的近似正确分布后,学习速度会急剧下降。
GAN 的核心目标是:抛弃复杂的近似推断和马尔可夫链,设计一种仅利用标准反向传播就能高效训练的深度生成模型。通过引入生成器与判别器的对抗机制,将复杂的无监督生成问题转化为深度学习最擅长的有监督二分类问题。
核心原理
GAN 的核心思想是同时训练两个相互对抗的神经网络:
- 生成器(Generator, G):输入随机噪声 (通常服从高斯分布),通过多层感知机将其映射为与真实数据分布相似的样本 。生成器从不直接接触真实数据,仅通过判别器的反馈来学习。
- 判别器(Discriminator, D):作为二分类器,区分输入样本是来自真实数据集还是生成器的输出。
两者以对抗方式训练,形成极小极大博弈(minimax game):生成器试图产生能够欺骗判别器的样本,判别器则努力提高分辨真伪的能力。最终目标是让生成器胜出——当判别器无法区分真实与生成样本(输出恒为 0.5)时,达到纳什均衡。
与传统生成模型的区别:传统方法(如变分推断)试图显式建模数据的底层分布,计算复杂。GAN 通过多层感知机直接将噪声映射为输出,利用反向传播高效训练,代价是不能获得数据的显式分布表示。
价值函数
GAN 的优化目标用价值函数 表示:
其中 为真实数据样本, 为输入生成器的随机噪声, 为生成器的输出, 为判别器认为 为真实样本的概率。
公式解析:
- 第一项 :针对真实数据的期望。若判别器能准确识别真实样本,则 ,该项趋近于 (最大值)。
- 第二项 :针对生成数据的期望。若判别器能准确识破生成样本,,该项趋近于 。而生成器希望愚弄判别器,使 ,令该项趋近于 。
极大极小的含义:
- 判别器目标():最大化 ,即让 、。
- 生成器目标():最小化 ,即让 ,使判别器无法区分真假。
价值函数 vs 损失函数
| MLP 损失函数 | GAN 价值函数 | |
|---|---|---|
| 目标 | 单一目标,最小化 | 双方博弈,极大极小 |
| 参与者 | 一个网络 | 两个网络(G 和 D) |
| 形式 | ||
| 优化方式 | 优化一个网络 | 交替优化两个网络 |
训练算法
训练过程在判别器和生成器之间交替进行。
训练判别器
每轮迭代中,先执行 步判别器更新:
- 从真实数据分布采样 个样本
- 从噪声分布采样 个样本 ,生成
- 将真实样本和生成样本组成 mini-batch,沿梯度上升方向更新判别器:
是超参数,需要合理选择:判别器过弱则对生成器无指导意义,过强则生成器的损失梯度趋近于 0,无法更新。论文中使用 。
训练生成器
随后执行 1 步生成器更新:
- 从噪声分布采样 个样本
- 沿梯度下降方向更新生成器:
非饱和技巧
实践中,直接最小化 在训练初期存在梯度消失问题:此时生成器很差,判别器轻易以高置信度拒绝生成样本,导致 ,梯度趋近于 0。
论文提出的工程技巧是:训练生成器时,不最小化 ,而转为最大化 。这能在训练初期提供充足的梯度信号。
值得注意的是,这一改动使得实际优化目标在数学上已不再严格等价于 JS 散度,但在工程实践中效果显著——理论赋予了论文发表的合理性,而这些工程直觉才是让网络跑通的关键。
理论结果
最优判别器
给定固定的生成器 ,对价值函数关于 求偏导并令其为 0,可得最优判别器为:
其中 是真实数据分布, 是生成器学到的分布。
全局最优与 Jensen-Shannon 散度
将最优判别器 代入价值函数,得到虚拟训练准则 :
经过代数变形,可以证明:
由于 Jensen-Shannon 散度非负,且仅在两个分布完全相等时为 0,因此全局最小值为 ,唯一解为 。此时判别器输出恒为 ,无法区分真假样本。
KL 散度
KL 散度(Kullback-Leibler Divergence)是 GAN 理论证明中的重要工具,衡量两个概率分布之间的差异:
- 直观含义:用分布 近似分布 时损失的信息量
- 非对称性:
- KL 散度越小,两个分布越接近;为 0 时两个分布完全一致
实验结果
论文在三个图像数据集上验证了 GAN 的效果:
- MNIST: 手写数字灰度图像
- TFD(Toronto Face Database):人脸灰度图像
- CIFAR-10: 彩色自然图像
由于 GAN 不显式提供概率密度分布,作者利用生成样本拟合高斯 Parzen 窗(Gaussian Parzen Window),在测试集上报告对数似然估计值:
| 模型 | MNIST | TFD |
|---|---|---|
| DBN | ||
| Stacked CAE | ||
| Deep GSN | ||
| Adversarial Nets |
在 MNIST 上,GAN 以 大幅超越了先前模型;在 TFD 上达到 ,与最佳模型表现可比。这证明了在完全摒弃马尔可夫链的前提下,仅利用反向传播,GAN 即可达到领域前沿水平。
此外,在隐空间中对两个噪声向量进行线性插值,生成图像呈现平滑过渡,证明模型真正学到了数据流形分布,而非简单记忆训练集。
总结
GAN 提出了一种全新的生成模型训练范式——通过两个网络的对抗博弈来隐式学习数据分布。其核心贡献在于:
- 范式转移:将分布匹配问题转化为分类任务,理论证明了在非参数极限下生成分布能收敛到真实分布
- 架构简洁:无需马尔可夫链或复杂的近似推断,训练和生成仅需前向/反向传播
- 生成质量高:能表示尖锐甚至退化的分布,生成图像比基于 MCMC 的方法更加锐利
GAN 的提出催生了 DCGAN、WGAN、StyleGAN、CycleGAN 等众多变体,在图像生成、风格迁移、数据增强等方向产生深远影响。尽管如今图像生成领域正被 Diffusion 模型主导,但 GAN 在实时渲染和高频细节纹理生成中仍保持独特的计算优势。
代码实战
以下代码基于 PyTorch 实现 GAN 的核心组件,使用 MNIST 手写数字生成任务演示对抗训练过程:
生成器将随机噪声 映射为 的图像:
class Generator(nn.Module):
def __init__(self, latent_dim, hidden_dim, img_dim):
super().__init__()
self.fc1 = nn.Linear(latent_dim, hidden_dim) # 64 -> 256
self.fc2 = nn.Linear(hidden_dim, hidden_dim * 2) # 256 -> 512
self.fc3 = nn.Linear(hidden_dim * 2, img_dim) # 512 -> 784
self.leaky_relu = nn.LeakyReLU(0.2)
def forward(self, z):
x = self.leaky_relu(self.fc1(z)) # (batch, 256)
x = self.leaky_relu(self.fc2(x)) # (batch, 512)
x = torch.tanh(self.fc3(x)) # (batch, 784)
return x判别器将输入图像映射为真实概率 ,使用 Dropout 防止判别器过强导致生成器梯度消失:
class Discriminator(nn.Module):
def __init__(self, img_dim, hidden_dim):
super().__init__()
self.fc1 = nn.Linear(img_dim, hidden_dim * 2) # 784 -> 512
self.fc2 = nn.Linear(hidden_dim * 2, hidden_dim) # 512 -> 256
self.fc3 = nn.Linear(hidden_dim, 1) # 256 -> 1
self.leaky_relu = nn.LeakyReLU(0.2)
self.dropout = nn.Dropout(0.3)
def forward(self, x):
x = self.leaky_relu(self.fc1(x)) # (batch, 512)
x = self.dropout(x)
x = self.leaky_relu(self.fc2(x)) # (batch, 256)
x = self.dropout(x)
x = torch.sigmoid(self.fc3(x)) # (batch, 1)
return x训练循环按照论文 Algorithm 1 交替训练判别器和生成器,使用非饱和损失技巧:
criterion = nn.BCELoss()
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
for epoch in range(num_epochs):
for real_imgs, _ in dataloader:
bs = real_imgs.size(0)
real_imgs = real_imgs.view(bs, -1).to(device)
ones = torch.ones(bs, 1, device=device)
zeros = torch.zeros(bs, 1, device=device)
# 训练判别器: max E[log D(x)] + E[log(1 - D(G(z)))]
z = torch.randn(bs, latent_dim, device=device)
fake_imgs = generator(z).detach()
d_loss = (criterion(discriminator(real_imgs), ones)
+ criterion(discriminator(fake_imgs), zeros)) / 2
opt_d.zero_grad()
d_loss.backward()
opt_d.step()
# 训练生成器: 非饱和损失 max log D(G(z))
z = torch.randn(bs, latent_dim, device=device)
g_loss = criterion(discriminator(generator(z)), ones)
opt_g.zero_grad()
g_loss.backward()
opt_g.step()参考文献
- Goodfellow, I., et al. (2014). Generative Adversarial Nets. NIPS 2014.
- 李沐. GAN 论文逐段精读. Bilibili.