为了提升推理速度并降低部署成本,模型剪枝已成为关键技术。本文将结合实践操作,讲解YOLOv8模型剪枝的方法原理、实施步骤及注意事项。

虽然YOLOv8n版本本身参数量少、推理速度快,能满足大多数工业检测需求,但谷歌研究表明:通过对大模型进行裁剪得到的小模型往往性能更优。

本文基于其他博客的剪枝方法的代码实现,专门针对YOLOv8模型进行剪枝优化,能够理解模型剪枝的底层操作。其核心创新点在于利用BN层(Batch Normalization)的特性,实现高效的通道级剪枝操作。

一、剪枝的理论基础

  • BN参数的重要性:BN层中的缩放参数(γ)代表了卷积核的重要程度,通过裁剪γ值较小的卷积核,可以实现剪枝。
  • 剪枝流程总体架构
    1. 训练稀疏模型(引入BN正则化)
    2. 计算剪枝阈值
    3. 剪除冗余卷积核
    4. 微调模型,恢复性能

二、YOLOv8剪枝的具体步骤

1. 预备工作

  • 模型训练: 先进行完整训练,获得基准性能指标。
  • 将LL_pruning.pyLL_train.py这两个文件放在根目录下

    LL_train.py代码如下所示:
    from ultralytics import YOLO  # 导入YOLO模型库  
    import os  # 导入os模块,用于处理文件路径  root = os.getcwd()  # 获取当前工作目录  ## 配置文件路径  
    name_yaml = os.path.join(root, "ultralytics/datasets/VOC.yaml")  # 数据集配置文件路径  
    name_pretrain = os.path.join(root, r"D:\practice_demo\ultralytics\runs\detect\jueyuanzi_yolov8m\best.pt")  # 预训练模型路径  ## 原始训练路径  
    path_train = os.path.join(root, "runs/detect/VOC")  # 原始训练结果保存路径  
    name_train = os.path.join(path_train, "weights/last.pt")  # 原始训练模型文件路径  ## 约束训练路径、剪枝模型文件  
    path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint")  # 约束训练结果保存路径  
    name_prune_before = os.path.join(path_constraint_train, "weights/last.pt")  # 剪枝前模型文件路径  
    name_prune_after = os.path.join(path_constraint_train, "weights/last_prune.pt")  # 剪枝后模型文件路径  ## 微调路径  
    path_fineturn = os.path.join(root, "runs/detect/VOC_finetune")  # 微调结果保存路径  def step1_train():  model = YOLO(name_pretrain)  # 加载预训练模型  model.train(data=name_yaml, imgsz=640, epochs=300, batch=32, name=path_train)  # 训练模型  ## 一定要添加【amp=False】  
    def step2_Constraint_train():  model = YOLO(name_train)  # 加载原始训练模型  model.train(data=name_yaml, imgsz=640, epochs=50, batch=32, amp=False, save_period=1, name=path_constraint_train)  # 训练模型  def step3_pruning():  from LL_pruning import do_pruning  # 导入剪枝函数  do_pruning(name_prune_before, name_prune_after)  # 执行剪枝操作  def step4_finetune():  model = YOLO(name_prune_after)  # 加载剪枝后的模型  model.train(data=name_yaml, imgsz=640, epochs=100, batch=32, save_period=1, name=path_fineturn)  # 微调模型  # 执行训练、约束训练、剪枝和微调步骤  
    step1_train()  # 训练模型  
    # step2_Constraint_train()  # 进行稀疏训练  
    # step3_pruning()  # 执行剪枝  
    # step4_finetune()  # 微调模型

LL_pruning.py代码如下所示:

​
from ultralytics import YOLO  # 导入YOLO模型
import torch  # 导入PyTorch库
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect  # 导入YOLO模型中的模块
import os  # 导入os模块,用于处理文件路径# os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # 可选:指定使用的GPU设备class PRUNE():def __init__(self) -> None:self.threshold = None  # 初始化阈值def get_threshold(self, model, factor=0.8):"""计算剪枝阈值:param model: YOLO模型:param factor: 剪枝比例,默认0.8"""ws = []  # 存储权重bs = []  # 存储偏置for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):  # 仅处理BatchNorm2d层w = m.weight.abs().detach()  # 获取权重的绝对值b = m.bias.abs().detach()  # 获取偏置的绝对值ws.append(w)  # 添加权重bs.append(b)  # 添加偏置print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())  # 打印权重和偏置的最大最小值# 合并所有权重ws = torch.cat(ws)# 计算剪枝阈值self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]def prune_conv(self, conv1: Conv, conv2: Conv):"""对卷积层的“相邻”卷积做通道级剪枝。参数----:param conv1: 第一个卷积层: Conv(Ultralytics封装的Conv模块,内部含 nn.Conv2d + BN + 激活)*上游* 被剪枝的卷积。删除它的某些 输出 通道。:param conv2: 第二个卷积层: Conv 或 Conv列表 / 纯 nn.Conv2d / None*下游* 接收 conv1 输出的卷积(可能有多支分支)。需要把 输入 通道同步删除。剪枝规则--------1. 用 conv1 中 BatchNorm 的缩放系数 γ 的绝对值做“重要性”指标。2. 选出 |γ| >= 全局阈值 的通道索引 keep_idxs(若太少则降低阈值,至少保留8个,防止结构非法)。3. 在 conv1 中:删掉其它通道 → 需要同时修改 BN 的各种统计量与 nn.Conv2d 的权重/偏置/out_channels。4. 在 conv2 中:这些被删的只是“输入特征图”,因此只更新 in_channels。"""# a. 根据BN中的参数,获取需要保留的indexgamma = conv1.bn.weight.data.detach()  # 获取BN层的权重beta = conv1.bn.bias.data.detach()  # 获取BN层的偏置keep_idxs = []  # 存储需要保留的索引local_threshold = self.threshold  # 使用全局阈值while len(keep_idxs) < 8:  # 确保至少保留8个卷积核keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]  # 获取满足条件的索引local_threshold = local_threshold * 0.5  # 如果不足8个,降低阈值n = len(keep_idxs)  # 保留的卷积核数量print(n / len(gamma))  # 打印保留的比例# b. 利用index对BN进行剪枝conv1.bn.weight.data = gamma[keep_idxs]  # 更新BN权重conv1.bn.bias.data = beta[keep_idxs]  # 更新BN偏置conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]  # 更新BN的方差conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]  # 更新BN的均值conv1.bn.num_features = n  # 更新BN的特征数量conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]  # 更新卷积层的权重conv1.conv.out_channels = n  # 更新卷积层的输出通道数# c. 利用index对conv1进行剪枝if conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]  # 更新卷积层的偏置# d. 利用index对conv2进行剪枝if not isinstance(conv2, list):conv2 = [conv2]  # 确保conv2是列表for item in conv2:if item is None: continue  # 跳过Noneif isinstance(item, Conv):conv = item.conv  # 获取卷积层else:conv = itemconv.in_channels = n  # 更新输入通道数conv.weight.data = conv.weight.data[:, keep_idxs]  # 更新卷积层的权重def prune(self, m1, m2):"""对模块进行剪枝:param m1: 第一个模块:param m2: 第二个模块"""if isinstance(m1, C2f):  # 如果m1是C2f模块,获取其cv2m1 = m1.cv2if not isinstance(m2, list):  # 确保m2是列表m2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1  # 获取C2f或SPPF的cv1self.prune_conv(m1, m2)  # 对卷积层进行剪枝def do_pruning(modelpath, savepath):"""执行剪枝操作:param modelpath: 原始模型路径:param savepath: 剪枝后模型保存路径"""pruning = PRUNE()  # 创建PRUNE实例### 0. 加载模型yolo = YOLO(modelpath)  # 从指定路径加载YOLO模型pruning.get_threshold(yolo.model, 0.8)  # 获取剪枝阈值,0.8为剪枝率### 1. 剪枝c2f中的Bottleneckfor name, m in yolo.model.named_modules():if isinstance(m, Bottleneck):  # 仅处理Bottleneck模块pruning.prune_conv(m.cv1, m.cv2)  # 对Bottleneck中的卷积层进行剪枝### 2. 指定剪枝不同模块之间的卷积核seq = yolo.model.model  # 获取模型的序列for i in [3, 5, 7, 8]:  # 指定需要剪枝的模块pruning.prune(seq[i], seq[i + 1])  # 对相邻模块进行剪枝### 3. 对检测头进行剪枝detect: Detect = seq[-1]  # 获取检测头last_inputs = [seq[15], seq[18], seq[21]]  # 获取最后输入的模块colasts = [seq[16], seq[19], None]  # 获取与最后输入相连的模块for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):pruning.prune(last_input, [colast, cv2[0], cv3[0]])  # 对输入模块和检测头进行剪枝pruning.prune(cv2[0], cv2[1])  # 对检测头的卷积层进行剪枝pruning.prune(cv2[1], cv2[2])  # 对检测头的卷积层进行剪枝pruning.prune(cv3[0], cv3[1])  # 对检测头的卷积层进行剪枝pruning.prune(cv3[1], cv3[2])  # 对检测头的卷积层进行剪枝### 4. 模型梯度设置与保存for name, p in yolo.model.named_parameters():p.requires_grad = True  # 设置所有参数的梯度为可计算# yolo.val()  # 验证模型性能torch.save(yolo.ckpt, savepath)  # 保存剪枝后的模型yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))  # 更新模型路径yolo.export(format="onnx")  # 导出为ONNX格式## 重新加载模型,修改保存命名,用以比较剪枝前后的onnx的大小yolo = YOLO(modelpath)  # 从指定路径加载YOLO模型yolo.export(format="onnx")  # 导出为ONNX格式if __name__ == "__main__":modelpath = "runs/detect1/14_Constraint/weights/last.pt"  # 原始模型路径savepath = "runs/detect1/14_Constraint/weights/last_prune.pt"  # 剪枝后模型保存路径do_pruning(modelpath, savepath)  # 执行剪枝操作​

2. 稀疏正则训练

  • 使用带有 BN正则的训练方式,促进BN参数稀疏化。

首先加载一个正常训练的yolov8模型权重(.pt文件),ultralytics/engine/trainer.py中添加如下代码,使得bn参数在训练时变得稀疏。

代码中对所有 BatchNorm 层加了 L1 正则,以便自动把不重要的通道“压”成零,后面再统一按阈值剪枝。关键代码如下:

...## add start=============================## add l1 regulation for step2_Constraint_trainl1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))## add end ==============================...
  • 为什么只对 BN 做正则?
    BatchNorm 的 γ(scale)系数直接影响通道输出强度:γ ≈ 0 时,该通道几乎不参与后续计算,用它来衡量“重要性”最直观。

  • L1 正则如何“稀疏”?
    在反向传播时,为每个 γ/β 的梯度额外加上 ±λ,这会让本就小的 γ 更快被拉向 0,从而在训练中自然分化出大 γ(保留通道)和小 γ(待剪通道)。

  • λ 为何随 epoch 递减?
    训练初期靠强正则快速分离;后期减弱正则,避免过度压榨保留通道,给微调留下空间。

  • bias 也正则吗?
    虽然偏置对通道筛选作用不如 γ 强,但适度收敛 β 能进一步去除边缘特征,提高稀疏度。

之后在LL_pruning.py中运行方框中的代码

注意事项:

稀疏训练需要关闭混合精度(amp=False
剪枝依赖于 BatchNorm 的 γ 值作为排序阈值,γ 越小越容易被剪除。若使用 FP16(混合精度),许多接近 0 的 γ 会被量化到同一值甚至下溢为 0,导致排序失真,同时 L1 正则梯度也容易消失,后续剪枝的阈值选择会变得不稳定。而使用 FP32(amp=False)能精确表示这些微小差异,确保稀疏模式可控。

稀疏训练的 batch size 不宜过大
由于关闭了混合精度,模型采用全精度计算,显存占用显著增加。若 batch size 设置过大,可能导致显存溢出(OOM),进而引发训练失败。

稀疏训练阶段要将 patience 设为 0 或较大值
稀疏训练的目标并非短期提升 mAP,而是让 BN 的 γ 在多个 epoch 内逐步被 L1 正则“压缩”。在此期间,验证集指标可能停滞甚至下降。若启用常规早停机制(默认 patience 为几十),训练可能在 γ 尚未充分分化前被提前终止,导致剪枝时阈值模糊、可剪通道不足。

3. 剪枝

执行以下代码;

剪枝中的注意点:

在 YOLOv8 中,当进行 split concat 操作时,若剪枝后的通道数不匹配会报错。LL_pruning.py 的剪枝代码怎么避免这一问题,暂时还没研究透,有大佬知道请不吝指教。

关于 do_pruning 方法启用 yolo.val() 后保存的剪枝模型缺失 BN 层的原因:
Ultralytics 的验证 / 导出流程会将 Conv + BatchNorm 静态融合到卷积权重和偏置中,从而提升推理速度和轻量化。这一过程会直接移除 BN 层,因此保存的 yolo.ckpt 是已融合的模型。

对比剪枝前后的模型文件(last.pt/last_prune.pt)及其 ONNX 转换结果:
剪枝后的 .pt 文件增大,而 ONNX 文件从 43MB 缩减至 36MB。这是因为 .pt 文件包含完整的 checkpoint 元数据,而 ONNX 仅保存精简的推理图结构,因此只需关注 ONNX 文件大小的优化即可。

4. 微调

在第二步稀疏正则训练中将BN约束注释

需要注意的是明明加载的是剪枝后的模型,但训练启动时打印的日志却显示为标准版模型的参数。并且经过验证,微调后的模型参数就是标准的yolo模型。所以需要进行一些修改,详细的讲解可以看YOLOv8 剪枝模型加载踩坑记:解决 YAML 覆盖剪枝结构的问题-CSDN博客

修改ultralytics/engine/model.py文件内容:
self.trainer.model包含从YAML文件加载的原始模型配置信息,以及从PT文件加载的剪枝后权重。只需将该变量的网络结构更新为剪枝后的网络结构就行,否则训练后的模型参数不会改变。

运行下面的代码

yolov8模型的剪枝到这就结束了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/90645.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/90645.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/90645.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaSE:随机数生成

随机数在游戏开发、密码学、模拟测试等场景中扮演着关键角色。本文将深入探讨Java中两种主流的随机数生成技术&#xff1a;Random类和Math.random()方法&#xff0c;并解析背后的类与对象概念&#xff0c;助你全面掌握随机数生成的核心机制。一、随机数生成的两大技术 Java提供…

Android 持久化存储原理与使用解析

一、核心存储方案详解1. SharedPreferences (SP)使用方式&#xff1a;// 获取实例 SharedPreferences sp getSharedPreferences("user_prefs", MODE_PRIVATE);// 写入数据 sp.edit().putString("username", "john_doe").putInt("login_cou…

无 sudo 权限的环境下将 nvcc (CUDA Toolkit) 安装到个人目录 linux

要在无 sudo 权限的环境下将 nvcc 安装到 home 个人目录&#xff0c;你可以手动安装 CUDA Toolkit 到你的 $HOME 目录&#xff0c;只需以下几步即可使用 nvcc 编译 CUDA 程序。 ✅ 步骤&#xff1a;本地安装 CUDA Toolkit&#xff08;含 nvcc&#xff09; 下载 CUDA Toolkit Ru…

从指标定义到AI执行流:衡石SENSE 6.0的BI PaaS如何重构ISV分析链路

一、痛点&#xff1a;ISV行业解决方案的“三重断链”传统ISV构建行业分析模块时面临的核心挑战&#xff1a;指标定义碎片化&#xff1a;客户A的“销售额”含税&#xff0c;客户B不含税&#xff0c;衍生指标无法复用&#xff1b;分析-执行割裂&#xff1a;发现库存异常后需人工导…

构建跨平台远程医疗系统中的视频通路技术方案探究

一、远程医疗走向日常化&#xff0c;音视频能力成为关键基础设施 随着医疗数字化与分级诊疗体系的不断演进&#xff0c;远程医疗正从试点探索阶段&#xff0c;逐步迈向常态化、标准化应用。从县域医院远程问诊、基层医疗协作&#xff0c;到大型三甲医院的术中协同、专科教学直…

Blackbox Exporter Docker 安装配置,并与 Prometheus 集成

1. 创建配置文件目录bashmkdir -p ~/docker/blackbox/config cd ~/docker/blackbox2. 创建 Blackbox Exporter 配置文件 config/blackbox.ymlyamlmodules:http_2xx: # HTTP 可用性检测(响应 2xx/3xx 状态码)prober: httphttp:valid_http_versions: ["HTTP/1.1", &qu…

杰理通用MCU串口+AT指令+485通讯工业语音芯片

一、概述 在现代智能设备与自动化系统中&#xff0c;语音交互功能日益普及&#xff0c;通用 MCU 语音芯片作为核心组件&#xff0c;承担着关键的语音处理任务。其强大的功能不仅体现在语音合成、识别等方面&#xff0c;还包括高效的通信能力。串口 AT 指令 485 通讯模式为通用…

Krpano 工具如何调节全景图片切割之后的分辨率

文章目录概要第一步1.1 复制一下这个文件中的key &#xff0c;打开 krpano Tools.exe第二步 修改切片之后的分辨率修改前的效果修改后的效果概要 前端渲染全景图模拟3D场景 Krpano 工具 获取到后的默认图片分辨率是2048*2048的&#xff0c;如果觉得分辨率低了可以自行在工具中…

物联网十大应用领域深度解析

一、智能物流技术基础&#xff1a;RFID、无线传感器网络、互联网与运筹学、供应链管理理论结合 应用场景&#xff1a;仓储管理&#xff1a;RFID标签实现库存实时监控&#xff0c;自动补货系统降低缺货率。配送优化&#xff1a;通过GPS与物联网数据分析规划最优路径&#xff0c;…

ElasticSearch基础数据查询和管理详解

目录 一、 ElasticSearch核心概念 1. 全文搜索&#xff08;Full-Text Search&#xff09; 2. 倒排索引&#xff08;Inverted Index&#xff09; 3. ElasticSearch常用术语 3.1 映射&#xff08;Mapping&#xff09; 3.2 索引&#xff08;Index&#xff09; 3.3 文档&…

SSE与Websocket有什么区别?

SSE&#xff08;Server-Sent Events&#xff09;和WebSocket都能实现服务器与客户端的实时通信&#xff0c;但它们在协议设计、应用场景和技术特性上有明显差异。以下从多个维度对比两者的区别&#xff1a; 1. 协议基础 SSE 基于HTTP协议&#xff0c;是HTTP的扩展。使用单向通…

力扣Hot100疑难杂症汇总

写在前面 这一篇博客主要用来记录力扣Hot100中我反复刷&#xff0c;但又反复错的难题&#xff0c;为了防止秋招手撕的时候尬住&#xff0c;写这篇博客记录一下那些容易遗忘而且对我来说难度较大的题目。后面复习的时候重点对着这个名单来刷题。 二叉树部分 114. 二叉树展开为…

硬核接线图+配置步骤:远程IO模块接入PLC全流程详解

远程IO模块和PLC&#xff08;可编程逻辑控制器&#xff09;的连接涉及多个方面&#xff0c;包括硬件准备、软件配置、接线方法以及注意事项等。PLC品牌大多分为国产、欧系、美系、日系。国产PLC主要有汇川、台达、和利时、信捷等品牌&#xff1b;欧美系PLC以西门子、施耐德、罗…

【数据结构】长幼有序:树、二叉树、堆与TOP-K问题的层次解析(含源码)

为什么我们要学那么多的数据结构&#xff1f;这是因为没有一种数据结构能够去应对所有场景。我们在不同的场景需要选择不同的数据结构&#xff0c;所以数据结构没有好坏之分&#xff0c;而评估数据结构的好坏要针对场景&#xff0c;就如我们已经学习的结构而言&#xff0c;如果…

wps dispimg python 解析实现参考

在 wps excel 中&#xff0c;可以把图片嵌入单元格&#xff0c;此时会图片单元格会显示如下内容 DISPIMG("ID_142D0E21999C4D899C0723FF7FA4A9DD",1)下面是针对这中图片文件的解析实现 参考博客&#xff1a;Python读取wps中的DISPIMG图片格式_wps dispimg-CSDN博客:h…

Java学习---Spring及其衍生(下)

接下来就到了Spring的另外2个知名的衍生框架&#xff0c;SpringBoot和SpringCloud。其中&#xff0c;SpringBoot 是由 Pivotal 团队开发的一个基于 Spring 的框架&#xff0c;它的设计目的是简化 Spring 应用程序的初始搭建和开发过程。SpringBoot 遵循 “约定优于配置” 的原则…

残月头像阁

残月头像阁 使用说明: 直接上传服务器即可## 项目简介残月头像阁是一个简洁美观的头像网站开源程序 支持快速部署与自定义采用拟态(Neumorphism)设计风格&#xff0c;提供多种分类的头像## 功能特性- &#x1f5bc;️ 多分类头像展示&#xff08;男生、女生、卡通、情侣、动漫&…

文献综述AI生成免费工具推荐:高效整理文献

做学术研究时&#xff0c;文献综述无疑是让很多学子和科研工作者头疼的环节。查阅、筛选、梳理大量文献&#xff0c;然后进行归纳总结&#xff0c;最终形成一篇条理清晰的文献综述&#xff0c;这一整个过程常常耗费数日甚至数周。而面对课业压力与紧迫的论文截止时间&#xff0…

OpenCV —— contours_matrix_()_[]

&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️Take your time ! &#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️…

android 小bug :文件冲突的问题

文章目录前言1、问题&#xff1a;两个文件冲突了2、原因&#xff1a;3、结果&#xff1a;后语前言 一个身份证模块识别的小bug&#xff0c;记录一下&#xff0c;这应该是第三次出现&#xff0c;每次出现都不太记得&#xff0c;还是得记录&#xff0c;不然都是重复检索的过程。…