目录

1.使用全连接网络训练和验证MNIST数据集

2.使用全连接网络训练和验证CIFAR10数据集


1.使用全连接网络训练和验证MNIST数据集

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
from PIL import Image
import os# 数据预处理
transform = transforms.Compose([transforms.ToTensor()])# 数据准备
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
eval_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=True)# 定义网络结构
class MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.fc1 = nn.Linear(784, 256)self.bn1 = nn.BatchNorm1d(256)self.relu = nn.ReLU()self.fc2 = nn.Linear(256, 128)self.bn2 = nn.BatchNorm1d(128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = self.bn1(self.fc1(x))x = self.relu(x)x = self.bn2(self.fc2(x))x = self.relu(x)x = self.fc3(x)return xmodel = MyNet()
# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练
def train(model, train_loader, epochs):model.train()for epoch in range(epochs):correct = 0for data, target in train_loader:output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()correct /= len(train_loader.dataset)print(f'Train Epoch:  {epoch} , loss: {loss.item():.4f},acc:{correct:.4f}')# 验证
def eval(model, eval_loader):model.eval()eval_loss = 0correct = 0with torch.no_grad():for data, target in eval_loader:output = model(data)eval_loss += criterion(output, target).item()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()eval_loss /= len(eval_loader.dataset)acc = 100.0 * correct / len(eval_loader.dataset)print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')# 保存模型
def save_model():torch.save(model.state_dict(), 'mnist_fc_model.pt')# 预测
def predict(img_path):model = MyNet()model.load_state_dict(torch.load('mnist_fc_model.pt'))model.eval()img = Image.open(img_path).convert('L')transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor()])t_img = transform(img).unsqueeze(0)print(t_img.shape)with torch.no_grad():output = model(t_img)_, predicted = torch.max(output.data, 1)print(predicted.item())epochs = 5train(model, train_loader, epochs)
eval(model, eval_loader)save_model()img_path = './img/7.png'
predict(img_path)

2.使用全连接网络训练和验证CIFAR10数据集

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])# 数据准备
train_dataset = datasets.CIFAR10(root='./cifar10', train=True, transform=transform, download=True)
eval_dataset = datasets.CIFAR10(root='./cifar10', train=False, transform=transform, download=True)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=True)# 定义网络结构
class MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.fc1 = nn.Linear(32 * 32 * 3, 1024)self.bn1 = nn.BatchNorm1d(1024)self.dropout1 = nn.Dropout(0.3)self.fc2 = nn.Linear(1024, 512)self.bn2 = nn.BatchNorm1d(512)self.dropout2 = nn.Dropout(0.3)self.fc3 = nn.Linear(512, 256)  # 增加第三层self.bn3 = nn.BatchNorm1d(256)self.fc4 = nn.Linear(256, 10)self.relu = nn.ReLU()def forward(self, x):x = x.view(-1, 32 * 32 * 3)x = self.dropout1(self.bn1(self.fc1(x)))x = self.relu(x)x = self.dropout2(self.bn2(self.fc2(x)))x = self.relu(x)x = self.bn3(self.fc3(x))x = self.relu(x)x = self.fc4(x)return xmodel = MyNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)def train(model, train_loader, epochs):model.train()for epoch in range(epochs):correct = 0for data, target in train_loader:data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()correct /= len(train_loader.dataset)print(f'Train Epoch:  {epoch} , loss: {loss.item():.4f},acc:{correct:.4f}')def eval(model, eval_loader):model.eval()eval_loss = 0correct = 0with torch.no_grad():for data, target in eval_loader:data, target = data.to(device), target.to(device)output = model(data)eval_loss += criterion(output, target).item()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()eval_loss /= len(eval_loader.dataset)acc = 100.0 * correct / len(eval_loader.dataset)print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')epochs = 25train(model, train_loader, epochs)
eval(model, eval_loader)

思考:为什么CIFAR10数据集的准确率很低?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/916043.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/916043.shtml
英文地址,请注明出处:http://en.pswp.cn/news/916043.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式学习的第三十四天-进程间通信-TCP

一、TCPTCP : 传输控制协议 传输层1. TCP特点(1).面向连接,避免部分数据丢失 (2).安全、可靠 (3).面向字节流 (4).占用资源开销大2.TCP安全可靠机制三次握手:指建立tcp连接时,需要客户端和服务端总共发送三次报文确认连接。确保双方均已做好 收发…

【爬虫】06 - 自动化爬虫selenium

自动化爬虫selenium 文章目录自动化爬虫selenium一:Selenium简介1:什么是selenium2:安装准备二:元素定位1:id 定位2:name 定位3:class 定位4:tag 定位5:xpath 定位(最常用…

2025年中国移动鸿鹄大数据实训营(大数据方向)kafka讲解及实践-第2次作业指导

书接上回,第二次作业比较容易解决,我问了ai,让他对我进行指导,按照它提供的步骤,我完成了本次实验,接下来我会标注出需要注意的细节,指导大家完成此次任务。 🎯 一、作业目标 ✔️…

三十七、【高级特性篇】定时任务:基于 APScheduler 实现测试计划的灵活调度

三十七、【高级特性篇】定时任务:基于 APScheduler 实现测试计划的灵活调度 前言 准备工作 第一部分:后端实现 - `APScheduler` 集成与任务调度 1. 安装 `django-apscheduler` 2. 配置 `django-apscheduler` 3. 数据库迁移 4. 创建调度触发函数 5. 启动 APScheduler 调度器 6…

RabbitMQ--消息顺序性

看本章之前强烈建议先去看博主的这篇博客 RabbitMQ--消费端单线程与多线程-CSDN博客 一、消息顺序性概念 消息顺序性是指消息在生产者发送的顺序和消费者接收处理的顺序保持一致。 二、RabbitMQ 顺序性保证机制 情况顺序保证情况备注单队列,单消费者消息严格按发送顺…

.net core接收对方传递的body体里的json并反序列化

1、首先我在通用程序里有一个可以接收对象型和数组型json串的反序列化方法public static async Task<Dictionary<string, string>> AllParameters(this HttpRequest request){Dictionary<string, string> parameters QueryParameters(request);request.Enab…

(10)机器学习小白入门 YOLOv:YOLOv8-cls 模型评估实操

YOLOv8-cls 模型评估实操 (1)机器学习小白入门YOLOv &#xff1a;从概念到实践 (2)机器学习小白入门 YOLOv&#xff1a;从模块优化到工程部署 (3)机器学习小白入门 YOLOv&#xff1a; 解锁图片分类新技能 (4)机器学习小白入门YOLOv &#xff1a;图片标注实操手册 (5)机器学习小…

Vue 脚手架基础特性

一、ref属性1.被用来给元素或子组件注册引用信息&#xff08;id的替代者&#xff09;2.应用在html标签上获取的是真实DOM元素&#xff0c;用在组件标签上是组件实例对象3.使用方式&#xff1a;(1).打标识&#xff1a;<h1 ref"xxx">...</h1> 或 <Schoo…

Ubuntu安装k8s集群入门实践-v1.31

准备3台虚拟机 在自己电脑上使用virtualbox 开了3台1核2G的Ubuntu虚拟机&#xff0c;你可以先安装好一台&#xff0c;安装第一台的时候配置临时调高到2核4G&#xff0c;安装速度会快很多&#xff0c;安装完通过如下命令关闭桌面&#xff0c;能够省内存占用&#xff0c;后面我们…

Word Press富文本控件的保存

新建富文本编辑器&#xff0c;并编写save方法如下&#xff1a; edit方法&#xff1a; export default function Edit({ attributes, setAttributes }) {return (<><div { ...useBlockProps() }><RichTexttagNameponChange{ (value) > setAttributes({ noteCo…

【编程趣味游戏】:基于分支循环语句的猜数字、关机程序

&#x1f31f;菜鸟主页&#xff1a;晨非辰的主页 &#x1f440;学习专栏&#xff1a;《C语言学习》 &#x1f4aa;学习阶段&#xff1a;C语言方向初学者 ⏳名言欣赏&#xff1a;"编程的核心是实践&#xff0c;而非空谈" 目录 1. 游戏1--猜数字 1.1 rand函数 1.2 sr…

UE5 UI 控件切换器

文章目录分类作用属性分类 面板 作用 可以根据索引切换要显示哪个子UI&#xff0c;可以拥有多个子物体&#xff0c;但是任何时间只能显示一个 属性 在这里指定要显示的UI的索引

scikit-learn 包

文章目录scikit-learn 包核心功能模块案例其他用法**常用功能详解****(1) 分类任务示例&#xff08;SVM&#xff09;****(2) 回归任务示例&#xff08;线性回归&#xff09;****(3) 聚类任务示例&#xff08;K-Means&#xff09;****(4) 特征工程&#xff08;PCA降维&#xff0…

Excel 将数据导入到SQLServer数据库

一般系统上线前期都会导入期初数据&#xff0c;业务人员一般要求你提供一个Excel模板&#xff0c;业务人员根据要求整理数据。SQLServer管理工具是支持批量导入数据的&#xff0c;所以我们可以使用该工具导入期初。Excel格式 第一行为字段1、连接登入的数据库并且选中你需要导入…

剪枝和N皇后在后端项目中的应用

剪枝算法&#xff08;Pruning Algorithm&#xff09; 生活比喻&#xff1a;就像修剪树枝一样&#xff0c;把那些明显不会结果的枝条提前剪掉&#xff0c;节省养分。 在后端项目中的应用场景&#xff1a; 搜索优化&#xff1a;在商品搜索中&#xff0c;如果某个分类下没有符合条…

cocos 2d游戏中多边形碰撞器会触发多次,怎么解决

子弹打到敌机 一发子弹击中&#xff0c;碰撞回调多次执行 我碰撞组件原本是多边形碰撞组件 PolygonCollider2D&#xff0c;我改成盒碰撞组件BoxCollider2D 就好了 用前端的节流方式。或者loading处理逻辑。我测试过了&#xff0c;是可以 本来就是多次啊,设计上貌似就是这样的…

Kubernetes环境中GPU分配异常问题深度分析与解决方案

Kubernetes环境中GPU分配异常问题深度分析与解决方案 一、问题背景与核心矛盾 在基于Kubernetes的DeepStream应用部署中&#xff0c;GPU资源的独占性分配是保障应用性能的关键。本文将围绕一个典型的GPU分配异常问题展开分析&#xff1a;多个请求GPU的容器本应独占各自的GPU&…

Django与模板

我叫补三补四&#xff0c;很高兴见到大家&#xff0c;欢迎一起学习交流和进步今天来讲一讲视图Django与模板文件工作流程模板引擎&#xff1a;主要参与模板渲染的系统。内容源&#xff1a;输入的数据流。比较常见的有数据库、XML文件和用户请求这样的网络数据。模板&#xff1a…

日本上市IT企业|8月25日将在大连举办赴日it招聘会

株式会社GSD的核心战略伙伴贝斯株式会社&#xff0c;将于2025年8月25日在大连香格里拉大酒店商务会议室隆重举办赴日技术人才专场招聘会。本次招聘会面向全国范围内的优秀IT人才&#xff0c;旨在为贝斯株式会社东京本社长期发展招募优质的系统开发与管理人才。招聘计划&#xf…

低功耗设计双目协同画面实现光学变焦内带AI模型

低功耗设计延长续航&#xff0c;集成储能模块保障阴雨天气下的铁塔路线的安全一、智能感知与识别技术 多光谱融合监控结合可见光、红外热成像、激光补光等技术&#xff0c;实现全天候监测。例如&#xff0c;红外热成像可穿透雨雾监测山火隐患&#xff0c;激光补光技术则解决夜间…