梯度下降(Gradient Descent)是深度学习中最核心的优化算法之一。大模型(如GPT、BERT)在训练时需要优化数十亿甚至上千亿的参数,而梯度下降及其变体(如SGD、Adam)正是实现这一优化的关键工具。它通过计算损失函数相对于参数的梯度,并沿梯度负方向迭代更新参数,从而最小化损失。

梯度下降解决的问题

在大模型训练中,我们需要最小化一个高维、非凸的损失函数。梯度下降的目标就是找到损失函数的局部甚至全局最优点,以使模型在训练数据和测试数据上表现良好。

主要解决的问题包括:

损失最小化:通过迭代不断减少模型预测与真实值之间的误差。

收敛效率:改进的优化算法(如Adam)可以加速收敛。

避免困在鞍点:高维空间中鞍点比局部极小值更常见,因此优化器需具备跳出鞍点的能力。

2. 原理与数学推导

2.1 基本公式

梯度下降的更新规则为:

公式如下:

θt+1=θt−η⋅∇θL(θt) \theta_{t+1} = \theta_t - \eta \cdot \nabla_\theta L(\theta_t) θt+1=θtηθL(θt)

其中:

  • θ\thetaθ 是模型参数;
  • L(θ)L(\theta)L(θ) 是损失函数;
  • η\etaη 是学习率(Learning Rate);
  • ∇θL\nabla_\theta LθL 是损失函数对参数的梯度。

2.2 损失函数的几何意义

损失函数可以看作一个“地形”,梯度下降就是沿着最陡峭的下坡路一步步走到山谷底部(全局或局部最小值)。


3. 梯度下降的种类与应用

算法特点适用场景
Batch GD使用全量数据,稳定但计算量大小数据集
SGD每次用一个样本,更新快但噪声大深度学习初期
Mini-Batch GD折中方案,批量样本大模型训练首选

4. 在大模型训练中的实践

  • 优化器:Adam / AdamW 广泛用于 LLM 训练;
  • Loss:交叉熵(Cross Entropy)是语言建模的常见选择;
  • 技巧:学习率调度(Warm-up)、梯度裁剪(Gradient Clipping)、正则化(Weight Decay)。

5. 可视化示例:梯度下降过程

以下示例演示了如何用 Python + Matplotlib 画出梯度下降在二维损失曲面上的收敛轨迹

import numpy as np
import matplotlib.pyplot as plt# 损失函数: f(x) = x^2 + 2x + 1
def loss(x):return x**2 + 2*x + 1# 梯度: f'(x) = 2x + 2
def grad(x):return 2*x + 2# 参数初始化
x = 5.0
eta = 0.2  # 学习率
history = [x]# 迭代梯度下降
for _ in range(15):x -= eta * grad(x)history.append(x)# 绘图
xs = np.linspace(-4, 6, 100)
ys = loss(xs)plt.figure(figsize=(8,4))
plt.plot(xs, ys, label="Loss Curve")
plt.scatter(history, [loss(h) for h in history], c="red", label="Steps", zorder=5)
plt.title("Gradient Descent Optimization Path")
plt.xlabel("Parameter x")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()

运行后会显示

  • 蓝色曲线:损失函数 L(x)=x2+2x+1L(x)=x^2+2x+1L(x)=x2+2x+1
  • 红点:梯度下降的更新轨迹,逐步逼近最小值。

6. 图示(直观理解)

损失 L(θ)
│       •            ← 初始参数 θ0
│     •
│   •
│ •
└──────────────────────────→ 参数 θ

7. 示例:PyTorch 训练循环(简化版)

import torch
import torch.nn as nn
import torch.optim as optim# 简单线性模型 y = wx + b
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.01)x = torch.randn(100, 1)
y = 3 * x + 1 + 0.1 * torch.randn(100, 1)for epoch in range(100):optimizer.zero_grad()y_pred = model(x)loss = criterion(y_pred, y)loss.backward()optimizer.step()if epoch % 10 == 0:print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

这段代码模拟了一个使用 AdamW + MSE Loss 的小型训练过程。

7. Jupyter Notebook详细版本

可视化与轨迹演示的demo示意

pip install numpy matplotlib torch pillow
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei']  # Mac/Windows 中文字体
matplotlib.rcParams['axes.unicode_minus'] = Falseimport numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import torch.nn as nn
import torch.optim as optim#############################
# 1. 一维梯度下降动画
#############################def loss_1d(x):return x**2 + 2*x + 1def grad_1d(x):return 2*x + 2x_init = 5.0
eta = 0.2
steps = [x_init]
x = x_init
for _ in range(15):x -= eta * grad_1d(x)steps.append(x)xs = np.linspace(-4, 6, 200)
ys = loss_1d(xs)
plt.figure(figsize=(8,4))
plt.plot(xs, ys, label="Loss Curve")
plt.scatter(steps, [loss_1d(s) for s in steps], c="red", label="Steps", zorder=5)
plt.title("1D 梯度下降路径")
plt.xlabel("参数 x")
plt.ylabel("损失 Loss")
plt.legend()
plt.grid(True)
plt.show()fig, ax = plt.subplots()
ax.plot(xs, ys, label="Loss Curve")
point, = ax.plot([], [], 'ro')
ax.legend()
ax.set_title("1D 梯度下降动画")
ax.set_xlabel("参数 x")
ax.set_ylabel("损失 Loss")def init():point.set_data([], [])return point,def update(frame):x_val = steps[frame]y_val = loss_1d(x_val)point.set_data([x_val], [y_val])return point,ani = animation.FuncAnimation(fig, update, frames=len(steps), init_func=init, blit=True)
plt.close(fig)
ani.save("gradient_descent_1d.gif", writer="pillow", fps=2)#############################
# 2. 三维损失曲面 + 路径
#############################def loss_2d(w):x, y = wreturn x**2 + y**2 + x*y + 2*x + 3*y + 5def grad_2d(w):x, y = wreturn np.array([2*x + y + 2, 2*y + x + 3])eta = 0.1
w = np.array([4.0, 4.0])
path = [w.copy()]
for _ in range(30):w -= eta * grad_2d(w)path.append(w.copy())X = np.linspace(-5, 5, 50)
Y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(X, Y)
Z = loss_2d([X, Y])fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.7)
path = np.array(path)
ax.plot(path[:,0], path[:,1], [loss_2d(p) for p in path], 'r-o')
ax.set_title("3D 损失曲面与梯度下降路径")
plt.show()#############################
# 3. 优化器对比:SGD vs Adam
#############################torch.manual_seed(0)
X = torch.randn(200,1)
y = 3*X + 1 + 0.1*torch.randn(200,1)def build_model():return nn.Linear(1,1)def train(optimizer_type, lr=0.01):model = build_model()criterion = nn.MSELoss()optimizer = optimizer_type(model.parameters(), lr=lr)losses = []for epoch in range(50):optimizer.zero_grad()y_pred = model(X)loss = criterion(y_pred, y)loss.backward()optimizer.step()losses.append(loss.item())return lossesloss_sgd = train(optim.SGD, lr=0.05)
loss_adam = train(optim.Adam, lr=0.01)plt.figure(figsize=(8,4))
plt.plot(loss_sgd, label="SGD")
plt.plot(loss_adam, label="Adam")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("优化器收敛速度对比:SGD vs Adam")
plt.legend()
plt.grid(True)
plt.show()

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/917334.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/917334.shtml
英文地址,请注明出处:http://en.pswp.cn/news/917334.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JVS更新日志】开源框架、APS排产、企业计划、物联网、逻辑引擎7.30更新说明!

项目介绍 JVS是企业级数字化服务构建的基础脚手架,主要解决企业信息化项目交付难、实施效率低、开发成本高的问题,采用微服务配置化的方式,提供了低代码数据分析物联网的核心能力产品,并构建了协同办公、企业常用的管理工具等&…

Eclipse中导入新项目,右键项目没有Run on Server,Tomcat的add and remove找不到项目

原因分析没有勾选Dynamic Web Module、Java、JavaScriptDynamic Web Module版本问题解决方法Eclipse中右键项目选择Properties左侧点击project facets勾选Dynamic Web Module、Java、JavaScript,注意Dynamic Web Module版本问题,要和tomcat版本对应。- Dynamic Web …

IntelliJ IDEA 2025系列通用软件安装教程(Windows版)

前言 JetBrains系列开发工具(如IntelliJ IDEA、PyCharm、WebStorm等)是程序员们非常喜爱的集成开发环境。2025年最新版本带来了更多强大的功能和改进。本教程将详细介绍如何在Windows系统上安装JetBrains 2025系列软件。 最近挖到一个宝藏级人工智能学习…

乌鸫科技前端二面

1. 你能给我介绍一下你参与的重要项目,并重点介绍一下做的内容?通俗解释: 挑一个你觉得最拿得出手、技术含量最高的项目,说说这个项目是干什么的(比如一个电商网站、一个后台管理系统),你在里面具体负责了…

《c++面向对象入门与实战》笔记

前年的书,翻出来整理一下7章.指针指针 sizeof为4*指针 sizeof为 所指类型的sizeof注意free后置空,避免野指针11章.类

easyExcel生成多个sheet的动态表头的实现

在使用 EasyExcel 实现“多个 Sheet 且每个 Sheet 表头是动态的”需求时&#xff0c;思路如下&#xff1a;✅ 实现思路概述 EasyExcel 的 ExcelWriter 支持多个 Sheet 写入。每个 Sheet&#xff1a; 使用 WriteSheet 创建&#xff1b;可以绑定一个动态生成的表头 List<List&…

SQL 连接类型示例:内连接与外连接

SQL 连接类型示例&#xff1a;内连接与外连接 示例数据表 假设我们有两个表&#xff1a; employees 表:emp_idemp_namedept_id1张三1012李四1023王五1034赵六NULLdepartments 表:dept_iddept_name101销售部102技术部104财务部1. 内连接 (INNER JOIN) 内连接只返回两个表中匹配的…

Ubuntu安装gpu驱动,cuda

系统初始化 1、安装基础软件 apt-get update apt-get -y install openssh-server openssh-client apt-utils freeipmi ipmitool sshpass ethtool zip unzip nano less git netplan.io iputils-ping mtr ipvsadm smartmontools python3-pip socat conntrack libvirt-clients li…

ctfshow_源码压缩包泄露

根据题目信息直接dirsearch解压下来一个.txt文件&#xff0c;一个index.phpflag{flag_here}不对那么就去看index.php也没有东西&#xff0c;于是查看wp发现是访问/fl000g.txt这才是对的还有很多源码泄露需要去了解• git源码泄露• svn源码泄露• DS_Store 文件泄露• 网站备份…

Python 程序设计讲义(54):Python 的函数——函数概述

Python 程序设计讲义&#xff08;54&#xff09;&#xff1a;Python 的函数——函数概述 目录Python 程序设计讲义&#xff08;54&#xff09;&#xff1a;Python 的函数——函数概述一、函数的类型1、内置函数2、自定义函数二、调用函数Python 提供了函数机制&#xff0c;把实…

学习Python中Selenium模块的基本用法(3:下载浏览器驱动续)

前一篇文章主要介绍下载针对火狐浏览器的WebDriver&#xff0c;写那篇文章时才找到能够下最新版本Chrome的WebDriver地址&#xff08;参考文献6&#xff09;&#xff0c;本文继续学习并验证针对Chrome浏览器的WebDriver下载和使用方法。Chrome的WebDriver版本与操作系统相关&am…

AIDL当Parcelable序列化的数据类通信时报“Class not found when unmarshalling“找不到该类时的解决方案

1. 报错栈 &#xff1a;cusText这个类找不到 2 16:01:29.796 1044 5718 E Parcel : Class not found when unmarshalling: com.cus.sdk.cusText 08-02 16:01:29.796 1044 5718 E Parcel : java.lang.ClassNotFoundException: com.cus.sdk.cusText 08-02 16:01:29.796 1…

Django模型查询与性能调优:告别N+1问题

文章目录一、查询基础QuerySet 详解一对多关联查询多对多关联查询二、N1查询问题问题分析检测方法解决方案三、高级查询优化values()values_list()values()和values_list()对比Q() 对象复杂查询查看生成的 SQL四、项目实战场景实战一、查询基础 QuerySet 详解 Django 中通过模…

PyTorch 中 Tensor 统计学函数及相关概念

文章目录PyTorch 中 Tensor 统计学函数及相关概念一、引言二、基础统计学函数&#xff08;一&#xff09;torch.mean()——均值计算&#xff08;二&#xff09;torch.sum()——总和计算&#xff08;三&#xff09;torch.prod()——元素积计算&#xff08;四&#xff09;torch.m…

浅拷贝与深拷贝的区别

浅拷贝和深拷贝是两种不同的对象复制方式&#xff0c;主要区别在于它们如何处理对象内部的引用类型字段。浅拷贝 (Shallow Copy)特点&#xff1a;只复制对象本身&#xff08;基本类型字段&#xff09;和对象中的引用&#xff08;地址&#xff09;不复制引用指向的实际对象原始对…

脚本统计MongoDB集合表数据量

脚本&#xff1a; #!/bin/bashipxxx.xx.xx.xx portxxxx dbxxxdb #user #passwmongo -host ${ip}:${port} <<EOF 2>/dev/null|grep -vE version|not match|session|compressors||Warning|delivers|upcoming|installation|https|switched|bye >collec use ${db}; sho…

图漾AGV行业常用相机使用文档

文章目录1.图漾相机设置IP1.1 前期准备2.FM851-E2相机2.1 FM851-E2适用场景2.2 FM851-E2 IO线和数据线定义2.2.1 IO接口定义2.2.2 数据接口线2.2.3 相机正面安装方向2.2.4 相机IO指示灯2.3 FM851-E2/FM855-E2-7相机RGB颜色异常【解决措施1】&#xff1a;【解决措施2】&#xff…

电力系统分析学习笔记(二)- 标幺值计算与变压器建模

电力系统分析学习笔记&#xff08;二&#xff09;- 标幺值计算与变压器建模 1. 电力系统参数计算的基本原理 1.1 基本级的概念与选择 基本级定义&#xff1a; 在多电压等级的电力系统中&#xff0c;需要将所有参数归算到同一个电压等级这个统一的电压等级称为基本级 基本级选择…

防火墙相关技术内容

防火墙的状态检测和会话技术一、防火墙的检测机制早期包过滤防火墙采用逐包检测机制&#xff0c;对每个报文独立检测其源地址、目的地址、端口等信息&#xff0c;根据预设规则决定转发或丢弃。安全隐患&#xff1a;仅基于单包信息判断&#xff0c;无法识别连接状态。例如&#…

在 Mac 上用 Vagrant 安装 K8s

文章目录&#x1f4cb; 1. 环境准备1.1 系统要求1.2 软件清单&#x1f680; 2. 安装步骤2.1 安装Parallels Desktop2.2 配置网络代理&#xff08;可选&#xff09;2.3 安装Homebrew2,4 准备项目目录2.5 安装Vagrant及插件2.6 配置Python环境2.6.1 安装Python管理工具2.6.2 配置…