终于该学习神经网络的搭建了,开心,嘻嘻

学习神经网络离不开torch.nn,先把他印在脑子里,什么是torch.nn?他是Pytorch的一个模块,包含了大量构建神经网络需要的类和方法,就像前面学习的torch.utils,什么?忘了torch.utils是啥了?

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

想起来了吗?

1.nn.Module

先来认识一下神经网络最基本的基类:nn.Module

nn.Module是 PyTorch 中所有神经网络模块的基类。它提供了许多重要的功能和方法,简化了神经网络的构建、训练和管理。

想要搭建自己的神经网络,首先要继承他,并实现他的两个方法:__init__和forward

__init__:主要用于定义神经网络中的各种层

forward:调用上面的层,数据经过哪些层,怎么处理都是在这里定义

什么是层?

神经网络中包含了各种层,他们都有各自处理数据的方式,每一层就像一把菜刀一样,__init__就是定义这些菜刀的规格,他们有多大,多长等等。forward呢,就是定义送进来的食材需要用哪些菜刀?需不需要加佐料,撒点孜然、葱花啥的

看一下基本的代码结构:

  • from torch import nnclass Model(nn.Module):def __init__(self):# 空的 一把菜刀也没有super(Model, self).__init__()def forward(self, x):# 食材怎么进来就怎么出去,没有经过任何处理return x

2.Conv2d

接下来要开始学习第一把刀的使用:Conv2d,PyTorch 中用于处理二维数据(如图像)的卷积层

背景知识:

1.名称解析

Conv2d:

  • Conv:Convolution的简写,表示卷积操作
  • 2d:表示卷积操作是对二维数据进行的,如:图像

2.维度(Pytorch)

在Pytorch中,图像通常是3D张量,维度是3,形状:(channel,height,width)

  • channel:代表通道数,彩色图像是3,灰色图像是1
  • height:图像的高度,每一行的像素数量
  • width:图像的宽度,每一列的像素数量

彩色图像:

  • 彩色图像:形状(3,68,68),代表这是一张68*68的彩色图像
  • 灰色图像:形状(1,68,68),代表这是一张68*68的灰色图像

在Pytorch的学习中,通常将多个图像放入一个batch(批次)中,这样可以进行批量处理,因此图像张量会有一个额外的维度,就是batch_size,代表批次的大小,此时图像就是一个4D张量,形状是(batch_size,channel,height,width)

3.空间维度

此时,我们再从空间维度上理解一下图像,因为他的基本结构只有(height,width),所以我们称为2D图像,视频是由一系列连续的图像帧组成的,每个时间点(帧)是一个图像,所以他是3D视频,形状(time,height,width)

Conv2d就是专门处理2D图像的,当然还有Conv3d,这个我们以后再说,有兴趣可以去官网了解一下

4.卷积:

卷积简单来说就是将卷积核(一个矩阵)在输入数据上滑动,并与输入内容的局部区域进行点乘操作,最后输出一个新的矩阵

蓝色矩阵就叫做卷积核绿色矩阵就是我们的输入数据(按照目前学习阶段,那应该是一个图像数据),黄色部分是卷积核移动到输入内容上的第一步,用卷积核上的每一个区域的数字和输入数据上的对应区域的数字进行相乘,再相加,比如:1*1+2*2+3*3......最后9个区域的和加起来,作为输出内容的第一行第一列的数据,(1+4+9)*3=42

然后卷积核向右平移一步,继续计算,最后得到一个3*3的矩阵,也就是棕色的那个

动图:

有条件的可以去官网看一下:https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

5.常用参数:

  • in_channels :输入数据的通道数
  • out_channels :输出数据的通道数,也是卷积核的个数
  • kernel_size :卷积核大小
  • stride:卷积核滑动的步数
  • padding:填充,对输入数据进行填充
  • padding_mode :对填充部分的数据进行更改,默认0

填充(绿色区域):

填充之后进行滑动,步数为2

实操部分

1.准备数据集

import torchvision.datasets
from torch.utils.data import DataLoader
from torchvision import transforms# 定义内置数据集
dataset = torchvision.datasets.CIFAR10(root='dataset', train=False, download=True, transform=transforms.ToTensor())# 定义数据加载器
dataloader = DataLoader(dataset, batch_size=64)

2.准备模型

from torch import nn# 准备模型
class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 定义一个网络层conv1 处理2D图像self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)def forward(self, x):# 调用第一个网络层conv1,用来处理输入数据xx = self.conv1(x)return xmodel = Model()

3.将数据循环写入tensorboard

torch.reshape:接受一个张量,不改变原始张量,返回一个新的张量

参数:

  • input:要改变的张量
  • shape:你希望的新张量的形状(batch_size,channel,height,width)

当某一维度是-1,表示由PyTorch 自动推算,但最多只能有一个维度为-1

writer = SummaryWriter("logs")
step = 0
# 遍历数据加载器
for data in dataloader:imgs, labels = data# 调用模型对图片进行处理output=model(imgs)# 把输入内容的图片写入tensorboardwriter.add_images("input", imgs, step)#改变输出图的形状output = torch.reshape(output,(-1,3,30,30))# 把输出内容的图片写入tensorboardwriter.add_images("output", output, step)step += 1
writer.close()

4.完整代码:

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms# 定义内置数据集
dataset = torchvision.datasets.CIFAR10(root='dataset', train=False, download=True, transform=transforms.ToTensor())# 定义数据加载器
dataloader = DataLoader(dataset, batch_size=64)# 准备模型
class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 定义一个网络层conv1 处理2D图像self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)def forward(self, x):# 调用第一个网络层conv1,用来处理输入数据xx = self.conv1(x)return xmodel = Model()writer = SummaryWriter("logs")
step = 0
# 遍历数据加载器
for data in dataloader:imgs, labels = data# 调用模型对图片进行处理output=model(imgs)# 把输入内容的图片写入tensorboardwriter.add_images("input", imgs, step)#改变输出图的形状output = torch.reshape(output,(-1,3,30,30))# 把输出内容的图片写入tensorboardwriter.add_images("output", output, step)step += 1
writer.close()

运行tensorboard,网页展示出来是这样的,因为模型处理后的图片通道是6,我们后面又手动改变了图片的通道,所以多余的数据就到了批次那里,input是64,output是128

好了,拜拜

    本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
    如若转载,请注明出处:http://www.pswp.cn/pingmian/86134.shtml
    繁体地址,请注明出处:http://hk.pswp.cn/pingmian/86134.shtml
    英文地址,请注明出处:http://en.pswp.cn/pingmian/86134.shtml

    如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

    相关文章

    学习C++、QT---07(C++的权限、C++的引用)

    每日一言 你解决的每一个难题,都是在为未来的自己解锁新技能。 权限的讲解 这边呢我们利用银行的一个案例来讲解权限的奥秘 权限指的是public、private 、protected 就是这三种权限,因此有这一张表进行分清他们之间的区别和联系 但是我们在平时的话会因…

    全球化短剧平台全栈技术架构白皮书:多区域部署、智能分发与沉浸式体验的完整解决方案

    一、全球化基础架构深度设计 全球网络基础设施构建 采用多活数据中心部署模式,在北美(弗吉尼亚)、欧洲(法兰克福)、亚太(新加坡)建立三大核心枢纽节点 构建混合CDN网络,整合AWS Clo…

    深入剖析 LGM—— 开启高分辨率 3D 内容创作新时代

    一、引言 在当今数字化时代,3D 内容创作的需求如井喷般增长,从游戏开发中绚丽多彩的虚拟世界,到影视制作里震撼人心的特效场景,再到工业设计中精准无误的产品原型,3D 技术无处不在。然而,传统 3D 内容创作…

    从用户到社区Committer:小米工程师隋亮亮的Apache Fory成长之路

    Apache Fory 是一个基于JIT和零拷贝的高性能多语言序列化框架,实现了高效紧凑的序列化协议,提供极致的性能、压缩率和易用性。在多语言序列化框架技术领域取得了重大突破,推动序列化技术步入高性能易用新篇章!这一切,都…

    【Koa系列】10min快速入门Koa

    简介 koa是基于node开发的一个服务端框架,功能同express,但更小巧简单。 官方仓库地址:https://github.com/koajs/koa 创建项目 创建文件夹nodeKoa,执行以下脚本 npm init -y npm i koa npm i nodemon 基础示例 创建一个服…

    IDEA与通义联合:智能编程效率革命

    IDEA与通义联合:智能编程效率革命 当最强Java IDE遇上顶尖AI助手,会碰撞出怎样的生产力火花? 思维导图解读:智能编程工作流 #mermaid-svg-uTAcSs1kBBmDwGfM {font-family:"trebuchet ms",verdana,arial,sans-serif;font…

    Docker 数据持久化完全指南:Volume、Bind Mount 与匿名卷

    Docker 数据持久化完全指南:Volume、Bind Mount 与匿名卷 引言 在 Docker 中,容器的文件系统默认是临时的,容器删除后数据也会丢失。为了实现数据持久化,Docker 提供了多种存储方式,主要包括: docker vo…

    OSS跨区域复制灾备方案:华东1到华南1的数据同步与故障切换演练

    1. 引言 对象存储服务(OSS)已成为现代数据架构的核心组件。随着业务全球化,跨区域数据灾备从“可选”变为“必选”。本文以阿里云OSS为实验环境,实战演练华东1(杭州)到华南1(深圳)的…

    前端登录状态管理:主流方案对比与安全实践指南

    根据目前业内前端登录状态管理的主流设计方案,及其演进趋势进行汇总,生成主要包括如下内容的报告: 登录状态保持的基础原理:从HTTP无状态问题出发解析技术需求,使用表格对比核心挑战。主流技术方案对比:详…

    动手用 Web 实现一个 2048 游戏

    文章目录 为什么选择 2048?关键技术点与算法详解HTML 结构:搭建游戏界面CSS 样式:美化游戏界面JavaScript 核心逻辑:驱动游戏运行1)数据结构:二维数组表示游戏网格2)核心算法:添加随…

    frp v0.62.1内网穿透搭建和使用

    官网:https://gofrp.org/zh-cn/ Github:https://github.com/fatedier/frp 开源项目 frp frp 是一种快速反向代理,允许您将位于 NAT 或防火墙后面的本地服务器公开给 Internet。目前支持 TCP 和 UDP,以及 HTTP 和 HTTPS 协议&…

    如何使用 USB 数据线将文件从 PC 传输到 iPhone

    虽然用 USB 数据线将文件从 PC 传输到安卓设备非常容易,但对于 iPhone 用户来说,情况就不同了。不过,幸运的是,我们找到了三种可靠的方法,可以使用 USB 数据线将文件从 PC 传输到 iPhone,让您轻松完成这项任…

    【C++高阶三】AVL树深度剖析

    【C高阶三】AVL树深度剖析 1.什么是AVL树2.AVL树的实现2.1节点类和基本结构2.2插入2.3旋转处理2.3.1左单旋2.3.2右单旋2.3.3左右双旋2.3.4右左双旋 1.什么是AVL树 AVL树也叫二叉搜索平衡树 因为二叉搜索树如果插入顺序是有序的,那么这棵树的查找效率将会是O(N)&…

    LangChain 文本分割器深度解析:从原理到落地应用(上)

    食用指南 LangChain 作为大语言模型应用开发框架,文本分割器是其核心组件之一,本文以此作为切入点,详细介绍文本分割的作用、策略、以及常见的文本切割器应用。考虑到篇幅过长,故拆分为上、中、下三篇,后续会在中篇介…

    【Java高频面试问题】高并发篇

    【Java高频面试问题】高并发篇 Kafka原理核心组件高吞吐核心机制高可用设计 Kafka 如何保证消息不丢失如何解决Kafka重复消费一、生产者端:根源防重二、消费者端:精准控制三、业务层:幂等性设计(核心方案) 如何解决Kaf…

    关于结构体,排序,递推的详细讲解(从属于GESP四级)

    本章内容 排序算法基础 结构体 递推 简单双指针 一、排序算法基础三剑客 冒泡 Bubble、选择 Selection、插入 Insertion 1. 预备知识 1.1 排序算法评价指标 指标 含义 影响答题的典型问法 时间复杂度 算法在最坏、平均或最好情况下所需比较 / 交换次数 “写出此算法…

    离线部署docker中的containerd服务

    containerd 是一个行业标准的容器运行时,专注于简单、健壮的容器执行。它是从 Docker 中分离出来的项目,旨在作为一个底层的运行时接口,供更高层次的容器管理层使用。 containerd 负责镜像传输、存储、容器执行、网络配置等工作。它向上为 Do…

    web布局15

    CSS 网格布局除了提供定义网格和放置网格项目的相关属性之外,也提供了一些控制对齐方式的属性。这些控制对齐方式的属性,和 Flexbox 布局中的对齐属性 justify-* 、align-* 、*-items 、*-content 、 *-self 等是相似的: 在网格布局中可以用它…

    leetcode 291. Word Pattern II和290. Word Pattern

    目录 291. Word Pattern II 290. Word Pattern 291. Word Pattern II 回溯法哈希表 class Solution {unordered_map<char,string> hashmap;unordered_set<string> wordset; public:bool wordPatternMatch(string pattern, string s) {return backtrack(pattern,…

    大模型的开发应用(十三):基于RAG的法律助手项目(上):总体流程简易实现

    RAG法律助手项目&#xff08;上&#xff09;&#xff1a;总体流程简易实现 1 项目介绍1.1 方案选型1.2 知识文档 2 文档解析3 知识库构建3.1 构建知识节点3.2 嵌入向量初始化3.2 向量存储 4 查询4.1 初始化大模型4.2 模型响应4.2 本文程序存在的问题 完整代码 1 项目介绍 本项…