九、批量标准化

是一种广泛使用的神经网络正则化技术,对每一层的输入进行标准化,进行缩放和平移,目的是加速训练,提高模型稳定性和泛化能力,通常在全连接层或是卷积层之和,激活函数之前使用

核心思想

对每一批数据的通道进行标准化,解决内部协变量偏移

        加速网络训练;运行使用更大的学习率;减少对初始化的依赖;提供轻微的正则化效果

思路:在输入上执行标准化操作,学习两可训练的参数:缩放因子γ和偏移量β

 批量标准化操作 在训练阶段和测试阶段行为是不同的。测试阶段没有mini_batch数据,无法直接计算当前batch的均值和方差,所以使用训练阶段计算的全局统量(均值和方差)进行标准化

1. 训练阶段的批量标准化

1.1 计算均值和方差

对于给定的神经网络层,输入,m是批次大小。我们计算该批次数据的均值和方差

均值

方差

1.2 标准化

用计算得到的均值和方差对数据进行标准化,使得没个特征的均值为0,方差为1

标准化后的值

ε是很小的常数,防止除0

1.3 缩放和平移

标准化的数据通常会通过可训练的参数进行缩放和平移,以挥发模型的表达能力

缩放

平移

γ和β是在训练过程中学习到的参数,会随着网络的训练过程通过反向传播进行更新

1.4 更新全局统计量

指数移动平均更新全局均值和方差

momentum是超变量,控制当前mini-batch统计量对全局统计量的贡献

它在0到1之间,控制mini-batch统计量的权重,在pytorch默认为0.1

与优化器中的momentum的区别

标准化中的:

更新全局统计量

控制当前mini-batch统计量对全局统计量的贡献

优化器中:

加速梯度下降,跳出局部最优

2.测试阶段的批量标准化

测试阶段没有mini-batch数据,所以通过EMA计算的全局统计量来进行标准化

测试阶段用全局统计量对输入数据进行标准化

对标准化后的数据进行缩放和平移

为什么用全局统计量

一致性

  • 测试阶段,输入数据通常是单个样本或少量样本无法准确计算均值和方差

  • 使用全局统计量可以确保测试阶段的行为与训练阶段一致

稳定性

  • 全局统计量是通过训练阶段的大量 mini-batch 数据计算得到的,能够更好地反映数据的整体分布

  • 使用全局统计量可以减少测试阶段的随机性,使模型的输出更加稳定

效率

  • 在测试阶段,使用预先计算的全局统计量可以避免重复计算,提高效率。

3. 作用

3.1 缓解梯度问题

防止激活值过大或过小,避免激活函数的饱和,缓解梯度消失或爆炸

3.2 加速训练

输入值分布更稳定,提高学习训练的效率,加速收敛

3.3 减少过拟合

类似于正则化,有助于提高模型的泛化能力

避免对单一数据点的过度拟合

4. 函数说明

torch.nn.BatchNorm1d 是 PyTorch 中用于一维数据的批量标准化(Batch Normalization)模块。

torch.nn.BatchNorm1d(num_features,         # 输入数据的特征维度eps=1e-05,           # 用于数值稳定性的小常数momentum=0.1,        # 用于计算全局统计量的动量affine=True,         # 是否启用可学习的缩放和平移参数track_running_stats=True,  # 是否跟踪全局统计量device=None,         # 设备类型(如 CPU 或 GPU)dtype=None           # 数据类型
)

参数说明:

eps:用于数值稳定性的小常数,添加到方差的分母中,防止除零错误。默认值:1e-05

momentum:用于计算全局统计量(均值和方差)的动量默认值:0.1,参考本节1.4

affine:是否启用可学习的缩放和平移参数(γ和 β)。如果 affine=True,则模块会学习两个参数;如果 affine=False,则不学习参数,直接输出标准化后的值 。默认值:True

track_running_stats:是否跟踪全局统计量(均值和方差)。如果 track_running_stats=True,则在训练过程中计算并更新全局统计量,并在测试阶段使用这些统计量。如果 track_running_stats=False,则不跟踪全局统计量,每次标准化都使用当前 mini-batch 的统计量。默认值:True

4. 代码实现

import torch
from torch import nn
from matplotlib import pyplot as pltfrom sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from torch.nn import functional as F
from torch import optim# 生成数据集:两个同心圆,内圈和外圈的点分别属于两个类别
x, y = make_circles(n_samples=2000, noise=0.1, factor=0.4, random_state=42)
# 转换为PyTorch张量
x = torch.tensor(x, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3,random_state=42)# 可视化数据集
plt.scatter(x[:, 0], x[:, 1], c=y, cmap='coolwarm', edgecolors="k")
plt.show()# 定义带批量归一化的神经网络
class NetWithBN(nn.Module):def __init__(self):super().__init__()# 第一层全连接层,输入维度2,输出维度64self.fc1 = nn.Linear(2, 64)# 第一层批量归一化self.bn1 = nn.BatchNorm1d(64)# 第二层全连接层,输入维度64,输出维度32self.fc2 = nn.Linear(64, 32)# 第二层批量归一化self.bn2 = nn.BatchNorm1d(32)# 第三层全连接层,输入维度32,输出维度2(两个类别)self.fc3 = nn.Linear(32, 2)def forward(self, x):# 前向传播:ReLU激活函数+批量归一化+全连接层x = F.relu(self.bn1(self.fc1(x)))x = F.relu(self.bn2(self.fc2(x)))x = self.fc3(x)return x# 定义不带批量归一化的神经网络
class NetWithoutBN(nn.Module):def __init__(self):super().__init__()# 第一层全连接层,输入维度2,输出维度64self.fc1 = nn.Linear(2, 64)# 第二层全连接层,输入维度64,输出维度32self.fc2 = nn.Linear(64, 32)# 第三层全连接层,输入维度32,输出维度2(两个类别)self.fc3 = nn.Linear(32, 2)def forward(self, x):# 前向传播:ReLU激活函数+全连接层x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 定义训练函数
def train(model, x_train, y_train, x_test, y_test, name, lr=0.1, epoches=500):# 定义交叉熵损失函数criterion = nn.CrossEntropyLoss()# 定义SGD优化器optimizer = optim.SGD(model.parameters(), lr=lr)# 用于记录训练损失和测试准确率train_loss = []test_acc = []for epoch in range(epoches):# 设置模型为训练模式model.train()# 前向传播y_pred = model(x_train)# 计算损失loss = criterion(y_pred, y_train)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录训练损失train_loss.append(loss.item())# 设置模型为评估模式model.eval()# 禁用梯度计算with torch.no_grad():# 前向传播y_test_pred = model(x_test)# 获取预测类别_, pred = torch.max(y_test_pred, dim=1)# 计算正确预测的数量correct = (pred == y_test).sum().item()# 计算测试准确率test_acc.append(correct / len(y_test))# 每100个epoch打印一次日志if epoch % 100 == 0:print(F"{name}|Epoch:{epoch},loss:{loss.item():.4f},acc:{test_acc[-1]:.4f}")return train_loss, test_acc# 创建带批量归一化的模型
model_bn = NetWithBN()
# 创建不带批量归一化的模型
model_nobn = NetWithoutBN()# 训练带批量归一化的模型
bn_train_loss, bn_test_acc = train(model_bn, x_train, y_train, x_test, y_test,name="BN")
# 训练不带批量归一化的模型
nobn_train_loss, nobn_test_acc = train(model_nobn, x_train, y_train, x_test, y_test,name="NoBN")# 定义绘图函数
def plot(bn_train_loss, nobn_train_loss, bn_test_acc, nobn_test_acc):# 创建绘图窗口fig = plt.figure(figsize=(10, 5))# 添加子图1:训练损失ax1 = fig.add_subplot(1, 2, 1)ax1.plot(bn_train_loss, "b", label="BN")ax1.plot(nobn_train_loss, "r", label="NoBN")ax1.legend()# 添加子图2:测试准确率ax2 = fig.add_subplot(1, 2, 2)ax2.plot(bn_test_acc, "b", label="BN")ax2.plot(nobn_test_acc, "r", label="NoBN")ax2.legend()# 显示图像plt.show()# 调用绘图函数
plot(bn_train_loss, nobn_train_loss, bn_test_acc, nobn_test_acc)

 

十、模型的保存和加载

 1.标准网络模型构建

class MyModel(nn.Module):def __init__(self,input_size,output_size):super(MyModel,self).__init__()self.fc1 = nn.Linear(input_size,128)self.fc2 = nn.Linear(128,64)self.fc3 = nn.Linear(64,output_size)def forward(self,x):x = self.fc1(x)x = self.fc2(x)output = self.fc3(x)return outputmodel = MyModel(input_size=10,output_size = 2)
x  =torch.randn(5,10)output = model(x)

 2. 序列化模型对象

模型保存

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

参数说明:

  • obj:要保存的对象,可以是模型、张量、字典等。

  • f:保存文件的路径或文件对象。可以是字符串(文件路径)或文件描述符。

  • pickle_module:用于序列化的模块,默认是 Python 的 pickle 模块。

  • pickle_protocol:pickle 模块的协议版本,默认是 DEFAULT_PROTOCOL(通常是最高版本)。

模型加载

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

参数说明:

  • f:文件路径或文件对象。可以是字符串(文件路径)或文件描述符。

  • map_location:指定加载对象的设备位置(如 CPU 或 GPU)。默认是 None,表示保持原始设备位置。例如:map_location=torch.device('cpu') 将对象加载到 CPU。

  • pickle_module:用于反序列化的模块,默认是 Python 的 pickle 模块。

  • pickle_load_args:传递给 pickle_module.load() 的额外参数。

import torch
import torch.nn as nn
import pickleclass MyModel(nn.Module):def __init__(self,input_size,output_size):super(MyModel,self).__init__()self.fc1 = nn.Linear(input_size,output_size,128)self.fc2 = nn.Linear(128,64)self.fc3 = nn.Linear(64,output_size)def forward(self,x):x = self.fc1(x)x = self.fc2(x)output = self.fc3(x)return output
def test001():model = MyModel(input_size=128,output_size=32)torch.save(model,"model.pkl",pickle_module=pickle,pickle_protocol=2)def test002():model = torch.load("model.pkl",map_location = "cpu",pickle_module=pickle)print(model)test001()
test002()

.pkl是二进制文件,内容是通过pickle模块化序列的python对象。可能存在兼容问题(python2,3的区别)

.pth是二进制文件,序列化的pytorch模型或张量。

3. 模型保存参数

import torch
import torch.nn as nn
import torch.optim as optim
import pickleclass MyModle(nn.Module):def __init__(self,input_size,output_size):super(MyModle,self).__init__()self.fc1 = nn.Linear(input_size,128)self.fc2 = nn.Linear(128,64)self.fc3 = nn.Linear(64,output_size)def forward(self,x):x = self.fc1(x)x = self.fc2(x)output = self.fc3(x)return outputdef test003():model = MyModle(input_size=128,output_size=32)optimizer = optim.SGD(model.parameters(),lr = 0.01)save_dict = {"init_params":{"input_size":128,"output_size":32,},"accuracy":0.99,"model_state_dict":model.state_dict(),"optimizer_state_dict":optimizer.state_dict(),}torch.save(save_dict,"model_dict.pth")def test004():save_dict = torch.load("model_dict.pth")model = MyModle(input_size = save_dict["init_params"]["input_size"],output_size = save_dict["init_params"]["output_size"],)model.load_state_dict(save_dict["model_state_dict"])optimizer = optim.SGD(model.parameters(),lr = 0.01)optimizer.load_state_dict(save_dict["optimizer_state_dict"])print(save_dict["accuracy"])print(model)test003()
test004()

推理时加载模型参数简单如下

# 保存模型状态字典
torch.save(model.state_dict(), 'model.pth')
​
# 加载模型状态字典
model = MyModel(128, 32)
model.load_state_dict(torch.load('model.pth'))
​

十一、项目实战

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/89498.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/89498.shtml
英文地址,请注明出处:http://en.pswp.cn/web/89498.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据可视化-67】基于pyecharts的航空安全深度剖析:坠毁航班数据集可视化分析

🧑 博主简介:曾任某智慧城市类企业算法总监,目前在美国市场的物流公司从事高级算法工程师一职,深耕人工智能领域,精通python数据挖掘、可视化、机器学习等,发表过AI相关的专利并多次在AI类比赛中获奖。CSDN…

【科研绘图系列】R语言绘制分组箱线图

文章目录 介绍 加载R包 数据下载 导入数据 画图1 画图2 合并图 系统信息 参考 介绍 【科研绘图系列】R语言绘制分组箱线图 加载R包 library(ggplot2) library(patchwork)rm(list = ls()) options(stringsAsFactors = F)

基于Android的旅游计划App

项目介绍系统打开进入登录页面,如果没有注册过账号,点击注册按钮输入账号、密码、邮箱即可注册,注册后可登录进入系统,系统分为首页、预订、我的三大模块,下面具体详细说说三大模块功能说明。1.首页显示旅游备忘或旅游…

【LeetCode 2163. 删除元素后和的最小差值】解析

目录LeetCode中国站原文原始题目题目描述示例 1:示例 2:提示:讲解分割线的艺术:前后缀分解与优先队列的完美邂逅第一部分:算法思想 —— “分割线”与前后缀分解1. 想象一条看不见的“分割线”2. 前后缀分解&#xff1…

控制鼠标和键盘

控制鼠标和键盘的Python库Python中有多个库可以用于控制鼠标和键盘,常用的包括pyautogui、pynput、keyboard和mouse等。这些库提供了模拟用户输入的功能,适用于自动化测试、GUI操作等场景。使用pyautogui控制鼠标pyautogui是一个跨平台的库,支…

基于按键开源MultiButton框架深入理解代码框架(二)(指针的深入理解与应用)

文章目录2、针对该开源框架理解3、分析代码3.1 再谈指针、数组、数组指针3.2 继续分析源码2、针对该开源框架理解 在编写按键模块的框架中,一定要先梳理按键相关的结构体、枚举等变量。这些数据是判断按键按下、状态跳转、以及绑定按键事件的核心。 这一部分定义是…

web前端渡一大师课 CSS属性计算过程

你是否了解CSS 的属性计算过程呢? <body> <h1>这是一个h1标题</h1> </body> 目前我们没有设置改h1的任何样式,但是却能看到改h1有一定的默认样式,例如有默认的字体大小,默认的颜色 那么问题来了,我们这个h1元素上面除了有默认字体大小,默认颜色等…

Redis高频面试题:利用I/O多路复用实现高并发

Redis 通过 I/O 多路复用&#xff08;I/O Multiplexing&#xff09;技术实现高并发&#xff0c;这是其单线程模型能够高效处理大量客户端连接的关键。以下是通俗易懂的解释&#xff0c;结合 Redis 的工作原理&#xff0c;详细说明其实现过程。 1. 什么是 I/O 多路复用&#xff…

爬虫小知识(二)网页进行交互

一、提交信息到网页 1、模块核心逻辑 “提交信息到网页” 是网络交互关键环节&#xff0c;借助 requests 库的 post() 函数&#xff0c;能模拟浏览器向网页发数据&#xff08;如表单、文件 &#xff09;&#xff0c;实现信息上传&#xff0c;让我们能与网页背后的服务器 “沟通…

WPF学习(五)

文章目录一、FileStream和StreamWriter理解1.1、具体关系解析1.2、类比理解1.3、总结1.4、示例代码1.5、 WriteLine()和 Write&#xff08;&#xff09;的区别1.6、 StreamWriter.Close的作用二、一、FileStream和StreamWriter理解 在 C# 中&#xff0c;StreamWriter 和 FileS…

ctf.show-web习题-web2-最简单的sql注入-flag获取详解、总结

解题思路打开靶场既然提示是最简单的sql注入了&#xff0c;那么直接尝试永真登录1 or 11#这里闭合就是简单的单引号可以看到没登录成功&#xff0c;但是有回显&#xff1a;欢迎你&#xff0c;ctfshowsql注入最喜欢的就是回显了&#xff01;这题的思路就是靠这个回显&#xff0c…

upload-labs 靶场通关(1-20)

目录 Pass-01(JS 绕过) Pass-02(文件类型验证) Pass-03(黑名单验证) Pass-04(黑名单验证.htaccess) Pass-05(大小写绕过) Pass-06(末尾空格) Pass-07(增加一个.) Pass-08(增加一个::$DATA) Pass-09&#xff08;代码不严谨&#xff09; Pass-10&#xff08;PPHPHP&am…

[附源码+数据库+毕业论文]基于Spring+MyBatis+MySQL+Maven+vue实现的酒店预订管理系统,推荐!

摘 要 使用旧方法对酒店预订信息进行系统化管理已经不再让人们信赖了&#xff0c;把现在的网络信息技术运用在酒店预订信息的管理上面可以解决许多信息管理上面的难题&#xff0c;比如处理数据时间很长&#xff0c;数据存在错误不能及时纠正等问题。 这次开发的酒店预订管理系…

LSTM入门案例(时间序列预测)| pytorch实现(可复现)

需求 假如我有一个时间序列&#xff0c;例如是前113天的价格数据&#xff08;训练集&#xff09;&#xff0c;然后我希望借此预测后30天的数据&#xff08;测试集&#xff09;&#xff0c;实际上这143天的价格数据都已经有了。这里为了简单&#xff0c;每一天的数据只有一个价…

Axure RP 10 预览显示“无标题文档”的空白问题探索【护航版】

1. 安装情况 官网 Axure RP 10&#xff1a;Download Axure RP 10 - Axure &#xff08;PS&#xff1a;11都出了&#xff09; 版本&#xff1a;10.0.0.3924 激活码&#xff1a;49bb9513c40444b9bcc3ce49a7a022f9 &#xff08;10/11都可以用&#xff0c;但只尝试了10&#xff…

基于SpringBoot+Vue的汽车租赁系统(协同过滤算法、腾讯地图API、支付宝沙盒支付、WebsSocket实时聊天、ECharts图形化分析)

系统亮点&#xff1a;协同过滤算法、腾讯地图API、支付宝沙盒支付、WebsSocket实时聊天、ECharts图形化分析&#xff1b;01系统开发工具与环境搭建—前后端分离架构项目架构&#xff1a;B/S架构运行环境&#xff1a;win10/win11、jdk17前端&#xff1a;技术&#xff1a;框架Vue…

数据结构入门:像整理收纳一样简单!

在我们生活中&#xff0c;经常会面对这样的问题&#xff1a; “我要怎么整理我的衣柜&#xff1f;” “电脑里照片太多了&#xff0c;怎么归类才方便查找&#xff1f;” 其实&#xff0c;程序员也有类似的烦恼。他们不整理衣柜&#xff0c;而是“整理数据”。而这门关于如何“收…

力扣每日一题--2025.7.15

&#x1f4da; 力扣每日一题–2025.7.15 3135. 有效单词 &#xff08;简单&#xff09; 大家好&#xff01;今天我们要来聊聊一道有趣的编程题——有效单词 &#x1f4dd; 题目描述 题目分析 &#x1f4da; 题目要求我们判断一个字符串是否为有效单词。有效单词需要满足以下…

Mysql数据库——增删改查CRUD

文章目录一、数据库的基础命令二、创建表三、增(create)四、查询&#xff08;retrieve)五、条件查询&#xff08;where&#xff09;六、修改&#xff08;update&#xff09;七、删除&#xff08;delete&#xff09;一、数据库的基础命令 1.使用客户端连接服务器 mysql -u root…

关于pytorch虚拟环境及具体bug问题修改

本篇博客包含对于虚拟环境概念的讲解和代码实现过程中相关bug的解决关于虚拟环境我的pytorch虚拟环境在D盘&#xff0c;相应python解释器也在D盘&#xff08;一起&#xff09;&#xff0c;但是我的pycharm中的项目在C盘&#xff0c;使用的是pytorch的虚拟环境&#xff0c;这是为…