神经网络中的回归详解

引言

神经网络(NeuralNetworks)是一种强大的机器学习模型,可用于分类和回归任务。本文聚焦于神经网络中的回归(Regression),即预测连续输出值(如房价、温度)。

回归问题:给定输入特征x⃗\vec{x}x,预测连续目标yyy。神经网络通过多层非线性变换学习复杂映射f:x⃗↦yf:\vec{x}\mapsto yf:xy

基本概念回顾

神经元与层

  • 神经元(Neuron):基本单元。输入x⃗=(x1,…,xn)\vec{x}=(x_1,\dots,x_n)x=(x1,,xn),权重w⃗=(w1,…,wn)\vec{w}=(w_1,\dots,w_n)w=(w1,,wn),偏置bbb
    计算:线性组合z=w⃗⋅x⃗+b=∑i=1nwixi+bz=\vec{w}\cdot\vec{x}+b=\sum_{i=1}^nw_ix_i+bz=wx+b=i=1nwixi+b
    然后激活:a=σ(z)a=\sigma(z)a=σ(z)σ\sigmaσ为激活函数。

  • (Layer):多个神经元组成。

    • 输入层:原始特征。
    • 隐藏层:中间变换。
    • 输出层:最终预测y^\hat{y}y^(回归中通常1个神经元,无激活或线性激活)。
  • 前馈神经网络(FeedforwardNeuralNetwork,FNN):信息从输入到输出单向流动。也称多层感知机(MLP)。

激活函数

激活引入非线性。常见:

  • Sigmoid:σ(z)=1/(1+e−z)\sigma(z)=1/(1+e^{-z})σ(z)=1/(1+ez),输出[0,1]。
  • Tanh:σ(z)=(ez−e−z)/(ez+e−z)\sigma(z)=(e^z-e^{-z})/(e^z+e^{-z})σ(z)=(ezez)/(ez+ez),输出[-1,1]。
  • ReLU:σ(z)=max⁡(0,z)\sigma(z)=\max(0,z)σ(z)=max(0,z),简单高效,避免梯度消失。
  • Linear:σ(z)=z\sigma(z)=zσ(z)=z,用于回归输出层。

隐藏层常用ReLU,输出层线性以输出任意实数。

神经网络回归模型结构

数学表示

假设网络有LLL层。第lll层有mlm_lml个神经元。

  • 输入:a⃗(0)=x⃗∈Rm0\vec{a}^{(0)}=\vec{x}\in\mathbb{R}^{m_0}a(0)=xRm0

  • lll层:
    z⃗(l)=W(l)a⃗(l−1)+b⃗(l) \vec{z}^{(l)}=W^{(l)}\vec{a}^{(l-1)}+\vec{b}^{(l)} z(l)=W(l)a(l1)+b(l)
    a⃗(l)=σ(l)(z⃗(l)) \vec{a}^{(l)}=\sigma^{(l)}(\vec{z}^{(l)}) a(l)=σ(l)(z(l))
    其中W(l)∈Rml×ml−1W^{(l)}\in\mathbb{R}^{m_l\times m_{l-1}}W(l)Rml×ml1为权重矩阵,b⃗(l)∈Rml\vec{b}^{(l)}\in\mathbb{R}^{m_l}b(l)Rml为偏置。

  • 输出:y^=a⃗(L)\hat{y}=\vec{a}^{(L)}y^=a(L)(标量)。

整个网络:y^=f(x⃗;θ)\hat{y}=f(\vec{x};\theta)y^=f(x;θ)θ={W(l),b⃗(l)}l=1L\theta=\{W^{(l)},\vec{b}^{(l)}\}_{l=1}^Lθ={W(l),b(l)}l=1L为参数。

示例结构

简单回归网络:输入2维,1隐藏层(3神经元),输出1维。

  • 输入层:x⃗=(x1,x2)\vec{x}=(x_1,x_2)x=(x1,x2)
  • 隐藏层:W(1)∈R3×2W^{(1)}\in\mathbb{R}^{3\times2}W(1)R3×2b⃗(1)∈R3\vec{b}^{(1)}\in\mathbb{R}^3b(1)R3,激活ReLU。
  • 输出层:W(2)∈R1×3W^{(2)}\in\mathbb{R}^{1\times3}W(2)R1×3b⃗(2)∈R\vec{b}^{(2)}\in\mathbb{R}b(2)R,激活线性。

损失函数

回归常用均方误差(MeanSquaredError,MSE):
L(y^,y)=12(y^−y)2 \mathcal{L}(\hat{y},y)=\frac{1}{2}(\hat{y}-y)^2 L(y^,y)=21(y^y)2
批次样本:$ \mathcal{L}=\frac{1}{N}\sum_{i=1}N\frac{1}{2}(\hat{y}_i-y_i)2 $

其他:MAE(L=∣y^−y∣\mathcal{L}=|\hat{y}-y|L=y^y),HuberLoss(对异常值鲁棒)。

训练过程:反向传播与梯度下降

前向传播

从输入计算到输出,得到y^\hat{y}y^L\mathcal{L}L

反向传播(Backpropagation)

计算梯度∂L/∂θ\partial\mathcal{L}/\partial\thetaL/θ

  • 输出层误差:δ(L)=∂L/∂z⃗(L)=(y^−y)⋅σ(L)′(z⃗(L))\delta^{(L)}=\partial\mathcal{L}/\partial\vec{z}^{(L)}=(\hat{y}-y)\cdot\sigma^{(L)'}(\vec{z}^{(L)})δ(L)=L/z(L)=(y^y)σ(L)(z(L))(线性激活时σ′=1\sigma'=1σ=1,故δ(L)=y^−y\delta^{(L)}=\hat{y}-yδ(L)=y^y)。
  • 向后传播:δ(l)=(W(l+1))Tδ(l+1)⊙σ(l)′(z⃗(l))\delta^{(l)}=(W^{(l+1)})^T\delta^{(l+1)}\odot\sigma^{(l)'}(\vec{z}^{(l)})δ(l)=(W(l+1))Tδ(l+1)σ(l)(z(l))⊙\odot为逐元素乘。
  • 梯度:
    ∂L∂W(l)=δ(l)(a⃗(l−1))T \frac{\partial\mathcal{L}}{\partial W^{(l)}}=\delta^{(l)}(\vec{a}^{(l-1)})^T W(l)L=δ(l)(a(l1))T
    ∂L∂b⃗(l)=δ(l) \frac{\partial\mathcal{L}}{\partial\vec{b}^{(l)}}=\delta^{(l)} b(l)L=δ(l)

优化:梯度下降

更新参数:θ:=θ−η∇θL\theta:=\theta-\eta\nabla_\theta\mathcal{L}θ:=θηθLη\etaη为学习率。

变体:

  • SGD:随机梯度下降,每批次更新。
  • Momentum:添加动量v:=βv−η∇v:=\beta v-\eta\nablav:=βvηθ:=θ+v\theta:=\theta+vθ:=θ+v
  • Adam:自适应学习率,结合动量和RMSProp。

完整训练算法

  1. 初始化θ\thetaθ(e.g.,Xavier初始化)。
  2. 对于每个epoch:
    a. 前向:计算y^\hat{y}y^L\mathcal{L}L
    b. 反向:计算梯度。
    c. 更新θ\thetaθ
  3. 监控验证损失,早停防止过拟合。

数学推导示例:简单网络

假设单隐藏层,输入1维xxx,隐藏1神经元,输出y^\hat{y}y^

  • 前向:
    z(1)=w1x+b1z^{(1)}=w_1x+b_1z(1)=w1x+b1a(1)=σ(z(1))a^{(1)}=\sigma(z^{(1)})a(1)=σ(z(1))(ReLU)。
    z(2)=w2a(1)+b2z^{(2)}=w_2a^{(1)}+b_2z(2)=w2a(1)+b2y^=z(2)\hat{y}=z^{(2)}y^=z(2)(线性)。
  • 损失:L=12(y^−y)2\mathcal{L}=\frac{1}{2}(\hat{y}-y)^2L=21(y^y)2
  • 梯度:
    ∂L/∂y^=y^−y\partial\mathcal{L}/\partial\hat{y}=\hat{y}-yL/y^=y^y
    ∂L/∂w2=(y^−y)a(1)\partial\mathcal{L}/\partial w_2=(\hat{y}-y)a^{(1)}L/w2=(y^y)a(1)
    ∂L/∂b2=y^−y\partial\mathcal{L}/\partial b_2=\hat{y}-yL/b2=y^y
    ∂L/∂a(1)=(y^−y)w2\partial\mathcal{L}/\partial a^{(1)}=(\hat{y}-y)w_2L/a(1)=(y^y)w2
    ∂L/∂z(1)=∂L/∂a(1)⋅σ′(z(1))\partial\mathcal{L}/\partial z^{(1)}=\partial\mathcal{L}/\partial a^{(1)}\cdot\sigma'(z^{(1)})L/z(1)=L/a(1)σ(z(1))(ReLU’:1 ifz(1)>0z^{(1)}>0z(1)>0,else0)。
    ∂L/∂w1=∂L/∂z(1)⋅x\partial\mathcal{L}/\partial w_1=\partial\mathcal{L}/\partial z^{(1)}\cdot xL/w1=L/z(1)x
    ∂L/∂b1=∂L/∂z(1)\partial\mathcal{L}/\partial b_1=\partial\mathcal{L}/\partial z^{(1)}L/b1=L/z(1)

正则化与优化技巧

  • 过拟合防治

    • L1/L2正则:添加λ∑∣w∣\lambda\sum|w|λwλ∑w2\lambda\sum w^2λw2到损失。
    • Dropout:训练时随机丢弃神经元(概率p)。
    • 数据增强:增加训练数据。
    • 早停:验证损失上升时停止。
  • 初始化:He初始化forReLU:w∼N(0,2/ml−1)w\sim\mathcal{N}(0,\sqrt{2/m_{l-1}})wN(0,2/ml1)

  • 批标准化(BatchNormalization):在每层后标准化z⃗(l)\vec{z}^{(l)}z(l),加速训练。

  • 学习率调度:余弦退火或指数衰减。

  • 超参数调优:层数、神经元数、学习率、批大小。用GridSearch或BayesianOptimization。

优点与缺点

  • 优点

    • 处理非线性关系:通用函数逼近器。
    • 自动特征提取:隐藏层学习高级表示。
    • 可扩展:深层网络捕捉复杂模式。
  • 缺点

    • 计算密集:训练需GPU。
    • 黑箱:解释性差(用SHAP或LIME改善)。
    • 需大量数据:小数据集易过拟合。
    • 梯度消失/爆炸:深层网络问题(用ReLU、残差连接缓解)。

应用场景

  • 房价预测:输入面积、位置等,输出价格。
  • 时间序列预测:RNN/LSTM变体,但基本FNN可用于简单回归。
  • 图像回归:CNN提取特征,后接全连接回归(如年龄估计)。
  • 金融:股票价格预测。

实际例子

例子1:线性回归模拟

用单层无激活网络模拟线性回归y=2x+1y=2x+1y=2x+1

  • 输入xxx,输出y^=wx+b\hat{y}=wx+by^=wx+b
  • 损失MSE。
  • 训练后w≈2w\approx2w2b≈1b\approx1b1

例子2:非线性回归

预测y=sin⁡(x)+噪声y=\sin(x)+噪声y=sin(x)+噪声

  • 网络:输入1,隐藏[64,64]ReLU,输出1线性。
  • 数据:1000点x∈[−π,π]x\in[-π,π]x[π,π]
  • 训练:Adam,MSE,epochs=1000。
    网络学习正弦曲线。

代码实现(Python with PyTorch)

import torch
import torch.nn as nn
import torch.optim as optimclass RegressionNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1, 64)self.fc2 = nn.Linear(64, 64)self.fc3 = nn.Linear(64, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))return self.fc3(x)# 数据
x = torch.randn(1000, 1) * 3.14
y = torch.sin(x) + 0.1 * torch.randn(1000, 1)# 训练
model = RegressionNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)for epoch in range(1000):optimizer.zero_grad()output = model(x)loss = criterion(output, y)loss.backward()optimizer.step()

总结

神经网络回归通过多层变换、反向传播和优化学习连续映射。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/96853.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/96853.shtml
英文地址,请注明出处:http://en.pswp.cn/web/96853.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVASCRIPT 前端数据库-V9--仙盟数据库架构-—仙盟创梦IDE

老版本 在v1 版本中我们讲述了 基础版的应用JAVASCRIPT 前端数据库-V1--仙盟数据库架构-—-—仙盟创梦IDE-CSDN博客接下载我们做一个更复杂的的其他场景由于,V1查询字段必须 id接下来我们修改了了代码JAVASCRIPT 前端数据库-V2--仙盟数据库架构-—-—仙盟创梦IDE-CS…

k8s核心资料基本操作

NamespaceNamespace是kubernetes系统中的一种非常重要资源,它的主要作用是用来实现多套环境的资源隔离或者多租户的资源隔离。默认情况下,kubernetes集群中的所有的Pod都是可以相互访问的。但是在实际中,可能不想让两个Pod之间进行互相的访问…

PostgreSQL——分区表

分区表一、分区表的意义二、传统分区表2.1、继承表2.2、创建分区表2.3、使用分区表2.4、查询父表还是子表2.5、constraint_exclusion参数2.6、添加分区2.7、删除分区2.8、分区表相关查询2.9、传统分区表注意事项三、内置分区表3.1、创建分区表3.2、使用分区表3.3、内置分区表原…

Linux任务调度全攻略

Linux下的任务调度分为两类,系统任务调度和用户任务调度。系统任务调度:系统周期性所要执行的工作,比如写缓存数据到硬盘、日志清理等。在/etc目录下有一个crontab文件,这个就是系统任务调度的配置文件。/etc/crontab文件包括下面…

回溯算法通关秘籍:像打怪一样刷题

🚀 回溯算法通关秘籍:像打怪一样刷题! 各位同学,今天咱们聊聊 回溯算法(Backtracking)。它听起来玄乎,但其实就是 “暴力搜索 剪枝” 的优雅版。 打个比方:回溯就是在迷宫里探险&am…

嵌入式Linux常用命令

📟 核心文件与目录操作pwd-> 功能: 打印当前工作目录的绝对路径。-> 示例: pwd -> 输出 /home/user/projectls [选项] [目录]-> 功能: 列出目录内容。-> 常用选项:-l: 长格式显示(详细信息)-a: 显示所有文件(包括隐…

深入理解 Linux 内核进程管理

在 Linux 系统中,进程是资源分配和调度的基本单位,内核对进程的高效管理直接决定了系统的性能与稳定性。本文将从进程描述符的结构入手,逐步剖析进程的创建、线程实现与进程终结的完整生命周期,带您深入理解 Linux 内核的进程管理…

ACP(三):让大模型能够回答私域知识问题

让大模型能够回答私域知识问题 未经过特定训练答疑机器人,是无法准确回答“我们公司项目管理用什么工具”这类内部问题。根本原因在于,大模型的知识来源于其训练数据,这些数据通常是公开的互联网信息,不包含任何特定公司的内部文档…

使用Xterminal连接Linux服务器

使用Xterminal连接Linux服务器(VMware虚拟机)的步骤如下,前提是虚拟机已获取IP(如 192.168.31.105)且网络互通: 一、准备工作(服务器端确认)确保SSH服务已安装并启动 Linux服务器需要…

ChatBot、Copilot、Agent啥区别

以下内容为AI生成ChatBot(聊天机器人)、Copilot(副驾驶)和Agent(智能体/代理)是AI应用中常见的三种形态,它们在人机交互、自动化程度和任务处理能力上有着显著的区别。特征维度ChatBot (聊天机器…

2025 年大语言模型架构演进:DeepSeek V3、OLMo 2、Gemma 3 与 Mistral 3.1 核心技术剖析

编者按: 在 Transformer 架构诞生八年之际,我们是否真的见证了根本性的突破,还是只是在原有设计上不断打磨?今天我们为大家带来的这篇文章,作者的核心观点是:尽管大语言模型在技术细节上持续优化&#xff0…

基于Matlab GUI的心电信号QRS波群检测与心率分析系统

心电信号(Electrocardiogram, ECG)是临床诊断心脏疾病的重要依据,其中 QRS 波群的准确检测对于心率分析、心律失常诊断及自动化心电分析系统具有核心意义。本文设计并实现了一套基于 MATLAB GUI 的心电信号处理与分析系统,集成了数…

1台SolidWorks服务器能带8-10人并发使用

在工业设计和机械工程领域,SolidWorks作为主流的三维CAD软件,其服务器部署方案直接影响企业协同效率。通过云飞云共享云桌面技术实现多人并发使用SolidWorks时,实际承载量取决于硬件配置、网络环境、软件优化等多维度因素的综合作用。根据专业…

String、StringBuilder和StringBuffer的区别

目录一. String:不可变的字符串二.StringBuilder:可变字符串三.StringBuffer:线程安全的可变字符串四.总结在 Java 开发中,字符串处理是日常编码中最频繁的操作之一。String、StringBuilder 和 StringBuffer 这三个类虽然都用于操…

Power Automate List Rows使用Fetchxml查询的一个bug

看一段FetchXML, 这段查询在XRMtoolbox中的fech test工具里执行完全ok<fetch version"1.0" mapping"logical" distinct"true" no-lock"false"> <entity name"new_projectchange"> <link-entity name"sy…

Letta(MemGPT)有状态AI代理的开源框架

1. 项目概述Letta&#xff08;前身为 MemGPT&#xff09;是一个用于构建有状态AI代理的开源框架&#xff0c;专注于提供长期记忆和高级推理能力。该项目是MemGPT研究论文的实现&#xff0c;引入了"LLM操作系统"的概念用于内存管理。核心特点有状态代理&#xff1a;具…

除了ollama还有哪些模型部署方式?多样化模型部署方式

在人工智能的浪潮中&#xff0c;模型部署是释放其强大能力的关键一环。大家都知道ollama&#xff0c;它在模型部署领域有一定知名度&#xff0c;操作相对简单&#xff0c;受到不少人的青睐。但其实&#xff0c;模型部署的世界丰富多样&#xff0c;今天要给大家介绍一款工具&…

Linux系统学习之进阶命令汇总

文章目录一、系统信息1.1 查看系统信息&#xff1a;uname1.2 查看主机名&#xff1a;hostname1.3 查看cpu信息&#xff1a;1.4 当前已加载的内核模块: lsmod1.5 查看磁盘空间使用情况: df1.6 管理磁盘分区: fdisk1.7 查看目录或文件磁盘使用情况: du1.8 查看I/O使用情况: iosta…

算法面试(2)------休眠函数sleep_for和sleep_until

操作系统&#xff1a;ubuntu22.04 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 这两个函数都定义在 头文件中&#xff0c;属于 std::this_thread 命名空间&#xff0c;用于让当前线程暂停执行一段时间。函数功能sleep_for(rel_time)让当前线程休眠一段相对时间&…

贪心算法应用:5G网络切片问题详解

Java中的贪心算法应用&#xff1a;5G网络切片问题详解 1. 5G网络切片问题概述 5G网络切片是将物理网络划分为多个虚拟网络的技术&#xff0c;每个切片可以满足不同业务需求&#xff08;如低延迟、高带宽等&#xff09;。网络切片资源分配问题可以抽象为一个典型的优化问题&…