基础神经网络模型搭建

【Pytorch】数据集的加载和处理(一)

【Pytorch】数据集的加载和处理(二)

损失函数计算模型输出和目标之间的距离。通过torch.nn 包可以定义一个负对数似然损失函数,负对数似然损失对于训练具有多个类的分类问题比较有效,负对数似然损失函数的输入为对数概率,而在模型搭建的输出层部分接触过log_softmax,它能从模型中获取对数概率

目录

基础模型搭建

数据集的加载和处理

定义损失函数

定义优化器

训练并评估模型


基础模型搭建

import torch
from torch import nn
import torch.nn.functional as F
class Net(nn.Module):def __init__(self):super(Net, self).__init__()def forward(self, x):pass
def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5, 1)self.conv2 = nn.Conv2d(20, 50, 5, 1)self.fc1 = nn.Linear(4*4*50, 500)self.fc2 = nn.Linear(500, 10)
def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2) x = x.view(-1, 4*4*50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)
Net.__init__ = __init__
Net.forward = forward
model = Net()

检查搭建情况 

print(model)

原位置为cpu 

 转移至所需CUDA设备

device = torch.device("cuda:0")
model.to(device)
print(next(model.parameters()).device)

数据集的加载和处理

导入MNIST训练数据集和验证数据集并处理

from torch import nn
from torchvision import datasets
from torch.utils.data import TensorDataset
path2data="./data"
train_data=datasets.MNIST(path2data, train=True, download=True)
x_train, y_train=train_data.data,train_data.targets
val_data=datasets.MNIST(path2data, train=False, download=True)
x_val,y_val=val_data.data, val_data.targets
if len(x_train.shape)==3:x_train=x_train.unsqueeze(1)
print(x_train.shape)
if len(x_val.shape)==3:x_val=x_val.unsqueeze(1)
print(x_val.shape)
train_ds = TensorDataset(x_train, y_train)
val_ds = TensorDataset(x_val, y_val)
for x,y in train_ds:print(x.shape,y.item())breakfrom torch.utils.data import DataLoader 
train_dl = DataLoader(train_ds, batch_size=8)
val_dl = DataLoader(val_ds, batch_size=8)

定义损失函数

损失函数计算模型输出和目标之间的距离。Pytorch 中的 optim 包提供了各种优化算法的实现,例如SGD、Adam、RMSprop 等。

通过torch.nn 包可以定义一个负对数似然损失函数,负对数似然损失对于训练具有多个类的分类问题比较有效,负对数似然损失函数的输入为对数概率,而在模型搭建的输出层部分接触过log_softmax,它能从模型中获取对数概率。

loss_func = nn.NLLLoss(reduction="sum")
for xb, yb in train_dl:# move batch to cuda devicexb=xb.type(torch.float).to(device)yb=yb.to(device)out=model(xb)loss = loss_func(out, yb)print (loss.item())break

得到一个测试值 

定义优化器

定义一个Adam优化器,优化器的输入是模型参数和学习率

from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-4)

通过opt .step()自动更新模型参数,同时需要注意计算下一批的梯度之前需将梯度归0

opt.step()
opt.zero_grad()

训练并评估模型

定义一个辅助函数 loss_batch来计算每个小批量的损失值。函数的 opt 参数引用优化器,如果给定,则计算梯度并按小批量更新模型参数。

def  loss_batch(loss_func,  xb,  yb,yb_h,  opt=None): loss = loss_func(yb_h, yb) metric_b =  metrics_batch(yb,yb_h) if opt is  not None: loss.backward()opt.step()opt.zero_grad()return loss.item(),metric_b

 定义一个辅助函数metrics_batch来计算每个小批量的性能指标,这里以准确率作为分类任务的性能指标,并使用 output.argmax 来获取概率最高的预测类

def metrics_batch(target, output):pred = output.argmax(dim=1, keepdim=True)corrects=pred.eq(target.view_as(pred)).sum().item()return corrects

定义一个辅助函数loss_epoch来计算整个数据集的损失和指标值。使用数据加载器对象获取小批量,将它们提供给模型,并计算每个小批量的损失和指标,通过两个运行变量来分别添加损失值和指标值。

def loss_epoch(model,loss_func,dataset_dl,opt=None):loss=0.0metric=0.0len_data=len(dataset_dl.dataset)for xb, yb in dataset_dl:xb=xb.type(torch.float).to(device)yb=yb.to(device)yb_h=model(xb)loss_b,metric_b=loss_batch(loss_func, xb, yb,yb_h, opt)loss+=loss_bif metric_b is not None:metric+=metric_bloss/=len_datametric/=len_datareturn loss, metric

最后,定义一个辅助函数train_val来训练多个时期的模型。在每个时期使用验证数据集评估模型的性能。训练和评估需要分别使用 model.train()和 model.eval()模式。torch.no_grad()可以阻止 autograd 在评估期间计算梯度。

def train_val(epochs, model, loss_func, opt, train_dl, val_dl):for epoch in range(epochs):model.train()train_loss,train_metric=loss_epoch(model,loss_func,train_dl,opt)model.eval()with torch.no_grad():val_loss, val_metric=loss_epoch(model,loss_func,val_dl)accuracy=100*val_metricprint("epoch: %d, train loss: %.6f, val loss: %.6f,accuracy: %.2f" %(epoch, train_loss,val_loss,accuracy))

 设定时期数为5,调用函数进行训练和评估

num_epochs=5
train_val(num_epochs, model, loss_func, opt, train_dl, val_dl)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/92987.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/92987.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/92987.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电子书转PDF格式教程,实现epub转PDF步骤

EPUB 格式属于流式文档,在屏幕尺寸各异的设备上都能自动适配显示。然而,要是你使用的是特定的阅读设备,像打印机、不支持 EPUB 格式的电子阅读器(例如某些早期的 Kindle 型号),或者需要在固定尺寸的屏幕上展…

Java学习第六十九部分——RabbitMQ

目录 一、前言提要 二、基本信息 1. 关键定义 2. 核心角色 3. 交换机类型 三、消息生命周期与可靠性机制 四、生态集成——与Java 五、应用场景 六、性能与选型对比 七、生产级最佳实践——基于Java 八、应用场景 九、一句话总结 一、前言提要 Spring AMQP是…

MDAC2.6问题解决指南:解决.NET Framework数据访问烦恼

MDAC2.6问题解决指南:解决.NET Framework数据访问烦恼 【下载地址】MDAC2.6问题解决指南 MDAC 2.6 问题解决指南为您提供了针对.NET Framework数据提供程序要求使用Microsoft Data Access Components (MDAC) 2.6或更高版本的全面解决方案。本指南详细介绍了如何在开…

会话跟踪模式

一、图片讲了什么?这张图片主要讲的是“会话跟踪技术”,也就是网站怎么记住你是谁、你做了什么。1. 什么是会话?会话(Session)就像你和网站的一次聊天,从你打开网页到关闭网页,这段时间就是一次…

C语言开发工具Win-TC

如你所知,WIN-TC是一个turbo C2 WINDOWS 平台开发工具,最大特点是支持中文界面,支持鼠标操作,程序段复制,为初学 c 语言、对高等编程环境不熟悉的同志们非常有帮助。该软件使用 turbo C2 为内核,提供 WINDO…

lwIP学习记录5——裸机lwIP工程学习后的总结

1、ping包的TTL生存时间如何修改当我们把工程烧录到板子上是,我们对板子的IP进行ping包,看到信息如下图这时候我好奇TTL是什么作用,为什么有的设备是64有的设备是128有的是255?解:TTL(Time to Live&#xf…

利用Trae将原型图转换为可执行的html文件,感受AI编程的魅力

1、UI设计原型效果2、通过Tare对话生成的效果图(5分钟左右)3、查资料做的效果图(30分钟左右))通过以上对比,显然差别不多能满足要求,只需要在继续优化就能搞定; 4、Trae生成的源码&l…

Chessboard and Queens

题目描述Your task is to place eight queens on a chessboard so that no two queens are attacking each other. As an additional challenge, each square is either free or reserved, and you can only place queens on the free squares. However, the reserved squares …

菜鸟教程R语言一二章阅读笔记

菜鸟教程R语言一二章阅读笔记 一.R语言基础教程 R 语言是为数学研究工作者设计的一种数学编程语言,主要用于统计分析、绘图、数据挖掘。侧重于数学工作者 R语言特点如下: R 语言环境软件属于 GNU 开源软件,兼容性好、使用免费 语法十分有利于…

Tactile-VLA:解锁视觉-语言-动作模型的物理知识,实现触觉泛化

25年7月来自清华、中科大和上海交大的论文“Tactile-VLA: Unlocking Vision-Language- Action Model’s Physical Knowledge For Tactile Generalization ”。 视觉-语言-动作 (VLA) 模型已展现出卓越的成就,这得益于其视觉-语言组件丰富的隐性知识。然而&#xff0…

HTML初学者第五天

<1>表格标签1.1基本语法<table><tr><td>单元格内的文字</td>...</tr>... </table>1.<table></table>是用于定义表格的标签。2.<tr></tr>标签用于定义表格中的行&#xff0c;必须嵌套在<table></ta…

FastAPI入门:demo、路径参数、查询参数

demo from fastapi import FastAPIapp FastAPI()app.get("/") async def root():return {"message": "Hello World"}在终端运行 fastapi dev main.py结果如下&#xff1a;打开http://127.0.0.1:8000&#xff1a;交互式API文档&#xff1a;位于h…

pytest中的rerunfailures的插件(失败重试)

目录 1-- 安装rerunfailures插件 2-- rerunfailures的使用 3-- 重试案例 安装rerunfailures插件 pip install pytest-rerunfailures点击左下角的控制台面板 输入 pip install pytest-rerunfailures 出现上图的情况就算安装完成了 rerunfailures的使用 可以添加一下参数使用&…

SpringMVC——建立连接

建立连接 将用户&#xff08;浏览器&#xff09;和java程序连接起来&#xff0c;也就是访问一个地址能够调用到我们的Spring程序。在 Spring MVC 中使用 RequestMapping来实现URL 路由映射&#xff0c;也就是浏览器连接程序的作用。 1.RequestMapping注解介绍 RequestMapping…

蘑菇云路由器使用教程

1: 手机连接路由器的Wi-Fi&#xff0c;在浏览器输入背面IP地址&#xff1a;192.168.132.1进入路由管理界面1.1: 电脑连接路由器网线在浏览器输入背面IP地址&#xff1a;192.168.132.1进入路由管理界面账号&#xff1a;admin密码&#xff1a;123456782:选择上网模式2.1&#xff…

ubuntu的tar解压指令相关

1. 指令说明参数作用-xextract&#xff0c;解包-z通过 gzip 解压&#xff08;.tar.gz、.tgz&#xff09;-vverbose&#xff0c;显示过程-ffile&#xff0c;后面紧跟压缩包文件名2. 什么时候用z参数场景是否加 -z结果.tar.gz / .tgz✅ 必须加 -z正常解压.tar.gz / .tgz❌ 没加 -…

车载诊断刷写 --- Flash关于擦除和写入大小

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 简单,单纯,喜欢独处,独来独往,不易合同频过着接地气的生活,除了生存温饱问题之外,没有什么过多的欲望,表面看起来很高冷,内心热情,如果你身…

【Verilog HDL 入门教程】 —— 学长带你学Verilog(基础篇)

文章目录一、Verilog HDL 概述1、Verilog HDL 是什么2、Verilog HDL产生的背景3、Verilog HDL 和 VHDL的区别二、Verilog HDL 基础知识1、Verilog HDL 语言要素1.1、命名规则1.2、注释符1.3、关键字1.4、数值1.4.1、整数及其表示1.4.2、实数及其表示1.4.3、字符串及其表示2、数…

SQL Developer Data Modeler:一款免费跨平台的数据库建模工具

SQL Developer Data Modeler 是由 Oracle 公司开发的一款免费的图形化数据建模和数据库设计工具&#xff0c;用于创建、浏览和编辑逻辑模型、关系模型、物理模型、多维模型和数据类型模型。 SQL Developer Data Modeler 既是一个独立的应用程序&#xff0c;同时也被集成到了 Or…

CSS面试题及详细答案140道之(21-40)

《前后端面试题》专栏集合了前后端各个知识模块的面试题&#xff0c;包括html&#xff0c;javascript&#xff0c;css&#xff0c;vue&#xff0c;react&#xff0c;java&#xff0c;Openlayers&#xff0c;leaflet&#xff0c;cesium&#xff0c;mapboxGL&#xff0c;threejs&…