进入尾声,一个完整的模型训练 ,点亮的第一个led

#自己注释版
import torch
import torchvision.datasets
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import time
# from model import *
from torch.utils.data import DataLoader#定义训练的设备
device= torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#准备数据集
train_data = torchvision.datasets.CIFAR10(root='./data_CIF',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root='./data_CIF',train=False,transform=torchvision.transforms.ToTensor(),download=True)#获得数据集长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为 : {train_data_size}")
print(f"测试数据集的长度为 : {test_data_size}")#利用 Dataloader 来加载数据集
train_loader =DataLoader(dataset=train_data,batch_size=64)
test_loader =DataLoader(dataset=test_data,batch_size=64)#搭建神经网络
class Tudui(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size=5,stride=1,padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32,out_channels=32,kernel_size=5,stride=1,padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size=5,stride=1,padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(in_features=64*4*4,out_features=64),nn.Linear(in_features=64,out_features=10),)def forward(self,x):x = self.model(x)return x#创建网络模型
tudui = Tudui()
#GPU
tudui.to(device)#损失函数
loss_fn = nn.CrossEntropyLoss()
#GPU
loss_fn.to(device)#优化器
# learning_rate = 0.001
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)#设置训练网络的一些参数
#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch = 10#添加tensorboard
writer = SummaryWriter("./logs_train")start_time = time.time()
for i in range(epoch):print(f"---------第{i+1}轮训练开始---------")#训练步骤开始tudui.train()       #当网络中有特定层的时候有用for data in train_loader:imgs, targets = data#GPUimgs.to(device)targets.to(device)output = tudui(imgs)loss = loss_fn(output,targets)      #算出误差# 优化器优化模型#梯度置零optimizer.zero_grad()#反向传播loss.backward()#更新参数optimizer.step()#展示输出total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(f"训练次数:{total_train_step} 花费时间:{end_time - start_time}")print(f"训练次数:{total_train_step},Loss:{loss.item()}")writer.add_scalar("train_loss",loss.item(),total_train_step)#测试步骤开始tudui.eval()        #当网络中有特定层的时候有用total_test_loss = 0#整体的正确个数total_accuracy = 0#取消梯度with torch.no_grad():for data in test_loader:imgs, targets = data#GPUimgs.to(device)targets.to(device)output = tudui(imgs)loss = loss_fn(output,targets)  #计算出误差#构建一些指标total_accuracy += (output.argmax(1) == targets).sum().item() #统计每一批预测的结果和真实结果相等的个数total_test_loss += loss.item()print(f"整体测试集上的loss:{total_test_loss}")print(f"整体测试集上的正确率:{total_accuracy/test_data_size}")print(f"整体测试集上的正确数:{total_accuracy}")#写入tensorboardwriter.add_scalar("test_loss",total_test_loss,total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step += 1#保存模型torch.save(tudui.state_dict(),f"./tudui{i}.pth")print("模型已保存")writer.close()

#ai注释版
import torch
import torchvision.datasets
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import time
from torch.utils.data import DataLoader# ----------------- 一、设备 -----------------
# 判断是否有可用的GPU,没有则用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# ----------------- 二、准备数据 -----------------
# 下载CIFAR10训练集
train_data = torchvision.datasets.CIFAR10(root='./data_CIF', train=True, transform=torchvision.transforms.ToTensor(), download=True)
# 下载CIFAR10测试集
test_data = torchvision.datasets.CIFAR10(root='./data_CIF', train=False, transform=torchvision.transforms.ToTensor(), download=True)# 打印训练集和测试集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为 : {train_data_size}")
print(f"测试数据集的长度为 : {test_data_size}")# 使用Dataloader封装数据,方便批量加载
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)# ----------------- 三、搭建神经网络 -----------------
class Tudui(nn.Module):def __init__(self):super().__init__()# 搭建一个简单的卷积神经网络self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2), # [b,3,32,32] -> [b,32,32,32]nn.MaxPool2d(2),  # [b,32,32,32] -> [b,32,16,16]nn.Conv2d(32, 32, 5, 1, 2), # -> [b,32,16,16]nn.MaxPool2d(2), # -> [b,32,8,8]nn.Conv2d(32, 64, 5, 1, 2), # -> [b,64,8,8]nn.MaxPool2d(2), # -> [b,64,4,4]nn.Flatten(),  # 拉平成一维 [b,64*4*4]nn.Linear(64*4*4, 64),nn.Linear(64, 10)  # CIFAR10 一共10类)def forward(self, x):return self.model(x)# 创建模型对象
tudui = Tudui()
tudui.to(device)  # 移动到GPU/CPU# ----------------- 四、定义损失函数和优化器 -----------------
# 交叉熵损失函数(多分类标准选择)
loss_fn = nn.CrossEntropyLoss().to(device)# SGD随机梯度下降优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)# ----------------- 五、训练准备 -----------------
total_train_step = 0   # 总训练次数
total_test_step = 0    # 总测试次数
epoch = 10             # 训练轮数# TensorBoard日志工具
writer = SummaryWriter("./logs_train")start_time = time.time()  # 记录起始时间# ----------------- 六、开始训练 -----------------
for i in range(epoch):print(f"---------第{i+1}轮训练开始---------")# 训练模式(启用BN、Dropout等)tudui.train()for data in train_loader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)# 前向传播output = tudui(imgs)# 计算损失loss = loss_fn(output, targets)# 优化器梯度清零optimizer.zero_grad()# 反向传播,自动求导loss.backward()# 更新参数optimizer.step()total_train_step += 1# 每100次打印一次训练lossif total_train_step % 100 == 0:end_time = time.time()print(f"训练次数:{total_train_step} 花费时间:{end_time - start_time}")print(f"训练次数:{total_train_step}, Loss:{loss.item()}")# 写入TensorBoardwriter.add_scalar("train_loss", loss.item(), total_train_step)# ----------------- 七、测试步骤 -----------------tudui.eval()  # 切换到测试模式(停用BN、Dropout)total_test_loss = 0total_accuracy = 0# 不计算梯度,节省显存,加快推理with torch.no_grad():for data in test_loader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)output = tudui(imgs)loss = loss_fn(output, targets)total_test_loss += loss.item()# 预测正确个数统计total_accuracy += (output.argmax(1) == targets).sum().item()print(f"整体测试集上的Loss: {total_test_loss}")print(f"整体测试集上的正确率: {total_accuracy / test_data_size}")print(f"整体测试集上的正确数: {total_accuracy}")# 写入TensorBoard(测试loss和准确率)writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)total_test_step += 1# ----------------- 八、保存模型 -----------------torch.save(tudui.state_dict(), f"./tudui{i}.pth")print("模型已保存")# ----------------- 九、关闭TensorBoard -----------------
writer.close()

 结果图

 忘记清除历史数据了

 

 完整的模型验证套路

import torch
import torchvision.transforms
from PIL import Image
from torch import nnimage_path = "./images/微信截图_20250719220956.png"
image = Image.open(image_path).convert('RGB')
print(type(image))transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
image = transform(image)
print(type(image))#搭建神经网络
class Tudui(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size=5,stride=1,padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32,out_channels=32,kernel_size=5,stride=1,padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size=5,stride=1,padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(in_features=64*4*4,out_features=64),nn.Linear(in_features=64,out_features=10),)def forward(self,x):x = self.model(x)return xmodel = Tudui()
model.load_state_dict(torch.load("tudui9.pth"))
image = torch.reshape(image, (1,3,32,32))
model.eval()
with torch.no_grad():output = model(image)
print(output)
print(output.argmax(1))

5确实是狗,验证成功 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/92419.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/92419.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/92419.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java变量详解:局部变量、成员变量、类变量区别及使用场景

作为Java开发者,深入理解不同变量的特性是写出高质量代码的基础。本文将为你全面解析三种核心变量类型,并通过实战案例展示它们的正确使用方式。一、变量类型概览 1. 局部变量(Local Variable) 定义:在方法、构造方法或…

【收集电脑信息】collect_info.sh

收集电脑信息 collect_info.sh #!/bin/bashoutput"info.txt" > "$output"# 1. OS Version echo " 操作系统名称及版本 " >> "$output" lsb_release -d | cut -f2- >> "$output" echo -e "\n" >…

服务器清理空间--主要是conda环境清理和删除

1.查看空间情况 (base) zhouy24RL-DSlab:~/zhouy24Files$ df -h Filesystem Size Used Avail Use% Mounted on udev 252G 0 252G 0% /dev tmpfs 51G 4.9M 51G 1% /run /dev/nvme0n1p3 1.9T 1.7T 42G 98% / tmpfs 252G …

UE5多人MOBA+GAS 26、为角色添加每秒回血回蓝(番外:添加到UI上)

文章目录添加生命值和蓝量的状态标签创建无限GE并应用监听添加和去除标签每秒回复配上UI添加生命值和蓝量的状态标签 添加新的标签 CRUNCH_API UE_DECLARE_GAMEPLAY_TAG_EXTERN(Stats_Health_Full)CRUNCH_API UE_DECLARE_GAMEPLAY_TAG_EXTERN(Stats_Health_Empty)CRUNCH_API U…

MetaGPT源码剖析(三):多智能体系统的 “智能角色“ 核心实现——Role类

每一篇文章都短小精悍,不啰嗦。今天我们来深入剖析Role类的代码实现。在多智能体协作系统中,Role(角色)就像现实世界中的 "员工",是执行具体任务、参与协作的基本单位。这段代码是 MetaGPT 框架的核心&#…

【项目经验】小智ai MCP学习笔记

理论 1、什么是MCP MCP(Model Context Protocol,模型上下文协议)是一种开放式协议,它实现了LLM与各种工具的调用。使LLM从对话、生成式AI变成了拥有调用三方工具的AI。用官方的比喻,MCP就是USB-C接口,只要实现了这个接口&#x…

Matlab学习笔记:矩阵基础

MATLAB学习笔记:矩阵基础 作为MATLAB的核心,矩阵是处理数据的基础工具。矩阵本质上是一个二维数组,由行和列组成,用于存储和操作数值数据。在本节中,我将详细讲解矩阵的所有知识点,包括创建、索引、运算、函数等,确保内容通俗易懂。我会在关键地方添加MATLAB代码示例,…

技术演进中的开发沉思-38 MFC系列:关于打印

打印程序也是MFC开发中不能忽视的一个环节,现在做打印开发so easy。但当年做打印开发还是挺麻烦。在当年的桌面程序里就像拼图的最后一块,看着简单,实则要把屏幕上的像素世界,准确映射到打印机的物理纸张上。而MFC 的打印机制就像…

Apache Ignite 长事务终止机制

这段内容讲的是 Apache Ignite 中长事务终止机制(Long Running Transactions Termination),特别是关于分区映射交换(Partition Map Exchange)与事务超时设置(Transaction Timeout)之间的关系。下…

网络编程---TCP协议

TCP协议基础知识TCP(Transmission Control Protocol,传输控制协议)是互联网核心协议之一,位于传输层(OSI第4层),为应用层提供可靠的、面向连接的、基于字节流的数据传输服务。它与IP协议共同构成…

K 近邻算法(K-Nearest Neighbors, KNN)详解及案例

K近邻算法(K-Nearest Neighbors, KNN)详解及案例 一、基本原理 K近邻算法是一种监督学习算法,核心思想是“物以类聚,人以群分”:对于一个新样本,通过计算它与训练集中所有样本的“距离”,找出距…

深入理解 Redis 集群化看门狗机制:原理、实践与风险

在分布式系统中,我们常常需要执行一些关键任务,这些任务要么必须成功执行,要么失败后需要明确的状态(如回滚),并且它们的执行时间可能难以精确预测。如何确保这些任务不会被意外中断,或者在长时…

Python机器学习:从零基础到项目实战

目录第一部分:思想与基石——万法归宗,筑基问道第1章:初探智慧之境——机器学习世界观1.1 何为学习?从人类学习到机器智能1.2 机器学习的“前世今生”:一部思想与技术的演进史1.3 为何是Python?——数据科学…

数据库:库的操作

1:查看所有数据库SHOW DATABASES;2:创建数据库CREATE DATABASE [ IF NOT EXISTS ] 数据库名 [ CHARACTER SET 字符集编码 | COLLATE 字符集校验规则 | ENCRYPTION { Y | N } ];[]:可写可不写{}:必选一个|:n 选 1ENCR…

AngularJS 动画

AngularJS 动画 引言 AngularJS 是一个流行的JavaScript框架,它为开发者提供了一种构建动态Web应用的方式。在AngularJS中,动画是一个强大的功能,可以帮助我们创建出更加生动和引人注目的用户界面。本文将详细介绍AngularJS动画的原理、用法以及最佳实践。 AngularJS 动画…

SonarQube 代码分析工具

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 🧠全面掌握 SonarQube:企业代码质量保障的利器 🚀 在当今 DevOps 流水线中,代码…

vmware vsphere esxi6.5 使用工具导出镜像

注:为什么使用这个工具,我这边主要因为esxi6.5自身bug导致web导出镜像会失败一、下载VMware-ovftool到本地系统(根据你的操作系统版本到官网下载安装,此处略)以下内容默认将VMware-ovftool安装到windows 本地系统为例。…

ES 踩坑记:Set Processor 字段更新引发的 _source 污染

问题背景 社区的一个伙伴想对一个 integer 的字段类型添加一个 keyword 类型的子字段,然后进行精确匹配的查询优化,提高查询的速度。 整个索引数据量不大,并不想进行 reindex 这样的复杂操作,就想到了使用 update_by_query 的存量…

如何彻底搞定 PyCharm 中 pip install 报错 ModuleNotFoundError: No module named ‘requests’ 的问题

如何彻底搞定 PyCharm 中 pip install 报错 ModuleNotFoundError: No module named ‘requests’ 的问题 在使用 PyCharm 开发 Python 项目时,ModuleNotFoundError: No module named requests 是一个常见但令人头疼的问题。本篇博文将从环境配置、原因分析到多种解…

powerquery如何实现表的拼接主键

在做表过程中,有时候没有基表,这个时候就要构造完整的主键,这样才可以使之后匹配的数据不会因为主键不全而丢失数据 我的处理方法是吧多个表的主键拼在一起然后去重,构造一个单单之后之间的表作为基表去匹配数据 所以就哟啊用到自…