GAN

Generative Adversarial Network 生成对抗网络

GAN 通过两个神经网络的对抗博弈来隐式学习数据分布,从而生成逼真的新样本。 由 Ian Goodfellow 等人于 2014 年提出,是生成模型领域最具影响力的框架之一。

核心思想:对抗博弈

GAN 的架构包含两个相互对抗的网络:

两者构成一个极小极大博弈 (Minimax Game):生成器试图最大化判别器的错误率,判别器试图最大化自身的判别准确率。

数学公式

GAN 的目标函数(价值函数)为:

minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中:

纳什均衡

理论上,当训练达到纳什均衡时,生成器学会了真实数据分布 pG=pdata,此时判别器对任意输入输出 D(x)=12,即无法区分真假。

训练过程

GAN 的训练交替进行以下两步:

第一步:固定 G,训练 D(最大化 V

maxDExpdata[logD(x)]+Ezpz[log(1D(G(z)))]

第二步:固定 D,训练 G(最小化 V

实践中通常最大化 logD(G(z)) 而非最小化 log(1D(G(z))),以避免早期梯度消失:

maxGEzpz[logD(G(z))]
python
import torch
import torch.nn as nn

# 判别器训练
real_output = D(real_data)
fake_data = G(torch.randn(batch_size, latent_dim))
fake_output = D(fake_data.detach())
d_loss = -torch.mean(torch.log(real_output) + torch.log(1 - fake_output))

# 生成器训练
fake_data = G(torch.randn(batch_size, latent_dim))
fake_output = D(fake_data)
g_loss = -torch.mean(torch.log(fake_output))

网络架构

组件 典型结构 输入 输出
生成器 G 全连接 / 转置卷积(反卷积)网络 噪声向量 zRd 生成样本 G(z)
判别器 D 全连接 / CNN 样本 xG(z) 概率值 D()[0,1]

生成器通常使用 BatchNorm + ReLU(最后一层用 Tanh),判别器使用 LeakyReLU + Dropout。

训练挑战

模式崩溃 (Mode Collapse)

生成器可能只学会生成少数几种"安全"的样本来欺骗判别器,而忽略真实数据分布中的其他模式。这是 GAN 训练中最常见的问题。

主要变种

变种 改进点 关键特性
DCGAN 使用卷积架构替代全连接 稳定训练,图像生成质量提升
WGAN 用 Wasserstein 距离替代 JS 散度 缓解模式崩溃,梯度更平滑
WGAN-GP 在 WGAN 基础上用梯度惩罚替代权重裁剪 训练更稳定,无需权重裁剪
CGAN 引入条件信息 y(类别标签等) 可控生成
CycleGAN 无配对数据的图像风格转换 循环一致性损失
StyleGAN 风格映射网络 + 渐进式生成 高分辨率人脸生成,风格可控
BigGAN 大规模训练 + 类条件生成 ImageNet 级别图像生成

与其他生成模型的对比

特性 GAN VAE DDPM
训练方式 对抗训练 变分推断(最大化 ELBO) 去噪分数匹配
生成质量 高(尤其是锐利度) 中等(易模糊)
训练稳定性 低(需精心调参)
模式覆盖 易模式崩溃
采样速度 快(单次前向传播) 慢(需迭代去噪)
似然估计 无显式似然 有(ELBO 下界)

应用场景


机器学习 | 深度学习 | 神经网络 | VAE | DDPM