用深度学习(LSTM)实现时间序列预测:从数据到闭环预测全解析

时间序列预测是工业、金融、环境等领域的核心需求——小到预测设备温度波动,大到预测股价走势,都需要从历史数据中挖掘时序规律。长短期记忆网络(LSTM)凭借对“长期依赖关系”的捕捉能力,成为时序预测的主流模型之一。

本文将基于MATLAB深度学习工具箱,以波形数据集(WaveformData) 为例,完整拆解LSTM时间序列预测的实现流程,重点讲解“闭环预测”的核心逻辑(用前一次预测结果作为下一次输入,无需真实值即可多步预测),并对代码逐行、参数逐个进行解析。

一、整体背景:LSTM与两种预测模式

LSTM是一种循环神经网络(RNN),通过“门控机制”(遗忘门、输入门、输出门)动态更新“隐藏状态”,从而记住序列中的关键历史信息,避免普通RNN的“梯度消失”问题。

时序预测有两种核心模式,也是本文的重点对比对象:

  • 开环预测:每次预测都需要“真实的历史数据”作为输入(比如预测第t步需要第t-1步的真实值),适合能实时获取真实数据的场景。
  • 闭环预测:仅用初始真实数据初始化,后续预测完全依赖“前一次的预测结果”作为输入(无需真实值),适合需要一次性预测多步未来、或无法获取实时真实数据的场景(如预测未来200天的温度)。

本文将从数据加载到闭环预测,一步步实现完整流程。

二、完整实现流程与代码解析

1. 第一步:加载与探索数据

首先加载示例数据集,了解数据结构,为后续处理做准备。

代码与逐行解析
% 加载波形数据集(MATLAB内置示例数据)
load WaveformData% 查看前5个序列的结构(数据是cell数组,每个元素是一个序列)
data(1:5)% 计算序列的通道数(所有序列通道数一致,才能训练网络)
numChannels = size(data{1},1)% 可视化前4个序列(堆叠图展示多通道)
figure
tiledlayout(2,2)  % 创建2x2的子图布局
for i = 1:4nexttile  % 激活下一个子图stackedplot(data{i}')  % 转置序列:让时间步为x轴,通道为y轴xlabel("Time Step")  % x轴标签:时间步
end% 划分训练集与测试集(9:1拆分)
numObservations = numel(data);  % 总序列数(data是cell数组,numel取元素个数)
idxTrain = 1:floor(0.9*numObservations);  % 训练集索引(前90%)
idxTest = floor(0.9*numObservations)+1:numObservations;  % 测试集索引(后10%)
dataTrain = data(idxTrain);  % 训练集序列
dataTest = data(idxTest);    % 测试集序列
关键参数与概念
  • WaveformData:MATLAB内置的合成波形数据集,结构为numObservations×1的cell数组,每个cell元素是numChannels×numTimeSteps的矩阵(numChannels=3,即每个时间步有3个特征;numTimeSteps为序列长度,不同序列长度不同)。
  • stackedplot:堆叠图函数,适合展示多通道时序数据(每个通道一条线,避免重叠)。
  • 数据划分逻辑:9:1拆分是时序预测的常用比例,既保证训练集足够大(学习规律),又保留测试集(评估泛化能力)。

2. 第二步:准备训练数据(核心:移位目标序列+归一化)

LSTM训练需要“输入-目标”配对的监督数据。时序预测的核心技巧是:输入为“去掉最后一个时间步的序列”,目标为“移位一个时间步的序列”,让LSTM学习“当前时间步→下一个时间步”的映射关系。

同时,为避免训练发散、提升收敛速度,需要对数据做“零均值单位方差”归一化。

代码与逐行解析
% 1. 构建训练集的“输入-目标”配对(移位序列)
for n = 1:numel(dataTrain)  % 遍历每个训练序列X = dataTrain{n};       % 取第n个训练序列(numChannels×numTimeSteps)XTrain{n} = X(:,1:end-1);  % 输入:去掉最后一个时间步(无法预测它的下一个值)TTrain{n} = X(:,2:end);    % 目标:移位一个时间步(每个输入对应下一个时间步的真实值)
end% 2. 归一化:计算训练集的均值和标准差(所有序列拼接后统计,保证一致性)
muX = mean(cat(2,XTrain{:}),2);  % 输入的均值:cat(2,...)按时间步拼接所有序列,mean(...,2)按通道算均值
sigmaX = std(cat(2,XTrain{:}),0,2);  % 输入的标准差:0表示除以N-1(无偏估计),2表示按通道算muT = mean(cat(2,TTrain{:}),2);  % 目标的均值
sigmaT = std(cat(2,TTrain{:}),0,2);  % 目标的标准差% 3. 对输入和目标进行归一化(用训练集的统计量,避免数据泄露)
for n = 1:numel(XTrain)XTrain{n} = (XTrain{n} - muX) ./ sigmaX;  % 输入归一化:(原始-均值)/标准差TTrain{n} = (TTrain{n} - muT) ./ sigmaT;  % 目标归一化
end
关键逻辑解释
  • 移位序列的原因:假设序列为[t1,t2,t3,t4],输入XTrain[t1,t2,t3],目标TTrain[t2,t3,t4],让LSTM学习“t1→t2”“t2→t3”“t3→t4”的映射,最终能实现“输入任意序列→预测下一个时间步”。
  • 归一化的必要性:若不同通道的数值范围差异大(如通道1是0-1,通道2是100-200),训练时会导致梯度更新失衡,模型难以收敛。用训练集统计量归一化,是为了避免“测试集信息泄露到训练集”(测试集的统计量未知)。

3. 第三步:定义LSTM网络架构

时序预测的LSTM网络需要适配“序列输入→序列输出”的需求,核心层包括:序列输入层、LSTM层、全连接层、回归层。

代码与逐行解析
layers = [sequenceInputLayer(numChannels)  % 序列输入层:输入维度=通道数(numChannels=3)lstmLayer(128)                   % LSTM层:128个隐藏单元(决定学习能力)fullyConnectedLayer(numChannels) % 全连接层:输出维度=通道数(与输入通道一致)regressionLayer];                % 回归层:定义回归任务的损失函数(默认均方误差MSE)
各层参数与作用详解
层名称参数配置作用说明
sequenceInputLayernumChannels=3接收“通道数×时间步”的序列输入,输入维度必须与数据的通道数一致(否则维度不匹配)。
lstmLayer128个隐藏单元隐藏单元数量决定LSTM的“记忆容量”:128个单元可捕捉中等复杂度的时序规律;数量越多学习能力越强,但易过拟合。
fullyConnectedLayernumChannels=3将LSTM输出的128维隐藏状态“映射”到3维(与输入通道数一致),确保输出序列的维度与目标序列匹配。
regressionLayer无参数(默认)回归任务的输出层,计算“预测值-真实值”的均方误差(MSE),作为训练的损失函数,指导网络更新权重。

4. 第四步:指定训练选项

训练选项决定模型的优化策略,需结合数据规模、网络复杂度调整。

代码与逐行解析
options = trainingOptions("adam", ...  % 优化器:Adam(自适应学习率,适合时序数据)MaxEpochs=200, ...                 % 最大训练轮数:200轮(平衡训练效果与时间)SequencePaddingDirection="left", ...% 序列对齐方式:左侧补零(保护右侧有效信息)Shuffle="every-epoch", ...         % 数据打乱:每轮训练前打乱训练集,避免过拟合Plots="training-progress", ...     % 可视化:显示训练进度(损失曲线、准确率等)Verbose=0);                        % 日志输出:0表示不打印详细训练日志(仅看进度图)
关键选项解释
  • Adam优化器:比SGD(随机梯度下降)收敛更快,通过自适应学习率调整不同参数的更新步长,适合LSTM这类复杂网络。
  • MaxEpochs=200:200轮是针对2000个序列的经验值——轮数太少可能欠拟合(没学会规律),太多则可能过拟合(记住训练集噪声)。
  • SequencePaddingDirection=“left”:不同序列长度不同,训练时需补零对齐。左侧补零是为了保护“右侧的近期信息”(时序数据中,右侧时间步更重要),避免右侧补零干扰预测。

5. 第五步:训练LSTM网络

调用trainNetwork函数,用训练集(XTrain, TTrain)和训练选项(options)训练网络。

代码与解析
% 训练网络:输入(XTrain)、目标(TTrain)、网络架构(layers)、训练选项(options)
net = trainNetwork(XTrain,TTrain,layers,options);
  • 输出:训练好的LSTM网络net,包含学习到的权重、偏置和网络结构。
  • 训练过程:运行时会弹出“训练进度图”,可观察训练损失(Training Loss)的下降趋势——若损失趋于平稳,说明网络收敛。

6. 第六步:测试网络(评估泛化能力)

测试的核心是:用训练好的网络预测测试集,计算误差(RMSE)评估泛化能力。

代码与逐行解析
% 1. 准备测试数据(与训练数据处理逻辑一致:移位+归一化)
for n = 1:size(dataTest,1)  % 遍历每个测试序列X = dataTest{n};        % 取第n个测试序列XTest{n} = (X(:,1:end-1) - muX) ./ sigmaX;  % 测试输入:移位+用训练集统计量归一化TTest{n} = (X(:,2:end) - muT) ./ sigmaT;    % 测试目标:移位+归一化
end% 2. 用测试集预测(指定左侧补零,与训练一致)
YTest = predict(net,XTest,SequencePaddingDirection="left");% 3. 计算每个测试序列的RMSE(均方根误差,评估预测精度)
for i = 1:size(YTest,1)% RMSE = sqrt(平均(预测值-真实值)^2),"all"表示对所有元素计算rmse(i) = sqrt(mean((YTest{i} - TTest{i}).^2,"all"));
end% 4. 可视化RMSE分布(直方图)
figure
histogram(rmse)  % 绘制RMSE的频率分布
xlabel("RMSE")    % x轴:RMSE值(越小精度越高)
ylabel("Frequency")  % y轴:频率(多少个序列的RMSE落在该区间)% 5. 计算所有测试序列的平均RMSE
mean(rmse)
评估逻辑
  • RMSE的意义:RMSE越小,预测值与真实值的偏差越小。例如,若平均RMSE=0.1,说明预测值与真实值的平均偏差仅0.1(归一化后的值,反归一化后可还原为原始尺度)。
  • 为什么用训练集统计量归一化:测试时无法获取“未来数据的统计量”,用训练集统计量才能模拟真实预测场景(避免数据泄露)。

7. 第七步:预测未来时间步(重点:开环vs闭环)

测试仅验证“单步预测”能力,实际应用中常需“多步预测”(如预测未来200个时间步)。此时需区分开环与闭环两种模式,闭环预测是本文核心

7.1 先理解:开环预测(依赖真实值)

开环预测的逻辑是:每次预测都需要“前一个时间步的真实值”作为输入,适合能实时获取真实数据的场景(如实时监测设备数据,用真实值预测下一秒)。

% 选择一个测试序列(索引=2)
idx = 2;
X = XTest{idx};  % 测试输入序列
T = TTest{idx};  % 测试目标序列% 1. 初始化网络状态(重置隐藏状态,避免历史数据干扰)
net = resetState(net);
% 2. 用前75个时间步的真实数据更新网络状态(让网络“记住”初始上下文)
offset = 75;  % 初始真实数据的时间步长度
[net,~] = predictAndUpdateState(net,X(:,1:offset));% 3. 开环预测:用真实值作为输入,预测剩余时间步
numTimeSteps = size(X,2);  % 测试序列总时间步
numPredictionTimeSteps = numTimeSteps - offset;  % 需预测的时间步数量
Y_open = zeros(numChannels,numPredictionTimeSteps);  % 存储开环预测结果for t = 1:numPredictionTimeStepsXt = X(:,offset+t);  % 输入:第offset+t步的真实值(开环的核心:依赖真实值)[net,Y_open(:,t)] = predictAndUpdateState(net,Xt);  % 预测+更新网络状态
end% 4. 可视化开环预测结果
figure
t = tiledlayout(numChannels,1);  % 按通道堆叠子图
title(t,"Open Loop Forecasting")
for i = 1:numChannelsnexttileplot(T(i,:))  % 真实值(目标序列)hold on% 预测值:从offset步开始,拼接offset步的真实值+预测值plot(offset:numTimeSteps,[T(i,offset) Y_open(i,:)],'--')ylabel("Channel " + i)
end
xlabel("Time Step")
nexttile(1)
legend(["True Value" "Forecasted Value"])
  • 开环的局限性:必须获取每个时间步的真实值才能继续预测,无法一次性预测多步未来(如无法直接预测未来200步,需等每一步真实值产生)。
7.2 核心:闭环预测(无需真实值,用前一次预测当输入)

闭环预测的逻辑是:仅用初始真实数据初始化,后续预测完全依赖“前一次的预测结果”作为输入,可一次性预测任意多步未来,适合无法获取实时真实数据的场景(如预测未来一个月的销量)。

代码与逐行解析
% 1. 重置网络状态(关键!清除历史隐藏状态,确保从干净的初始状态开始)
net = resetState(net);% 2. 用测试序列的所有真实数据初始化网络状态(让网络“记住”完整的初始上下文)
offset = size(X,2);  % offset=测试序列的总时间步(用全部真实数据初始化)
[net,Z] = predictAndUpdateState(net,X);  % Z是初始预测结果(与测试序列长度一致)% 3. 闭环预测:预测未来200个时间步(可自定义数量)
numPredictionTimeSteps = 200;  % 需预测的未来时间步数量
Xt = Z(:,end);  % 初始输入:最后一个时间步的预测值(闭环的核心:用预测值当输入)
Y_closed = zeros(numChannels,numPredictionTimeSteps);  % 存储闭环预测结果% 循环预测:每一步用前一次的预测值作为输入
for t = 1:numPredictionTimeSteps% 预测当前时间步+更新网络状态[net,Y_closed(:,t)] = predictAndUpdateState(net,Xt);% 更新输入:下一次预测用当前的预测值Xt = Y_closed(:,t);
end% 4. 可视化闭环预测结果
numTimeSteps = offset + numPredictionTimeSteps;  % 总时间步=初始真实数据+预测数据
figure
t = tiledlayout(numChannels,1);
title(t,"Closed Loop Forecasting")  % 标题:闭环预测for i = 1:numChannelsnexttileplot(T(i,1:offset))  % 初始真实数据(前offset步)hold on% 预测数据:从offset步开始,拼接offset步的真实值+未来200步的预测值plot(offset:numTimeSteps,[T(i,offset) Y_closed(i,:)],'--')ylabel("Channel " + i)
endxlabel("Time Step")
nexttile(1)
legend(["Input (True Value)" "Forecasted Value"])
闭环预测的核心细节
  1. 为什么要resetState
    LSTM的隐藏状态会“记忆”历史数据,若不重置,网络会携带上一次预测的残留信息(如之前预测过的其他序列),导致当前预测的初始状态错误,误差被不断放大。resetState能将隐藏状态清零,确保从“干净的初始状态”开始学习当前序列的上下文。

  2. predictAndUpdateState的作用
    该函数是闭环预测的核心工具,同时完成两个任务:

    • 基于当前输入(真实值或预测值)计算预测结果;
    • 更新网络的隐藏状态(让网络“记住”当前输入的信息,为下一次预测做准备)。
  3. 循环逻辑的关键
    每次循环中,Xt = Y_closed(:,t)将“当前预测值”作为“下一次预测的输入”,形成“预测→输入→再预测”的闭环,无需任何真实值即可持续预测多步未来。

三、闭环预测的优缺点与适用场景

特点优点缺点适用场景
数据依赖仅需初始真实数据,后续无需真实值误差会累积(前一步预测不准,后一步偏差更大)无法获取实时真实数据、需一次性预测多步未来(如预测未来1年的季节性波动)
灵活性可自定义预测步数(如预测200步、500步)精度通常低于开环预测长期趋势预测、资源有限无法实时采集数据的场景
计算效率一次性循环完成多步预测,无需等待真实数据需合理初始化网络状态(否则初始误差大)批量预测、离线预测任务

四、总结

本文通过完整的MATLAB代码,拆解了LSTM时间序列预测的全流程:从数据加载与移位处理、网络架构设计、训练优化,到开环与闭环预测的实现。核心结论如下:

  1. 数据处理是基础:移位目标序列让LSTM学习“当前→下一个”的映射,归一化避免训练发散,左侧补零保护有效信息。
  2. 网络架构需适配任务:sequenceInputLayer匹配通道数,lstmLayer隐藏单元数量平衡学习能力与过拟合,regressionLayer适配回归任务。
  3. 闭环预测是核心亮点:通过resetState初始化状态、predictAndUpdateState预测+更新状态、循环用前一次预测当输入,实现无需真实值的多步预测,适合实际应用中的长期预测需求。

掌握这套流程后,你可以将其迁移到自己的时序数据(如温度、销量、股价),只需调整通道数、隐藏单元数量、预测步数等参数,即可快速实现定制化的时间序列预测。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/100386.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/100386.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/100386.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

gpu-z功能介绍,安装与使用方法

GPU-Z 功能介绍、安装与使用方法 一、核心功能 硬件信息检测 识别显卡型号、制造商、核心架构(如NVIDIA Ada Lovelace、AMD RDNA 3)、制造工艺(如5nm、7nm)。显示显存类型(GDDR6X、HBM2e)、容量、带宽及显…

数据搬家后如何处理旧 iPhone

每年,苹果都会推出新款 iPhone,激发了人们升级到 iPhone 17、iPhone 17 Pro、iPhone 17 Pro Max 或 iPhone Air 等新机型的热情。但在获得新 iPhone 之前,有一件重要的事情要做:将数据从旧 iPhone 转移到新设备。虽然许多用户都能…

Java关键字深度解析(上)

这是一份全面的Java关键字实战指南 目录 1.数据类型关键字:内存布局与性能优化 1.1 基础类型的内存密码 byte-内存的极简主义者 int-Java世界的万能钥匙 long - 时间与ID的守护者 1.2 引用类型的架构设计 String-不是关键字但胜于关键字 2.访问修饰符:企业级权限控制 …

C语言深度解析:指针数组与数组指针的区别与应用

目录 1 引言:从名字理解本质区别 2 指针数组:灵活管理多个指针 2.1 基本概念与声明方式 2.2 内存布局与特性 2.3 典型应用场景:字符串数组与多维度数据管理 2.3.1 静态分配示例:字符串数组 2.3.2 动态分配示例:…

Node.js 高级应用:负载均衡与流量限制

在当今高并发的网络应用环境中,如何有效地分配服务器资源并保护系统免受恶意攻击是开发者必须面对的重要问题。Node.js 作为一款广受欢迎的服务器端 JavaScript 运行时环境,提供了丰富的工具和模块来应对这些挑战。本文将深入探讨如何在 Node.js 中实现负…

信任链验证流程

信任链验证流程 (The Chain of Trust)整个过程就像一场严格的接力赛,每一棒都必须从可信的上一位手中接过接力棒(信任),验证无误后,再跑自己的那段路,并把信任传递给下一棒现在,我们来详细解读图…

黄昏时刻复古胶片风格人像风光摄影后期Lr调色教程,手机滤镜PS+Lightroom预设下载!

调色教程这套 黄昏时刻复古胶片风格人像风光摄影后期 Lr 调色方案,以落日余晖为核心色彩元素,加入复古胶片质感,让画面充满温暖与怀旧氛围。整体色调偏向橙红与青绿的互补对比,天空的夕阳光影与人像肤色相互映衬,既有胶…

硬件驱动——I.MX6ULL裸机启动(3)(按键设置及中断设置

重点:1.GIC:(Generic Interrupt Controller)通用中断控制器,是ARM架构中用于管理中断的核心模块,主要用于现代多核处理器系统。它负责接收,分发并分发中断请求,减轻CPU负担&#x…

用deepseek对GPU服务器进行压力测试

利用 DeepSeek 模型对 GPU 服务器进行压力测试,核心思路是通过模拟高负载的模型推理 / 微调任务,验证 GPU 服务器在计算、显存、网络等维度的承载能力,同时观察稳定性与性能瓶颈。以下是具体的测试方案,涵盖测试环境准备、核心测试…

ARM(7)IMX6ULL 按键控制(轮询 + 中断)优化工程

一、硬件介绍1. 开关功能定义共 3 个开关(两红一黄),功能分工明确:中间开关:复位按钮左边开关:低功耗按钮右边开关:用户独立控制的试验按键(核心控制对象)2. 核心电平逻辑…

【QT随笔】什么是Qt元对象系统?Qt元对象系统的核心机制与应用实践

【QT随笔】什么是Qt元对象系统?Qt元对象系统的核心机制与应用实践 之所以写下这篇文章,是因为前段时间自己面试的时候被问到了!因此想借此分享一波!!!本文主要详细解释Qt元对象系统的概念、作用及实现机制…

从技术视角解析加密货币/虚拟货币/稳定币的设计与演进

随着加密货币行情的持续走高,除了资产价值,我想试着从底层程序设计与架构角度解析比特币、以太坊、稳定币以及新兴公链的核心技术方案。作者在2018年设计实施了基于区块链技术的金融项目,并荣获了国家课题进步奖,对加密货币及场景…

[MySQL]Order By:排序的艺术

[MySQL]Order By:排序的艺术 1. 简介 在数据库管理中,数据的排序是一项至关重要的操作。MySQL 的 ORDER BY 子句为我们提供了强大而灵活的功能,用于对查询结果进行排序。无论是按照字母顺序排列名称,还是根据日期或数值进行升序…

【工具代码】使用Python截取视频片段,截取视频中的音频,截取音频片段

目录 ■截取视频方法 1.下载 ffmpeg-8.0-essentials_build 2.配置到环境变量 3.python代码 4.运行 5.效果 ■更多 截取视频中的音频 截取音频 Sony的CR3图片,转换为JPG ■截取视频方法 1.下载 ffmpeg-8.0-essentials_build "https://www.gyan.de…

Three.js 平面始终朝向相机

instanceMesh需要让实例像粒子一样始终朝向相机 可以如下处理shaderexport const billboarding // billboarding函数的GLSL实现 // 参数: // - position: 顶点动态位置偏移 // - positionLocal: mesh的position // - horizontal: 水平方向是否朝向相机 // - vertical: 垂直方…

旗讯 OCR 识别系统深度解析:一站式解决表格、手写文字、证件识别难题!

在数字化办公日益普及的今天,“纸质文档转电子”“图片信息提取” 等需求愈发频繁,但传统手动录入不仅效率低下,还容易出现数据错误。近期发现一款实用性极强的工具 —— 旗讯数字 OCR 识别系统,其覆盖多场景的识别功能、极简操作…

MissionPlanner架构梳理之(十四)日志浏览

概述和目的 Mission Planner 中的日志浏览系统提供了加载、查看、分析和解读 ArduPilot 驱动的飞行器生成的飞行日志的工具。飞行日志包含飞行操作期间记录的关键遥测数据,使用户能够查看飞行性能、诊断问题并从过去的飞行中获取见解。 本页记录了日志浏览系统的架…

机器学习shap分析案例

在进行数据分析和机器学习时经常用到shap,本文对shap相关的操作进行演示。波士顿数据集链接在这里。 SHAP Analysis Guide Set up 导入必要包 import pandas as pd import numpy as np import lightgbm as lgb import matplotlib import matplotlib.pyplot as p…

网络编程相关函数

1. 套接字操作相关1.1 socketint socket(int domain, int type, int protocol);参数说明int domain协议族,常用 AF_INET(IPv4)、AF_INET6(IPv6)int type套接字类型,SOCK_DGRAM(UDP)、…

ESLint 自定义 Processor(处理器)

ESLint 自定义 Processor(处理器) 🔹 什么是 Processor? 在 ESLint 中,Processor(处理器)是一种扩展机制,允许处理非标准 JavaScript/TypeScript 文件。默认情况下,ESLin…