知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 1. 数据预处理
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
])# 2. 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)# 3. 创建数据加载器
batch_size = 64  # 每批处理64个样本
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义模型、损失函数和优化器
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量self.layer1 = nn.Linear(784, 128)  # 第一层:784个输入,128个神经元self.relu = nn.ReLU()  # 激活函数self.layer2 = nn.Linear(128, 10)  # 第二层:128个输入,10个输出(对应10个数字类别)def forward(self, x):x = self.flatten(x)  # 展平图像x = self.layer1(x)   # 第一层线性变换x = self.relu(x)     # 应用ReLU激活函数x = self.layer2(x)   # 第二层线性变换,输出logitsreturn x# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 初始化模型
model = MLP()
model = model.to(device)  # 将模型移至GPU(如果可用)criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器# 5. 训练模型(记录每个 iteration 的损失)
def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()  # 设置为训练模式# 新增:记录每个 iteration 的损失all_iter_losses = []  # 存储所有 batch 的损失iter_indices = []     # 存储 iteration 序号(从1开始)for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)  # 移至GPU(如果可用)optimizer.zero_grad()  # 梯度清零output = model(data)  # 前向传播loss = criterion(output, target)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数# 记录当前 iteration 的损失(注意:这里直接使用单 batch 损失,而非累加平均)iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)  # iteration 序号从1开始# 统计准确率和损失(原逻辑保留,用于 epoch 级统计)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 每100个批次打印一次训练信息(可选:同时打印单 batch 损失)if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')# 原 epoch 级逻辑(测试、打印 epoch 结果)不变epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')# 绘制所有 iteration 的损失曲线plot_iter_losses(all_iter_losses, iter_indices)# 保留原 epoch 级曲线(可选)# plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies, epochs)return epoch_test_acc  # 返回最终测试准确率# 6. 测试模型
def test(model, test_loader, criterion, device):model.eval()  # 设置为评估模式test_loss = 0correct = 0total = 0with torch.no_grad():  # 不计算梯度,节省内存和计算资源for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy  # 返回损失和准确率# 7.绘制每个 iteration 的损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('每个 Iteration 的训练损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 8. 执行训练和测试(设置 epochs=2 验证效果)
epochs = 2  
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/92763.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/92763.shtml
英文地址,请注明出处:http://en.pswp.cn/web/92763.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分析代码并回答问题

代码 <template><div>Counter: {{ counter }}</div><div>Double Counter: {{ doubleCounter }}</div> </template><script setup lang"ts"> import { ref, computed } from "vue";const counter ref(0);const …

在macOS上扫描192.168.1.0/24子网的所有IP地址

在macOS上扫描192.168.1.0/24子网的所有IP地址&#xff0c;可以通过终端命令实现。以下是几种常用方法&#xff1a; 使用ping命令循环扫描 打开终端执行以下脚本&#xff0c;会逐个ping测试192.168.1.1到192.168.1.254的地址&#xff0c;并过滤出有响应的IP&#xff1a; for i …

Java基础05——类型转换(本文为个人学习笔记,内容整理自哔哩哔哩UP主【遇见狂神说】的公开课程。 > 所有知识点归属原作者,仅作非商业用途分享)

Java基础05——类型转换 类型转换 由于Java是强类型语言&#xff0c;所以要进行有些运算的时候&#xff0c;需要用到类型转换。 如&#xff1a;byte(占1个字节)&#xff0c;short(占2个字节)&#xff0c;char(占2个字节)→int(4个字节)→long(占8个字节)→float(占4个字节)→do…

mysql基础(二)五分钟掌握全量与增量备份

全量备份 Linux环境 数据备份 数据库的备份与恢复有多中方法&#xff0c;通过mysql自带的mysqldump工具可对数据库进行备份。语法&#xff1a; mysqldump -u username -p password --databases db_name > file_name .sql说明&#xff1a; -u参数指定用户名&#xff0c;usern…

使用Windbg分析多线程死锁项目实战问题分享

目录 1、问题描述 2、使用.effmach x86命令切换到32位上下文 3、切换到UI线程&#xff0c;发现UI线程死锁了 4、使用!locks命令查看临界区锁的详细信息&#xff0c;遇到了问题 5、使用dt命令查看临界区对象信息&#xff0c;找到发生死锁的多个线程 6、用户态锁与内核态锁…

防火墙组网方式总结

一、部署模式&#xff1a;灵活适配多样网络环境下一代防火墙&#xff08;NGAF&#xff09;具备极强的网络适应能力&#xff0c;支持五种核心部署模式&#xff0c;可根据不同网络需求灵活选择。路由模式&#xff1a;防火墙相当于路由器&#xff0c;位于内外网之间负责路由寻址&a…

AI大模型:(二)5.1 文生视频(Text-to-Video)模型发展史

目录 1.介绍 2.发展历史 2.1.早期探索阶段(2015-2019) 2.1.1.技术萌芽期 2.1.2.RNN/LSTM时代 2.2.技术突破期(2020-2021) 2.2.1 Transformer引入视频生成 2.2.2 扩散模型的兴起 2.3.商业化突破期(2022-2023) 2.3.1 产品化里程碑 2.3.2 竞争格局形成 2.4.革命…

14mm寻北仪能否塞进液压支架生死缝隙?

在煤矿井下世界的方寸之间&#xff0c;液压支架的每个关键节点都承载着千钧重压。顶梁铰接点、立柱顶端、掩护梁角落&#xff0c;恰恰是空间最为局促的“禁区”。ER-MNS-10A MEMS寻北仪应运而生&#xff01;它采用了先进的MEMS陀螺技术&#xff0c;以14mm至薄高度、40g极致轻盈…

python之浅拷贝深拷贝

文章目录潜拷贝(shallow copy)深拷贝(deep copy)总结一下python的浅拷贝和深拷贝.潜拷贝(shallow copy) python中潜拷贝指的是:构造一个新的复合对象&#xff0c;然后将原对象中的对象引用插入其中 平常开发过程中潜拷贝是比深拷贝更常见的场景. 比如编程中使用到的一些基本的…

普通大学本科生如何入门强化学习?

问题:你平时是如何紧跟大型语言模型和智能体技术前沿的&#xff1f;有哪些具体的学习和跟踪方式&#xff1f;回答:我会通过“输入-内化-实践”结合的方式跟踪前沿。首先&#xff0c;学术动态方面&#xff0c;每天花10分钟浏览arXiv的http://cs.CL和http://cs.AI板块&#xff0c…

新手向:Python实现数据可视化图表生成

Python数据可视化入门&#xff1a;从零开始生成图表数据可视化是数据分析过程中不可或缺的关键环节&#xff0c;它通过将抽象的数字信息转化为直观的图形展示&#xff0c;帮助分析师和决策者更快速、更准确地发现数据中隐藏的模式、规律和发展趋势。在当今大数据时代&#xff0…

VBA即用型代码手册:计算选择的单词数Count Words in Selection

我给VBA下的定义&#xff1a;VBA是个人小型自动化处理的有效工具。可以大大提高自己的劳动效率&#xff0c;而且可以提高数据的准确性。我这里专注VBA,将我多年的经验汇集在VBA系列九套教程中。作为我的学员要利用我的积木编程思想&#xff0c;积木编程最重要的是积木如何搭建及…

DNS(域名系统)

分层结构根域名&#xff08;ipv4&#xff0c;13台&#xff09;&#xff0c;二级域名&#xff0c;三级域名……相关记录A将域名解析为ipv4地址AAAA将域名解析为ipv6地址MX指名该区域为邮件服务区PTR反向查询将主机名解析为域名NS记录服务器的名字CNAME别名查询方式递归查询迭代查…

【大模型】强化学习算法总结

角色和术语定义 State&#xff1a;状态Action&#xff1a;动作Policy/actor model&#xff1a;策略模型&#xff0c;用于决策行动的主要模型Critic/value model&#xff1a;价值模型&#xff0c;用于评判某个行动的价值大小Reward model&#xff1a;奖励模型&#xff0c;用于给…

基于梅特卡夫定律的开源链动2+1模式AI智能名片S2B2C商城小程序价值重构研究

摘要&#xff1a;梅特卡夫定律揭示了网络价值与用户数量的平方关系&#xff0c;在互联网经济中&#xff0c;连接的深度与形式正因人的参与发生质变。本文以开源链动21模式、AI智能名片与S2B2C商城小程序的协同应用为研究对象&#xff0c;通过实证分析其在社群团购、下沉市场等场…

Ubuntu22.04安装CH340驱动及串口

一、CH340驱动安装 1.1 查看USB设备能否被识别 CtrlAltT打开终端&#xff1a; lsusb 插入设备前&#xff1a; 插入设备后&#xff1a; 输出中包含ID 1a86:7523 QinHeng Electronics CH340 serial converter的信息&#xff0c;这表明CH340设备已经被系统识别。 1.2 查看USB转串…

CPU缓存(CPU Cache)和TLB(Translation Lookaside Buffer)缓存现代计算机体系结构中用于提高性能的关键技术

CPU缓存&#xff08;CPU Cache&#xff09;和TLB&#xff08;Translation Lookaside Buffer&#xff09;缓存是现代计算机体系结构中用于提高性能的关键技术。它们通过减少CPU访问数据和指令的延迟来提高系统的整体效率。以下是对这两者的详细解释&#xff1a; 1. CPU 缓存 CPU…

唐扬·高并发系统设计40问

课程下载&#xff1a;https://download.csdn.net/download/m0_66047725/91644703 00开篇词 _ 为什么你要学习高并发系统设计&#xff1f;.pdf 00开篇词丨为什么你要学习高并发系统设计&#xff1f;.mp3 01 _ 高并发系统&#xff1a;它的通用设计方法是什么&#xff1f;.pdf …

基于Spring Data Elasticsearch的分布式全文检索与集群性能优化实践指南

基于Spring Data Elasticsearch的分布式全文检索与集群性能优化实践指南 技术背景与应用场景 随着大数据时代的到来&#xff0c;海量信息的存储与检索成为各类应用的核心需求。Elasticsearch 作为一款分布式搜索引擎&#xff0c;凭借其高可扩展、高可用和实时检索的优势&#x…

Linux系统编程——基础IO

一些前置知识&#xff1a;文件 属性 内容文件 分为 打开的文件、未打开的文件打开的文件&#xff1a;由进程打开&#xff0c;本质是 进程与文件 的关系&#xff1b;维护的文件对象先加载文件属性&#xff0c;文件内容一般按需加载未打开的文件&#xff1a;在永久性存储介质 —…