【2026】 LLM 大模型系统学习指南 (33)
GAN 入门:生成对抗网络的核心原理与基础实操
生成对抗网络(Generative Adversarial Networks, GAN)是生成式 AI 的核心框架之一,核心逻辑是 “两个模型对抗训练、互相促进”—— 生成器(Generator)负责 “造假”(生成模拟数据),判别器(Discriminator)负责 “鉴真”(区分真实数据与伪造数据),最终让生成器造出足以以假乱真的数据,实现无监督生成。
本文从基础概念入手,用通俗类比拆解 GAN 的工作原理、训练流程,再通过完整代码实现简单 GAN,帮你快速入门这一 “对抗式学习” 框架。
一、GAN 的核心概念:“造假者” 与 “鉴宝师” 的博弈
GAN 的本质是 “零和博弈”,两个核心组件分工明确,如同现实中的造假者与鉴宝师:
1. 两大核心组件
- 生成器(Generator, G):“造假者”。输入随机噪声(Noise),通过神经网络生成 “伪造数据”(如假图像、假文本),目标是让伪造数据尽可能接近真实数据,骗过判别器。
- 判别器(Discriminator, D):“鉴宝师”。输入数据(可能是真实数据或生成器造的假数据),输出一个概率值(0-1),判断该数据是 “真实数据”(输出接近 1)还是 “伪造数据”(输出接近 0),目标是精准区分真假。
2. 核心目标
- 生成器 G:最小化判别器 D 的 “识别准确率”,即让 D 把假数据误判为真(D (G (Noise)) → 1);
- 判别器 D:最大化 “识别准确率”,即让 D 把真数据判为真(D (Real Data) → 1)、把假数据判为假(D (G (Noise)) → 0);
- 最终平衡:生成器 G 的造假水平足够高,判别器 D 无法区分真假(D (G (Noise)) ≈ 0.5),达到 “纳什均衡”。
通俗类比
就像一场持续的博弈:
- 造假者(G)初期造的假币很粗糙,鉴宝师(D)一眼就能识破;
- 造假者根据鉴宝师的判断,不断改进造假技术,造出更逼真的假币;
- 鉴宝师同时也在学习,提升识别假币的能力;
- 最终造假者造出足以以假乱真的假币,鉴宝师只能靠猜(正确率 50%),博弈达到平衡。
二、GAN 的工作原理:对抗训练的完整流程
GAN 的训练是 “交替训练生成器和判别器” 的循环过程,每一轮训练都让两个模型的能力同步提升,具体流程如下:
1. 训练前准备
- 数据集:真实数据(如 MNIST 手写数字图像、自然风景照片),无需标签(无监督学习);
- 随机噪声:生成器的输入,通常是服从正态分布的随机向量(如 100 维噪声向量),为生成器提供 “创作素材”;
- 模型初始化:生成器和判别器均为简单神经网络(如全连接网络、简单 CNN),参数随机初始化。
2. 一轮完整训练流程(核心步骤)
步骤 1:训练判别器 D(提升鉴真能力)
- 用真实数据训练 D:输入真实数据,让 D 输出接近 1 的概率,计算损失(如二分类交叉熵),反向传播更新 D 的参数;
- 用伪造数据训练 D:生成器 G 输入随机噪声,生成假数据,让 D 输出接近 0 的概率,计算损失,反向传播继续更新 D 的参数;
- 关键:训练 D 时,固定生成器 G 的参数(不更新),只优化 D 的识别能力。
步骤 2:训练生成器 G(提升造假能力)
- 生成假数据:G 输入随机噪声,生成假数据;
- 让 D “误判”:将假数据输入 D,要求 D 输出接近 1 的概率(即让 D 误以为是真数据),计算损失,反向传播更新 G 的参数;
- 关键:训练 G 时,固定判别器 D 的参数(不更新),只优化 G 的造假能力。
步骤 3:循环迭代
重复步骤 1 和步骤 2,直到达到预设训练轮次,或判别器对假数据的判断概率接近 0.5(无法区分真假)。
3. 损失函数:量化对抗效果
GAN 的损失函数基于二分类交叉熵,分别定义生成器和判别器的损失:
- 判别器损失(Loss_D):包含 “识别真实数据的损失” 和 “识别伪造数据的损失”,目标是最小化 Loss_D;真实数据噪声
- 生成器损失(Loss_G):让判别器误判假数据为真,目标是最小化 Loss_G;噪声
三、GAN 的关键要点:训练稳定的核心注意事项
基础 GAN 的训练容易出现 “不稳定” 问题,核心原因是两个模型的能力失衡(如 D 太强导致 G 无法学习,或 G 太强导致 D 失效),需注意以下关键要点:
1. 常见问题
- 模式崩溃(Mode Collapse):生成器只生成少数几种类型的假数据(如只生成 MNIST 中的数字 “5”),无法覆盖真实数据的所有分布;
- 梯度消失 / 爆炸:训练后期,判别器识别能力过强(输出接近 0 或 1),导致生成器的梯度趋近于 0,无法继续更新;
- 训练振荡:两个模型的能力交替碾压,损失函数剧烈波动,无法收敛。
2. 基础解决思路
- 交替训练节奏:每训练 1-2 轮判别器,再训练 1 轮生成器,避免某一方能力失衡;
- 数据预处理:对真实数据归一化(如像素值缩放到 [-1,1]),让生成器和判别器的输入分布更稳定;
- 激活函数选择:判别器输出层用 Sigmoid(输出 0-1 概率),中间层用 LeakyReLU(避免梯度消失);生成器输出层用 Tanh(配合数据归一化);
- 权重初始化:采用正交初始化或 Xavier 初始化,让模型参数分布更合理。
四、实操:用基础 GAN 生成 MNIST 手写数字
以 MNIST 手写数字数据集为例,实现基础 GAN,让生成器学会生成逼真的手写数字,代码简洁易懂,适合入门:
1. 完整代码(PyTorch 实现)
python
运行
# 安装依赖
# pip install torch torchvision numpy matplotlib
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# ---------------------- 1. 数据准备(MNIST无监督数据) ----------------------
transform = transforms.Compose([
transforms.ToTensor(), # 转为张量(0-1范围)
transforms.Normalize((0.5,), (0.5,)) # 归一化到[-1,1],配合生成器Tanh输出
])
# 加载MNIST数据集(仅用训练集,无监督学习)
train_dataset = datasets.MNIST(
root="./data", train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# ---------------------- 2. 定义GAN的两个核心模型 ----------------------
# (1)生成器G:输入100维噪声→输出28×28手写数字图像
class Generator(nn.Module):
def __init__(self, noise_dim=100, img_dim=784):
super().__init__()
self.model = nn.Sequential(
nn.Linear(noise_dim, 256), # 100维噪声→256维
nn.LeakyReLU(0.2), # 避免梯度消失
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, img_dim),
nn.Tanh() # 输出[-1,1],匹配数据归一化
)
def forward(self, x):
return self.model(x)
# (2)判别器D:输入28×28图像→输出0-1概率(真/假)
class Discriminator(nn.Module):
def __init__(self, img_dim=784):
super().__init__()
self.model = nn.Sequential(
nn.Linear(img_dim, 1024), # 784维图像→1024维
nn.LeakyReLU(0.2),
nn.Dropout(0.3), # 防止过拟合
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid() # 输出0-1概率
)
def forward(self, x):
return self.model(x)
# ---------------------- 3. 模型初始化与训练配置 ----------------------
# 超参数
noise_dim = 100 # 噪声维度
img_dim = 28 * 28 # 图像扁平化维度(28×28)
epochs = 50 # 训练轮次
lr = 2e-4 # 学习率
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化模型、损失函数、优化器
generator = Generator(noise_dim, img_dim).to(device)
discriminator = Discriminator(img_dim).to(device)
criterion = nn.BCELoss() # 二分类交叉熵损失
# 分别定义两个模型的优化器
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
# ---------------------- 4. 开始对抗训练 ----------------------
# 固定噪声(用于每轮可视化生成效果)
fixed_noise = torch.randn(10, noise_dim).to(device)
for epoch in range(epochs):
# 记录每轮损失
loss_d_total = 0.0
loss_g_total = 0.0
for batch_idx, (real_imgs, _) in enumerate(train_loader): # 无监督:忽略标签
batch_size = real_imgs.size(0)
real_imgs = real_imgs.view(-1, img_dim).to(device) # 扁平化图像(28×28→784)
# ---------------------- 步骤1:训练判别器D ----------------------
# 1.1 用真实数据训练D:目标输出1
real_labels = torch.ones(batch_size, 1).to(device) # 真实数据标签:1
output_real = discriminator(real_imgs)
loss_d_real = criterion(output_real, real_labels)
# 1.2 用伪造数据训练D:目标输出0
noise = torch.randn(batch_size, noise_dim).to(device) # 随机噪声
fake_imgs = generator(noise) # G生成假图像
fake_labels = torch.zeros(batch_size, 1).to(device) # 伪造数据标签:0
output_fake = discriminator(fake_imgs.detach()) # 固定G,不更新其参数
loss_d_fake = criterion(output_fake, fake_labels)
# 1.3 计算D的总损失,反向传播更新D
loss_d = loss_d_real + loss_d_fake
optimizer_d.zero_grad()
loss_d.backward()
optimizer_d.step()
# ---------------------- 步骤2:训练生成器G ----------------------
# 2.1 让D误判假数据为真:目标输出1
output_fake_g = discriminator(fake_imgs) # 不固定D,更新G时需D的梯度
loss_g = criterion(output_fake_g, real_labels) # 假数据标签设为1
# 2.2 计算G的损失,反向传播更新G
optimizer_g.zero_grad()
loss_g.backward()
optimizer_g.step()
# 累积损失
loss_d_total += loss_d.item()
loss_g_total += loss_g.item()
# 每轮平均损失
avg_loss_d = loss_d_total / len(train_loader)
avg_loss_g = loss_g_total / len(train_loader)
print(f"Epoch {epoch+1}/{epochs} | Loss_D: {avg_loss_d:.4f} | Loss_G: {avg_loss_g:.4f}")
# 每10轮可视化生成效果
if (epoch + 1) % 10 == 0:
generator.eval()
with torch.no_grad():
generated_imgs = generator(fixed_noise).cpu().numpy() # 生成假图像
generated_imgs = (generated_imgs + 1) / 2 # 从[-1,1]转回[0,1],方便可视化
# 绘制10个生成的数字
plt.figure(figsize=(10, 2))
for i in range(10):
plt.subplot(1, 10, i+1)
plt.imshow(generated_imgs[i].reshape(28, 28), cmap="gray")
plt.axis("off")
plt.title(f"Generated Images (Epoch {epoch+1})")
plt.savefig(f"gan_generated_epoch_{epoch+1}.png")
plt.show()
generator.train()
print("GAN训练完成!生成结果已保存。")
2. 代码核心逻辑拆解
- 数据预处理:MNIST 图像归一化到 [-1,1],配合生成器的 Tanh 输出,让数据分布更稳定;
- 生成器:4 层全连接网络,用 LeakyReLU 避免梯度消失,输出层 Tanh 匹配数据范围;
- 判别器:4 层全连接网络,用 Dropout 防止过拟合,输出层 Sigmoid 输出 0-1 概率;
- 训练节奏:每批数据先训练判别器(更新 D 参数),再训练生成器(更新 G 参数),平衡两者能力;
- 可视化:用固定噪声生成图像,观察每 10 轮的生成效果变化,直观判断模型训练进度。
3. 预期效果
- 训练初期(前 10 轮):生成的数字模糊、轮廓不清晰,判别器损失(Loss_D)远低于生成器损失(Loss_G);
- 训练中期(20-30 轮):数字轮廓逐渐清晰,能分辨出大致形状,Loss_D 和 Loss_G 趋于平衡;
- 训练后期(40-50 轮):生成的数字与真实 MNIST 数字高度相似,轮廓完整、无明显畸变。
五、GAN 的应用场景与延伸方向
1. 典型应用场景
- 图像生成:生成人脸、风景、艺术画(如 StyleGAN 生成高清人脸);
- 图像编辑:图像修复(填补破损区域)、风格迁移(如把照片转为油画风格);
- 数据增强:生成模拟训练数据,解决小样本学习问题(如医疗影像数据增强);
- 超分辨率重建:将低分辨率图像提升为高分辨率。
2. 进阶延伸方向
基础 GAN 是后续复杂 GAN 架构的基石,常见进阶方向包括:
- DCGAN:用卷积层替代全连接层,提升图像生成质量;
- WGAN-GP:解决基础 GAN 的训练不稳定和模式崩溃问题;
- StyleGAN:实现风格化生成,可精准控制生成内容的细节(如人脸的表情、发型);
- CycleGAN:实现跨域生成(如把马变成斑马、把照片变成素描)。
总结:GAN 入门的核心逻辑与学习建议
- 核心逻辑:GAN 的本质是 “对抗中学习”—— 生成器和判别器互相博弈,最终达到平衡,实现以假乱真的生成;
- 学习关键:先理解 “造假者 - 鉴宝师” 的类比,再通过基础代码跑通训练流程,直观感受模型迭代过程;
- 避坑重点:关注训练稳定性,控制交替训练节奏、合理设置激活函数和归一化方式,避免模式崩溃和梯度消失;
- 进阶路径:先掌握基础 GAN,再学习 DCGAN、WGAN-GP 等进阶架构,逐步过渡到复杂场景。
GAN 的思想简洁而强大,是生成式 AI 的重要分支,掌握其基础原理后,能轻松理解后续各类复杂生成模型。








