前面我们谈到RNN与LSTM之间的关系,而GRU也是循环神经网络中的一种模型,那么它与LSTM有什么区别呢?

接下来我来对GRU(Gated Recurrent Unit)模型进行一次深度解析,重点关注其内部结构、参数以及与LSTM的对比。GRU是LSTM的一种流行且高效的变体,由Cho等人在2014年提出,旨在解决与LSTM相同的长期依赖问题,但通过更简化的结构和更少的参数来实现。

核心思想:简化LSTM,保持性能

  • LSTM的复杂性: LSTM通过细胞状态C_t和隐藏状态h_t,以及三个门(遗忘门、输入门、输出门)来管理信息流。虽然有效,但参数较多,计算稍显复杂。

  • GRU的解决方案: GRU的核心创新在于:

    1. 合并状态: 取消了独立的细胞状态(C_t),只保留隐藏状态(h_t)。隐藏状态h_t同时承担了LSTM中细胞状态(承载长期记忆)和隐藏状态(作为当前输出)的双重角色。

    2. 减少门数量: 将LSTM的三个门合并为两个门

      • 更新门(z_t): 融合了LSTM中遗忘门输入门的功能。它决定了有多少旧的隐藏状态信息需要保留,以及有多少新的候选隐藏状态信息需要加入。

      • 重置门(r_t): 控制前一个隐藏状态h_{t-1}对计算新的候选隐藏状态的影响程度。它决定了在生成候选状态时,应该“重置”或忽略多少过去的信息。

    3. 简化计算流程: 合并状态和门减少了计算步骤和参数数量,通常训练更快,并且在许多任务上表现与LSTM相当甚至有时更好。

GRU单元的内部结构与计算流程(关键!)

想象一个GRU单元在时间步 t 的处理过程。它接收两个输入:

  1. 当前时间步的输入: x_t (维度 input_dim)

  2. 前一时间步的隐藏状态: h_{t-1} (维度 hidden_dim)

它产生一个输出:

  1. 当前时间步的隐藏状态: h_t (维度 hidden_dim)

单元内部的计算涉及以下步骤:

  1. 更新门:

    • z_t = σ(W_z · [h_{t-1}, x_t] + b_z)

    • σ 是 Sigmoid 激活函数(输出 0 到 1)。

    • W_z 是更新门的权重矩阵 (维度 hidden_dim x (hidden_dim + input_dim))。

    • [h_{t-1}, x_t] 表示将 h_{t-1} 和 x_t 拼接成一个向量 (维度 hidden_dim + input_dim)。

    • b_z 是更新门的偏置向量 (维度 hidden_dim)。

    • z_t 的每个元素在 0 到 1 之间,值接近 1 表示倾向于保留更多旧状态 h_{t-1},值接近 0 表示倾向于采用更多新候选状态 h̃_t

  2. 重置门:

    • r_t = σ(W_r · [h_{t-1}, x_t] + b_r)

    • W_r 是重置门的权重矩阵 (维度 hidden_dim x (hidden_dim + input_dim))。

    • b_r 是重置门的偏置向量 (维度 hidden_dim)。

    • r_t 的每个元素在 0 到 1 之间,值接近 0 表示“重置”(忽略)前一个隐藏状态 h_{t-1},值接近 1 表示“保留”前一个隐藏状态 h_{t-1}。它主要用于控制h_{t-1}在计算候选状态时的贡献。

  3. 候选隐藏状态:

    • h̃_t = tanh(W_h · [r_t * h_{t-1}, x_t] + b_h)

    • tanh 激活函数将值压缩到 -1 到 1 之间。

    • W_h 是候选隐藏状态的权重矩阵 (维度 hidden_dim x (hidden_dim + input_dim))。

    • b_h 是候选隐藏状态的偏置向量 (维度 hidden_dim)。

    • [r_t * h_{t-1}, x_t] 表示将 r_t 与 h_{t-1} 逐元素相乘的结果 和 x_t 拼接起来。

      • 这是GRU的关键操作之一!r_t * h_{t-1} 表示根据重置门有选择地“过滤”前一个隐藏状态的信息。如果 r_t 接近 0,相当于在计算候选状态时忽略了 h_{t-1},只基于当前输入 x_t(和偏置)进行计算;如果 r_t 接近 1,则完整保留 h_{t-1} 的信息用于计算新候选状态。

    • h̃_t 表示基于当前输入 x_t 和经过重置门筛选后的前一个状态 r_t * h_{t-1} 计算出的新的、候选的隐藏状态。

  4. 计算当前隐藏状态:

    • h_t = (1 - z_t) * h̃_t + z_t * h_{t-1}

    • * 表示逐元素乘法

    • 这是GRU的核心操作,也是更新门发挥作用的地方:

      • z_t * h_{t-1}: 表示保留多少旧状态 h_{t-1}

      • (1 - z_t) * h̃_t: 表示加入多少新候选状态 h̃_t

      • 新的隐藏状态 h_t 是旧状态 h_{t-1} 和候选新状态 h̃_t 的线性插值,由更新门 z_t 控制比例。

      • 如果 z_t 接近 1,则 h_t ≈ h_{t-1}(几乎完全保留旧状态,忽略当前输入)。

      • 如果 z_t 接近 0,则 h_t ≈ h̃_t(几乎完全采用基于当前输入和重置后状态计算的新候选状态)。

可视化表示(简化)

GRU的内部参数详解

从上面的计算过程可以看出,一个标准的GRU单元包含以下参数:

  1. 权重矩阵 (Weights): 共有 3 组,分别对应更新门、重置门、候选隐藏状态。

    • W_z: 更新门的权重矩阵 (维度: hidden_dim x (hidden_dim + input_dim))

    • W_r: 重置门的权重矩阵 (维度: hidden_dim x (hidden_dim + input_dim))

    • W_h: 候选隐藏状态的权重矩阵 (维度: hidden_dim x (hidden_dim + input_dim))

  2. 偏置向量 (Biases): 共有 3 组,与权重矩阵一一对应。

    • b_z: 更新门的偏置向量 (维度: hidden_dim)

    • b_r: 重置门的偏置向量 (维度: hidden_dim)

    • b_h: 候选隐藏状态的偏置向量 (维度: hidden_dim)

重要说明

  • 参数共享: 同一个GRU层中的所有时间步 t 共享同一套参数 (W_zW_rW_hb_zb_rb_h)。这是循环神经网络的核心特性。

  • 参数总量计算: 对于一个GRU层:

    • 总参数量 = 3 * [hidden_dim * (hidden_dim + input_dim) + hidden_dim]

    • 简化: 3 * (hidden_dim * hidden_dim + hidden_dim * input_dim + hidden_dim) = 3 * (hidden_dim^2 + hidden_dim * input_dim + hidden_dim)

    • 与LSTM对比: GRU的参数数量是LSTM的 3/4 (75%)。例如:input_dim=100hidden_dim=256:

      • LSTM参数量: 4 * (256^2 + 256*100 + 256) = 4 * 91392 = 365, 568

      • GRU参数量: 3 * (256^2 + 256*100 + 256) = 3 * 91392 = 274, 176

      • 减少了 91, 392 个参数 (约25%)。

  • 输入维度: input_dim 是输入数据 x_t 的特征维度。

  • 隐藏层维度: hidden_dim 是一个超参数,决定了:

    • 隐藏状态 h_t、更新门 z_t、重置门 r_t、候选状态 h̃_t 的维度。

    • 模型的容量。更大的 hidden_dim 通常能学习更复杂的模式,但也需要更多计算资源和数据。

  • 激活函数:

    • 更新门(z_t)和重置门(r_t): 使用 Sigmoid (σ),输出0-1,控制信息流比例。

    • 候选隐藏状态(h̃_t): 使用 tanh,将值规范到-1到1之间,提供非线性变换。

  • 关键操作解读:

    • 重置门(r_t): 作用于计算候选状态 h̃_t 之前。它决定在生成新的候选信息时,应该考虑多少过去的状态 h_{t-1}。如果模型发现 h_{t-1} 与预测未来无关(例如,遇到句子边界或主题切换),它可以学习将 r_t 设置为接近0,从而在计算 h̃_t 时“重置”或忽略 h_{t-1},主要依赖当前输入 x_t

    • 更新门(z_t): 作用于生成最终隐藏状态 h_t 时。它决定了新的 h_t 应该由多少旧状态 h_{t-1} 和多少新候选状态 h̃_t 组成。这类似于LSTM中遗忘门(保留多少旧细胞状态)和输入门(添加多少新候选细胞状态)的组合功能。一个接近1的 z_t 允许信息在隐藏状态中长期保留(缓解梯度消失),一个接近0的 z_t 则使隐藏状态快速更新为新的信息。

GRU如何解决长期依赖问题?

  1. 更新门是关键: 公式 h_t = (1 - z_t) * h̃_t + z_t * h_{t-1} 是核心。这个加法操作 (+) 允许梯度在 h_t 直接流向 h_{t-1} 时相对稳定地流动(类似于LSTM细胞状态中的加法)。反向传播时,梯度 ∂h_t / ∂h_{t-1} 包含 z_t 项(可能接近1)。只要网络能够学习到在需要长期记忆的位置让 z_t 接近1,梯度就可以几乎无损地流过许多时间步。

  2. 门控机制赋予选择性:

    • 选择性重置: 重置门 r_t 允许模型在计算新的候选信息时,有选择地丢弃与当前计算无关的过去信息

    • 选择性更新: 更新门 z_t 允许模型有选择地将新的相关信息(来自 h̃_t)融合进隐藏状态,同时保留相关的长期信息(来自 h_{t-1})。

  3. 参数效率: 更少的参数意味着模型更容易训练(尤其是在数据量有限时),收敛可能更快,且计算开销更低,同时通常能达到与LSTM相当的性能。

GRU vs LSTM:主要区别总结

特性LSTMGRU
状态数量两个:细胞状态 C_t + 隐藏状态 h_t一个:隐藏状态 h_t
门数量三个:遗忘门 f_t, 输入门 i_t, 输出门 o_t两个:更新门 z_t, 重置门 r_t
核心操作C_t = f_t * C_{t-1} + i_t * g_t
h_t = o_t * tanh(C_t)
h_t = (1 - z_t) * h̃_t + z_t * h_{t-1}
h̃_t = tanh(W·[r_t * h_{t-1}, x_t] + b)
参数数量4组权重矩阵 + 4组偏置 (≈4h(h+d+h))3组权重矩阵 + 3组偏置 (≈3h(h+d+h)) (比LSTM少25%)
计算效率相对较高相对较低 (更少的参数和计算步骤)
性能在大多数任务上非常强大,尤其超长序列大多数任务上与LSTM性能相当或接近,有时略优或略劣,在中小型数据集上有时表现更好
输出h_t (可能用于预测)h_t (直接作为输出和下一时间步的输入)

选择建议:

  • 计算资源/时间敏感: 优先考虑GRU(更快,更少参数)。

  • 任务性能至上(尤其是超长序列): 两种都试试,LSTM有时在极端长序列任务中更鲁棒(得益于独立的细胞状态),但差异通常不大。

  • 数据集较小: GRU可能更有优势(更少参数,降低过拟合风险)。

  • 实践: 在很多现代应用中(如Transformer之前的RNN时代),GRU因其效率成为LSTM的有力竞争者。最佳选择通常需要通过实验在具体任务和数据集上验证。

总结

GRU通过合并细胞状态与隐藏状态以及将三个门简化为两个门(更新门和重置门),创造了一种比LSTM更简洁高效的循环神经网络结构。其核心在于:

  1. 更新门(z_t): 控制新隐藏状态 h_t 由多少旧状态 h_{t-1} 和多少新候选状态 h̃_t 组成,是维持长期依赖的关键。

  2. 重置门(r_t): 控制前一个状态 h_{t-1} 在计算新候选状态 h̃_t 时的影响程度,实现有选择的信息重置。

  3. 候选状态(h̃_t): 基于当前输入 x_t 和经过重置门筛选的前状态 r_t * h_{t-1} 计算得出的潜在新状态。

GRU的参数包括三组权重矩阵(W_zW_rW_h)和对应的偏置向量(b_zb_rb_h),其总参数量约为LSTM的75%。这种结构上的简化使得GRU通常训练更快、计算开销更小,并且在广泛的序列建模任务中展现出与LSTM相当甚至有时更优的性能,成为处理长期依赖问题的一种强大而实用的工具。理解GRU的门控机制和参数作用,对于有效使用和调优模型至关重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/87230.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/87230.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/87230.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025年数字信号、计算机通信与软件工程国际会议(DSCCSE 2025)

2025年数字信号、计算机通信与软件工程国际会议(DSCCSE 2025) 2025 International Conference on Digital Signal, Computer Communication, and Software Engineering 一、大会信息 会议简称:DSCCSE 2025 大会地点:中国北京 审稿…

北峰智能SDC混合组网通信方案,助力无网络场景高效作业

在自然灾害、公共安全事件或大规模活动应急响应中,专用无线对讲通信因其不受外部网络限制、免去通话费用、无需拨号便可实现即时语音调度的特点,展现出其不可替代的价值。尤其在许多无基础设施的地区,对智能化调度管理的需求并不亚于城市地区…

HarmonyOS应用开发高级认证知识点梳理 (二) 组件交互

以下是 HarmonyOS 应用开发中 ‌组件交互‌ 的核心知识点梳理(高级认证备考重点),涵盖事件传递、状态管理、通信机制及生命周期协同: 一、事件处理机制 基础交互类型‌ (1)点击事件(onClick) 核心要点‌…

【SQL优化案例】索引创建不合理导致SQL消耗大量CPU资源

#隐式转换 第一章 适用环境 oracle 11glinux 6.9 第二章 Top SQL概况 下面列出我们发现的特定模块中Top SQL的相关情况: SQL_ID 模块 SQL类型 主要问题 fnc58puaqkd1n 无 select 索引创建不合理,导致全索引扫描,产生了大量逻辑读 …

autoas/as 工程的RTE静态消息总线实现与端口数据交换机制详解

0. 概述 autoas/as 工程的RTE(Runtime Environment)通过自动生成C代码,将各SWC(软件组件)之间的数据通信全部静态化、结构化,实现了类似“静态消息总线”的通信模型。所有端口的数据交换都必须经过RTE接口…

【机器学习第四期(Python)】LightGBM 方法原理详解

LightGBM 概述 一、LightGBM 简介二、LightGBM 原理详解⚙️ 核心原理🧠 LightGBM 的主要特点 三、LightGBM 实现步骤(Python)🧪 可调参数推荐完整案例代码(回归任务 可视化) 参考 LightGBM 是由微软开源的…

时序数据库IoTDB监控指标采集与可视化指南

一、概述 本文以时序数据库IoTDB V1.0.1版本为例,介绍如何通过Prometheus采集Apache IoTDB的监控指标,并使用Grafana进行可视化。 二、Prometheus聚合运算符 Prometheus支持多种聚合运算符,用于在时间序列数据上进行聚合操作。以下是一些常…

React安装使用教程

一、React 简介 React 是由 Facebook 开发和维护的一个用于构建用户界面的 JavaScript 库,适用于构建复杂的单页应用(SPA)。它采用组件化、虚拟 DOM 和声明式编程等理念,已成为前端开发的主流选择。 二、React 安装方式 2.1 使用…

.NET MAUI跨平台串口通讯方案

文章目录 MAUI项目架构设计平台特定实现接口定义Windows平台实现Android平台实现 MAUI主界面实现依赖注入配置相关学习资源.NET MAUI开发移动端开发平台特定实现依赖注入与架构移动应用发布跨平台开发最佳实践性能优化测试与调试开源项目参考 MAUI项目架构设计 #mermaid-svg-OG…

BUUCTF在线评测-练习场-WebCTF习题[MRCTF2020]你传你[特殊字符]呢1-flag获取、解析

解题思路 打开靶场&#xff0c;左边是艾克&#xff0c;右边是诗人&#xff0c;下面有个文件上传按钮 结合题目&#xff0c;是一个文件上传漏洞&#xff0c;一键去世看源码可知是提交按钮&#xff0c;先上传个一句话木马.php试试 <?php eval($_POST[shell]); ?> 被过…

【容器】容器平台初探 - k8s整体架构

目录 K8s总揽 K8s主要组件 组件说明 一、Master组件 二、WokerNode组件 K8s是Kubernetes的简称&#xff0c;它是Google的开源容器集群管理系统&#xff0c;其提供应用部署、维护、扩展机制等功能&#xff0c;利用k8s能很方便地管理跨机器运行容器化的应用。 K8s总揽 K8s主…

C++--继承

文章目录 继承1. 继承的概念及定义1.1 继承的概念1.2 继承的定义1.2.1 定义格式1.2.2 继承方式和访问限定符1.2.3 继承基类成员访问方式的变化1.2.3.1 基类成员访问方式的变化规则1.2.3.2 默认继承方式 1.3 继承类模版 2. 基类和派生类的转化3. 继承中的作用域3.1 隐藏3.2 经典…

无REPOSITORY、TAG的docker悬空镜像究竟是什么?是否可删除?

有时候&#xff0c;使用docker images指令我们可以发现大量的无REPOSITORY、TAG的docker镜像&#xff0c;这些镜像究竟是什么&#xff1f; 它们没有REPOSITORY、TAG名称&#xff0c;没有办法引用&#xff0c;那么它们还有什么用&#xff1f; [rootcdh-100 data]# docker image…

创建一个基于YOLOv8+PyQt界面的驾驶员疲劳驾驶检测系统 实现对驾驶员疲劳状态的打哈欠检测,头部下垂 疲劳眼睛检测识别

如何使用Yolov8创建一个基于YOLOv8的驾驶员疲劳驾驶检测系统 文章目录 1. 数据集准备2. 安装依赖3. 创建PyQt界面4. 模型训练1. 数据集准备2. 模型训练数据集配置文件 (data.yaml)训练脚本 (train.py) 3. PyQt界面开发主程序 (MainProgram.py) 4. 运行项目5. 关键代码解释数据集…

使用FFmpeg将YUV编码为H.264并封装为MP4,通过api接口实现

YUV数据来源 摄像头直接采集的原始视频流通常为YUV格式&#xff08;如YUV420&#xff09;&#xff0c;尤其是安防摄像头和网络摄像头智能手机、平板电脑的摄像头通过硬件接口视频会议软件&#xff08;如Zoom、腾讯会议&#xff09;从摄像头捕获YUV帧&#xff0c;进行预处理&am…

tcpdump工具交叉编译

本文默认系统已经安装了交叉工具链环境。 下载相关版本源码 涉及tcpdump源码&#xff0c;以及tcpdump编译过程依赖的pcap库源码。 网站&#xff1a;http://www.tcpdump.org/release wget http://www.tcpdump.org/release/libpcap-1.8.1.tar.gz wget http://www.tcpdump.org/r…

神经网络中torch.nn的使用

卷积层 通过卷积核&#xff08;滤波器&#xff09;在输入数据上滑动&#xff0c;卷积层能够自动检测和提取局部特征&#xff0c;如边缘、纹理、颜色等。不同的卷积核可以捕捉不同类型的特征。 nn.conv2d() in_channels:输入的通道数&#xff0c;彩色图片一般为3通道 out_c…

在MATLAB中使用GPU加速计算及多GPU配置

文章目录 在MATLAB中使用GPU加速计算及多GPU配置一、基本GPU加速使用1. 检查GPU可用性2. 将数据传输到GPU3. 执行GPU计算 二、多GPU配置与使用1. 选择特定GPU设备2. 并行计算工具箱中的多GPU支持3. 数据并行处理&#xff08;适用于深度学习&#xff09; 三、高级技巧1. 异步计算…

【unitrix】 4.12 通用2D仿射变换矩阵(matrix/types.rs)

一、源码 这段代码定义了一个通用的2D仿射变换矩阵结构&#xff0c;可用于表示二维空间中的各种线性变换。 /// 通用2D仿射变换矩阵&#xff08;元素仅需实现Copy trait&#xff09; /// /// 该矩阵可用于表示二维空间中的任意仿射变换&#xff0c;支持以下应用场景&#xff…

android RecyclerView隐藏整个Item后,该Item还占位留白问题

前言 android RecyclerView隐藏整个Item后,该Item还占位留白问题 思考了利用隐藏和现实来控制item 结果实现不了方案 解决方案 要依据 model 的第三个参数&#xff08;布尔值&#xff09;决定是否保留数据&#xff0c;可以通过 ​filter 高阶函数结合 ​空安全操作符​ 实…