PyTorch生成式人工智能——ACGAN详解与实现
- 0. 前言
- 1. ACGAN 简介
- 1.1 ACGAN 技术原理
- 1.2 ACGAN 核心思想
- 1.3 损失函数
- 2. 模型训练流程
- 3. 使用 PyTorch 构建 ACGAN
- 3.1 数据处理
- 3.2 模型构建
- 3.3 模型训练
- 3.4 模型测试
- 相关链接
0. 前言
在生成对抗网络 (Generative Adversarial Network, GAN) 的众多变体中,ACGAN
(Auxiliary Classifier GAN
) 是一个非常经典且实用的条件生成模型。它的核心思想是:在判别器中除了保留“真假判别”这一任务外,额外加入一个辅助分类器,让判别器同时预测输入样本的类别。这样,生成器在训练时不仅需要“欺骗判别器”,还必须生成能够被正确分类的样本,从而在图像语义和类别可控性上得到显著提升。
这一改进让 ACGAN
能够在条件图像生成中表现出色,在复杂数据集上实现按类别生成的能力。相比于传统条件生成对抗网络 (Conditional GAN, cGAN) 简单地把标签拼接到输入,ACGAN
通过 “辅助分类监督” 提供了更细粒度的学习信号,使得生成器得到的梯度更加稳定和有意义。在本节中,将详细介绍 ACGAN
原理,并使用 PyTorch
构建 ACGAN
模型。
1. ACGAN 简介
1.1 ACGAN 技术原理
生成对抗网络 (Generative Adversarial Network, GAN) 的众多变体中,ACGAN
(Auxiliary Classifier GAN
) 能够从随机噪声中生成逼真的图像、文本甚至音乐。然而,传统的 GAN
有一个显著的局限性:缺乏对生成过程的精确控制。我们无法指定要生成“数字7”的图片还是一只“戴墨镜的猫”。
为了解决这个问题,条件生成对抗网络 (Conditional GAN, cGAN) 应运而生。它通过将类别标签信息同时注入生成器 (Generator
) 和判别器 (Discriminator
),实现了条件生成。但这仍然不够完美,cGAN
的判别器最终只输出一个“真/假”的概率,它并没有显式地告诉生成器它生成的图片在类别上是否正确。
ACGAN
(Auxiliary Classifier GAN
) 正是在 CGAN
的基础上,对判别器的任务进行了至关重要的扩展。它不仅判断真伪,还同时担任一个“分类器”的角色。这个简单的改变,极大地提升了生成图像的质量和多样性,尤其是在生成特定类别的图像时。
1.2 ACGAN 核心思想
ACGAN
的核心思想非常直观:为判别器增加一个辅助任务——对输入图像进行分类。其中,生成器的输入包括随机噪声向量 zzz 和目标类别标签 ccc;判别器的输出包括:
- 一个源 (
Source
) 输出:一个标量概率,表示图像是来自真实数据分布的概率 - 一个辅助类别 (
Class
) 输出:类别概率分布
通过引入这个辅助的分类任务,ACGAN
迫使判别器不仅要学习“什么样的图像是真实的”,还要学习“真实图像属于什么类别”。反过来,生成器为了欺骗这个更强大的判别器,也必须生成既逼真又类别分明的图像。
1.3 损失函数
损失函数包含两部分:
- 源判别损失 (
source loss
),用来训练真假判别,通常使用二元交叉熵 - 类别判别损失 (
auxiliary classification loss
),使用多元交叉熵(真实图像的类别为真实标签,生成图像的类别为生成器的条件标签
训练目标:
- 判别器
D
,最小化源判别损失(正确区分真实/虚假图像)并最小化类别判别损失(正确预测类别) - 生成器
G
:生成图像以最大化判别器认为是“真实”的概率,并最小化判别器给出的类别预测与条件类的一致性
2. 模型训练流程
模型训练流程如下:
- 从真实数据中取一批数据 x,c{x,c}x,c
- 判别器更新:
- 计算真实样本的源判别损失与类别判别损失
- 用噪声和随机标签生成虚假样本 x~=G(z,c)\tilde x=G(z,c)x~=G(z,c),计算虚假样本的源判别损失与类别判别损失(可选)
- 把这些损失加权后更新
D
- 生成器更新:
- 用一批噪声与条件标签生成样本 x~\tilde xx~
- 通过
D
计算源输出与辅助类别输出 - 生成器的损失是希望源输出为“真实”,并希望辅助类别输出为生成时的条件标签
- 更新
G
3. 使用 PyTorch 构建 ACGAN
接下来,使用 PyTorch
实现 ACGAN
,并在 MNIST
数据集上进行训练生成手写数字。
3.1 数据处理
(1) 首先,导入所需库并设置超参数:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
import os# 设置超参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 100
num_classes = 10
batch_size = 64
lr = 0.0002
num_epochs = 100
sample_interval = 400# 创建输出目录
os.makedirs("images", exist_ok=True)
os.makedirs("models", exist_ok=True)
(2) 加载 MNIST
数据集,将图像转换为张量,并将像素值从 [0,1]
归一化到 [-1,1]
范围:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=transform
)
(3) 构建数据加载器:
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True
)
3.2 模型构建
(1) 首先,定义权重初始化函数:
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)
(2) 定义生成器。生成器接收随机噪声和类别标签作为输入,通过嵌入层将标签转换为与噪声相同维度的向量,然后将二者相乘融合,之后通过全连接层和转置卷积层逐步上采样,最终生成 28 x 28
的图像:
class Generator(nn.Module):def __init__(self, latent_dim, num_classes):super(Generator, self).__init__()# 将类别标签转换为嵌入向量self.label_emb = nn.Embedding(num_classes, latent_dim)self.init_size = 7 # 初始特征图大小self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 1, 3, stride=1, padding=1),nn.Tanh())def forward(self, noise, labels):# 将噪声和标签嵌入相乘gen_input = torch.mul(self.label_emb(labels), noise)out = self.l1(gen_input)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return img
除了将类别标签转换为嵌入向量进行融合外,也可以直接使用标签的独热编码与噪声向量进行拼接。
(3) 定义判别器。判别器使用卷积层逐步提取特征,最后通过两个全连接层分别输出样本真伪的概率(源判别输出)和类别概率(类别判别输出):
class Discriminator(nn.Module):def __init__(self, num_classes):super(Discriminator, self).__init__()# 卷积层提取特征self.features = nn.Sequential(# 输入: 1x28x28nn.Conv2d(1, 16, 3, stride=2, padding=1), # 16x14x14nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.Conv2d(16, 32, 3, stride=2, padding=1), # 32x7x7nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(32, 0.8),nn.Conv2d(32, 64, 3, stride=2, padding=1), # 64x4x4nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(64, 0.8),nn.Conv2d(64, 128, 3, stride=2, padding=1), # 128x2x2nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(128, 0.8),)# 计算特征图大小: 128 * 2 * 2 = 512self.feature_size = 128 * 2 * 2# 输出真实/虚假的概率self.adv_layer = nn.Sequential(nn.Linear(self.feature_size, 1), nn.Sigmoid())# 输出类别概率self.aux_layer = nn.Sequential(nn.Linear(self.feature_size, num_classes), nn.Softmax(dim=1))def forward(self, img):# 提取特征features = self.features(img)features = features.view(features.size(0), -1) # 展平# 预测真伪和类别validity = self.adv_layer(features)label = self.aux_layer(features)return validity, label
(4) 初始化生成器和判别器,并打印模型结构:
generator = Generator(latent_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)# 初始化权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 打印模型结构
print("Generator structure:")
print(generator)
print("\nDiscriminator structure:")
print(discriminator)
输出模型结构如下所示:
3.3 模型训练
(1) 初始化损失函数和优化器:
# 定义损失函数
adversarial_loss = nn.BCELoss()
auxiliary_loss = nn.CrossEntropyLoss()# 定义优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
(2) 定义变量记录训练过程的损失变化:
G_losses = []
D_losses = []
(3) 实现训练循环。训练过程分为两个部分,先训练判别器,使其能正确区分真实和生成样本,并正确分类;然后训练生成器,使其能生成被判别器判定为真实且分类正确的样本:
# 训练循环
for epoch in range(num_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 准备真实/虚假标签valid = torch.ones(batch_size, 1).to(device)fake = torch.zeros(batch_size, 1).to(device)# 真实图像和标签real_imgs = imgs.to(device)real_labels = labels.to(device)# 训练判别器optimizer_D.zero_grad()# 真实样本的损失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, real_labels)) / 2# 生成虚假样本z = torch.randn(batch_size, latent_dim).to(device)gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)gen_imgs = generator(z, gen_labels)# 虚假样本的损失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# 总判别器损失d_loss = (d_real_loss + d_fake_loss) / 2# 计算判别器准确率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([real_labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()# 训练生成器optimizer_G.zero_grad()# 生成器希望判别器将虚假样本判断为真实validity, pred_label = discriminator(gen_imgs)g_loss = (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels)) / 2g_loss.backward()optimizer_G.step()# 记录损失G_losses.append(g_loss.item())D_losses.append(d_loss.item())# 打印训练状态if i % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}, acc: {100*d_acc:.2f}%] "f"[G loss: {g_loss.item():.4f}]")# 定期保存生成的图像样本batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:# 保存生成器生成的图像save_image(gen_imgs.data[:25], f"images/{batches_done}.png", nrow=5, normalize=True)
(4) 训练完成后,保存模型权重:
torch.save(generator.state_dict(), "models/generator_final.pth")
torch.save(discriminator.state_dict(), "models/discriminator_final.pth")
(5) 绘制训练过程中的损失曲线
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("loss_curve.png")
plt.show()
3.4 模型测试
在 images
文件夹中可以看到训练过程中生成的样本,随着训练进行,生成的数字越来越清晰:
使用训练完成的模型,生成制定类别的数字:
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i in range(10):img = generate_digit(generator, i)ax = axes[i//5, i%5]ax.imshow(img.cpu().squeeze(), cmap='gray')ax.set_title(f"Digit: {i}")ax.axis('off')
plt.tight_layout()
plt.savefig("generated_digits.png")
plt.show()
生成数字 1
:
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i in range(10):img = generate_digit(generator, 1)ax = axes[i//5, i%5]ax.imshow(img.cpu().squeeze(), cmap='gray')ax.set_title(f"Digit: 1")ax.axis('off')
plt.tight_layout()
plt.savefig("generated_digits.png")
plt.show()
相关链接
PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(11)——神经风格迁移
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现
PyTorch生成式人工智能(20)——像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(21)——归一化流模型(Normalizing Flow Model)
PyTorch生成式人工智能(27)——从零开始训练GPT模型
PyTorch生成式人工智能(28)——MuseGAN详解与实现