关于torch.nn:

使⽤Pytorch来构建神经⽹络, 主要的⼯具都在torch.nn包中.
nn依赖于autograd来定义模型, 并对其⾃动求导.

构建神经⽹络的典型流程:

  1. 定义⼀个拥有可学习参数的神经⽹络
  2. 遍历训练数据集
  3. 处理输⼊数据使其流经神经⽹络
  4. 计算损失值
  5. 将⽹络参数的梯度进⾏反向传播
  6. 以⼀定的规则更新⽹络的权重
我们⾸先定义⼀个Pytorch实现的神经⽹络:
# 导⼊若⼲⼯具包
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义⼀个简单的⽹络类
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 定义第⼀层卷积神经⽹络, 输⼊通道维度=1, 输出通道维度=6, 卷积核⼤⼩3*3self.conv1 = nn.Conv2d(1, 6, 3)# 定义第⼆层卷积神经⽹络, 输⼊通道维度=6, 输出通道维度=16, 卷积核⼤⼩3*3self.conv2 = nn.Conv2d(6, 16, 3)# 定义三层全连接⽹络self.fc1 = nn.Linear(16 * 6 * 6, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x): # 在(2, 2)的池化窗⼝下执⾏最⼤池化操作x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):# 计算size, 除了第0个维度上的batch_sizesize = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_features
net = Net()
print(net)

输出结果:
Net((conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))(fc1): Linear(in_features=576, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)
注意:
模型中所有的可训练参数, 可以通过net.parameters()来获得.
params = list(net.parameters())
print(len(params))
print(params[0].size())

输出结果:
10
torch.Size([6, 1, 3, 3])

输出结果:
假设图像的输⼊尺⼨为32 * 32:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

输出结果:
tensor([[ 0.1242, 0.1194, -0.0584, -0.1140, 0.0661, 0.0191, -0.0966, 
0.0480, 0.0775, -0.0451]], grad_fn=<AddmmBackward>)

有了输出张量后, 就可以执⾏梯度归零和反向传播的操作了.
net.zero_grad()
out.backward(torch.randn(1, 10))

注意:
  • torch.nn构建的神经⽹络只⽀持mini-batches的输⼊, 不⽀持单⼀样本的输⼊.
  • ⽐如: nn.Conv2d 需要⼀个4D Tensor, 形状为(nSamples, nChannels, Height, Width). 如果你的输⼊只有单⼀样本形式, 则需要执⾏input.unsqueeze(0), 主动将3D Tensor扩充成4D Tensor.
损失函数
  • 损失函数的输⼊是⼀个输⼊的pair: (output, target), 然后计算出⼀个数值来评估output和target之间的差距⼤⼩.
  • 在torch.nn中有若⼲不同的损失函数可供使⽤, ⽐如nn.MSELoss就是通过计算均⽅差损失来评估输⼊和⽬标值之间的差距.
应⽤nn.MSELoss计算损失的⼀个例⼦:
output = net(input)
target = torch.randn(10)
# 改变target的形状为⼆维张量, 为了和output匹配
target = target.view(1, -1)
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)

输出结果:
tensor(1.1562, grad_fn=<MseLossBackward>)

关于⽅向传播的链条: 如果我们跟踪loss反向传播的⽅向, 使⽤.grad_fn属性打印, 将可以看到⼀张完整的计算图如下:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear-> MSELoss-> loss

当调⽤loss.backward()时, 整张计算图将对loss进⾏⾃动求导, 所有属性requires_grad=True的Tensors都将参与梯度求导的运算, 并将梯度累加到Tensors中的.grad属性中.
print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU

输出结果:

反向传播(backpropagation)

在Pytorch中执⾏反向传播⾮常简便, 全部的操作就是loss.backward().
在执⾏反向传播之前, 要先将梯度清零, 否则梯度会在不同的批次数据之间被累加.
执⾏⼀个反向传播的⼩例⼦:
# Pytorch中执⾏梯度清零的代码
net.zero_grad()
print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)
# Pytorch中执⾏反向传播的代码
loss.backward()
print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

输出结果:
conv1.bias.grad before backward
tensor([0., 0., 0., 0., 0., 0.])
conv1.bias.grad after backward
tensor([-0.0002, 0.0045, 0.0017, -0.0099, 0.0092, -0.0044])

更新⽹络参数

  • 更新参数最简单的算法就是SGD(随机梯度下降).
  • 具体的算法公式表达式为: weight = weight - learning_rate * gradient
⾸先⽤传统的Python代码来实现SGD如下:
learning_rate = 0.01
for f in net.parameters():f.data.sub_(f.grad.data * learning_rate)

然后使⽤Pytorch官⽅推荐的标准代码如下:
# ⾸先导⼊优化器的包, optim中包含若⼲常⽤的优化算法, ⽐如SGD, Adam等
import torch.optim as optim
# 通过optim创建优化器对象
optimizer = optim.SGD(net.parameters(), lr=0.01)
# 将优化器执⾏梯度清零的操作
optimizer.zero_grad()
output = net(input)
loss = criterion(output, target)
# 对损失值执⾏反向传播的操作
loss.backward()
# 参数的更新通过⼀⾏标准代码来执⾏
optimizer.step()

⼩节总结

  • 学习了构建⼀个神经⽹络的典型流程:
    • 定义⼀个拥有可学习参数的神经⽹络
    • 遍历训练数据集
    • 处理输⼊数据使其流经神经⽹络
    • 计算损失值
    • 将⽹络参数的梯度进⾏反向传播
    • 以⼀定的规则更新⽹络的权重
  • 学习了损失函数的定义:
  • 采⽤torch.nn.MSELoss()计算均⽅误差.
  • 通过loss.backward()进⾏反向传播计算时, 整张计算图将对loss进⾏⾃动求导, 所有属性
  • requires_grad=True的Tensors都将参与梯度求导的运算, 并将梯度累加到Tensors中
  • 的.grad属性中.
  • 学习了反向传播的计算⽅法:
  • 在Pytorch中执⾏反向传播⾮常简便, 全部的操作就是loss.backward().
  • 在执⾏反向传播之前, 要先将梯度清零, 否则梯度会在不同的批次数据之间被累加.
    • net.zero_grad()
    • loss.backward()
  • 学习了参数的更新⽅法:
  • 定义优化器来执⾏参数的优化与更新.
    • optimizer = optim.SGD(net.parameters(), lr=0.01)
  • 通过优化器来执⾏具体的参数更新.
    • optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/89380.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/89380.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/89380.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网络爬虫的详细知识点

基本介绍 什么是网络爬虫 网络爬虫&#xff08;Web Crawler&#xff09;是一种自动化程序&#xff0c;用于从互联网上抓取、解析和存储网页数据。其核心功能是模拟人类浏览行为&#xff0c;通过HTTP/HTTPS协议访问目标网站&#xff0c;提取文本、链接、图片或其他结构化信息&…

AndroidX中ComponentActivity与原生 Activity 的区别

一、AndroidX 与原生 Activity 的区别 1. 概念与背景 原生 Activity&#xff1a;指 Android 早期&#xff08;API 1 起&#xff09;就存在于 android.app 包下的 Activity 类&#xff08;如 android.app.Activity&#xff09;&#xff0c;是 Android 最初的 Activity 实现&…

Spring AI 使用 Elasticsearch 作为向量数据库

前言 嗨&#xff0c;大家好&#xff0c;我是雪荷&#xff0c;最近在公司开发 AI 知识库&#xff0c;同时学到了一些 AI 开发相关的技术&#xff0c;这期先与大家分享一下如何用 ES 当做向量数据库。 安装ES 第一步我们先安装 Elasticsearch&#xff0c;这里建议 Elasticsear…

TypeScript 配置全解析:tsconfig.json、tsconfig.app.json 与 tsconfig.node.json 的深度指南

前言在现代前端和后端开发中&#xff0c;TypeScript 已经成为许多开发者的首选语言。然而&#xff0c;TypeScript 的配置文件&#xff08;特别是多个配置文件协同工作时&#xff09;常常让开发者感到困惑。本文将深入探讨 tsconfig.json、tsconfig.app.json 和 tsconfig.node.j…

读书笔记(学会说话)

1、一个人只有会说话&#xff0c;才会有好人缘&#xff0c;做事才会顺利。会说话的人容易成功。善于说话的人易成功&#xff0c;而不善说话的人往往寸步难行。我们要把话说得好听&#xff0c;同时更要把事做得漂亮。或许一句话&#xff0c;一件事&#xff0c;就可能使人生的旅途…

私有服务器AI智能体搭建-大模型选择优缺点、扩展性、可开发

以下是主流 AI 框架与模型的对比分析&#xff0c;涵盖其优缺点、扩展性、可开发性等方面。 文章目录一、AI 框架对比二、主流大模型对比三、扩展性对比总结四、可开发性对比总结五、选择建议&#xff08;按场景&#xff09;六、未来趋势一、AI 框架对比 框架优点缺点扩展性可开…

OpenCV直线段检测算法类cv::line_descriptor::LSDDetector

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 该类用于实现 LSD (Line Segment Detector) 直线段检测算法。LSD 是一种快速、准确的直线检测方法&#xff0c;能够在不依赖边缘检测的前提下直接从…

Go语言流程控制(if / for)

分支结构package mainimport ("fmt""strconv" )/* 1.顺序结构 2.分支结构 3.循环结构 *//* if 条件1 {// 条件1为真时执行的代码 } else if 条件2 {// 条件1为假但条件2为真时执行的代码 } else {// 所有条件均为假时执行的代码 }一种特殊的条件分支结构if…

wx小程序设置沉浸式导航文字高度问题

第一步&#xff1a;在app.json中设置"navigationStyle": "custom"第二步骤&#xff1a;文件的home.js中// pages/test/test.js Page({/*** 页面的初始数据*/data: {statusBarHeight: 0,navBarHeight: 44 // 自定义导航内容区高度(单位px)},/*** 生命周期函…

C++算法竞赛篇:DevC++ 如何进行debug调试

C算法竞赛篇&#xff1a;DevC 如何进行debug调试前言一、准备工作&#xff1a;编译生成可执行程序二、核心步骤&#xff1a;设置断点与启动调试1. 设置断点2. 启动调试模式三、调试操作&#xff1a;逐步执行与变量监控1. 逐步执行代码2. 监控变量值变化四、调试结束前言 在算法…

语音大模型速览(三)- cosyvoice2

CosyVoice 2: Scalable Streaming Speech Synthesis with Large Language Models 论文链接&#xff1a;https://arxiv.org/pdf/2412.10117代码链接&#xff1a;https://github.com/FunAudioLLM/CosyVoice 一句话总结 CosyVoice 2 是一款改进的流式语音合成模型&#xff0c;其…

-lstdc++与-static-libstdc++的用法和差异

CMakeLists.txt 里写了&#xff1a; target_link_libraries(${PROJECT_NAME} PRIVATEgccstdc ) target_link_options(${PROJECT_NAME} PRIVATE -static-libstdc)看起来像是“链接了两次 C 标准库”&#xff0c;其实它们的作用完全不同&#xff1a;1. target_link_libraries(...…

Redis学习其二(事务,SpringBoot整合,持久化RDB和AOF)

文章目录5,事务5.1Redis 事务不保证原子性的原因5.2事务操作过程5.3监控6,SpringBoot整合Redis6.1Redis客户端6.1.1Jedis简单使用6.1.2Lettuce&Jedis6.2配置相关6.3使用6.3.1使用RedisTemplate6.3.2Redis工具类7,持久化RDB7.1RDB持久化原理7.2触发机制save命令flushall命令…

springboot项目部署到K8S

java后台 创建harbor镜像拉取Secret&#xff1a;kubectl create secret docker-registry harbor-regcred \--docker-server \ #harbor仓库地址--docker-username \ #harbor 账号--docker-password \ #harbor密码-n productionDockerfile FROM *harbor地址*/library/custom-jdk…

【FPGA开发】一文轻松入门Modelsim的基本操作

Modelsim仿真的步骤 &#xff08;1&#xff09;创建新的工程。 &#xff08;2&#xff09;在弹出的窗口中&#xff0c;确定项目名和工作路径&#xff0c;库保持为work不变(如有需要可以根据需求进行更改)。 &#xff08;3&#xff09;添加已经存在的文件&#xff08;rtl代码和t…

服务攻防-Java组件安全FastJson高版本JNDI不出网C3P0编码绕WAF写入文件CI链

服务攻防-Java组件安全&FastJson&高版本JNDI&不出网C3P0&编码绕WAF&写入文件CI链26天 原创 朝阳 Sec朝阳 2025年07月18日 09:23 湖北 标题已修改 演示环境&#xff1a; https://github.com/lemono0/FastJsonParty FastJson全版本Docker漏洞环境(涵盖1.…

【Python】DRF核心组件详解:Mixin与Generic视图

在 Django REST Framework (DRF) 中&#xff0c;mixins.CreateModelMixin、mixins.ListModelMixin、GenericAPIView 和 GenericViewSet 是构建 API 视图的核心组件。以下是对这些组件的主要方法及其职责的简要说明&#xff0c;内容清晰且结构化&#xff1a;1. mixins.CreateMod…

HTML+CSS+JS基础

文章目录&#xff08;一&#xff09;html1.常见标签&#xff08;1&#xff09;注释&#xff08;2&#xff09;标题 h1~h6&#xff08;3&#xff09;段落 p&#xff08;4&#xff09;换行与空格 br \ &#xff08;5&#xff09;格式化标签 b i s u&#xff08;6&#xff09;…

Vue导出Html为Word中包含图片在Microsoft Word显示异常问题

问题背景 碰到一个问题&#xff1a;将包含图片和SVG数学公式的HTML内容导出为Word文档时&#xff0c;将图片都转为ase64格式导出&#xff0c;在WPS Word中显示正常&#xff0c;但是在Microsoft Word中出现图片示异常。具体问题表现 WPS兼容性&#xff1a;在WPS中显示正常&#…

椭圆曲线密码学 Elliptic Curve Cryptography

密码学是研究在存在对抗行为的情况下还能安全通信的技术。即算法加密信息&#xff0c;再算法解密出信息。加密分为两类 1. Symmetric-key Encryption (secret key encryption) 即一种密钥&#xff0c;加密和解密使用同一密钥&#xff0c;可相互转换 2. Asymmetric-key Encry…