冲冲冲😊
here😊

文章目录

  • PyTorch多层感知机模型构建与MNIST分类训练笔记
    • 🎯 1. 任务概述
    • ⚙️ 2. 环境设置
      • 2.1 导入必要库
      • 2.2 GPU配置
    • 🧠 3. 模型构建
      • 3.1 模型定义关键点
      • 3.2 损失函数选择
      • 3.3 模型初始化与设备选择
    • 🔧 4. 优化器配置
      • 4.1 随机梯度下降优化器
    • 🔄 5. 训练循环实现
      • 5.1 训练函数设计
      • 5.2 测试函数设计
    • 📦 6. 数据准备
      • 6.1 加载MNIST数据集
    • 🚀 7. 训练执行
      • 7.1 训练循环主体
      • 7.2 训练过程输出(部分)
    • 📊 8. 结果可视化
      • 8.1 损失曲线绘制
      • 8.2 准确率曲线绘制

PyTorch多层感知机模型构建与MNIST分类训练笔记

🎯 1. 任务概述

解决MNIST手写数字分类问题,创建一个简单的多层感知机(MLP)模型

  • 使用torch.nn.Linear层构建模型
  • 使用ReLU作为激活函数
  • 包含两个全连接隐藏层(120和84个神经元)和输出层(10个神经元对应10个数字类别)
  • 模型输入为展平后的28×28=784像素图像

⚙️ 2. 环境设置

2.1 导入必要库

import torch
from torch import nn
import os

2.2 GPU配置

# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,6"  # 只使用空闲的GPU

🧠 3. 模型构建

3.1 模型定义关键点

class Model(nn.Module):def __init__(self):super().__init__()# 第一层输入展平后的特征长度28乘28,创建120个神经元self.liner_1 = nn.Linear(28*28, 120)# 第二层输入的是前一层的输出,创建84个神经元self.liner_2 = nn.Linear(120, 84)# 输出层接受第二层的输入84,输出分类个数10self.liner_3 = nn.Linear(84, 10)def forward(self, input):x = input.view(-1, 28*28)  # 将输入展平为二维(1,28,28)->(28*28)x = torch.relu(self.liner_1(x))x = torch.relu(self.liner_2(x))x = self.liner_3(x)return x

📝 模型结构说明

  1. 输入层:将28×28图像展平为784维向量
  2. 隐藏层1:120个神经元,使用ReLU激活
  3. 隐藏层2:84个神经元,使用ReLU激活
  4. 输出层:10个神经元对应10个数字类别

3.2 损失函数选择

loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失函数
'''
注意两个参数
1. weight: 各类别的权重(处理不平衡数据集)
2. ignore_index: 忽略特定类别的索引
另外,它要求实际类别为数值编码,而不是独热编码
'''

🔍 为什么选择交叉熵损失?

  • 适用于多分类问题
  • 内部集成了Softmax计算,简化实现流程
  • 对错误分类有较强的惩罚

3.3 模型初始化与设备选择

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
# print(device)  # 可选:打印使用的设备

💡 GPU加速提示
使用.to(device)将模型移动到GPU可显著加快训练速度,特别是对于大模型和大数据集

🔧 4. 优化器配置

4.1 随机梯度下降优化器

optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

🔧 关键参数解析

  • params: 需要优化的模型参数(通常为model.parameters()
  • lr=0.005: 学习率,控制参数更新步长的超参数
  • 其他可选参数:momentum(动量),weight_decay(L2正则化)

🔄 5. 训练循环实现

5.1 训练函数设计

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 获取当前数据集样本总数量num_batches = len(dataloader)   # 获取当前data loader总批次数# train_loss用于累计所有批次的损失之和, correct用于累计预测正确的样本总数train_loss, correct = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)# 进行预测,并计算当前批次的损失pred = model(X)loss = loss_fn(pred, y)# 利用反向传播算法,根据损失优化模型参数optimizer.zero_grad()   # 先将梯度清零loss.backward()          # 损失反向传播,计算模型参数梯度optimizer.step()         # 根据梯度优化参数with torch.no_grad():# correct用于累计预测正确的样本总数correct += (pred.argmax(1) == y).type(torch.float).sum().item()# train_loss用于累计所有批次的损失之和train_loss += loss.item()# train_loss 是所有批次的损失之和,所以计算全部样本的平均损失时需要除以总的批次数train_loss /= num_batches# correct 是预测正确的样本总数,若计算整个epoch总体正确率,需要除以样本总数量correct /= sizereturn train_loss, correct

5.2 测试函数设计

def test(dataloader, model):size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizereturn test_loss, correct

📊 数据加载器相关方法区别

方法返回内容适用场景
len(dataset)数据集总样本数(如100)数据统计、划分
len(dataloader)总批次数(如4)训练循环控制
len(dataloader.dataset)等同于 len(dataset)需要访问原始数据时

📦 6. 数据准备

6.1 加载MNIST数据集

import torchvision
from torchvision.transforms import ToTensortrain_ds = torchvision.datasets.MNIST("data/", train=True, transform=ToTensor(), download=True)
test_ds = torchvision.datasets.MNIST("data/", train=False, transform=ToTensor(), download=True)train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=64)

🚀 7. 训练执行

7.1 训练循环主体

# 对全部的数据集训练50个epoch(一个epoch表示对全部数据训练一遍)
epochs = 50 
train_loss, train_acc = [], []
test_loss, test_acc = [], []for epoch in range(epochs):# 调用train()函数训练epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)# 调用test()函数测试epoch_test_loss, epoch_test_acc = test(test_dl, model)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)# 定义一个打印模板template = ("epoch:{:2d},train_loss:{:.6f},train_acc:{:.1f}%,""test_loss:{:.5f},test_acc:{:.1f}%")print(template.format(epoch, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))print("Done")

7.2 训练过程输出(部分)

epoch: 0,train_loss:2.157364,train_acc:46.7%,test_loss:1.83506,test_acc:63.7%
epoch: 1,train_loss:1.222660,train_acc:74.3%,test_loss:0.74291,test_acc:81.8%
epoch: 2,train_loss:0.612381,train_acc:84.0%,test_loss:0.49773,test_acc:86.3%
...
epoch:48,train_loss:0.110716,train_acc:96.9%,test_loss:0.12003,test_acc:96.4%
epoch:49,train_loss:0.108877,train_acc:97.0%,test_loss:0.11783,test_acc:96.5%
Done

📈 训练趋势分析

  • 初始准确率:46.7%(训练集),63.7%(测试集)
  • 最终准确率:97.0%(训练集),96.5%(测试集)
  • 过拟合现象轻微:训练集和测试集性能差距仅0.5%

📊 8. 结果可视化

8.1 损失曲线绘制

import matplotlib.pyplot as pltplt.plot(range(1, epochs+1), train_loss, label="train_loss")
plt.plot(range(1, epochs+1), test_loss, label="test_loss", ls="--")
plt.xlabel("epoch")
plt.legend()
plt.show()

注释:损失曲线显示训练初期损失快速下降,后期趋于平稳

8.2 准确率曲线绘制

plt.plot(range(1, epochs+1), train_acc, label="train_acc")
plt.plot(range(1, epochs+1), test_acc, label="test_acc")
plt.xlabel("epoch")
plt.legend()
plt.show()

注释:准确率曲线稳步上升,最终达到96.5%的测试准确率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/91058.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/91058.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/91058.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

android tabLayout 切换fragment fragment生命周期

1、TabLayout 与 Fragment 结合使用的常见方式 通常会使用 FragmentPagerAdapter 或 FragmentStatePagerAdapter 与 ViewPager 配合,再将 TabLayout 与 ViewPager 关联,实现通过 TabLayout 切换 Fragment。 以下是布局文件示例 activity_main.xml: <LinearLayout xmln…

马蹄集 BD202401补给

可怕的战争发生了&#xff0c;小度作为后勤保障工作人员&#xff0c;也要为了保卫国家而努力。现在有 N(1≤N≤)个堡垒需要补给&#xff0c;然而总的预算 B(1≤B≤)是有限的。现在已知第 i 个堡垒需要价值 P(i) 的补给&#xff0c;并且需要 S(i) 的运费。 鉴于小度与供应商之间…

《Llava:Visual Instruction Tuning》论文精读笔记

论文链接&#xff1a;arxiv.org/pdf/2304.08485 参考视频&#xff1a;LLAVA讲解_哔哩哔哩_bilibili [论文速览]LLaVA: Visual Instruction Tuning[2304.08485]_哔哩哔哩_bilibili 标题&#xff1a;Visual Instruction Tuning 视觉指令微调 背景引言 大模型的Instruction…

【DataWhale】快乐学习大模型 | 202507,Task01笔记

引言 我从2016年开始接触matlab看别人做语音识别&#xff0c;再接触tensorflow的神经网络&#xff0c;2017年接触语音合成&#xff0c;2020年做落地的医院手写数字识别。到2020年接触pytorch做了计算机视觉图像分类&#xff0c;到2021年做了目标检测&#xff0c;2022年做了文本…

机器学习中的朴素贝叶斯(Naive Bayes)模型

1. 用实例来理解朴素贝叶斯 下面用具体的数据来演示垃圾邮件 vs 正常邮件的概率计算假设我们有一个小型邮件数据集邮件内容类别&#xff08;垃圾/正常&#xff09;“免费 赢取 大奖”垃圾“免费 参加会议”正常“中奖 点击 链接”垃圾“明天 开会”正常“赢取 免费 礼品”垃圾 …

document.documentElement详解

核心概念定义 它始终指向当前文档的根元素&#xff0c;在 HTML 文档中对应 <html> 标签。与 document.body&#xff08;对应 <body>&#xff09;和 document.head&#xff08;对应 <head>&#xff09;形成层级关系。与 document.body 的区别 <html> &l…

c#进阶之数据结构(动态数组篇)----Queue

1、简介这个是c#封装的队列类型&#xff0c;同栈相反&#xff0c;这个是先进先出&#xff0c;一般用于事件注册&#xff0c;或者数据的按顺序处理&#xff0c;理解为需要排队处理的可以用队列来处理。注意&#xff0c;队列一定是有顺序的&#xff0c;先进确实是会先出&#xff…

使用 keytool 在服务器上导入证书操作指南(SSL 证书验证错误处理)

使用 keytool 在服务器上导入证书操作指南(SSL 证书验证错误处理) 一、概述 本文档用于指导如何在运行 Java 应用程序的服务器上,通过keytool工具将证书导入 Java 信任库,解决因证书未被信任导致的 SSL/TLS 通信问题(如PKIX path building failed错误)。 二、操作步骤…

VUE export import

目录 命名导出 导出变量 导出函数 总结 默认导出 导出变量 导出函数 总结 因为总是搞不懂export和Import什么时候需要加{}&#xff0c;什么时候不用&#xff0c;所以自己测试了一下&#xff0c;以下是总结。 需不需要加{}取决于命名导出还是默认导出&#xff0c;命名导…

端侧宠物识别+拍摄控制智能化:解决设备识别频次识别率双低问题

随着宠物成为家庭重要成员&#xff0c;宠物影像创作需求激增&#xff0c;传统相机系统 “人脸优先” 的调度逻辑已难以应对宠物拍摄的复杂场景。毛发边缘模糊、动态姿态多变、光照反差剧烈等问题&#xff0c;推动着智能拍摄技术向 “宠物优先” 范式转型。本文基于端侧 AI 部署…

Popover API 实战指南:前端弹层体验的原生重构

&#x1fa84; Popover API 实战指南&#xff1a;前端弹层体验的原生重构 还在用 position: absolute JS 定位做 tooltip&#xff1f;还在引入大型 UI 库只为做个浮层&#xff1f;现在浏览器已经支持了真正原生的「弹出层 API」&#xff0c;一行 HTMLCSS 就能构建可交互、无障…

CCS-MSPM0G3507-6-模块篇-OLED的移植

前言基础篇结束&#xff0c;接下来我们来开始进行模块驱动如果懂把江科大的OLED移植成HAL库&#xff0c;那其实也没什么难首先配置OLED的引脚这里我配置PA16和17为推挽输出&#xff0c;PA0和1不要用&#xff0c;因为只有那两个引脚能使用MPU6050 根据配置出来的引脚&#xff0c…

意识边界的算法战争—脑机接口技术重构人类认知的颠覆性挑战

一、神经解码的技术奇点当瘫痪患者通过脑电波操控机械臂饮水&#xff0c;当失语者借由皮层电极合成语音&#xff0c;脑机接口&#xff08;BCI&#xff09;正从医疗辅助工具演变为认知增强的潘多拉魔盒。这场革命的核心突破在于神经信号解析精度的指数跃迁&#xff1a;传统脑电图…

详解彩信 SMIL规范

以下内容将系统地讲解彩信 MMS&#xff08;Multimedia Messaging Service&#xff09;中使用的 SMIL&#xff08;Synchronized Multimedia Integration Language&#xff09;规范&#xff0c;涵盖历史、语法结构、在彩信中的裁剪与扩展、常见实现细节以及最佳实践。末尾附示例代…

《红蓝攻防:构建实战化网络安全防御体系》

《红蓝攻防&#xff1a;构建实战化网络安全防御体系》文章目录第一部分&#xff1a;网络安全的攻防全景 1、攻防演练的基础——红队、蓝队、紫队 1.1 红队&#xff08;攻击方&#xff09; 1.2 蓝队&#xff08;防守方&#xff09; 1.3 紫队&#xff08;协调方&#xff09; 2、5…

MFC UI大小改变与自适应

文章目录窗口最大化库EasySize控件自适应大小窗口最大化 资源视图中开放最大化按钮&#xff0c;添加窗口样式WS_MAXIMIZEBOX。发送大小改变消息ON_WM_SIZE()。响应大小改变。 void CDlg::OnSize(UINT nType, int cx, int cy) {CDialog::OnSize(nType, cx, cy);//获取改变后窗…

【Linux网络】:HTTP(应用层协议)

目录 一、HTTP 1、URL 2、协议格式 3、请求方法 4、状态码 5、Header信息 6、会话保持Cookie 7、长连接 8、简易版HTTP服务器代码 一、HTTP 我们在编写网络通信代码时&#xff0c;我们可以自己进行协议的定制&#xff0c;但实际有很多优秀的工程师早就写出了许多非常…

C++-linux 7.文件IO(三)文件元数据与 C 标准库文件操作

文件 IO 进阶&#xff1a;文件元数据与 C 标准库文件操作 在 Linux 系统中&#xff0c;文件操作不仅涉及数据的读写&#xff0c;还包括对文件元数据的管理和高层库函数的使用。本文将从文件系统的底层存储机制&#xff08;inode 与 dentry&#xff09;讲起&#xff0c;详细解析…

WordPress Ads Pro Plugin本地文件包含漏洞(CVE-2025-4380)

免责声明 本文档所述漏洞详情及复现方法仅限用于合法授权的安全研究和学术教育用途。任何个人或组织不得利用本文内容从事未经许可的渗透测试、网络攻击或其他违法行为。 前言:我们建立了一个更多,更全的知识库。每日追踪最新的安全漏洞,追中25HW情报。 更多详情: http…