预测效果

在这里插入图片描述

代码功能

该代码实现了一个结合卷积神经网络(CNN)和Kolmogorov–Arnold网络(KAN)的混合模型(CNN-KAN),用于时间序列预测任务。核心功能包括:

  1. 数据加载与预处理:加载标准化后的训练集和测试集(时间序列数据)。
  2. 模型构建
    • CNN部分:提取时间序列的局部特征(使用1D卷积层和池化层)。
    • KAN部分:替代全连接层,通过样条基函数增强非线性拟合能力,提高预测精度。
  3. 模型训练与评估:使用MSE损失和Adam优化器训练模型,保存最佳模型参数,并在测试集上计算评估指标(MSE、RMSE、MAE、R²)。
  4. 结果可视化:绘制训练/测试损失曲线,并反归一化预测结果。

算法步骤

  1. 数据加载

    • 使用joblib加载预处理后的训练集(train_set, train_label)和测试集(test_set, test_label)。
    • 封装为DataLoader(批量大小=64)。
  2. 模型定义

    • KANLinear
      • 基础线性变换 + 样条基函数(B-splines)的非线性变换。
      • 支持动态网格更新和正则化损失计算。
    • CNN1DKANModel
      • 卷积块:多个Conv1d + ReLU + MaxPool1d层(参考VGG架构)。
      • 自适应平均池化:替代全连接层,减少参数量。
      • KAN输出层:生成最终预测结果。
  3. 模型训练

    • 损失函数:均方误差(MSELoss)。
    • 优化器:Adam(学习率=0.0003)。
    • 训练循环
      • 前向传播 → 计算损失 → 反向传播 → 参数更新。
      • 记录每个epoch的训练/测试MSE,保存最佳模型(最低测试MSE)。
  4. 模型评估

    • 加载最佳模型进行预测。
    • 计算指标:(模型拟合优度)、MSERMSEMAE
    • 反归一化预测结果(使用预训练的StandardScaler)。
  5. 可视化

    • 绘制训练/测试MSE随epoch的变化曲线。
    • 输出评估指标和反归一化后的结果。

技术路线

  1. 框架:PyTorch(模型构建、训练、评估)。
  2. 数据预处理:使用StandardScaler标准化数据(通过joblib保存/加载)。
  3. 模型架构
    • 特征提取:CNN(1D卷积层)捕获时间序列局部模式。
    • 非线性映射:KAN层替代传统全连接层,通过样条函数灵活拟合复杂关系。
  4. 评估指标sklearn计算MSE等。
  5. 可视化matplotlib绘制损失曲线。

关键参数设定

参数说明
batch_size64数据批量大小
epochs50训练轮数
learn_rate0.0003Adam优化器学习率
conv_archs((2, 32), (2, 64))CNN层配置(卷积层数×通道数)
grid_size5KAN样条网格大小
spline_order3样条多项式阶数
output_dim1预测输出维度(回归任务)

运行环境

  • Python库
    torch, joblib, numpy, pandas, sklearn, matplotlib
    
  • 硬件:支持CUDA的GPU(优先)或CPU(自动切换)。
  • 数据依赖
    • 预处理的训练/测试集文件(train_set, train_label等)。
    • 预训练的StandardScalerscaler文件)。

应用场景

  1. 时间序列预测
    • 如股票价格、气象数据、电力负荷等序列数据的未来值预测。
  2. 高非线性关系建模
    • KAN层通过样条基函数灵活拟合复杂非线性模式,优于传统全连接层。
  3. 轻量化模型需求
    • 自适应池化替代全连接层,减少参数量(模型总参数量:22,432)。
  4. 研究验证
    • 探索CNN与KAN结合的混合架构在预测任务中的有效性(最终R²=0.995,拟合优度高)。

补充说明

  • 创新点:KAN作为输出层,通过动态网格更新和正则化约束(L1 + 熵),增强模型表达能力。
  • 性能:50个epoch后测试集MSE=0.2627(反归一化后MSE=0.0041),预测精度高。
  • 扩展性:可通过调整卷积架构、KAN参数适配不同时间序列长度和复杂度。

完整代码

  • 完整代码订阅专栏获取
# 模型预测
# 模型 测试集 验证  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 模型加载
model = torch.load('best_model_cnn_kan.pt')
model = model.to(device)# 预测数据
original_data = []
pre_data = []
with torch.no_grad():for data, label in test_loader:origin_lable = label.tolist()original_data += origin_lablemodel.eval()  # 将模型设置为评估模式data, label = data.to(device), label.to(device)# 预测test_pred = model(data)  # 对测试集进行预测test_pred = test_pred.tolist()pre_data += test_pred
[8]
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score# 模型分数
score = r2_score(original_data, pre_data)
print('*'*50)
print('模型分数--R^2:',score)print('*'*50)
# 测试集上的预测误差
test_mse = mean_squared_error(original_data, pre_data)
test_rmse = np.sqrt(test_mse)
test_mae = mean_absolute_error(original_data, pre_data)
print('测试数据集上的均方误差--MSE: ',test_mse)
print('测试数据集上的均方根误差--RMSE: ',test_rmse)
print('测试数据集上的平均绝对误差--MAE: ',test_mae)
**************************************************
模型分数--R^2: 0.9954956071920047
**************************************************
测试数据集上的均方误差--MSE:  0.004104453060426307
测试数据集上的均方根误差--RMSE:  0.06406600549766082
测试数据集上的平均绝对误差--MAE:  0.047805079976603375[19]
from sklearn.preprocessing import StandardScaler, MinMaxScaler# 将列表转换为 NumPy 数组
original_data = np.array(original_data)
pre_data = np.array(pre_data)# 反归一化处理
# 使用相同的均值和标准差对预测结果进行反归一化处理
# 反标准化
scaler  = load('scaler')
original_data = scaler.inverse_transform(original_data)
pre_data = scaler.inverse_transform(pre_data)
[20]
# 可视化结果
plt.figure(figsize=(12, 6), dpi=100)
plt.plot(original_data, label='原始值',color='orange')  # 真实值
plt.plot(pre_data, label='CNN-KAN预测值',color='green')  # 预测值
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/914377.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/914377.shtml
英文地址,请注明出处:http://en.pswp.cn/news/914377.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UI前端与数字孪生结合实践探索:智慧物流的仓储优化与管理系统

hello宝子们...我们是艾斯视觉擅长ui设计和前端数字孪生、大数据、三维建模、三维动画10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩!一、引言:仓储管理的 “数字孪生革命”传统物流仓储正面临 “效率瓶颈、可视化差、响应滞…

【Android】在平板上实现Rs485的数据通讯

前言 在工业控制领域,Android 设备通过 RS485 接口与 PLC(可编程逻辑控制器)通信是一种常见的技术方案。最近在实现一个项目需要和plc使用485进行通讯,记录下实现的方式。 我这边使用的从平的Android平板,从平里面已经…

MySQL技术笔记-备份与恢复完全指南

目录 前言 一、备份概述 (一)备份方式 (二)备份策略 二、物理备份及恢复 (一)备份操作 (二)恢复操作 三、逻辑备份及恢复 (一)逻辑备份 &#xff0…

SpringBoot或OpenFeign中 Jackson 配置参数名蛇形、小驼峰、大驼峰、自定义命名

SpringBoot或OpenFeign中 Jackson 配置参数名蛇形、小驼峰、大驼峰、自定义命名 前言 在调用外部接口时,对方给出的接口文档中,入参参数名一会大写加下划线,一会又是驼峰命名。 示例如下: {"MOF_DIV_CODE": "xx…

uni-app 途径站点组件开发与实现分享

在移动应用开发中,涉及到出行、物流等场景时,途径站点的展示是一个常见的需求。本文将为大家分享一个基于 uni-app 开发的途径站点组件,该组件能够清晰展示路线中的各个站点信息,包括站点名称、到达时间、是否已到达等状态&#x…

kotlin中集合的用法

从一个实际应用看起以下kotlin中代码语法正确吗 var testBeanAIP0200()var testList:List<AIP0200> ArrayList()testList.add(testBean)这段Kotlin代码存在语法错误&#xff0c;主要问题在于&#xff1a;List<AIP0200> 是Kotlin中的不可变集合接口&#xff0c;不能…

深入理解 Java Map 与 Set

文章目录前言1. 搜索树1.1 什么是搜索树1.2 查找1.3 插入1.4 删除情况一&#xff1a;cur 没有子节点&#xff08;即为叶子节点&#xff09;情况二&#xff1a;cur 只有一个子节点&#xff08;只有左子树或右子树&#xff09;情况三&#xff1a;cur 有两个子节点&#xff08;左右…

excel如何只保留前几行

方法一&#xff1a;手动删除多余行 选中你想保留的最后一行的下一行&#xff08;比如你只保留前10行&#xff0c;那选第11行&#xff09;。按住 Shift Ctrl ↓&#xff08;Windows&#xff09;或 Shift Command ↓&#xff08;Mac&#xff09;&#xff0c;选中从第11行到最…

实时连接,精准监控:风丘科技数据远程显示方案提升试验车队管理效率

风丘科技推出的数据远程实时显示方案更好地满足了客户对于试验车队远程实时监控的需求&#xff0c;并真正实现了试验车队的远程管理。随着新的数据记录仪软件IPEmotion RT和相应的跨平台显示解决方案的引入&#xff0c;让我们的客户端不仅可在线访问记录器系统状态&#xff0c;…

灰盒级SOA测试工具Parasoft SOAtest重新定义端到端测试

还在为脆弱的测试环境、强外部依赖和低效的测试复用拖慢交付而头疼&#xff1f;尤其在银行、医疗、制造等关键领域&#xff0c;传统的端到端测试常因环境不稳、接口难模拟、用例难共享而举步维艰。 灰盒级SOA测试工具Parasoft SOAtest以可视化编排简化复杂测试流程&#xff0c…

OKHttp 核心知识点详解

OKHttp 核心知识点详解 一、基本概念与架构 1. OKHttp 简介 类型&#xff1a;高效的HTTP客户端特点&#xff1a; 支持HTTP/2和SPDY&#xff08;多路复用&#xff09;连接池减少请求延迟透明的GZIP压缩响应缓存自动恢复网络故障2. 核心组件组件功能OkHttpClient客户端入口&#…

从“被动巡检”到“主动预警”:塔能物联运维平台重构路灯管理模式

从以往的‘被动巡检’转变至如今的‘主动预警’&#xff0c;塔能物联运维平台对路灯管理模式展开了重新构建。城市路灯属于极为重要的市政基础设施范畴&#xff0c;它的实际运行状态和市民出行安全以及城市形象有着直接且紧密的关联。不过呢&#xff0c;传统的路灯管理模式当下…

10. 常见的 http 状态码有哪些

总结 1xx: 正在处理2xx: 成功3xx: 重定向&#xff0c;302 重定向&#xff0c;304 协商缓存4xx: 客户端错误&#xff0c;401 未登录&#xff0c;403 没权限&#xff0c;404 资源不存在5xx: 服务器错误常见的 HTTP 状态码详解 HTTP 状态码&#xff08;HTTP Status Code&#xff0…

springBoot对接第三方系统

yml文件 yun:ip: port: username: password: controller package com.ruoyi.web.controller.materials;import com.ruoyi.common.core.controller.BaseController; import com.ruoyi.common.core.domain.AjaxResult; import com.ruoyi.materials.service.IYunService; import o…

【PTA数据结构 | C语言版】车厢重排

本专栏持续输出数据结构题目集&#xff0c;欢迎订阅。 文章目录题目代码题目 一列挂有 n 节车厢&#xff08;编号从 1 到 n&#xff09;的货运列车途径 n 个车站&#xff0c;计划在行车途中将各节车厢停放在不同的车站。假设 n 个车站的编号从 1 到 n&#xff0c;货运列车按照…

量子计算能为我们做什么?

科技公司正斥资数十亿美元投入量子计算领域&#xff0c;尽管这项技术距离实际应用还有数年时间。那么&#xff0c;未来的量子计算机将用于哪些方面&#xff1f;为何众多专家坚信它们会带来颠覆性变革&#xff1f; 自 20 世纪 80 年代起&#xff0c;打造一台利用量子力学独特性质…

BKD 树(Block KD-Tree)Lucene

BKD 树&#xff08;Block KD-Tree&#xff09;是 Lucene 用来存储和快速查询 **多维数值型数据** 的一种磁盘友好型数据结构&#xff0c;可以把它想成&#xff1a;> **“把 KD-Tree 分块压缩后落到磁盘上&#xff0c;既能做磁盘顺序读&#xff0c;又能像内存 KD-Tree 一样做…

【Mysql作业】

第一次作业要求1.首先打开Windows PowerShell2.连接到MYSQL服务器3.执行以下SQL语句&#xff1a;-- 创建数据库 CREATE DATABASE mydb6_product;-- 使用数据库 USE mydb6_product;-- 创建employees表 CREATE TABLE employees (id INT PRIMARY KEY,name VARCHAR(50) NOT NULL,ag…

(C++)STL:list认识与使用全解析

本篇基于https://cplusplus.com/reference/list/list/讲解 认识 list是一个带头结点的双向循环链表翻译总结&#xff1a; 序列容器&#xff1a;list是一种序列容器&#xff0c;允许在序列的任何位置进行常数时间的插入和删除操作。双向迭代&#xff1a;list支持双向迭代&#x…

Bash函数详解

目录**1. 基础函数****2. 参数处理函数****3. 文件操作函数****4. 日志与错误处理****5. 实用工具函数****6. 高级函数技巧****7. 常用函数库示例****总结&#xff1a;Bash 函数核心要点**1. 基础函数 1.1 定义与调用 可以自定义函数名称&#xff0c;例如将greet改为yana。❌…