PyTorch深度学习总结

第九章 PyTorch中torch.nn模块的循环层


文章目录

  • PyTorch深度学习总结
  • 前言
  • 一、循环层
      • 1. 简单循环层(RNN)
      • 2. 长短期记忆网络(LSTM)
      • 3. 门控循环单元(GRU)
      • 4. 双向循环层
  • 二、循环层参数
      • 1. 输入维度相关参数
      • 2. 隐藏层相关参数
      • 3. 其他参数
  • 三、函数总结


前言

上文介绍了PyTorch中介绍了池化和torch.nn模块中的池化层函数,本文将进一步介绍torch.nn模块中的循环层。


一、循环层

在PyTorch中,循环层Recurrent Layers)是处理序列数据的重要组件,常用于自然语言处理、时间序列分析等领域。
下面为你详细介绍几种常见的循环层:

1. 简单循环层(RNN)

  • 原理简单循环层RNN)是最基础的循环神经网络结构,它在每个时间步接收当前输入和上一个时间步的隐藏状态,通过特定的激活函数计算当前时间步的隐藏状态。这种结构使得RNN能够对序列数据中的时间依赖关系进行建模。
  • PyTorch实现:在PyTorch中,可以使用torch.nn.RNN类来构建简单循环层。以下是一个简单的示例代码:
import torch
import torch.nn as nn# 定义输入维度、隐藏维度和层数
input_size = 10
hidden_size = 20
num_layers = 1# 创建RNN层
rnn = nn.RNN(input_size, hidden_size, num_layers)# 生成输入数据
batch_size = 3
seq_len = 5
input_data = torch.randn(seq_len, batch_size, input_size)# 初始化隐藏状态
h_0 = torch.randn(num_layers, batch_size, hidden_size)# 前向传播
output, h_n = rnn(input_data, h_0)
  • 应用场景:简单循环层适用于处理一些简单的序列数据,例如短文本分类、简单的时间序列预测等。但由于存在梯度消失梯度爆炸的问题,对于长序列数据的处理效果不佳。

2. 长短期记忆网络(LSTM)

  • 原理长短期记忆网络LSTM)是为了解决RNN的梯度消失问题而提出的。它引入了门控机制,包括输入门、遗忘门和输出门,通过这些门控单元可以更好地控制信息的流动,从而有效地捕捉序列数据中的长距离依赖关系。
  • PyTorch实现:在PyTorch中,可以使用torch.nn.LSTM类来构建LSTM层。示例代码如下:
import torch
import torch.nn as nn# 定义输入维度、隐藏维度和层数
input_size = 10
hidden_size = 20
num_layers = 1# 创建LSTM层
lstm = nn.LSTM(input_size, hidden_size, num_layers)# 生成输入数据
batch_size = 3
seq_len = 5
input_data = torch.randn(seq_len, batch_size, input_size)# 初始化隐藏状态和细胞状态
h_0 = torch.randn(num_layers, batch_size, hidden_size)
c_0 = torch.randn(num_layers, batch_size, hidden_size)# 前向传播
output, (h_n, c_n) = lstm(input_data, (h_0, c_0))
  • 应用场景:LSTM广泛应用于自然语言处理中的机器翻译、文本生成,以及时间序列分析中的股票价格预测、天气预测等领域。

3. 门控循环单元(GRU)

  • 原理门控循环单元GRU)是LSTM的一种简化版本,它将LSTM中的输入门和遗忘门合并为一个更新门,并取消了细胞状态,只保留隐藏状态。这种简化使得GRU的计算效率更高,同时也能够较好地捕捉序列数据中的长距离依赖关系。
  • PyTorch实现:在PyTorch中,可以使用torch.nn.GRU类来构建GRU层。示例代码如下:
import torch
import torch.nn as nn# 定义输入维度、隐藏维度和层数
input_size = 10
hidden_size = 20
num_layers = 1# 创建GRU层
gru = nn.GRU(input_size, hidden_size, num_layers)# 生成输入数据
batch_size = 3
seq_len = 5
input_data = torch.randn(seq_len, batch_size, input_size)# 初始化隐藏状态
h_0 = torch.randn(num_layers, batch_size, hidden_size)# 前向传播
output, h_n = gru(input_data, h_0)
  • 应用场景:GRU在一些对计算资源要求较高的场景中表现出色,例如实时语音识别、在线文本分类等。

4. 双向循环层

  • 原理双向循环层Bidirectional RNN/LSTM/GRU)是在单向循环层的基础上扩展而来的。它同时考虑了序列数据的正向和反向信息,通过将正向和反向的隐藏状态拼接或相加,能够更全面地捕捉序列数据中的上下文信息。
  • PyTorch实现:在PyTorch中,可以通过设置bidirectional=True来创建双向循环层。以双向LSTM为例,示例代码如下:
import torch
import torch.nn as nn# 定义输入维度、隐藏维度和层数
input_size = 10
hidden_size = 20
num_layers = 1# 创建双向LSTM层
lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)# 生成输入数据
batch_size = 3
seq_len = 5
input_data = torch.randn(seq_len, batch_size, input_size)# 初始化隐藏状态和细胞状态
h_0 = torch.randn(num_layers * 2, batch_size, hidden_size)
c_0 = torch.randn(num_layers * 2, batch_size, hidden_size)# 前向传播
output, (h_n, c_n) = lstm(input_data, (h_0, c_0))
  • 应用场景双向循环层自然语言处理中的命名实体识别、情感分析等任务中表现出色,因为这些任务需要充分利用上下文信息来做出准确的判断。

二、循环层参数

以下为你详细介绍 PyTorch 中几种常见循环层(RNN、LSTM、GRU)的常见参数:

1. 输入维度相关参数

  • input_size
  • 含义:该参数表示输入序列中每个时间步的特征数量。可以理解为输入数据的特征维度
    - 例子:在处理文本数据时,如果使用词向量表示每个单词,词向量的维度就是 input_size。假如使用 300 维的词向量,那么 input_size 就为 300。
  • batch_first
  • 含义:这是一个布尔类型的参数,用于指定输入和输出张量的维度顺序。当 batch_first=True 时,输入和输出张量的形状为 (batch_size, seq_len, input_size);当 batch_first=False(默认值)时,形状为 (seq_len, batch_size, input_size)
    - 例子:假设 batch_size 为 32,seq_len 为 10,input_size 为 50。若 batch_first=True,输入张量形状就是 (32, 10, 50);若 batch_first=False,输入张量形状则为 (10, 32, 50)

2. 隐藏层相关参数

  • hidden_size
  • 含义:代表隐藏状态的维度,即每个时间步中隐藏层神经元的数量。隐藏状态在循环层的计算中起着关键作用,它会在不同时间步之间传递信息。
    - 例子:如果 hidden_size 设置为 128,意味着每个时间步的隐藏层有 128 个神经元,隐藏状态的维度就是 128。
  • num_layers
  • 含义:表示循环层的堆叠层数。多层循环层可以学习更复杂的序列模式,通过堆叠多个循环层,模型能够从不同抽象层次上处理序列数据。
    - 例子:当 num_layers 为 2 时,意味着有两个循环层堆叠在一起,前一层的输出会作为后一层的输入。

3. 其他参数

  • bias
  • 含义:布尔类型参数,用于决定是否在循环层中使用偏置项。bias=True 表示使用偏置,bias=False 则不使用。
    - 例子:在大多数情况下,bias 默认为 True,即使用偏置项,这样可以增加模型的灵活性。
  • dropout
  • 含义:该参数用于在循环层中应用 Dropout 正则化,以防止过拟合。取值范围为 0 到 1 之间,表示 Dropout 的概率。
    - 例子:当 dropout = 0.2 时,意味着在训练过程中,每个神经元有 20% 的概率被随机置为 0。需要注意的是,dropout 只在 num_layers > 1 时有效。
  • bidirectional
  • 含义:布尔类型参数,用于指定是否使用双向循环层。bidirectional=True 表示使用双向循环层,bidirectional=False 表示使用单向循环层。
    - 例子:在双向 LSTM 中,设置 bidirectional=True 后,模型会同时考虑序列的正向和反向信息,最后将正反向的隐藏状态进行拼接或相加。
  • LSTM 的 proj_size
  • 含义:用于指定 LSTM 中投影层的维度。投影层可以将隐藏状态的维度进行压缩,从而减少模型的参数数量。
    - 例子:若 proj_size 为 64,原本 hidden_size 为 128,那么经过投影层后,隐藏状态的维度会变为 64。

三、函数总结

循环层类型原理PyTorch实现应用场景优缺点
简单循环层(RNN)每个时间步接收当前输入和上一个时间步的隐藏状态,通过激活函数计算当前时间步隐藏状态,对序列时间依赖关系建模rnn = nn.RNN(input_size, hidden_size, num_layers)短文本分类、简单时间序列预测等简单序列数据处理优点:结构简单;缺点:存在梯度消失或爆炸问题,处理长序列效果不佳
长短期记忆网络(LSTM)引入门控机制(输入门、遗忘门和输出门),控制信息流动,捕捉长距离依赖关系lstm = nn.LSTM(input_size, hidden_size, num_layers)机器翻译、文本生成、股票价格预测、天气预测等优点:能有效处理长序列;缺点:计算复杂度相对较高
门控循环单元(GRU)将LSTM的输入门和遗忘门合并为更新门,取消细胞状态,保留隐藏状态gru = nn.GRU(input_size, hidden_size, num_layers)实时语音识别、在线文本分类等对计算资源要求高的场景优点:计算效率高;缺点:在某些复杂长序列任务效果可能不如LSTM
双向循环层(Bidirectional RNN/LSTM/GRU)同时考虑序列正向和反向信息,通过拼接或相加正反向隐藏状态捕捉上下文信息lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)命名实体识别、情感分析等需充分利用上下文信息的任务优点:能更全面捕捉上下文;缺点:计算量更大

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/88040.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/88040.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/88040.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu 24.04 LTS 服务器配置:安装 JDK、Nginx、Redis。

Ubuntu 24.04 LTS 服务器配置:安装 JDK、Nginx、Redis。新建用来放置软件安装包的目录 mkdir /home/software 配置目录所有者为 ubuntu 用户: chown ubuntu /home/software将软件安装包上传到 /home/software配置 JDK-8 新建 jdk 安装目录 mkdir /usr/ja…

工作中用到过哪些设计模式?是怎么实现的?

1. 单例模式(结合 Spring Component)场景:配置中心、全局状态管理 Spring 实现:java// 自动注册为Spring Bean(默认单例) Component public class AppConfig {Value("${server.port}")private in…

Leetcode 3609. Minimum Moves to Reach Target in Grid

Leetcode 3609. Minimum Moves to Reach Target in Grid 1. 解题思路2. 代码实现 题目链接:3609. Minimum Moves to Reach Target in Grid 1. 解题思路 这一题我一开始走岔了,走了一个正向遍历走法的思路,无论怎么剪枝都一直超时。后来看了…

工作流引擎:IDEA没有actiBPMN插件怎么办?

文章目录一、问题描述二、替代方案一、问题描述 我们在学习activiti7工作流引擎的时候,需要设计流程图。 一般推荐的就是使用IDEA插件actiBPMN进行开发。 但是,这个插件在IDEA2019后的版本都不在支持。 也就是搜不到 那么,怎么办了&#x…

Android音视频探索之旅 | CMake基础语法 创建支持Ffmpeg的Android项目

一.CMake语法 CMake语法非常多,我们知道如何导入静态库和动态库以及最基础的使用,目前是够用的。其它方面则根据实际项目同步学习。 1.1.基础语法-常用 cmake_minimum_required:指定cmake最小版本include_directories:引入&#x…

React Native 初始化项目和模拟器运行

中文官方文档:https://reactnative.cn/docs/environment-setup 英文官方文档:https://reactnative.dev/docs/getting-started-without-a-framework#step-1-creating-a-new-application 创建新项目 1、初始化 # 如果你之前全局安装过旧的react-native-cli…

20250706-5-Docker 快速入门(上)-创建容器常用选项_笔记

一、创建容器常用选项1. 创建容器常用选项1)常用选项创建容器常用选项交互式选项:-i:保持标准输入打开,允许交互式操作-t:分配伪终端,使容器像传统终端一…

插值与拟合(3):B样条曲线

在路径规划问题中,通常会用到B样条来平滑路径,本文实现并封装了三次准均匀开放B样条曲线,供大学学习使用。作者提供了三套代码方案。可以用于不同平台:方案1:MATLAB;方案2:标准C;方案…

[免费]基于Python豆瓣电影数据分析及可视化系统(Flask+echarts+pandas)【论文+源码+SQL脚本】

大家好,我是java1234_小锋老师,看到一个不错的于Python豆瓣电影数据分析及可视化系统(Flaskechartpandas)【论文源码SQL脚本】,分享下哈。项目介绍随着如今电影越来越多,各种各样的烂片和捞钱的商业片也层出不穷,而有意…

SQL127 月总刷题数和日均刷题数

SQL127 月总刷题数和日均刷题数 withtemp as (selectDATE_FORMAT(submit_time, "%Y%m") as submit_month,count(question_id) as month_q_cnt,round(count(question_id) / day(last_day(max(submit_time))),3) as avg_day_q_cntfrompractice_recordwhereyear(submit…

unity luban接入

1.找到luban官网并下载他的例子和.net8.0的sdk安装 官网地址如下 快速上手 | Luban 参考大佬教程如下 Luban新版本接入教程_哔哩哔哩_bilibili 2.找到他的luban_examples-main示例下的两个文件MiniTemplate和tool 3.MiniTemplate这个文件复制一份到项目工程下,自…

Django服务开发镜像构建

最后完整的项目目录结构1、安装依赖pip install django django-tables2 django-filter2、创建项目和主应用django-admin startproject configcd configpython manage.py startapp dynamic_models3、配置settings.py将项目模块dynamic_models加入进来,django_tables2…

20250706-3-Docker 快速入门(上)-常用镜像管理命令_笔记

一、配置加速器1. Docker Hub简介与地址公共镜像仓库: 由Docker公司维护的公共镜像仓库,包含大量容器镜像默认下载源: Docker工具默认从这个公共镜像库下载镜像访问地址: https://hub.docker.com镜像搜索功能: 可通过浏览器访问图形化管理系…

【unity游戏开发——优化篇】使用Occlusion Culling遮挡剔除,只渲染相机视野内的游戏物体提升游戏性能

注意:考虑到优化的内容比较多,我将该内容分开,并全部整合放在【unity游戏开发——优化篇】专栏里,感兴趣的小伙伴可以前往逐一查看学习。 文章目录 前言实战1、确保所有静止的3D物体都标记为Occluder Static静态遮挡体和Occludee …

通用业务编号生成工具类(MyBatis-Plus + Spring Boot)详解 + 3种调用方式

在企业应用开发中,我们经常需要生成类似 BZ -240704-0001 这种“业务编号”,它通常具有以下特点:前缀:代表业务类型,如 BZ 表示包装日期:年月日格式,通常为 yyMMdd序列号:当天内递增…

前端相关性能优化笔记

1.打开速度怎么变快 - 首屏加载优化2.再次打开速度怎么变快 - 缓存优化了3.操作怎么才顺滑 - 渲染优化4.动画怎么保证流畅 - 长任务拆分2.1 首屏加载指标细化:1.FP(First Paint 首次绘制) 2.FCP(First contentful Paint 首次内容绘制),FP 到 FCP 中间其实主要是 SPA…

7.7晚自习作业

实操作业02:Spark核心开发 作业说明 请严格按照步骤操作,并将最终结果文件(命名为:sparkcore_result.txt)于20点前上传。结果文件需包含每一步的关键命令执行结果文本输出。 一、数据读取与转换操作 上传账户数据$…

手机FunASR识别SIM卡通话占用内存和运行性能分析

手机FunASR识别SIM卡通话占用内存和运行性能分析 --本地AI电话机器人 上一篇:手机无网离线使用FunASR识别SIM卡语音通话内容 下一篇:手机通话语音离线ASR识别商用和优化方向 一、前言 书接上一文《阿里FunASR本地断网离线识别模型简析》,…

虚幻引擎Unreal Engine5恐怖游戏设计制作教程,从入门到精通从零开始完整项目开发实战详细讲解中英字幕

和大家分享一个以前收集的UE5虚幻引擎恐怖游戏开发教程,这是国外一个大神制作的视频教程,教程从零开始到制作出一款完整的游戏。内容讲解全面,如蓝图基础知识讲解、角色控制、高级交互系统、高级库存系统、物品检查、恐怖环境氛围设计、过场动…

多人协同开发时Git使用命令

拉取仓库代码 # 拉取远程仓库至本地tar_dir路径 git clone gitgithub.com:your-repo.git target_dir # 默认是拉取远程master分支,下面拉取并切换到自己需要开发的分支上 # 假设自己需要开发的分支是/feature/my_branch分支 git checkout -b feature/my_branch orig…