胡说八道:

各位观众老爷,大家好,我是诗人啊_,今天和各位分享模型剪枝的相关知识和操作,一文速通~
屏幕前的你,帅气低调有内涵,美丽大方很优雅… 所以,求个点赞、收藏、关注呗~
正经标题模型剪枝理论入门及 PyTorch API 实战
此文讲解 torch.nn.utils.prune 模块的使用,模型剪枝的执行步骤请看 ↓↓↓↓↓

模型剪枝的概念与实践(PyTorch版)

前言

深度神经网络的大型预训练模型往往依赖庞大的参数量实现SOTA效果,但生物神经网络却通过稀疏连接完成复杂任务。模型剪枝正是受此启发,通过将稠密连接转化为稀疏连接,在保持性能的前提下压缩模型,本文基于PyTorch详细介绍模型剪枝的概念与实操。
在这里插入图片描述

一、什么是模型剪枝?

  • 核心思想:仿照生物神经网络的稀疏连接特性,移除冗余参数或结构,实现模型压缩与加速。
  • 本质:将稠密网络转化为稀疏网络,在精度损失可接受的范围内减少参数量和计算量。
  • PyTorch支持:需使用torch.nn.utils.prune模块,要求PyTorch版本≥1.4.0,支持多种剪枝方式:
    • 特定网络模块的剪枝
    • 多参数模块的剪枝
    • 全局剪枝
    • 用户自定义剪枝

在这里插入图片描述

二、剪枝的基本原理(以LeNet为例)

2.1 准备工作

先定义经典LeNet网络作为示例:

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as Fdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, 3)  # 输入1通道,输出6通道,3x3卷积核self.conv2 = nn.Conv2d(6, 16, 3)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, int(x.nelement() / x.shape[0]))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xmodel = LeNet().to(device=device)

2.2 剪枝核心机制:掩码(Mask)

剪枝通过掩码张量实现参数筛选,核心逻辑如下:

  1. 原始参数(如weight)被拆分为:
    • weight_orig:保留原始参数值(可训练)
    • weight_mask:掩码张量(0表示剪枝移除,1表示保留)
  2. 实际使用的参数weight = weight_orig * weight_mask(被掩码为0的参数失效)
  3. 剪枝后,weight从可训练参数(Parameter)变为普通属性(Attribute)

2.3 单模块剪枝示例

conv1层的weight参数为例,执行随机非结构化剪枝:

module = model.conv1
# 对conv1的weight参数剪枝30%
prune.random_unstructured(module, name="weight", amount=0.3)
剪枝后参数变化:
  • named_parameters()weight变为weight_orig(保留原始值)
  • named_buffers()中新增weight_mask(掩码张量)
  • module.weightweight_orig * weight_mask的结果(含0值的剪枝后参数)
# 剪枝后参数查看
print("参数列表:", list(module.named_parameters()))  # 含weight_orig、bias
print("掩码列表:", list(module.named_buffers()))      # 含weight_mask
print("剪枝后weight:\n", module.weight)               # 含0值的有效参数

2.4 剪枝永久化(remove操作)

剪枝默认是临时的,执行prune.remove()可将掩码效果永久应用到参数:

# 永久化剪枝(无法撤销)
prune.remove(module, 'weight')
永久化后变化:
  • weight_orig消失,weight恢复为可训练参数(值 = 剪枝后的有效参数)
  • weight_mask被移除(无需保留)

三、常见剪枝方式实战

3.1 特定模块剪枝

针对单个模块的特定参数(如weightbias)剪枝,支持多种策略:

剪枝函数作用适用场景
random_unstructured随机移除单个参数非结构化剪枝(单权重)
l1_unstructured移除L1范数最小的单个参数非结构化剪枝(优先移除小值)
ln_structured移除Lₙ范数最小的结构化单元结构化剪枝(通道/神经元)
示例:对bias参数执行L1剪枝
# 对conv1的bias参数剪枝3个(绝对值最小的3个)
prune.l1_unstructured(module, name="bias", amount=3)
print("剪枝后bias:", module.bias)  # 含0值的剪枝后偏置

3.2 多参数模块剪枝

对模型中多个模块批量剪枝(如所有卷积层/全连接层):

# 对所有卷积层和全连接层分别剪枝
for name, module in model.named_modules():if isinstance(module, nn.Conv2d):# 卷积层:L1非结构化剪枝20%prune.l1_unstructured(module, name="weight", amount=0.2)elif isinstance(module, nn.Linear):# 全连接层:L2结构化剪枝40%prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)
效果:
  • 所有卷积层的weight均被剪枝20%
  • 所有全连接层的weight均被剪枝40%
  • 每个模块独立生成weight_origweight_mask

3.3 全局剪枝(Global Pruning)

局部剪枝(单模块/多模块)要求每层剪枝比例固定,而全局剪枝以整个网络为单位分配剪枝比例(总剪枝量固定,每层比例自适应)。

示例:全局剪枝20%参数
# 定义参与剪枝的模块和参数
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),(model.fc2, 'weight'),(model.fc3, 'weight')
)# 全局剪枝20%(总参数量的20%)
prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=0.2
)
特点:
  • 总剪枝比例固定(如20%),但每层剪枝比例不同
  • 重要性低的层(参数值小)会被剪枝更多
# 查看各层剪枝比例
print("conv1稀疏度:{:.2f}%".format(100 * torch.sum(model.conv1.weight == 0) / model.conv1.weight.nelement()
))
print("全局总稀疏度:{:.2f}%".format(100 * (torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + ...) / (model.conv1.weight.nelement() + model.conv2.weight.nelement() + ...)
))

3.4 用户自定义剪枝

通过继承BasePruningMethod实现自定义剪枝规则,只需重写__init__compute_mask方法。

示例:每隔一个参数剪枝一个(50%比例)
class MyPruningMethod(prune.BasePruningMethod):PRUNING_TYPE = "unstructured"  # 非结构化剪枝(单参数)def compute_mask(self, t, default_mask):mask = default_mask.clone()# 自定义规则:每隔一个参数剪枝一个(索引为偶数的置0)mask.view(-1)[::2] = 0return mask# 封装为剪枝函数
def my_unstructured_pruning(module, name):MyPruningMethod.apply(module, name)return module# 对fc3的bias参数应用自定义剪枝
my_unstructured_pruning(model.fc3, name="bias")
print("自定义剪枝掩码:", model.fc3.bias_mask)  # 0和1交替出现

四、剪枝模型的序列化

剪枝后的模型状态字典(state_dict)会保留:

  • 原始参数:weight_origbias_orig
  • 掩码张量:weight_maskbias_mask
# 剪枝前后状态字典对比
print("剪枝前:", model.state_dict().keys())
# 执行剪枝...
print("剪枝后:", model.state_dict().keys())  # 含orig和mask

总结

  1. 核心逻辑:通过掩码张量筛选参数,实现模型稀疏化
  2. 关键操作:单模块剪枝→多模块批量剪枝→全局剪枝→自定义剪枝
  3. 实用技巧
    • 非结构化剪枝(单权重)适合压缩模型,结构化剪枝(通道/神经元)适合加速推理
    • 剪枝后建议微调模型,恢复精度损失
    • 永久化剪枝(remove)可减小模型存储体积

通过合理的剪枝策略,可在保持模型性能的同时显著降低参数量和计算成本,是模型部署的重要优化手段。

我是诗人啊_程序员,致力于分享人工智能方面的知识,近期 NLP 自然语言处理系列文章发布中,如果感兴趣,来个关注呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/94942.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/94942.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/94942.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kubernetes 服务发现与健康检查详解

Kubernetes 提供了多种机制来管理服务发现、负载均衡和容器健康状态监控。本文将围绕以下几个方面展开:Service 类型:ClusterIP、NodePort、Headless Service、LoadBalancer(MetallB)Ingress 的实现原理健康检查探针:L…

如何规划一年、三年、五年的IP发展路线图?

‍在知识付费领域,规划 IP 发展路线,需要从短期、中期、长期不同阶段,系统地布局内容、运营与商业变现,逐步提升 IP 影响力与商业价值。一年目标:立足定位,夯实基础精准定位,打磨内容利用创客匠…

C++从入门到实战(二十)详细讲解C++List的使用及模拟实现

C从入门到实战(二十)C List的使用及模拟实现前言一、什么是List1.1 List的核心特性1.2 List与vector的核心差异1.3 List的构造、拷贝构造与析构1.3.1 常用构造函数1.3.2 析构函数1.4 List的迭代器1.4.1 迭代器类型与用法示例1:正向迭代器遍历…

人工智能学习:机器学习相关面试题(一)

1、 机器学习中特征的理解 def: 特征选择和降维 特征选择:原有特征选择出子集 ,不改变原来的特征空间 降维:将原有的特征重组成为包含信息更多的特征, 改变了原有的特征空间降维的主要方法 Principal Component Analysis (主成…

亚马逊巴西战略升级:物流网络重构背后的生态革新与技术赋能之路

在全球电商版图中,拉美市场正以惊人的增长速度成为新的战略高地,而巴西作为其中的核心市场,凭借庞大的人口基数、高速发展的数字经济以及不断提升的消费能力,吸引着众多电商巨头争相布局。近日,亚马逊宣布将于2025年底…

PS自由变换

自由变换 自由变换用来对图层、选区、路径或像素内容进行灵活的像素调整。可以进行缩放、旋转、扭曲等多种操作。快捷键:CtrlT,操作完成后使用Enter键可以确认变换自由变换过程中如果出现失误,可以按ESC退出;满意可以按enter确定。…

【K8s】整体认识K8s之存储--volume

为什么要用volume?首先。容器崩溃或重启时,所有的数据都会丢失,我们可以把数据保存到容器的外部,比如硬盘nfs,这样,即使容器没了,数据还在;第二就是容器之间是隔离的。我们如果想共享…

flutter工程

安装flutter 在VSCode中安装flutter extension、flutter组件 国内源下载flutter 3.35.2的SDK,安装,官网下载不了 将flutter安装目录加入环境变量中 D:\program\flutter_sdk\flutter\bin 执行 C:\Windows\System32>flutter --version Flutter 3.35.2 •…

C/C++ 高阶数据结构 —— 二叉搜索树(二叉排序树)

​ 🎁个人主页:工藤新一 ​ 🔍系列专栏:C面向对象(类和对象篇) ​ 🌟心中的天空之城,终会照亮我前方的路 ​ 🎉欢迎大家点赞👍评论📝收藏⭐文章…

stm32F4挂载emmc以及重定义printf

1.Cubemx SDIO USART 使用串口输出调试信息 FATFS Clock Configuration 防止堆栈溢出 2.Keil5 新建自定义文件夹及文件 将文件夹添加进工程 新建.c与.h文件,保存到自定义的文件夹,并添加到工程中 bsp_emmc.c #include "bsp_emmc.h" #include…

基于AI的大模型在S2B2C商城小程序中的应用与定价策略自我评估

摘要:本文聚焦电商行业,结合开源AI大模型与AI智能名片S2B2C商城小程序的技术特性,提出基于行业数据挖掘与自我评估的定价策略。通过分析行业价格分布与销量占比,结合商品设计、品牌创意度、商品丰富度及内功等评估指标&#xff0c…

中国移动云电脑一体机-创维LB2004_瑞芯微RK3566_2G+32G_开ADB安装软件教程

中国移动云电脑一体机-创维LB2004_瑞芯微RK3566_2G32G_开ADB安装软件教程简介:中国移动云电脑一体机-创维LB2004,显示器是23.8英寸1920x1080分辨率,安卓盒子配置是瑞芯微RK3566-四核-1.8GHz处理器-2G32G,预装Android11系统。具体操…

普蓝自研AutoTrack-4X导航套件平台适配高校机器人实操应用

在当前高校机器人工程、人工智能、自动化等专业的教学与科研中,师生们常常面临一个核心痛点:缺乏一套 “开箱即用、可深研、能落地” 的自主移动导航平台 —— 要么是纯仿真环境脱离实际硬件,要么是硬件零散需大量时间搭建,要么是…

2025年工会证考试题库及答案

一、单选题1.工会法人资格审查登记机关自收到申请登记表之日起(  )日内对有关申请文件进行审查,对审查合格者,办理登记手续,发放《工会法人资格证书》及其副本和《工会法人法定代表人证书》。A.二十B.十五C.六十D.三十答案:D 解析:第七条基…

【OpenGL】LearnOpenGL学习笔记17 - Cubemap、Skybox、环境映射(反射、折射)

上接:https://blog.csdn.net/weixin_44506615/article/details/150935025?spm1001.2014.3001.5501 完整代码:https://gitee.com/Duo1J/learn-open-gl | https://github.com/Duo1J/LearnOpenGL 一、立方体贴图 (Cubemap) 立方体贴图就是一个包含了6张2…

第十七章 ESP32S3 SW_PWM 实验

本章将介绍使用 ESP32-S3 LED 控制器(LEDC)。 LEDC 主要用于控制 LED,也可产生PWM信号用于其他设备的控制。该控制器有 8 路通道,可以产生独立的波形,驱动 RGB LED 等设备。 LED PWM 控制器可在无需 CPU 干预的情况下自动改变占空比&#xff…

Flink CDC如何保障数据的一致性

Flink CDC如何保障数据的一致性 前言 在大规模流处理中,故障是无可避免的。机器会宕机,网络会抖动。一个可靠的流处理引擎不仅要能高效地处理数据,更要在遇到这些故障时,保证计算结果的正确性。Apache Flink 正是因其强大的容错机…

Spring Boot 定时任务入门

1. 概述 在产品的色彩斑斓的黑的需求中,有存在一类需求,是需要去定时执行的,此时就需要使用到定时任务。例如说,每分钟扫描超时支付的订单,每小时清理一次日志文件,每天统计前一天的数据并生成报表&#x…

学习:uniapp全栈微信小程序vue3后台(6)

26.实现描述评分标签的双向数据绑定 /pages/wallpaper/picadd Array.prototype.splice() splice() 方法就地移除或者替换已存在的元素和/或添加新的元素。 二次确认 展现 确认标签 删除标签 温故知新: 标签: 关闭标签 27.uni-data-select调用云端分类…

Azure Marketplace 和 Microsoft AppSource的区别

微软的商业应用生态中,Azure Marketplace 和 Microsoft AppSource 是微软并行的两个主要“应用市场”(Marketplace),它们共同构成了微软的“商业市场”(Commercial Marketplace)计划,但服务的目…