生成对抗网络(Generative Adversarial Network, GAN)是一种通过对抗训练生成数据的深度学习模型,由生成器(Generator)和判别器(Discriminator)两部分组成,其核心思想源于博弈论中的零和博弈。

一、核心组成

生成器(G)

  目标:生成逼真的假数据(如图像、文本),试图欺骗判别器。
输入:随机噪声(通常服从高斯分布或均匀分布)。
输出:合成数据(如假图像)。

判别器(D)

  目标:区分真实数据(来自训练集)和生成器合成的假数据。
输出:概率值(0到1),表示输入数据是真实的概率。

二、关于对抗训练

1. 动态博弈

  1)生成器尝试生成越来越逼真的数据,使得判别器无法区分真假。
2)判别器则不断优化自身,以更准确地区分真假数据。
3)两者交替训练,最终达到纳什均衡(生成器生成的数据与真实数据分布一致,判别器无法区分,输出概率恒为0.5)。

2. 优化目标(极小极大博弈)

min⁡Gmax⁡DV(D,G)=Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]\min_{G}{\max_D}V(D,G)=E_{x\sim p_{data}}[logD(x)]+E_{z\sim p_z}[log(1-D(G(z)))]GminDmaxV(D,G)=Expdata[logD(x)]+Ezpz[log(1D(G(z)))]

  其中,
D(x)D(x)D(x):判别器对真实数据的判别结果;
G(z)G(z)G(z):生成器生成的假数据;
判别器希望最大化V(D,G)V(D,G)V(D,G)(正确分类真假数据);
生成器希望最小化V(D,G)V(D,G)V(D,G)(让判别器无法区分)。

3.交替更新

1) 固定生成器,训练判别器:

  用真实数据(标签1)和生成数据(标签0)训练判别器,提高其鉴别能力。

2) 固定判别器,训练生成器:

  通过反向传播调整生成器参数,使得判别器对生成数据的输出概率接近1(即欺骗判别器)。

三、典型应用

  图像生成:生成逼真的人脸、风景、艺术画(如 DCGAN、StyleGAN);
图像编辑:图像修复(填补缺失区域)、风格迁移(如将照片转为油画风格);
数据增强:为小样本任务生成额外的训练数据;
超分辨率重建:将低分辨率图像恢复为高分辨率图像。

四、优势与挑战

优势

  无监督学习:无需对数据进行标注,仅通过真实数据即可训练(适用于标注成本高的场景)。

  生成高质量数据:相比其他生成模型(如变分自编码器 VAE),GAN 在图像生成等任务中往往能生成更逼真、细节更丰富的数据。

  灵活性:生成器和判别器可以采用不同的网络结构(如卷积神经网络 CNN、循环神经网络 RNN 等),适用于多种数据类型(图像、文本、音频等)。

挑战

  训练不稳定:容易出现 “模式崩溃”(生成器只生成少数几种相似数据,缺乏多样性)或难以收敛;

  平衡难题:生成器和判别器的能力需要匹配,否则可能一方过强导致另一方无法学习(如判别器太弱,生成器无需优化即可欺骗它);

  可解释性差:生成器的内部工作机制难以解释,生成结果的可控性较弱(近年通过改进模型如 StyleGAN 缓解了这一问题)。

五、Python示例

  使用 PyTorch 实现简单 的GAN 模型,生成手写数字图像。

import matplotlib
matplotlib.use('TkAgg')import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as npplt.rcParams['font.sans-serif']=['SimHei']  # 中文支持
plt.rcParams['axes.unicode_minus']=False  # 负号显示# 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 数据加载和预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 将图像归一化到 [-1, 1]
])train_dataset = datasets.MNIST(root='./data', train=True,download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定义生成器网络
class Generator(nn.Module):def __init__(self, latent_dim=100, img_dim=784):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 256),nn.LeakyReLU(0.2),nn.BatchNorm1d(256),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.BatchNorm1d(512),nn.Linear(512, img_dim),nn.Tanh()  # 输出范围 [-1, 1])def forward(self, z):return self.model(z).view(z.size(0), 1, 28, 28)# 定义判别器网络
class Discriminator(nn.Module):def __init__(self, img_dim=784):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_dim, 512),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(256, 1),nn.Sigmoid()  # 输出概率值)def forward(self, img):img_flat = img.view(img.size(0), -1)return self.model(img_flat)# 初始化模型
latent_dim = 100
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)# 定义损失函数和优化器
criterion = nn.BCELoss()
lr = 0.0002
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))# 训练函数
def train_gan(epochs):for epoch in range(epochs):for i, (real_imgs, _) in enumerate(train_loader):batch_size = real_imgs.size(0)real_imgs = real_imgs.to(device)# 创建标签real_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# ---------------------#  训练判别器# ---------------------d_optimizer.zero_grad()# 计算判别器对真实图像的损失real_pred = discriminator(real_imgs)d_real_loss = criterion(real_pred, real_labels)# 生成假图像z = torch.randn(batch_size, latent_dim).to(device)fake_imgs = generator(z)# 计算判别器对假图像的损失fake_pred = discriminator(fake_imgs.detach())d_fake_loss = criterion(fake_pred, fake_labels)# 总判别器损失d_loss = d_real_loss + d_fake_lossd_loss.backward()d_optimizer.step()# ---------------------#  训练生成器# ---------------------g_optimizer.zero_grad()# 生成假图像fake_imgs = generator(z)# 计算判别器对假图像的预测fake_pred = discriminator(fake_imgs)# 生成器希望判别器将假图像判断为真g_loss = criterion(fake_pred, real_labels)g_loss.backward()g_optimizer.step()# 打印训练进度if i % 100 == 0:print(f"Epoch [{epoch}/{epochs}] Batch {i}/{len(train_loader)} "f"Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}")# 每个epoch结束后,生成一些样本图像if (epoch + 1) % 10 == 0:generate_samples(generator, epoch + 1, latent_dim, device)# 生成样本图像
def generate_samples(generator, epoch, latent_dim, device, n_samples=16):generator.eval()z = torch.randn(n_samples, latent_dim).to(device)with torch.no_grad():samples = generator(z).cpu()# 可视化生成的样本fig, axes = plt.subplots(4, 4, figsize=(8, 8))for i, ax in enumerate(axes.flatten()):ax.imshow(samples[i][0].numpy(), cmap='gray')ax.axis('off')plt.tight_layout()plt.savefig(f"gan_samples/gan_samples_epoch_{epoch}.png")plt.close()generator.train()# 训练模型
train_gan(epochs=50)# 生成最终样本
generate_samples(generator, "final", latent_dim, device)

最终生成的样本:
在这里插入图片描述

六、小结

  GAN通过对抗机制实现了强大的生成能力,成为生成模型领域的里程碑技术。衍生变体(如CGAN、CycleGAN等)进一步扩展了其应用场景。



End.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/90825.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/90825.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/90825.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue和Element的使用

文章目录1.vue 脚手架创建步骤2.vue项目开发流程3.vue路由4.Element1.vue 脚手架创建步骤 创建一个文件夹 vue双击进入文件夹,在路径上输入cmd输入vue ui, 目的:调出图形化用户界面点击创建 9. 10.在vscode中打开 主要目录介绍 src目录介绍 vue项目启动 图形化界面中没有npm…

如何设置直播间的观看门槛,让直播间安全有效地运行?

文章目录前言一、直播间观看门槛有哪几种形式?二、设置直播间的观看门槛,对直播的好处是什么三、如何一站式实现上述功能?总结前言 打造一个安全、高效、互动良好的直播间并非易事。面对海量涌入的观众,如何有效识别并阻挡潜在的…

【SkyWalking】配置告警规则并通过 Webhook 推送钉钉通知

🧭 本文为 【SkyWalking 系列】第 3 篇 👉 系列导航:点击跳转 【SkyWalking】配置告警规则并通过 Webhook 推送钉钉通知 简介 介绍 SkyWalking 告警机制、告警规则格式以及如何通过 webhook 方式将告警信息发送到钉钉。 引入 服务响应超时…

关于 验证码系统 详解

验证码系统的目的是:阻止自动化脚本访问网页资源,验证访问者是否为真实人类用户。它通过各种测试(图像、行为、计算等)判断请求是否来自机器人。一、验证码系统的整体架构验证码系统通常由 客户端 服务端 风控模型 数据采集 四…

微服务集成snail-job分布式定时任务系统实践

前言 从事开发工作的同学,应该对定时任务的概念并不陌生,就是我们的系统在运行过程中能够自动执行的一些任务、工作流程,无需人工干预。常见的使用场景包括:数据库的定时备份、文件系统的定时上传云端服务、每天早上的业务报表数…

依赖注入的逻辑基于Java语言

对于一个厨师,要做一道菜。传统的做法是:你需要什么食材,就自己去菜市场买什么。这意味着你必须知道去哪个菜市场、怎么挑选食材、怎么讨价还价等等。你不仅要会做菜,还要会买菜,职责变得复杂了。 而依赖注入就像是有一…

skywalking镜像应用springboot的例子

目录 1、skywalking-ui连接skywalking-oap服务失败问题 2、k8s环境 检查skywalking-oap服务状态 3、本地iidea启动服务连接skywalking oap服务 4、基于apache-skywalking-java-agent-9.4.0.tgz构建skywalking-agent镜像 4.1、Dockerfile内容如下 4.2、AbstractBuilder.M…

3. java 堆和 JVM 内存结构

1. JVM介绍和运行流程-CSDN博客 2. 什么是程序计数器-CSDN博客 3. java 堆和 JVM 内存结构-CSDN博客 4. 虚拟机栈-CSDN博客 5. JVM 的方法区-CSDN博客 6. JVM直接内存-CSDN博客 7. JVM类加载器与双亲委派模型-CSDN博客 8. JVM类装载的执行过程-CSDN博客 9. JVM垃圾回收…

UnityShader——SSAO

目录 1.是什么 2.原理 3.各部分解释 2.1.从屏幕空间到视图空间 2.2.以法线半球为基,获取随机向量 2.3.应用偏移,并将其转换为uv坐标 2.4.获取深度 2.5.比较并计算贡献 2.6.最后计算 4.改进 4.1.平滑过渡 4.2.模糊 5.变量和语句解释 5.1._D…

【设计模式】外观模式(门面模式)

外观模式(Facade Pattern)详解一、外观模式简介 外观模式(Facade Pattern) 是一种 结构型设计模式,它为一个复杂的子系统提供一个统一的高层接口,使得子系统更容易使用。 外观模式又称为门面模式&#xff0…

【6.1.1 漫画分库分表】

漫画分库分表 “数据量大了不可怕,可怕的是不知道如何优雅地拆分。” 🎭 人物介绍 架构师老王:资深数据库架构专家,精通各种分库分表方案Java小明:对分库分表充满疑问的开发者ShardingSphere师傅:Apache S…

Tomcat问题:启动脚本startup.bat中文乱码问题解决

一、问题描述 我们第一次下载或者打开Tomcat时可能在控制台会出现中文乱码问题二、解决办法 我的是8.x版本的tomcat用notepad打开:logging.properties 找到:java.util.logging.ConsoleHandler.encoding设置成GBK,重启tomcat即可

Linux中Gitee的使用

一、Gitee简介:Gitee(码云)是中国的一个代码托管和协作开发平台,类似于GitHub或GitLab,主要面向开发者提供代码管理、项目协作及开源生态服务。适用场景个人开发者:托管私有代码或参与开源项目。中小企业&a…

Oracle大表数据清理优化与注意事项详解

一、性能优化策略 1. 批量处理优化批量大小选择: 小批量(1,000-10,000行):减少UNDO生成,但需要更多提交次数中批量(10,000-100,000行):平衡性能与资源消耗大批量(100,000行):适合高配置环境,但需监控资源使…

Anaconda及Conda介绍及使用

文章目录Anaconda简介为什么选择 Anaconda?Anaconda 安装Win 平台macOS 平台Linux 平台Anaconda 界面使用Conda简介Conda下载安装conda 命令环境管理包管理其他常用命令Jupyter Notebook(可选)Anaconda简介 Anaconda 是一个数据科学和机器学…

外包干了一周,技术明显退步

我是一名本科生,自2019年起,我便在南京某软件公司担任功能测试的工作。这份工作虽然稳定,但日复一日的重复性工作让我逐渐陷入了舒适区,失去了前进的动力。两年的时光匆匆流逝,我却在原地踏步,技术没有丝毫…

【QT】多线程相关教程

一、核心概念与 Qt 线程模型 1.线程与进程的区别: 线程是程序执行的最小单元,进程是资源分配的最小单元,线程共享进程的内存空间(堆,全局变量等),而进程拥有独立的内存空间。Qt线程只要关注同一进程内的并发。 2.为什么使用多线程…

VS 版本更新git安全保护问题的解决

问题:我可能移动了一个VS C# 项目,然后,发现里面的git版本检测不能用了 正在打开存储库: X:\Prj_C#\3D fatal: detected dubious ownership in repository at X:/Prj_C#/3DSnapCatch X:/Prj_C#/3D is owned by:S-1-5-32-544 but the current …

Git常用命令一览

Git 是基于 Linux内核开发的版本控制工具。与常用的版本控制工具 CVS, Subversion 等不同,它采用了分布式版本库的方式,不必服务器端软件支持(ps:这得分是用什么样的服务端,使用http协议或者git协议等不太一样。并且在…

基于 JSON 文件定位图片缺陷点并保存

基于JSON的图片缺陷处理流程 ├── 1. 输入检查 │ ├── 验证图片文件是否存在 │ └── 验证JSON文件是否存在 │ ├── 2. 数据加载 │ ├── 打开并加载图片 │ └── 读取并解析JSON文件 │ ├── 3. 缺陷信息提取 │ ├── 检查JSON中是否存在shapes字…