神经网络训练过程详解

神经网络训练过程是一个动态的、迭代的学习过程,接下来基于一段代码展示模型是如何逐步学习数据规律的。

神经网络拟合二次函数:代码详解

下面将详细解释这段代码,它使用神经网络拟合一个带有噪声的二次函数 y = x² + 2x + 1

import torch
import numpy as np
import matplotlib.pyplot as plt# 1. 生成模拟数据
np.random.seed(22)  # 设置随机种子确保结果可重现
X = np.linspace(-5, 5, 100).reshape(-1,1)  # 创建100个点,范围-5到5
y_ = X**2 + 2* X + 1  # 二次函数:y = x² + 2x + 1
y = y_ + np.random.rand(100, 1) * 1.5  # 添加均匀分布噪声# 转换为PyTorch张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)# 打印数据信息
print("X 原类型:", type(X), " 形状:", X.shape)
print("X_tensor 类型:", type(X_tensor), " 形状:", X_tensor.shape)
print("\ny 原类型:", type(y), " 形状:", y.shape)
print("y_tensor 类型:", type(y_tensor), " 形状:", y_tensor.shape)
print("\nX_tensor 前 2 个元素:\n", X_tensor[:2])
print("y_tensor 前 2 个元素:\n", y_tensor[:2])

代码解析

1. 神经网络定义

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()# 定义三个全连接层self.layer1 = nn.Linear(1, 64)    # 输入层到隐藏层1 (1个特征->64个神经元)self.layer2 = nn.Linear(64, 128)  # 隐藏层1到隐藏层2 (64个神经元->128个神经元)self.layer3 = nn.Linear(128, 1)   # 隐藏层2到输出层 (128个神经元->1个输出)self.relu = nn.ReLU()             # ReLU激活函数def forward(self, x):# 前向传播过程(带激活函数)x = self.relu(self.layer1(x))  # 第一层后接ReLU激活x = self.relu(self.layer2(x))  # 第二层后接ReLU激活x = self.layer3(x)             # 输出层(无激活函数)return xdef forward_line(self, x):# 前向传播过程(不带激活函数)x = self.layer1(x)  # 无激活x = self.layer2(x)  # 无激活x = self.layer3(x)  # 输出层return x

网络结构详解

输入层(1个神经元) → [ReLU激活] → 隐藏层1(64个神经元) → [ReLU激活] → 隐藏层2(128个神经元) → 输出层(1个神经元)
  • 输入层:1个神经元,对应单个特征x
  • 隐藏层1:64个神经元,使用ReLU激活函数
  • 隐藏层2:128个神经元,使用ReLU激活函数
  • 输出层:1个神经元,无激活函数(回归问题)
  • 为什么需要多层和ReLU?:为了拟合非线性关系(二次函数)

2. 模型训练

import torch.optim as optim# 初始化模型、损失函数和优化器
model = SimpleNN()
criterion = nn.MSELoss()  # 均方误差损失(回归问题常用)
optimizer = optim.Adam(model.parameters(), lr=0.01)  # Adam优化器
epochs = 200  # 训练轮数# 训练循环
loss_history = []  # 记录损失变化
for epoch in range(epochs):# 前向传播outputs = model(X_tensor)loss = criterion(outputs, y_tensor)# 反向传播和优化optimizer.zero_grad()  # 清空梯度loss.backward()        # 反向传播计算梯度optimizer.step()       # 更新参数# 记录损失loss_history.append(loss.item())# 每50轮打印一次损失if(epoch + 1)%50 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

训练过程说明

  1. 前向传播:输入数据通过网络得到预测值
  2. 损失计算:比较预测值与真实值的差异(均方误差)
  3. 反向传播
    • zero_grad(): 清空之前的梯度
    • backward(): 计算新的梯度
  4. 参数更新step()使用梯度更新权重
  5. 损失记录:跟踪训练过程中的损失变化

3. 结果可视化

# 创建图表
plt.figure(figsize=(12,4))# 左图:训练损失曲线
plt.subplot(1,2,1)
plt.plot(loss_history)
plt.title('Training Loss')
plt.xlabel('Epochs')
plt.ylabel('MSE Loss')# 右图:预测结果对比
plt.subplot(1, 2, 2)
with torch.no_grad():  # 禁用梯度计算(预测时不需要)predictions = model(X_tensor).numpy()  # 获取预测值# 绘制三种数据:
plt.scatter(X, y, label='original data')  # 散点:带噪声的原始数据
plt.plot(X, y_, 'g--', label='True Relation', linewidth=4)  # 绿色虚线:真实二次函数
plt.plot(X, predictions, 'r--', label='prediction', linewidth=4)  # 红色虚线:神经网络预测plt.title('pre vs true')
plt.legend()  # 显示图例plt.tight_layout()
plt.show()

可视化解读

  1. 损失曲线:展示训练过程中损失值如何下降
  2. 预测对比
    • 散点:带噪声的原始数据
    • 绿色虚线:真实的二次函数 y = x² + 2x + 1
    • 红色虚线:神经网络预测的曲线

4. 模型保存与预测

# 1. 保存完整模型
torch.save(model, 'full_model.pth')# 2. 加载完整模型(无需提前定义模型结构)
loaded_model = torch.load('full_model.pth')
loaded_model.eval()  # 设置为评估模式(关闭dropout等)# 使用模型进行预测
new_data = torch.tensor([[3.0], [-2.5]], dtype=torch.float32)
predictions = loaded_model(new_data).detach().numpy()print("\nPrediction Examples:")
for x, pred in zip(new_data, predictions):# 计算真实值(注意:这里是二次函数,不是线性)true_value = x.item()**2 + 2 * x.item() + 1print(f"Input {x.item():.1f} -> Predicted: {pred[0]:.2f} | True: {true_value:.2f}")

模型保存与加载

  1. torch.save(model, 'full_model.pth'):保存整个模型结构+参数
  2. torch.load('full_model.pth'):加载完整模型
  3. model.eval():将模型设置为评估模式(影响dropout、batchnorm等层)

预测示例

  • 输入x=3.0:真实值=3² + 2×3 + 1 = 16.00
  • 输入x=-2.5:真实值=(-2.5)² + 2×(-2.5) + 1 = 2.25

代码关键点总结

  1. 数据生成

    • 创建二次函数 y = x² + 2x + 1
    • 添加均匀分布噪声模拟真实数据
  2. 网络结构

    • 深度网络(1-64-128-1)适合拟合非线性关系
    • 使用ReLU激活函数引入非线性能力
  3. 训练配置

    • 均方误差损失(MSE)适合回归问题
    • Adam优化器自动调整学习率
    • 200个训练轮次足够收敛
  4. 可视化

    • 损失曲线监控训练过程
    • 预测对比评估模型性能
  5. 模型部署

    • 保存和加载完整模型
    • 对新数据进行预测

为什么这个网络能拟合二次函数?

  1. 非线性激活函数:ReLU使网络能学习非线性关系
  2. 足够容量:两个隐藏层提供足够的表达能力
  3. 优化能力:Adam优化器有效调整参数
  4. 迭代训练:200轮训练使网络逐步逼近目标函数

这个示例展示了神经网络如何学习复杂的非线性关系,即使数据中存在噪声,网络也能捕捉到潜在的函数规律。

训练过程可视化

让我们通过一个动画来理解训练过程(想象以下动态变化):

Epoch 0:  损失: 35.42  | 预测线: 随机波动
Epoch 50: 损失: 2.65   | 预测线: 开始呈现线性趋势
Epoch 100:损失: 2.29   | 预测线: 接近真实关系但仍有偏差
Epoch 200:损失: 1.97   | 预测线: 几乎与真实关系重合
Epoch 500:损失: 1.97   | 预测线: 稳定在最优解附近

训练过程分步解析

1. 初始化阶段 (Epoch 0)

  • 权重和偏置随机初始化(通常使用正态分布或均匀分布)
  • 神经网络对数据一无所知
  • 预测结果完全随机
  • 损失值非常高(约35.42)

2. 早期训练阶段 (Epoch 1-50)

# 第一次迭代
outputs = model(X_tensor)  # 随机预测
loss = criterion(outputs, y_tensor)  # 计算损失(很大)
loss.backward()  # 计算梯度
optimizer.step()  # 首次更新参数
  • 网络开始识别数据的基本模式
  • 预测线开始呈现大致正确的斜率
  • 损失值快速下降(从35.42到约2.65)
  • 模型学习速度最快(梯度最大)

3. 中期训练阶段 (Epoch 50-200)

# 典型迭代
outputs = model(X_tensor)  # 预测接近真实值
loss = criterion(outputs, y_tensor)  # 中等损失
loss.backward()  # 计算较小梯度
optimizer.step()  # 微调参数
  • 网络捕捉到线性关系的基本特征
  • 预测线越来越接近绿色真实关系线
  • 损失值缓慢下降(从2.65到1.97)
  • 学习速度变慢(梯度变小)

4. 后期训练阶段 (Epoch 200-500)

# 后期迭代
outputs = model(X_tensor)  # 预测非常接近真实值
loss = criterion(outputs, y_tensor)  # 小损失
loss.backward()  # 计算微小梯度
optimizer.step()  # 微小调整参数
  • 网络优化细节
  • 预测线与真实关系线几乎重合
  • 损失值稳定在约1.97
  • 模型收敛(参数变化很小)

训练过程关键元素详解

1. 前向传播 (Forward Pass)

outputs = model(X_tensor)
  • 输入数据通过神经网络各层
  • 计算过程:
    输入x → 线性变换: z1 = w1*x + b1→ ReLU激活: a1 = max(0, z1)→ 线性变换: output = w2*a1 + b2
    

2. 损失计算 (Loss Calculation)

loss = criterion(outputs, y_tensor)
  • 计算预测值与真实值的差异
  • 使用均方误差公式:
    MSE = 1/N * Σ(预测值 - 真实值)²
    

3. 反向传播 (Backward Pass)

loss.backward()
  • 计算损失函数对每个参数的梯度
  • 使用链式法则从输出层向输入层反向传播
  • 梯度表示"参数应该如何调整以减少损失"

4. 参数更新 (Parameter Update)

optimizer.step()
  • Adam优化器根据梯度更新参数
  • 更新公式简化表示:
    新参数 = 旧参数 - 学习率 * 梯度
    
  • 学习率(0.01)控制更新步长

训练过程可视化分析

损失曲线图

损失值
35 |*················
30 | *···············
25 |  *··············
20 |   *·············
15 |    *············
10 |     *···········5 |      **·········2 |         ****····1 |             ****0 +-----------------→ Epoch0  50 100 200 500
  • 曲线特点:开始陡峭下降,后期平缓
  • 表明:初期学习快,后期优化细调

预测结果演变

真实关系: y = 2x + 1 (绿色虚线)Epoch 0:预测线: 随机波动 (红色线)Epoch 50:预测线: 大致正确斜率但截距偏差Epoch 100:预测线: 接近真实线,部分区域过拟合噪声Epoch 500:预测线: 几乎与绿色虚线重合

为什么训练有效?

  1. 梯度下降原理:每次更新都向减少损失的方向移动
  2. 链式法则:高效计算所有参数的梯度
  3. 自适应优化器:Adam自动调整学习率
  4. 非线性能力:ReLU激活函数使网络能拟合复杂模式
  5. 迭代优化:多次重复使模型逐步接近最优解

训练结束后的模型状态

  • 权重和偏置已优化到最佳值
  • 网络学习到了潜在规律 y = 2x + 1
  • 能够准确预测新数据点
  • 损失值稳定在最低点(约1.97)

这个训练过程展示了神经网络如何从随机初始状态开始,通过反复的预测、评估和调整,最终学习到数据背后的规律。即使数据中存在噪声,神经网络也能识别出真实的线性关系。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/93717.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/93717.shtml
英文地址,请注明出处:http://en.pswp.cn/web/93717.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode100-560和为K的子数组

本文基于各个大佬的文章上点关注下点赞,明天一定更灿烂!前言Python基础好像会了又好像没会,所有我直接开始刷leetcode一边抄样例代码一边学习吧。本系列文章用来记录学习中的思考,写给自己看的,也欢迎大家在评论区指导…

【PZ-ZU47DR-KFB】璞致FPGA ZYNQ UltraScalePlus RFSOC QSPI Flash 固化常见问题说明

1 Flash 固化Flash 固化需要先生成 BOOT.bin 文件,这边以裸机的串口工程进行讲解如何生成 BOOT.bin 文件及 Flash 固化操作。有读者会遇到,只使用 PL 端的情况,也需要进行 Flash 固化。我们需要添加 PS 端最小配置(包含 Flash 配置…

数据结构:查找表

一、数据结构的概念数据结构是指相互之间存在一种或多种特定关系的数据元素的集合。它不仅仅是存储数据的方式,更强调数据之间的逻辑关系和操作方法。数据结构主要从以下几个角度来理解:1. 数据之间的关系逻辑结构:集合结构:元素之…

自建知识库,向量数据库 (十)之 文本向量化——仙盟创梦IDE

自建文章向量化技术:AI 浪潮下初学者的进阶指南 在人工智能(AI)蓬勃发展的浪潮中,向量化作为将文本数据转化为数值向量表示的关键技术,成为理解和处理文本的基石。本文将结合给定的代码示例,深入探讨自建文…

数据结构 -- 顺序表的特点、操作函数

线性表顺序存储的优缺点优点无需为表中的逻辑关系增加额外的存储空间,利用连续的内存单元存储数据,存储密度高。支持 随机访问,通过下标可在 O(1) 时间复杂度内定位元素(如数组按索引取值),查询效率稳定。缺…

反向代理实现服务器联网

下载脚本:https://gitee.com/995770513/ssh-reverse-socket然后解压到 D:\Download在本机运行 cd D:\Download\ssh-reverse-socket-master\ssh-reverse-socket-master python socket5_proxy.py --ssh_cmd "xaserver10.150.10.51 -p 22" --socket5_port 78…

C语言关于函数传参和返回值的一些想法2(参数可修改的特殊情况)

我最近写了一篇文章名为“C语言关于函数传参和返回值的一些想法”(C语言关于函数传参和返回值的一些想法-CSDN博客),里面提到了一种观点就是传参的参数在函数体内部是只读的,不能写它,因为如果写了,也就是污…

前端AI对话功能实现攻略

一、对话内容渲染 在前端页面的 AI 对话场景中,对话内容的渲染效果直接影响用户的阅读体验和交互效率。合理选择对话格式、优化流式对话呈现、嵌入自定义内容以及实现语音播报等功能,是提升整体体验的关键。 对话格式选择 MarkDown 作为一种轻量级标记语…

深入理解Redis持久化:让你的数据永不丢失

1 Redis持久化概述 1.1 什么是Redis持久化 Redis作为一个高性能的内存数据库,默认情况下数据存储在内存中,这意味着一旦服务器重启或发生故障,内存中的数据将会丢失。为了保证数据的持久性和可靠性,Redis提供了持久化机制,将内存中的数据保存到磁盘中。 持久化是Redis实…

IC验证 AHB-RAM 项目(二)——接口与事务代码的编写

目录准备工作接口相关代码编写事务相关代码编写准备工作 DVT(Design and Verification Tools)是一款专门为 IC 验证打造的 IDE 插件,可以理解为智能的 Verilog/System Verilog 编辑器,在 VS Code、Eclipse 软件中使用。 接口相关…

基于Spring Boot的智能民宿预订与游玩系统设计与实现 民宿管理系统 民宿预订系统 民宿订房系统

🔥作者:it毕设实战小研🔥 💖简介:java、微信小程序、安卓;定制开发,远程调试 代码讲解,文档指导,ppt制作💖 精彩专栏推荐订阅:在下方专栏&#x1…

大模型的底层运算线性代数

深度学习的本质是用数学语言描述并处理真实世界中的信息,而线性代数正是这门语言的基石。它不仅提供了高效的数值计算工具,更在根本上定义了如何以可计算、可组合、可度量的方式表示和变换数据。 1 如何描述世界📊 真实世界的数据&#xff08…

Rust 中 i32 与 *i32 的深度解析

Rust 中 &i32 与 *i32 的深度解析 在 Rust 中,&i32 和 *i32 是两种完全不同的指针类型,它们在安全性、所有权和使用方式上有本质区别。以下是详细对比: 核心区别概览 #mermaid-svg-rCa8lLmHB7MK9P6K {font-family:"trebuchet ms…

【PyTorch项目实战】OpenNMT本地机器翻译框架 —— 支持本地部署和自定义训练

文章目录一、OpenNMT(Neural Machine Translation,NMT)1. 概述2. 核心特性3. 系统架构4. 与其他翻译工具的区别二、基于 OpenNMT-py 的机器翻译框架1. 环境配置(以OpenNMT-py版本为例)(1)pip安装…

基于prompt的生物信息学:多组学分析的新界面

以前总以为综述/评论是假大空,最近在朋友的影响下才发现,大佬的综述/评论内容的确很值得一读,也值得分享的。比如这篇讲我比较感兴趣的AI辅助生信分析的,相信大家都是已经实践中用上了,看看大佬的评论,拓宽…

Nacos-8--分析一下nacos中的AP和CP模式

Nacos支持两种模式来满足不同场景下的需求:AP模式(强调可用性)和CP模式(强调一致性)。 这两种模式的选择主要基于CAP理论,该理论指出在一个分布式系统中,无法同时保证一致性(Consist…

水闸安全监测的主要核心内容

水闸安全监测是指通过一系列技术手段和管理措施,对水闸的结构状态、运行性能及环境条件进行实时或定期的观测与评估,以确保水闸在设计寿命期内的安全性和可靠性。其核心目标是及时发现潜在的安全隐患,防止事故发生,保障水利工程的…

嵌入式系统学习Day19(数据结构)

数据结构的概念: 相互之间存在一种或多种特定关系的数据元素的集合。数据之间关系:逻辑关系:集合,线性(1对1,中间位置的值有且仅有一个前驱,一个后继),树(1对…

Pandas中数据清理、连接数据以及合并多个数据集的方法

一、简介1.数据清理的重要性:在进行数据分析前,需进行数据清理,使每个观测值成一行、每个变量成一列、每种观测单元构成一张表格。2.数据组合的必要性:数据整理好后,可能需要将多张表格组合才能进行某些分析&#xff0…

JavaSSM框架从入门到精通!第二天(MyBatis(一))!

一、 Mybatis 框架1. Mybatis 框架简介Mybatis 是 apache 的一个开源项目,名叫 iBatis ,2010 年这个项目由 apache 迁移到了 google,并命名为 Mybatis,2013 年迁移到了 GitHub,可以在 GitHub 下载源码。2. Mybatis 的下…