循环神经网络(Recurrent Neural Network, RNN)是一类专门处理序列数据的神经网络,如时间序列、自然语言、音频等。与前馈神经网络不同,RNN 引入了循环结构,能够捕捉序列中的时序信息,使模型在不同时间步之间共享参数。这种结构赋予了 RNN 处理变长输入、保留历史信息的能力,成为序列建模的强大工具。

RNN 的基本原理与核心结构

传统神经网络在处理序列数据时,无法利用序列中的时序依赖关系。RNN 通过在网络中引入循环连接,使得信息可以在不同时间步之间传递。

1. 简单 RNN 的数学表达

在时间步t,RNN 的隐藏状态\(h_t\)的计算如下:

\(h_t = \sigma(W_{hh}h_{t-1} + W_{xh}x_t + b)\)

其中,\(x_t\)是当前时间步的输入,\(h_{t-1}\)是上一时间步的隐藏状态,\(W_{hh}\)和\(W_{xh}\)是权重矩阵,b是偏置,\(\sigma\)是非线性激活函数(如 tanh 或 ReLU)。

2. RNN 的展开结构

虽然 RNN 在结构上包含循环,但在计算时通常将其展开为一个时间步序列。这种展开视图更清晰地展示了 RNN 如何处理序列数据:

plaintext

x1    x2    x3    ...   xT
|     |     |           |
v     v     v           v
h0 -> h1 -> h2 -> ... -> hT
|     |     |           |
v     v     v           v
y1    y2    y3    ...   yT

其中,\(h_0\)通常初始化为零向量,\(y_t\)是时间步t的输出(如果需要)。

3. RNN 的局限性

简单 RNN 虽然能够处理序列数据,但存在严重的梯度消失或梯度爆炸问题,导致难以学习长距离依赖关系。这限制了它在处理长序列时的性能。

长短期记忆网络(LSTM)与门控循环单元(GRU)

为了解决简单 RNN 的局限性,研究人员提出了更复杂的门控机制,主要包括 LSTM 和 GRU。

1. 长短期记忆网络(LSTM)

LSTM 通过引入遗忘门、输入门和输出门,有效控制信息的流动:

\(\begin{aligned} f_t &= \sigma(W_f[h_{t-1}, x_t] + b_f) \\ i_t &= \sigma(W_i[h_{t-1}, x_t] + b_i) \\ o_t &= \sigma(W_o[h_{t-1}, x_t] + b_o) \\ \tilde{C}_t &= \tanh(W_C[h_{t-1}, x_t] + b_C) \\ C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\ h_t &= o_t \odot \tanh(C_t) \end{aligned}\)

其中,\(f_t\)、\(i_t\)、\(o_t\)分别是遗忘门、输入门和输出门,\(C_t\)是细胞状态,\(\odot\)表示逐元素乘法。

2. 门控循环单元(GRU)

GRU 是 LSTM 的简化版本,合并了遗忘门和输入门,并将细胞状态和隐藏状态合并:

\(\begin{aligned} z_t &= \sigma(W_z[h_{t-1}, x_t] + b_z) \\ r_t &= \sigma(W_r[h_{t-1}, x_t] + b_r) \\ \tilde{h}_t &= \tanh(W_h[r_t \odot h_{t-1}, x_t] + b_h) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{aligned}\)

其中,\(z_t\)是更新门,\(r_t\)是重置门。

RNN 的典型应用场景

RNN 在各种序列建模任务中取得了广泛应用:

  1. 自然语言处理:机器翻译、文本生成、情感分析、命名实体识别等。
  2. 语音识别:将语音信号转换为文本。
  3. 时间序列预测:股票价格预测、天气预测等。
  4. 视频分析:动作识别、视频描述生成。
  5. 音乐生成:自动作曲。
使用 PyTorch 实现 RNN 进行文本分类

下面我们使用 PyTorch 实现一个基于 LSTM 的文本分类模型,使用 IMDB 电影评论数据集进行情感分析。

python

运行

import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.legacy import data, datasets
import random
import numpy as np# 设置随机种子,保证结果可复现
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True# 定义字段
TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm', include_lengths=True)
LABEL = data.LabelField(dtype=torch.float)# 加载IMDB数据集
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)# 创建验证集
train_data, valid_data = train_data.split(random_state=random.seed(SEED))# 构建词汇表
MAX_VOCAB_SIZE = 25000
TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE, vectors="glove.6B.100d")
LABEL.build_vocab(train_data)# 创建迭代器
BATCH_SIZE = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits((train_data, valid_data, test_data), batch_size=BATCH_SIZE,sort_within_batch=True,device=device)# 定义LSTM模型
class LSTMClassifier(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout, pad_idx):super().__init__()# 嵌入层self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)# LSTM层self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout)# 全连接层self.fc = nn.Linear(hidden_dim * 2, output_dim)# Dropout层self.dropout = nn.Dropout(dropout)def forward(self, text, text_lengths):# text = [sent len, batch size]# 应用dropout到嵌入层embedded = self.dropout(self.embedding(text))# embedded = [sent len, batch size, emb dim]# 打包序列packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'))# 通过LSTM层packed_output, (hidden, cell) = self.lstm(packed_embedded)# 展开序列output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)# output = [sent len, batch size, hid dim * num directions]# hidden = [num layers * num directions, batch size, hid dim]# cell = [num layers * num directions, batch size, hid dim]# 我们使用双向LSTM的最终隐藏状态hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))# hidden = [batch size, hid dim * num directions]return self.fc(hidden)# 初始化模型
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]model = LSTMClassifier(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT, PAD_IDX)# 加载预训练的词向量
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)# 优化器和损失函数
optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()model = model.to(device)
criterion = criterion.to(device)# 准确率计算函数
def binary_accuracy(preds, y):"""Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8"""# 四舍五入预测值rounded_preds = torch.round(torch.sigmoid(preds))correct = (rounded_preds == y).float()  # 转换为float计算准确率acc = correct.sum() / len(correct)return acc# 训练函数
def train(model, iterator, optimizer, criterion):epoch_loss = 0epoch_acc = 0model.train()for batch in iterator:optimizer.zero_grad()text, text_lengths = batch.textpredictions = model(text, text_lengths).squeeze(1)loss = criterion(predictions, batch.label)acc = binary_accuracy(predictions, batch.label)loss.backward()optimizer.step()epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / len(iterator), epoch_acc / len(iterator)# 评估函数
def evaluate(model, iterator, criterion):epoch_loss = 0epoch_acc = 0model.eval()with torch.no_grad():for batch in iterator:text, text_lengths = batch.textpredictions = model(text, text_lengths).squeeze(1)loss = criterion(predictions, batch.label)acc = binary_accuracy(predictions, batch.label)epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / len(iterator), epoch_acc / len(iterator)# 训练模型
N_EPOCHS = 5best_valid_loss = float('inf')for epoch in range(N_EPOCHS):train_loss, train_acc = train(model, train_iterator, optimizer, criterion)valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)if valid_loss < best_valid_loss:best_valid_loss = valid_losstorch.save(model.state_dict(), 'lstm-model.pt')print(f'Epoch: {epoch+1:02}')print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')# 测试模型
model.load_state_dict(torch.load('lstm-model.pt'))
test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

RNN 的挑战与发展趋势

尽管 RNN 在序列建模中取得了成功,但仍面临一些挑战:

  1. 长序列处理困难:即使是 LSTM 和 GRU,在处理极长序列时仍有困难。
  2. 并行计算能力有限:RNN 的时序依赖性导致难以高效并行化。
  3. 注意力机制的兴起:注意力机制可以更灵活地捕获序列中的长距离依赖,减少对完整历史的依赖。

近年来,RNN 的发展趋势包括:

  1. 注意力机制与 Transformer:注意力机制和 Transformer 架构在许多序列任务中取代了传统 RNN,如 BERT、GPT 等模型。
  2. 混合架构:结合 RNN 和注意力机制的优点,如 Google 的 T5 模型。
  3. 少样本学习与迁移学习:利用预训练模型(如 XLNet、RoBERTa)进行微调,减少对大量标注数据的需求。
  4. 神经图灵机与记忆网络:增强 RNN 的记忆能力,使其能够处理更复杂的推理任务。

循环神经网络为序列数据处理提供了强大的工具,尽管面临一些挑战,但通过不断的研究和创新,RNN 及其变体仍在众多领域发挥着重要作用,并将继续推动序列建模技术的发展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/82550.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/82550.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/82550.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java 项目登录请求业务解耦模块全面

登录是统一的闸机&#xff1b; 密码存在数据库中&#xff0c;用的是密文&#xff0c;后端加密&#xff0c;和数据库中做对比 1、UserController public class UserController{Autowiredprivate IuserService userservicepublic JsonResult login(Validated RequestBody UserLo…

【手写数据库核心揭秘系列】第9节 可重入的SQL解析器,不断解析Structure Query Language,语言翻译好帮手

可重入的SQL解析器 文章目录 可重入的SQL解析器一、概述 二、可重入解析器 2.1 可重入设置 2.2 记录状态的数据结构 2.3 节点数据类型定义 2.4 头文件引用 三、调整后的程序结构 四、总结 一、概述 现在就来修改之前sqlscanner.l和sqlgram.y程序,可以不断输入SQL语句,循环执…

微软开源bitnet b1.58大模型,应用效果测评(问答、知识、数学、逻辑、分析)

微软开源bitnet b1.58大模型,应用效果测评(问答、知识、数学、逻辑、分析) 目 录 1. 前言... 2 2. 应用部署... 2 3. 应用效果... 3 1.1 问答方面... 3 1.2 知识方面... 4 1.3 数字运算... 6 1.4 逻辑方面... …

用HTML5+JavaScript实现汉字转拼音工具

用HTML5JavaScript实现汉字转拼音工具 前一篇博文&#xff08;https://blog.csdn.net/cnds123/article/details/148067680&#xff09;提到&#xff0c;当需要将拼音添加到汉字上面时&#xff0c;用python实现比HTML5JavaScript实现繁琐。在这篇博文中用HTML5JavaScript实现汉…

鸿蒙OSUniApp 开发的动态背景动画组件#三方框架 #Uniapp

使用 UniApp 开发的动态背景动画组件 前言 在移动应用开发中&#xff0c;动态背景动画不仅能提升界面美感&#xff0c;还能增强用户的沉浸感和品牌辨识度。无论是登录页、首页还是活动页&#xff0c;恰到好处的动态背景都能让产品脱颖而出。随着鸿蒙&#xff08;HarmonyOS&am…

云原生技术架构技术探索

文章目录 前言一、什么是云原生技术架构二、云原生技术架构的优势三、云原生技术架构的应用场景结语 前言 在当今的技术领域&#xff0c;云原生技术架构正以一种势不可挡的姿态席卷而来&#xff0c;成为了众多开发者、企业和技术爱好者关注的焦点。那么&#xff0c;究竟什么是…

AWS之AI服务

目录 一、AWS AI布局 ​​1. 底层基础设施与芯片​​ ​​2. AI训练框架与平台​​ ​​3. 大模型与应用层​​ ​​4. 超级计算与网络​​ ​​与竞品对比​​ AI服务 ​​1. 机器学习平台​​ ​​2. 预训练AI服务​​ ​​3. 边缘与物联网AI​​ ​​4. 数据与AI…

lwip_bind、lwip_listen 是阻塞函数吗

在 lwIP 协议栈中&#xff0c;lwip_bind 和 lwip_listen 函数本质上是非阻塞的。 通常&#xff0c;bind和listen在大多数实现中都是非阻塞的&#xff0c;因为它们只是设置套接字的属性&#xff0c;不需要等待外部事件。阻塞通常发生在接受连接&#xff08;accept&#xff09;、…

【后端高阶面经:消息队列篇】28、从零设计高可用消息队列

一、消息队列架构设计的核心目标与挑战 设计高性能、高可靠的消息队列需平衡功能性与非功能性需求,解决分布式系统中的典型问题。 1.1 核心设计目标 吞吐量:支持百万级消息/秒处理,通过分区并行化实现横向扩展。延迟:端到端延迟控制在毫秒级,适用于实时业务场景。可靠性…

【运维实战】Linux 内存调优之进程内存深度监控

写在前面 内容涉及 Linux 进程内存监控 监控方式包括传统工具 ps/top/pmap ,以及 cgroup 内存子系统&#xff0c;proc 内存伪文件系统 监控内容包括进程内存使用情况&#xff0c; 内存全局数据统计&#xff0c;内存事件指标&#xff0c;以及进程内存段数据监控 监控进程的内…

决策树 GBDT XGBoost LightGBM

一、决策树 1. 决策树有一个很强的假设&#xff1a; 信息是可分的&#xff0c;否则无法进行特征分支 2. 决策树的种类&#xff1a; 2. ID3决策树&#xff1a; ID3决策树的数划分标准是信息增益&#xff1a; 信息增益衡量的是通过某个特征进行数据划分前后熵的变化量。但是&…

java基础学习(十四)

文章目录 4-1 面向过程与面向对象4-2 Java语言的基本元素&#xff1a;类和对象面向对象的思想概述 4-3 对象的创建和使用内存解析匿名对象 4-1 面向过程与面向对象 面向过程(POP) 与 面向对象(OOP) 二者都是一种思想&#xff0c;面向对象是相对于面向过程而言的。面向过程&…

TCP 三次握手,第三次握手报文丢失会发生什么?

文章目录 RTO(Retransmission Timeout)注意 客户端收到服务端的 SYNACK 报文后&#xff0c;会回给服务端一个 ACK 报文&#xff0c;之后处于 ESTABLISHED 状态 因为第三次握手的 ACK 是对第二次握手中 SYN 的确认报文&#xff0c;如果第三次握手报文丢失了&#xff0c;服务端就…

deepseek告诉您http与https有何区别?

有用户经常问什么是Http , 什么是Https &#xff1f; 两者有什么区别&#xff0c;下面为大家介绍一下两者的区别 一、什么是HTTP HTTP是一种无状态的应用层协议&#xff0c;用于在客户端浏览器和服务器之间传输网页信息&#xff0c;默认使用80端口 二、HTTP协议的特点 HTTP协议…

openresty如何禁止海外ip访问

前几天&#xff0c;我有一个徒弟问我&#xff0c;如何禁止海外ip访问他的网站系统&#xff1f;操作系统采用的是centos7.9&#xff0c;发布服务采用的是openresty。通过日志他发现&#xff0c;有很多类似以下数据 {"host":"172.30.7.95","clientip&q…

理解 Redis 事务-20 (MULTI、EXEC、DISCARD)

理解 Redis 事务&#xff1a;MULTI、EXEC、DISCARD Redis 事务允许你将一组命令作为一个单一的原子操作来执行。这意味着事务中的所有命令要么全部执行&#xff0c;要么全部不执行。这对于在需要一起执行多个操作时保持数据完整性至关重要。本课程将涵盖 Redis 事务的基础知识…

Milvus分区-分片-段结构详解与最佳实践

导读&#xff1a;在构建大规模向量数据库应用时&#xff0c;数据组织架构的设计往往决定了系统的性能上限。Milvus作为主流向量数据库&#xff0c;其独特的三层架构设计——分区、分片、段&#xff0c;为海量向量数据的高效存储和检索提供了坚实基础。 本文通过图书馆管理系统的…

Kettle 远程mysql 表导入到 hadoop hive

kettle 远程mysql 表导入到 hadoop hive &#xff08;教学用 &#xff09; 文章目录 kettle 远程mysql 表导入到 hadoop hive创建 对象 执行 SQL 语句 -mysql 导出 CSV格式CSV 文件远程上传到 HDFS运行 SSH 命令远程登录 run SSH 并执行 hadoop fs -put 建表和加载数据总结 创…

Linux输出命令——echo解析

摘要 全面解析Linux echo命令核心功能&#xff0c;涵盖文本输出、变量解析、格式控制及高级技巧&#xff0c;助力提升Shell脚本开发与终端操作效率。 一、核心功能与定位 作为Shell脚本开发的基础工具&#xff0c;echo命令承担着信息输出与数据传递的重要角色。其主要功能包…

Windows系统下 NVM 安装 Node.js 及版本切换实战指南

以下是 Windows 11 系统下使用 NVM 安装 Node.js 并实现版本自由切换的详细步骤&#xff1a; 一、安装 NVM&#xff08;Node Version Manager&#xff09; 1. 卸载已有 Node.js 如果已安装 Node.js&#xff0c;请先卸载&#xff1a; 控制面板 ➔ 程序与功能 ➔ 找到 Node.js…