第一部分——起手式

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimuse_cuda = torch.cuda.is_available()if use_cuda:device = torch.device("cuda")
else: device = torch.device("cpu")print(f"Using device {device}")

第二部分——计算均值、方差

transform = transforms.Compose([#将数据转换成Tensor张量transforms.ToTensor()
]
)#读取数据
datasets1 = datasets.MNIST('./data',train=True,download = True, transform =transform)
datasets1_len = len(datasets1)#设置数据加载器、批次大小全部图片
train_loader = torch.utils.data.DataLoader(datasets1, batch_size=datasets1_len, shuffle = True)#循环训练集 DataLoader,0是起始索引
for batch_idx, data in enumerate(train_loader,0):inputs, targets = data #将训练集图(60000,1,28,28)像转换为(60000*1,28*28)的二维数组,-1 是占位符用于自动计算维度大小x = inputs.view(-1,28*28)#计算均值-0.3081x_mean =x.mean().item()#计算标准差-0.1307x_std =x.std().item()print(f"mean: {x_mean}, std: {x_std}")
#mean: 0.13066047430038452, std: 0.30810782313346863

第三部分——网络模型

#自定义类构建模型、继承torch.nn.module初始化网络模型
class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.fc1 = torch.nn.Linear(784, 128)#Liner线性加权求和,784是input,128是当前层神经元个数self.dropout = torch.nn.Dropout(p = 0.2)self.fc2 = torch.nn.Linear(128, 10)#input=上一层的神经元个数,输出是10,做一个0-9的10分类def forward(self, x):#把x的每条数据展成一维数组28*28=784x = torch.flatten(x,1)x = self.fc1(x)x = F.relu(x)x = self.dropout(x)x = self.fc2(x)output = F.log_softmax(x, dim=1)#做完softmax然后取log,便于后续计算损失函数(损失函数需要取log)return output       

第四部分——训练策略、测试策略

#创建实例
model = Net().to(device)#每个批次如何训练
def train_step(data, target, model, optimizer):optimizer.zero_grad()#梯度归零output = model(data)loss = F.nll_loss(output,target)#nll是负对数似然,output是y_head,target是y_trueloss.backward()#反向传播求梯度optimizer.step()#根据梯度更新网络return loss#每个批次如何测试
def test(data, target, model, test_loss, correct):output = model(data)#累积计算每个批次的损失test_loss += F.nll_loss(output,target,reduction='sum').item()#获取对数概率最大对应的索引,dim=1:表示选取每一行概率最大的索引,keepdim = True 表示维度保持不变pred = output.argmax(dim=1, keepdim=True)#统计预测值与正确值相同的数量,eq在做比较,返回True/Fasle,sum是求和,item是将数据取出来(原来是tensor)correct += pred.eq(target.view_as(pred)).sum().item()return test_loss, correct

第五部分——开始训练

#真正分轮次训练
EPOCHS = 5#调参优化器,lr是学习率
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)for epoch in range(EPOCHS):model.train()#设置为训练模式:BN层计算的是均值方差for batch_index, (data, target) in enumerate(train_loader):data, target = data.to(device),target.to(device)loss = train_step(data, target, model, optimizer)#每隔10个批次打印一次信息if batch_index%10 ==0:print('Train Epoch:{epoch} [{batch}/{total_batch} {percent}%] train_loss:{loss:.3f}'.format(epoch=epoch+1,#第几个批次batch = batch_index*len(data),#已跑多少数据total_batch = len(train_loader.dataset),#当前轮总数据条数percent = 100.0*batch_index/len(train_loader),#当前轮数已占训练集百分比loss = loss.item()#损失是tensor,转为数值))       #设置为测试模式:BN层计算的是滑动平均,Droput层不进行预测model.eval()test_loss = 0correct = 0with torch.no_grad():#不求梯度for data, target in test_loader:data, target = data.to(device), target.to(device)test_loss, correct = test_step(data, target, model, test_loss, correct)    test_loss = test_loss/len(test_loader.dataset)print('\n Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(test_loss,correct,len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

完整代码

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimuse_cuda = torch.cuda.is_available()if use_cuda:device = torch.device("cuda")
else: device = torch.device("cpu")print(f"Using device {device}")#数据预处理
transform = transforms.Compose([#将数据转换成Tensor张量transforms.ToTensor(),#图片数据归一化:0.1307是均值,0.3081是方差。数值和数据集有关系transforms.Normalize((0.1307),(0.3081))
]
)#读取数据
datasets1 =datasets.MNIST('./data',train=True,download = True, transform =transform)
datasets2 =datasets.MNIST('./data',train=False,download = True, transform =transform)#设置数据加载器、批次大小128、是否打乱顺序-是
train_loader = torch.utils.data.DataLoader(datasets1, batch_size=128, shuffle = True)
#测试批次可以大,测试集不需要打乱顺序-False
test_loader = torch.utils.data.DataLoader(datasets2, batch_size =1000,shuffle = False)#自定义类构建模型、继承torch.nn.module初始化网络模型
class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.fc1 = torch.nn.Linear(784, 128)#Liner线性加权求和,784是input,128是当前层神经元个数self.dropout = torch.nn.Dropout(p = 0.2)self.fc2 = torch.nn.Linear(128, 10)#input=上一层的神经元个数,输出是10,做一个0-9的10分类def forward(self, x):#把x的每条数据展成一维数组28*28=784x = torch.flatten(x,1)x = self.fc1(x)x = F.relu(x)x = self.dropout(x)x = self.fc2(x)output = F.log_softmax(x, dim=1)#做完softmax然后取log,便于后续计算损失函数(损失函数需要取log)return output       #创建实例
model = Net().to(device)#每个批次如何训练
def train_step(data, target, model, optimizer):optimizer.zero_grad()#梯度归零output = model(data)loss = F.nll_loss(output,target)#nll是负对数似然,output是y_head,target是y_trueloss.backward()#反向传播求梯度optimizer.step()#根据梯度更新网络return loss#每个批次如何测试
def test_step(data, target, model, test_loss, correct):output = model(data)#累积计算每个批次的损失test_loss += F.nll_loss(output,target,reduction='sum').item()#获取对数概率最大对应的索引,dim=1:表示选取每一行概率最大的索引,keepdim = True 表示维度保持不变pred = output.argmax(dim=1, keepdim=True)#统计预测值与正确值相同的数量,eq在做比较,返回True/Fasle,sum是求和,item是将数据取出来(原来是tensor)correct += pred.eq(target.view_as(pred)).sum().item()return test_loss, correct#真正分轮次训练
EPOCHS = 5#调参优化器,lr是学习率
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)for epoch in range(EPOCHS):model.train()#设置为训练模式:BN层计算的是均值方差for batch_index, (data, target) in enumerate(train_loader):data, target = data.to(device),target.to(device)loss = train_step(data, target, model, optimizer)#每隔10个批次打印一次信息if batch_index%10 ==0:print('Train Epoch:{epoch} [{batch}/{total_batch} {percent}%] train_loss:{loss:.3f}'.format(epoch=epoch+1,#第几个批次batch = batch_index*len(data),#已跑多少数据total_batch = len(train_loader.dataset),#当前轮总数据条数percent = 100.0*batch_index/len(train_loader),#当前轮数已占训练集百分比loss = loss.item()#损失是tensor,转为数值))       #设置为测试模式:BN层计算的是滑动平均,Droput层不进行预测model.eval()test_loss = 0correct = 0with torch.no_grad():#不求梯度for data, target in test_loader:data, target = data.to(device), target.to(device)test_loss, correct = test_step(data, target, model, test_loss, correct)    test_loss = test_loss/len(test_loader.dataset)print('\n Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(test_loss,correct,len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/95774.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/95774.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/95774.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JAVA高级】实现word转pdf 实现,源码概述。深坑总结

之前的需求做好后,需求,客户突发奇想。要将生成的word转为pdf! 因为不想让下载文档的人改动文档。 【JAVA】实现word添加标签实现系统自动填入字段-CSDN博客 事实上这个需求难度较高,并不是直接转换就行的 word文档当中的很多东西都需要处理 public static byte[] gener…

数据驱动测试提升自动化效率

测试工程师老王盯着满屏重复代码叹气:“改个搜索条件要重写20个脚本,这班加到啥时候是个头?” 隔壁组的小李探过头:“试试数据驱动呗,一套脚本吃遍所有数据,我们组上周测了300个组合都没加班!”…

模板引用(Template Refs)全解析2

三、v-for 中的模板引用 当在 v-for 中使用模板引用时,引用的 value 会自动变为一个数组,包含列表中所有元素/组件的引用(需 Vue 3.5+ 版本,旧版需手动处理且顺序不保证)。 1. 基本用法(Vue 3.5+) <script setup> import { ref, useTemplateRef, onMounted } f…

【Linux系统】进程间通信:System V IPC——共享内存

前文中我们介绍了管道——匿名管道和命名管道来实现进程间通信&#xff0c;在介绍怎么进行通信时&#xff0c;我们有提到过不止管道的方式进行通信&#xff0c;还有System V IPC&#xff0c;今天这篇文章我们就来学习一下System V IPC中的共享内存1. 为何引入共享内存&#xff…

[优选算法专题二滑动窗口——最大连续1的个数 III]

题目链接 最大连续1的个数 III 题目描述 题目解析 问题本质 输入&#xff1a;二进制数组nums&#xff08;只包含 0 和 1&#xff09;和整数k操作&#xff1a;最多可以将k个 0 翻转成 1目标&#xff1a;找到翻转后能得到的最长连续 1 的子数组长度 这个问题的核心是要找到一…

C#单元测试(xUnit + Moq + coverlet.collector)

C#单元测试 xUnit Moq coverlet.collector 1.添加库 MlyMathLib 2.编写库函数内容 using System;namespace MlyMathLib {public interface IUserRepo{string GetName(int id);}public class UserService{private readonly IUserRepo _repo;public UserService(IUserRepo repo…

【数据库】Oracle学习笔记整理之五:ORACLE体系结构 - 参数文件与控制文件(Parameter Files Control Files)

Oracle体系结构 - 参数文件与控制文件&#xff08;Parameter Files & Control Files&#xff09; 参数文件与控制文件是Oracle数据库的“双核基石”&#xff1a;参数文件是实例的“启动配置中心”&#xff0c;定义运行环境与规则&#xff1b;控制文件是数据库的“物理元数据…

GDB典型开发场景深度解析

GDB典型开发场景深度解析 以下是开发过程中最常见的GDB使用场景&#xff0c;结合具体实例和调试技巧&#xff0c;帮助开发者高效解决实际问题&#xff1a;一、崩溃分析&#xff08;Core Dump调试&#xff09; 场景&#xff1a;程序突然崩溃&#xff0c;生成了core文件 # 启动调…

存储、硬盘、文件系统、 IO相关常识总结

目录 &#xff08;一&#xff09;存储 &#xff08;1&#xff09;定义 &#xff08;2&#xff09;分类 &#xff08;二&#xff09;硬盘 &#xff08;1&#xff09;容量&#xff08;最主要的参数&#xff09; &#xff08;2&#xff09;转速 &#xff08;3&#xff09;访…

docker安装mongodb及java连接实战

1.docker部署mongodb docker run --name mongodb -d -p 27017:27017 -v /data/mongodbdata:/data/db -e MONGO_INITDB_ROOT_USERNAMEtestmongo -e MONGO_INITDB_ROOT_PASSWORDtest123456 mongodb:4.0.112.项目实战 <dependencies><dependency><groupId>org.m…

Java设计模式之《工厂模式》

目录 1、介绍 1.1、定义 1.2、优缺点 1.3、使用场景 2、实现 2.1、简单工厂模式 2.2、工厂方法模式 2.3、抽象工厂模式 3、小结 前言 在面向对象编程中&#xff0c;创建对象实例最常用的方式就是通过 new 操作符构造一个对象实例&#xff0c;但在某些情况下&#xff0…

【异步】js中异步的实现方式 async await /Promise / Generator

JS的异步相关知识 js里面一共有以下异步的解决方案 传统的回调 省略 。。。。 生成器 Generator 函数是 ES6 提供的一种异步编程解决方案, 语法上&#xff0c;首先可以把它理解成&#xff0c;Generator 函数是一个状态机&#xff0c;封装了多个内部状态。执行 Generator 函数…

JVM字节码文件结构

Class文件结构class文件是二进制文件&#xff0c;这里要介绍的是这个二级制文件的结构。思考&#xff1a;一个java文件编译成class文件&#xff0c;如果要描述一个java文件&#xff0c;需要哪些信息呢&#xff1f;基本信息&#xff1a;类名、父类、实现哪些接口、方法个数、每个…

11.web api 2

5. 操作元素属性 5.1操作元素常用属性 &#xff1a;通过 JS 设置/修改标签元素属性&#xff0c;比如通过 src更换 图片最常见的属性比如&#xff1a; href、title、src 等5.2 操作元素样式属性 &#xff1a;通过 JS 设置/修改标签元素的样式属性。使用 className 有什么好处&a…

java中数组和list的区别是什么?

在Java中&#xff0c;数组&#xff08;Array&#xff09;和List&#xff08;通常指java.util.List接口的实现类&#xff0c;如ArrayList、LinkedList&#xff09;是两种常用的容器&#xff0c;但它们在设计、功能和使用场景上有显著区别。以下从核心特性、使用方式等方面详细对…

Python爬取推特(X)的各种数据

&#x1f31f; Hello&#xff0c;我是蒋星熠Jaxonic&#xff01; &#x1f308; 在浩瀚无垠的技术宇宙中&#xff0c;我是一名执着的星际旅人&#xff0c;用代码绘制探索的轨迹。 &#x1f680; 每一个算法都是我点燃的推进器&#xff0c;每一行代码都是我航行的星图。 &#x…

Oracle数据库文件管理与空间问题解决指南

在Oracle数据库运维中&#xff0c;表空间、数据文件及相关日志文件的管理是保障数据库稳定运行的核心环节。本文将系统梳理表空间与数据文件的调整、关键文件的移动、自动扩展配置&#xff0c;以及常见空间不足错误的排查解决方法&#xff0c;为数据库管理员提供全面参考。 一、…

华为实验综合小练习

描述&#xff1a; 1 内网有A、B、C 三个部门。所在网段如图所示。 2 内网服务器配置静态IP,网关192.168.100.1。 3 sw1和R1之间使用vlan200 192.168.200.0/30 互联。 4 R1向运营商申请企业宽带并申请了5个公网IP&#xff1a;200.1.1.1-.5子网掩码 255.255.255.248&#xff0c;网…

Flink面试题及详细答案100道(1-20)- 基础概念与架构

《前后端面试题》专栏集合了前后端各个知识模块的面试题&#xff0c;包括html&#xff0c;javascript&#xff0c;css&#xff0c;vue&#xff0c;react&#xff0c;java&#xff0c;Openlayers&#xff0c;leaflet&#xff0c;cesium&#xff0c;mapboxGL&#xff0c;threejs&…

爬虫逆向之滑块验证码加密分析(轨迹和坐标)

本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的。否则由此产生的一切后果均与作者无关&#xff01;在爬虫开发过程中&#xff0c;滑块验证码常常成为我们获取数据的一大阻碍。而滑块验证码的加密方式多种多样&#xff0c;其中轨迹加密和坐标加密是比较常见的…