目录

一、第三方库导入

二、数据集准备

三、使用转置卷积的生成器

四、使用卷积的判别器

五、生成器生成图像

六、主程序

七、运行结果

7.1 生成器和判别器的损失函数图像

7.2 训练过程中生成器生成的图像

八、完整的pytorch代码


由于之前写gans的代码时,我的生成器和判别器不是使用的全连接网络就是卷积,但是无论这两种方法怎么组合,最后生成器生成的图像效果都很不好。因此最后我选择了生成器使用转置卷积,而判别器使用卷积,最后得到的生成图像确实效果比之前好很多了。

一、第三方库导入

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False  # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader

二、数据集准备

# 手写数字数据集
class MINISTDataset(Dataset):def __init__(self, files, root_dir, transform=None):self.files = filesself.root_dir = root_dirself.transform = transformself.labels = []for f in files:parts = f.split("_")p = parts[2].split(".")[0]self.labels.append(int(p))def __len__(self):return len(self.files)def __getitem__(self, idx):img_path = os.path.join(self.root_dir, self.files[idx])img = Image.open(img_path).convert("L")if self.transform:img = self.transform(img)label = self.labels[idx]return img, label

三、使用转置卷积的生成器

class Generator(nn.Module):def __init__(self, latent_dim=100):super().__init__()self.main = nn.Sequential(# 输入: latent_dim维噪声 -> 输出: 7x7x256nn.ConvTranspose2d(latent_dim, 256, kernel_size=7, stride=1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 上采样: 7x7 -> 14x14nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),# 上采样: 14x14 -> 28x28nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 输出层: 28x28x1nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),nn.Tanh())def forward(self, x):# 将噪声重塑为 (batch_size, latent_dim, 1, 1)x = x.view(x.size(0), -1, 1, 1)return self.main(x)

四、使用卷积的判别器

class Discriminator(nn.Module):def __init__(self):super().__init__()self.main = nn.Sequential(# 输入: 1x28x28nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),  # 输出: 32x14x14nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 输出: 64x7x7nn.BatchNorm2d(64),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # 输出: 128x7x7nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Flatten(),nn.Linear(128 * 7 * 7, 1),nn.Sigmoid())def forward(self, x):return self.main(x)

五、生成器生成图像

# 展示生成器生成的图像
def gen_img_plot(test_input, save_path):gen_imgs = gen(test_input).detach().cpu()gen_imgs = gen_imgs.view(-1, 28, 28)plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow(gen_imgs[i], cmap="gray")plt.axis("off")plt.savefig(save_path, dpi=300)plt.close()

六、主程序

if __name__ == "__main__":# 对数据做归一化处理transforms = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 路径base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'train_dir = os.path.join(base_dir, "minist_train")# 获取文件夹里图像的名称train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]# 创建数据集和数据加载器train_dataset = MINISTDataset(train_files, train_dir, transform=transforms)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 参数epochs = 50lr = 0.0002# 初始化模型的优化器和损失函数gen = Generator()dis = Discriminator()d_optim = torch.optim.Adam(dis.parameters(), lr=lr, betas=(0.5, 0.999))  # 判别器的优化器g_optim = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))  # 生成器的优化器loss_fn = torch.nn.BCELoss()  # 二分类交叉熵损失函数# 记录lossD_loss = []G_loss = []# 训练for epoch in range(epochs):d_epoch_loss = 0g_epoch_loss = 0count = len(train_loader)  # 返回批次数for step, (img, _) in enumerate(train_loader):# 每个批次的大小size = img.size(0)random_noise = torch.randn(size, 100)# 判别器训练d_optim.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))# d_real_loss.backward()gen_img = gen(random_noise)gen_img = gen_img.view(size, 1, 28, 28)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))# d_fake_loss.backward()d_loss = (d_real_loss + d_fake_loss) / 2d_loss.backward()d_optim.step()# 生成器的训练g_optim.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()g_optim.step()# 计算在一个epoch里面所有的g_loss和d_losswith torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss# 计算平均损失值with torch.no_grad():d_epoch_loss = d_epoch_loss / countg_epoch_loss = g_epoch_loss / countD_loss.append(d_epoch_loss.item())G_loss.append(g_epoch_loss.item())print("Epoch:", epoch, "  D loss:", d_epoch_loss.item(), "  G Loss:", g_epoch_loss.item())# 每隔2个epoch绘制生成器生成的图像if (epoch + 1) % 2 == 0:test_input = torch.randn(16, 100)name = f"gen_img_{epoch}.jpg"save_path = os.path.join('C:\\Users\\Administrator\\PycharmProjects\\CNN\\gen_img_11', name)gen_img_plot(test_input, save_path)# 绘制损失曲线图plt.figure(figsize=(12, 6))plt.plot(D_loss, label="判别器", color="tomato")plt.plot(G_loss, label="生成器", color="orange")plt.xlabel("epoch")plt.ylabel("loss")plt.title("生成器和判别器的损失曲线图")plt.legend()plt.grid()plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\gen_dis_loss_11.jpg", dpi=300, bbox_inches="tight")plt.close()

七、运行结果

7.1 生成器和判别器的损失函数图像

7.2 训练过程中生成器生成的图像

这里只展示一部分

gen_img_1.jpg

gen_img_25.jpg

gen_img_49.jpg

八、完整的pytorch代码

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False  # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader# 手写数字数据集
class MINISTDataset(Dataset):def __init__(self, files, root_dir, transform=None):self.files = filesself.root_dir = root_dirself.transform = transformself.labels = []for f in files:parts = f.split("_")p = parts[2].split(".")[0]self.labels.append(int(p))def __len__(self):return len(self.files)def __getitem__(self, idx):img_path = os.path.join(self.root_dir, self.files[idx])img = Image.open(img_path).convert("L")if self.transform:img = self.transform(img)label = self.labels[idx]return img, label# 改进的生成器(使用转置卷积)
class Generator(nn.Module):def __init__(self, latent_dim=100):super().__init__()self.main = nn.Sequential(# 输入: latent_dim维噪声 -> 输出: 7x7x256nn.ConvTranspose2d(latent_dim, 256, kernel_size=7, stride=1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 上采样: 7x7 -> 14x14nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),# 上采样: 14x14 -> 28x28nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 输出层: 28x28x1nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),nn.Tanh())def forward(self, x):# 将噪声重塑为 (batch_size, latent_dim, 1, 1)x = x.view(x.size(0), -1, 1, 1)return self.main(x)# 改进的判别器(使用深度卷积网络)
class Discriminator(nn.Module):def __init__(self):super().__init__()self.main = nn.Sequential(# 输入: 1x28x28nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),  # 输出: 32x14x14nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 输出: 64x7x7nn.BatchNorm2d(64),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # 输出: 128x7x7nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Flatten(),nn.Linear(128 * 7 * 7, 1),nn.Sigmoid())def forward(self, x):return self.main(x)# 展示生成器生成的图像
def gen_img_plot(test_input, save_path):gen_imgs = gen(test_input).detach().cpu()gen_imgs = gen_imgs.view(-1, 28, 28)plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow(gen_imgs[i], cmap="gray")plt.axis("off")plt.savefig(save_path, dpi=300)plt.close()if __name__ == "__main__":# 对数据做归一化处理transforms = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 路径base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'train_dir = os.path.join(base_dir, "minist_train")# 获取文件夹里图像的名称train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]# 创建数据集和数据加载器train_dataset = MINISTDataset(train_files, train_dir, transform=transforms)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 参数epochs = 50lr = 0.0002# 初始化模型的优化器和损失函数gen = Generator()dis = Discriminator()d_optim = torch.optim.Adam(dis.parameters(), lr=lr, betas=(0.5, 0.999))  # 判别器的优化器g_optim = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))  # 生成器的优化器loss_fn = torch.nn.BCELoss()  # 二分类交叉熵损失函数# 记录lossD_loss = []G_loss = []# 训练for epoch in range(epochs):d_epoch_loss = 0g_epoch_loss = 0count = len(train_loader)  # 返回批次数for step, (img, _) in enumerate(train_loader):# 每个批次的大小size = img.size(0)random_noise = torch.randn(size, 100)# 判别器训练d_optim.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))# d_real_loss.backward()gen_img = gen(random_noise)gen_img = gen_img.view(size, 1, 28, 28)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))# d_fake_loss.backward()d_loss = (d_real_loss + d_fake_loss) / 2d_loss.backward()d_optim.step()# 生成器的训练g_optim.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()g_optim.step()# 计算在一个epoch里面所有的g_loss和d_losswith torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss# 计算平均损失值with torch.no_grad():d_epoch_loss = d_epoch_loss / countg_epoch_loss = g_epoch_loss / countD_loss.append(d_epoch_loss.item())G_loss.append(g_epoch_loss.item())print("Epoch:", epoch, "  D loss:", d_epoch_loss.item(), "  G Loss:", g_epoch_loss.item())# 每隔2个epoch绘制生成器生成的图像if (epoch + 1) % 2 == 0:test_input = torch.randn(16, 100)name = f"gen_img_{epoch}.jpg"save_path = os.path.join('C:\\Users\\Administrator\\PycharmProjects\\CNN\\gen_img_11', name)gen_img_plot(test_input, save_path)# 绘制损失曲线图plt.figure(figsize=(12, 6))plt.plot(D_loss, label="判别器", color="tomato")plt.plot(G_loss, label="生成器", color="orange")plt.xlabel("epoch")plt.ylabel("loss")plt.title("生成器和判别器的损失曲线图")plt.legend()plt.grid()plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\gen_dis_loss_11.jpg", dpi=300, bbox_inches="tight")plt.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/93412.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/93412.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/93412.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ubuntu 通过NAT模式上网

这里必须使用VMnet8 设置为NAT模式 下面设置Ip地址区域ubuntu ip地址设置来自于上面

盲盒抽谷机小程序系统开发:从0到1的完整方法论

开发一款成功的盲盒抽谷机小程序系统,需兼顾技术实现、用户体验与商业逻辑。本文将从需求分析、UI/UX设计、技术架构、测试上线到运营增长,系统梳理从0到1的完整方法论。需求分析:明确“为谁而做”盲盒抽谷机的核心用户是18-35岁的二次元爱好…

web开发,在线%射击比赛管理%系统开发demo,基于html,css,jquery,python,django,三层mysql数据库

经验心得 两业务单,业务crud开发很简单了,自行学习,我说一下学习流程。什么是前端,用到那些技术html,css,javascript分别是什么?进阶jquery,bootstrap,各种常见前端组件又是什么,前端框架react,angular以及…

Centos9傻瓜式linux部署CRMEB 开源商城系统(PHP)

服务器环境推荐要求* Nignx(必须) * PHP 7.1 ~ 7.4(必须此版本内,版本过大会警告不兼容) * MySQL 5.7 ~ 8.0(必须) * Redis(非必须)后台页面展示:…

AI 云电竞游戏盒子:从“盒子”到“云-端-芯”一体化竞技平台的架构实践

摘要 AI 云电竞游戏盒子(以下简称“电竞盒”)不再是一台简单的客厅游戏主机,而是一套以 AI 调度为核心、以云原生架构为骨架、以边缘渲染为肌肉、以端侧感知为神经的“云-端-芯”协同竞技系统。本文基于 2024 年 Q2 落地的量产方案&#xff0…

基于kuboard实现kubernetes的集群管理

1、前提条件安装docker-compose2、步骤在本地目录创建kuboard-v4\在该目录下创建文件docker-compose.yaml,内容如下:configs:create_db_sql:content: |CREATE DATABASE kuboard DEFAULT CHARACTER SET utf8mb4 DEFAULT COLLATE utf8mb4_unicode_ci;cre…

Linux操作系统软件编程——多线程

什么是线程线程的定义是轻量级的进程,可以实现多任务的并发。线程是操作系统任务调度的最小单位线程的创建由某个进程创建,且进程创建线程时,会为其分配独立的栈区空间(默认8M)。线程和所在的进程,以及进程…

linux下找到指定目录下最新日期log文件

以下是一个完整的C函数&#xff0c;用于在指定目录下自动查找最近更新的日志文件&#xff08;根据文件名中的时间戳选择最新的文件&#xff09;&#xff1a;#include <stdio.h> #include <stdlib.h> #include <string.h> #include <dirent.h> #include…

《数学模型》经典案例——钢管的订购与运输

一、问题描述 要铺设一条 A1→A2→⋯→A15A_1 \rightarrow A_2 \rightarrow \cdots \rightarrow A_{15}A1​→A2​→⋯→A15​ 的输送天然气的主管道&#xff0c;如图 6.22 所示。经筛选后可以生产这种主管道钢管的钢厂有 S1,S2,⋯,S7S_1, S_2, \cdots, S_7S1​,S2​,⋯,S7​ 。…

Java Web部署

今天小编来分享下如何将本地写的Java Web程序部署到Linux上。 小编介绍两种方式&#xff1a; 部署基于Linux Systemd服务、基于Docker容器化部署 首先部署基于Linux Systemd服务 那么部署之前&#xff0c;要对下载所需的环境 软件下载 Linux&#xff08;以ubuntu&#xf…

告别AI“炼丹术”:“策略悬崖”理论如何为大模型对齐指明科学路径

摘要&#xff1a;当前&#xff0c;我们训练大模型的方式&#xff0c;尤其是RLHF&#xff0c;充满了不确定性&#xff0c;时常产生“谄媚”、“欺骗”等怪异行为&#xff0c;被戏称为“炼丹”。一篇来自上海AI Lab的重磅论文提出的“策略悬崖”理论&#xff0c;首次为这个混沌的…

深入理解C#特性:从应用到自定义

——解锁元数据标记的高级玩法&#x1f4a1; 核心认知&#xff1a;特性本质揭秘 public sealed class ReviewCommentAttribute : System.Attribute { ... }特性即特殊类&#xff1a;所有自定义特性必须继承 System.Attribute&#xff08;基础规则&#xff09;命名规范&#xff…

机器学习-集成学习(EnsembleLearning)

0 结果展示 0.1 鸢尾花分类 import pandas as pd import numpy as npfrom sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score, recall_score, f1_score, classification_repo…

Golang database/sql 包深度解析(一)

database/sql 是 Go 语言标准库中用于与 SQL&#xff08;或类 SQL&#xff09;数据库交互的核心包&#xff0c;提供了一套轻量级、通用的接口&#xff0c;使得开发者可以用统一的方式操作各种不同的数据库&#xff0c;而无需关心底层数据库驱动的具体实现。 核心设计理念 datab…

文章自然润色 API 数据接口

文章自然润色 API 数据接口 ai / 文本处理 基于 AI 的文章润色 专有模型 / 智能纠错。 1. 产品功能 基于自有专业模型进行 AI 智能润色对原始内容进行智能纠错高效的文本润色性能全接口支持 HTTPS&#xff08;TLS v1.0 / v1.1 / v1.2 / v1.3&#xff09;&#xff1b;全面兼容…

【状压DP】3276. 选择矩阵中单元格的最大得分|2403

本文涉及知识点 C动态规划 3276. 选择矩阵中单元格的最大得分 给你一个由正整数构成的二维矩阵 grid。 你需要从矩阵中选择 一个或多个 单元格&#xff0c;选中的单元格应满足以下条件&#xff1a; 所选单元格中的任意两个单元格都不会处于矩阵的 同一行。 所选单元格的值 互…

IDEA 清除 ctrl+shift+r 全局搜索记录

定位文件&#xff1a;在Windows系统中&#xff0c;文件通常位于C:Users/用户名/AppData/Roaming/JetBrains/IntelliJIdea(idea版本)/workspace目录下&#xff0c;文件名为一小串随机字符&#xff1b;在Mac系统中&#xff0c;文件位于/Users/用户名/Library/Application /Suppor…

解锁AI大模型:Prompt工程全面解析

解锁AI大模型&#xff1a;Prompt工程全面解析 本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型开发 学习视频/籽料/面试题 都在这>>Github<< 从新手到高手&#xff0c;Prompt 工程究竟是什么&#xff1f; 在当今数字化时代&#xff0c;AI …

HTTP0.9/1.0/1.1/2.0

在HTTP0.9中&#xff0c;只有GET方法&#xff0c;没有请求头headers&#xff0c;没有状态码&#xff0c;只能用于传输HTML文件。到了HTTP1.0(1996)&#xff0c;HTTP1.0传输请求头&#xff0c;有状态码&#xff0c;并且新增了POST和HEAD方法。HTTP1.0中&#xff0c;使用短连接&a…

gitee 流水线+docker-compose部署 nodejs服务+mysql+redis

文章中的方法是自己琢磨出来的&#xff0c;或许有更优解&#xff0c;共同学习&#xff0c;共同进步&#xff01; docker-compose.yml 文件配置&#xff1a; 说明&#xff1a;【配置中有个别字段冗余&#xff0c;但不影响使用】该文件推荐放在nodejs项目的根目录中&#xff0c…