1. 引言:Excel 的“一键魔法”背后藏着什么智慧?

在 Excel 中,我们只需右键 → 添加趋势线,一条完美的直线就出现了。它快得像魔法,但魔法背后,是数学的严谨。

今天,我们不关心 Excel 内部用了什么算法可能是解析法),而是要问:

如果机器一开始‘什么都不知道’,它该怎么一步步‘学会’画出这条线?


2. 我们的任务:预测一个简单的线性关系

假设我们有一组数据,它们大致遵循 y = 2x + 1 的规律,但有一些随机噪声(模拟真实世界的测量误差)。

我们的目标:仅凭这10个数据点,让机器“学会”这个规律,找到最佳的 a(斜率)和 b(截距)


3. 准备数据:10 个“玩具”样本

我们用 Python 生成这10个数据点:

import numpy as np
import matplotlib.pyplot as plt# 设置随机种子,保证结果可复现
np.random.seed(42)# 生成 10 个 x 值,从 1 到 10
x = np.arange(1, 11)  # [1, 2, 3, ..., 10]# 生成 y 值:y = 2x + 1 + 噪声
true_a, true_b = 2, 1
noise = np.random.normal(0, 0.5, size=x.shape)  # 添加标准差为 0.5 的高斯噪声
y = true_a * x + true_b + noise# 查看数据
print("x:", x)
print("y:", np.round(y, 2))  # 保留两位小数

输出

x: [ 1  2  3  4  5  6  7  8  9 10]
y: [ 3.25  4.93  7.32  9.76 10.88 12.88 15.79 17.38 18.77 21.27]

关键点:真实规律是 y = 2x + 1,但数据有噪声,所以 y 值不完全精确。


4. 可视化:我们的目标是什么?

让我们画出这些数据点,并标出真实的线(y=2x+1):

plt.figure(figsize=(8, 6))
plt.scatter(x, y, color='blue', label='真实数据点', s=50)
# 画出真实规律的线
x_line = np.linspace(0, 11, 100)
y_line = true_a * x_line + true_b
plt.plot(x_line, y_line, 'g--', label=f'真实规律 y={true_a}x+{true_b}', linewidth=2)plt.xlabel('x')
plt.ylabel('y')
plt.title('我们的“学习”任务:根据10个数据点,找到最佳拟合线')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()


我们的任务就是:仅从这些蓝色散点,让机器找到一条最接近绿色虚线的直线


5. 模型定义:y_pred = a * x + b

我们假设模型的形式是:

y_pred = a * x + b
  • a:斜率(slope)
  • b:截距(intercept)
  • 初始时,我们不知道 a 和 b 是多少,可以先猜一个值,比如 a=0, b=0

6. 损失函数:我们“错”了多少?

我们需要一个“尺子”来衡量预测的好坏。用均方误差 (MSE)

MSE = (1/n) * Σ(y - y_pred)²

代码实现:

def compute_mse(y_true, y_pred):return np.mean((y_true - y_pred) ** 2)# 初始猜测:a=0, b=0
a, b = 0.0, 0.0
y_pred_initial = a * x + b
initial_loss = compute_mse(y, y_pred_initial)
print(f"初始损失 (a=0, b=0): {initial_loss:.3f}")
初始损失 (a=0, b=0): 182.431

解释:损失很大,因为 y_pred 全是 0,和真实 y 差很远。

7. 核心:梯度下降——如何“学习”?

现在,我们教模型如何“学习”:

  • 比喻:你在浓雾中的山坡上,想找到谷底(损失最小的地方)。你看不见路,但能感觉到脚下的坡度(梯度)。你每次都向坡度最陡的下坡方向走一步(更新参数),一步步接近谷底。

  • 数学计算(对 ab 求偏导):

    • ∂MSE/∂a = (2/n) * Σ((a*x + b - y) * x)
    • ∂MSE/∂b = (2/n) * Σ(a*x + b - y)
  • 更新规则

    • a = a - learning_rate * ∂MSE/∂a
    • b = b - learning_rate * ∂MSE/∂b

8. 动手实现:手动训练模型
# 超参数
learning_rate = 0.01  # 学习率,步长
epochs = 10000           # 训练轮数# 初始化参数
a, b = 0.0, 0.0# 记录历史,用于画图
loss_history = []
a_history = []
b_history = []for i in range(epochs):# 前向:计算预测值y_pred = a * x + b# 计算损失loss = compute_mse(y, y_pred)loss_history.append(loss)a_history.append(a)b_history.append(b)# 计算梯度n = len(x)da = (2 / n) * np.sum((y_pred - y) * x)  # 注意:y_pred - ydb = (2 / n) * np.sum(y_pred - y)# 更新参数a = a - learning_rate * dab = b - learning_rate * db# 打印进度if (i+1) % 1000 == 0:print(f"第 {i+1} 轮: a={a:.3f}, b={b:.3f}, 损失={loss:.3f}")print(f"\n训练完成!")
print(f"我们的模型找到: a ≈ {a:.3f}, b ≈ {b:.3f}")
print(f"真实规律是: a = {true_a}, b = {true_b}")

输出

第 100 轮: a=2.054, b=0.840, 损失=0.153
第 200 轮: a=2.021, b=1.070, 损失=0.124
第 300 轮: a=2.007, b=1.168, 损失=0.119
第 400 轮: a=2.001, b=1.211, 损失=0.118
第 500 轮: a=1.999, b=1.229, 损失=0.118
第 600 轮: a=1.997, b=1.237, 损失=0.118
第 700 轮: a=1.997, b=1.240, 损失=0.118
第 800 轮: a=1.997, b=1.242, 损失=0.118
第 900 轮: a=1.997, b=1.243, 损失=0.118
第 1000 轮: a=1.997, b=1.243, 损失=0.118训练完成!
我们的模型找到: a ≈ 1.997, b ≈ 1.243
真实规律是: a = 2, b = 1

看! 模型通过1000次迭代,a 从 0 逐渐接近 2,b 接近 1,损失不断下降。


9. 可视化学习过程
  • 损失下降曲线
plt.plot(loss_history)
plt.xlabel('训练轮数 (Epoch)')
plt.ylabel('损失 (MSE)')
plt.title('损失随训练过程下降')
plt.grid(True, alpha=0.3)
plt.show()

  • 参数变化过程
plt.plot(a_history, label='a (斜率)')
plt.plot(b_history, label='b (截距)')
plt.axhline(y=true_a, color='g', linestyle='--', label='真实 a')
plt.axhline(y=true_b, color='r', linestyle='--', label='真实 b')
plt.xlabel('训练轮数')
plt.ylabel('参数值')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

  • 最终拟合结果
plt.figure(figsize=(8, 6))
plt.scatter(x, y, color='blue', label='真实数据点', s=50)
plt.plot(x_line, y_line, 'g--', label=f'真实规律 y={true_a}x+{true_b}', linewidth=2)
plt.plot(x_line, a * x_line + b, 'r-', label=f'拟合线 y={a:.2f}x+{b:.2f}', linewidth=2)
plt.xlabel('x')
plt.ylabel('y')
plt.title('梯度下降拟合结果')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

  • excel拟合结果
10. 对比专业工具:Scikit-learn

让我们看看专业的机器学习库 scikit-learn 是怎么做的:

from sklearn.linear_model import LinearRegression# 注意:sklearn 需要 2D 输入
X = x.reshape(-1, 1)model = LinearRegression()
model.fit(X, y)print(f"Scikit-learn 结果: 斜率 a = {model.coef_[0]:.3f}, 截距 b = {model.intercept_:.3f}")

输出

Scikit-learn 结果: 斜率 a = 1.997, 截距 b = 1.243

完全一致! 我们手动实现的梯度下降,和 sklearn 的结果一样。


附. 深入理解:梯度下降 vs 解析法——殊途同归的两种智慧

你可能会问:“我们手动实现的梯度下降,和 scikit-learnLinearRegression,是同一种方法吗?”

答案是:不是。 它们在求解方式、数学基础和适用场景上有本质区别。但最终结果一致,是因为它们都在寻找同一个“最优解”。

下面,我们来揭开这个“黑箱”。


一、核心原理:迭代逼近 vs 数学解析
方法核心思想求解方式
手动梯度下降逐步“试错”,像盲人下山,一步步逼近最优解。迭代法:重复“计算梯度 → 更新参数”直到收敛。
scikit-learn LinearRegression直接用数学公式算出理论最优解,一步到位。解析法(闭式解):通过公式 w = (XᵀX)⁻¹Xᵀy 直接计算。
1. 手动梯度下降(迭代法)
  • 数学逻辑

    • 定义损失函数(MSE):L(a,b) = \frac{1}{n} \sum_{i=1}^{n} (y_i - (ax_i + b))^2

    • 计算梯度(偏导数):\frac{\partial L}{\partial a} = \frac{2}{n} \sum_{i=1}^{n} (ax_i + b - y_i) x_i\frac{\partial L}{\partial b} = \frac{2}{n} \sum_{i=1}^{n} (ax_i + b - y_i)

    • 更新参数:a = a - \eta \cdot \frac{\partial L}{\partial a}, \quad b = b - \eta \cdot \frac{\partial L}{\partial b}     (η 为学习率)

  • 特点

    • 近似解:需要足够迭代才能接近最优。

    • 适用于大数据:可分批处理(如随机梯度下降)。

    • 需调参:学习率、迭代次数等。

2. scikit-learn LinearRegression(解析法)
  • 数学逻辑

    • 将问题表示为矩阵形式:\mathbf{y} = \mathbf{X}\mathbf{w} + \boldsymbol{\epsilon}

    • 最优参数 ww 的闭式解为:\mathbf{w} = (\mathbf{X}^T \mathbf{X})^{-1} \mathbf{X}^T \mathbf{y}
      这个公式是通过对损失函数求导并令其为 0 推导出的理论最优解

  • 特点

    • 精确解:无需迭代,一步到位。

    • 计算复杂度高:矩阵求逆的复杂度为 O(n3)O(n3),特征数多时很慢。

    • 无超参数:直接计算,无需调学习率。


二、关键差异对比
维度梯度下降(迭代法)LinearRegression(解析法)
求解方式逐步迭代逼近直接公式计算
结果性质近似解(可无限接近)理论最优解
计算效率大数据更高效特征多时慢
超参数依赖需调学习率、迭代次数无超参数
适用场景大数据、高维(如深度学习)小数据、低维(如本例)
数值稳定性受学习率影响受多重共线性影响

三、为什么结果能一致?—— 殊途同归

尽管方法不同,但我们的手动实现和 sklearn 结果几乎一致,原因在于:

梯度下降的迭代过程,本质上是在数学上“收敛于”解析法的最优解。

就像用牛顿法求方程的根:虽然是迭代过程,但最终会无限接近理论上的精确解。

在本例中,由于数据量小、特征少,解析法高效且精确。而梯度下降通过300次迭代,也成功逼近了这个最优解。


四、通俗类比
  • 梯度下降:像盲人下山,通过感受坡度,一步步向下走,最终到达谷底,适用于复杂、大规模问题,是深度学习的基石。
  • 解析法:像使用 GPS,直接定位谷底坐标,一步到位,适用于简单、小规模问题,高效精确。

两种方法最终会到达同一个“山底”,只是路径和效率不同。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/95628.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/95628.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/95628.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于上拉电阻

上拉电阻的作用:辅助浮空状态输出高电平 其实就是确定这根线的电平,不能让他处于一种未知的状态。 其次也可以起到限制电流的作用,防止损坏原件 那么上拉电阻如何取值? 首先来看一下驱动能力。 因为线上是一定有寄生电容的&am…

PiscCode构建Mediapipe 手势识别“剪刀石头布”小游戏

在计算机视觉与人机交互领域,手势识别是一个非常有趣的应用场景。本文将带你用 Mediapipe 和 Python 实现一个基于摄像头的手势识别“剪刀石头布”小游戏,并展示实时手势与游戏结果。 1. 项目概述 该小游戏能够实现: 实时检测手势&#xff0…

【VoNR】VoNR 不等于 VoLTE on 5G

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G技术研究。 博客内容主要围绕…

计算机网络:网络设备在OSI七层模型中的工作层次和传输协议

OSI七层模型(物理层、数据链路层、网络层、传输层、会话层、表示层、应用层)中,不同网络设备因功能不同,工作在不同层次。以下是典型网络设备的工作层次及核心功能:1. 物理层(第1层) 核心功能&a…

RSA-e和phi不互素

1.题目import gmpy2 import libnum p 1656713884642828937525841253265560295123546793973683682208576533764344166170780019002774068042673556637515136828403375582169041170690082676778939857272304925933251736030429644277439899845034340194709105071151095131704526…

基于单片机蒸汽压力检测/蒸汽余热回收

传送门 👉👉👉👉单片机作品题目速选一览表🚀 👉👉👉👉单片机作品题目功能速览🚀 🔥更多文章戳👉小新单片机-CSDN博客&#x1f68…

https 协议与 wss 协议有什么不同

HTTPS 是用于网页数据传输的安全协议,而 WSS 是用于实时双向通信(如聊天、直播)的安全协议,二者的设计目标、应用场景、底层逻辑均存在本质区别。以下从 7 个核心维度展开对比,并补充关键关联知识,帮助彻底…

主流分布式数据库集群选型指南

以下是关于主流分布式可扩展数据库集群的详细解析,涵盖技术分类、代表产品及适用场景,帮助您高效选型:一、分布式数据库核心分类 1. NewSQL 数据库(强一致性 分布式事务)产品开发方核心特性适用场景TiDBPingCAPHTAP架…

#T1359. 围成面积

题目描述编程计算由“*”号围成的下列图形的面积。面积计算方法是统计*号所围成的闭合曲线中水平线和垂直线交点的数目。如下图所示,在1010的二维数组中,有“*”围住了15个点,因此面积为15。输入1010的图形。输出输出面积。样例输入数据 10 0…

Hive on Tez/Spark 执行引擎对比与优化

在大数据开发中,Hive 已经成为最常用的数据仓库工具之一。随着业务数据规模的不断扩大,Hive 默认的 MapReduce 执行引擎 显得笨重低效。为了提升查询性能,Hive 支持了 Tez 和 Spark 作为底层执行引擎。本文将带你对比 Hive on Tez 与 Hive on Spark 的区别,并分享调优经验。…

深入理解 Next.js 的路由机制

深入理解 Next.js 的路由机制 作者:码力无边在上一篇文章中,我们成功创建并运行了第一个 Next.js 应用。当你打开项目文件夹时,你可能会注意到一个名为 pages 的目录。这个目录看似普通,但它却是 Next.js 路由系统的核心。今天&am…

modbus_tcp和modbus_rtu对比移植AT-socket,modbus_tcp杂记

modbus_rtu通信时没有连接过程&#xff0c;主机和从机各自初始化自身串口就行了&#xff0c;而rtu需要确定从机ID。注:在TCP连接中&#xff0c;不同的网卡有不同的IP&#xff0c;port对应具体的程序。/* 先读取数据 */for (i 0; i < len; i){if (pdPASS ! xQueueReceive(re…

Docker Compose 详解:从安装到使用的完整指南

在现代容器化应用开发中&#xff0c;Docker Compose 是一个不可或缺的工具&#xff0c;它能够帮助我们轻松定义和运行多容器的 Docker 应用程序。 一、什么是 Docker Compose&#xff1f; Docker Compose 是 Docker 官方提供的一个工具&#xff0c;用于定义和运行多容器 Dock…

springboot配置多数据源(mysql、hive)

MyBatis-Plus 不能也不建议同时去“控制” Hive。它从设计到实现都假定底层是 支持事务、支持标准 SQL 方言 的 关系型数据库&#xff08;MySQL、PostgreSQL、Oracle、SQL Server 等&#xff09;&#xff0c;而 Hive 两者都不完全符合。如果操作两个数据源都是mysql或者和关系数…

2025年上海市星光计划第十一届职业院校技能大赛高职组“信息安全管理与评估”赛项交换部分前6题详解(仅供参考)

1.北京总公司和南京分公司有两条裸纤采用了骨干链路配置,做必要的配置,只允许必要的Vlan 通过,不允许其他 Vlan 信息通过包含 Vlan1,禁止使用 trunk链路。 骨干链路位置​​:总公司 SW 与分公司 AC 之间的两条物理链路(Ethernet 1/0/5-6 必要 VLAN​​: •总公司:Vlan…

学习nginx location ~ .*.(js|css)?$语法规则

引言 nginx作为一款高性能的Web服务和反向代理服务&#xff0c;在网站性能优化中扮演着重要的角色。其中&#xff0c;location指令的正确配置是优化工作的关键之一。 这篇记录主要解析location ~ .*\.(js|css)?$这一特定的语法规则&#xff0c;帮助大家理解其在nginx配置中的…

Nmap网络扫描工具详细使用教程

目录 Nmap 主要功能 网络存活主机发现 (ARP Ping Scan) 综合信息收集扫描 (Stealth SYN Service OS) 全端口扫描 (Full Port Scan) NSE 漏洞脚本扫描 SMB 信息枚举 HTTP 服务深度枚举 SSH 安全审计 隐蔽扫描与防火墙规避 Nmap 主要功能 Nmap 主要有以下几个核心功能…

Spring Boot 3.x 的 @EnableAsync应用实例

语法结构使用 EnableAsync 其实就像为你的应用穿上一件时尚的外套&#xff0c;简单又高效&#xff01;只需在你的配置类上添加这个注解&#xff0c;轻松开启异步之旅。代码如下&#xff1a;想象一下&#xff0c;你的应用一瞬间变得灵活无比&#xff0c;像一个跳舞的机器人&…

Nginx Tomcat Jar包开机启动自动配置

一、Nginx配置1、创建systemd nginx 服务文件vi /usr/lib/systemd/system/nginx.service### 内容[Unit] DescriptionThe nginx HTTP and reverse proxy server Afternetwork.target[Service] Typeforking ExecStartPre/mnt/nginx/sbin/nginx -t ExecStart/mnt/nginx/sbin/nginx…

修订版!Uniapp从Vue3编译到安卓环境踩坑记录

Uniapp从Vue3编译到安卓环境踩坑记录 在使用Uniapp开发Vue3项目并编译到安卓环境时&#xff0c;我遇到了不少问题&#xff0c;现将主要踩坑点及解决方案整理如下&#xff0c;供大家参考。 1. 动态导入与静态导入问题 问题描述&#xff1a; 在Vue3项目中使用的动态导入语法在Uni…