pytorch量化

[!note]

  • 官方定义:performing computations and storing tensors at lower bitwidths than floating point precision.
  • 支持INT8量化,可以降低4倍的模型大小以及显存需求,加速2-4倍的推理速度
  • 通俗理解:降低权重和激活值的精度(FP32→INT8),从而提高模型大小以及显存需求。

一、前置知识

1.1 算子融合

​ 将多个连续层的计算操作合并为单个复合算子,减少对内存的访问次数

e.g. 例如将Conv → BN → ReLU, 融合为ConvBnReLU

操作流程内存访问次数计算强度
未融合(3个算子)6次
已融合(1个算子)2次

​ NVIDA GPU:

// 未融合:多次启动核函数
conv_kernel<<<...>>>(input, weight, temp1);
bias_kernel<<<...>>>(temp1, bias, temp2);
relu_kernel<<<...>>>(temp2, output);// 已融合:单核函数完成所有操作
fused_kernel<<<...>>>(input, weight, bias, output) {float val = conv2d(input, weight);val += bias;output = max(val, 0.0f);
}

二、量化知识

2.1 对称量化 & 非对称量化

⚙️ 区别

  • 对称量化(Symmetric Quantization)

X i n t = r o u n d ( X f l o a t s c a l e ) , s c a l e = m a x ( ∣ X ∣ ) 2 n − 1 − 1 X_{int}=round(\frac{X_{float}}{scale}), scale = \frac{max(|X|)}{2^{n-1}-1} Xint=round(scaleXfloat),scale=2n11max(X)

  • 非对称量化(Affine Quantization)

X i n t = r o u n d ( X f l o a t s c a l e ) + z e r o _ p o i n t , s c a l e = m a x x − m i n ) x 2 n − 1 X_{int}=round(\frac{X_{float}}{scale}) + zero\_point, scale = \frac{max_x-min_)x}{2^{n}-1} Xint=round(scaleXfloat)+zero_point,scale=2n1maxxmin)x

z e r o _ p o i n t = r o u n d ( − m i n ( x ) s c a l e ) zero\_point = round(\frac{-min(x)}{scale}) zero_point=round(scalemin(x))

特性对称量化(Symmetric Quantization)非对称量化(Affine Quantization)
零点位置固定为0动态计算(zero_point)
数值范围[-127, 127] (int8)[0, 255] (uint8)
计算开销更低(无需zero_point计算)更高
精度损失对偏斜分布敏感更鲁棒,能更好处理数据分布偏斜的情况
典型应用权重量化(正负均衡)激活值量化
硬件支持广泛支持(如GPU/TPU)需要额外处理zero_point

🤖 工程实现角度:为什么 PTQ 常用非对称,QAT 用对称

模式推荐默认背后原因
PTQ权重:对称 激活:非对称因为激活是不可训练的静态量化,非对称能更好地适应非负分布
QAT权重:对称 激活:对称(人为设定)因为激活是可训练的,你可以通过训练让它“对称”起来,精度损失更可控
2.2 PTQ & QAT

[!note]

PTQ 是直接对训练后的模型参数进行量化,因此适合于快速部署;QAT是通过插入伪量化节点,在训练过程中模拟量化误差以达到更高的精度,因此需要重新训练。

⚙️ 区别

特性PTQ(训练后量化)QAT(量化感知训练)
训练阶段仅FP32训练插入伪量化节点训练
反向传播❌ 不支持✅ 通过STE支持
精度损失较大(尤其小模型)通常更小
计算开销低(仅需校准)高(需完整训练)
典型用途快速部署高精度要求的场景

[!tip]

QAT伪量化节点

  • 作用:在训练时模拟量化的误差。在每一层训练时,权重、激活值依然是FT32,但在每一层的传播中,值被“量化再还原”,模拟了量化过程。
  • 由于量化过程有round函数,是不可微的,因此需要Straight-Through Estimator(STE)近似梯度的 FakeQuant 模块

三、Pytorch实现量化的三种方式

参考链接:Quantization — PyTorch 2.7 documentation

特性Eager Mode QATFX Graph QATExport QAT
实现方式动态图模式符号化重写编译器优化
控制流支持
算子融合❌(只能手动融合)✅🌟
典型APIprepare_qatprepare_fxexport
Type只支持module支持 module & function支持 module & function

[!note]

无论是PTQ 还是 QAT , 每一种实现方式都需要 prepare_fx 和 convert_fx

model_prepared = quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs)
model_quantized = quantize_fx.convert_fx(model_prepared)

🎯 核心功能:在模型的每一个 qconfig_mapping 指定的量化位置(如 Conv2d、Linear)处,插入对应的 observerfake_quant 节点。

📦 插入两类模块:

类型对应 prepare 的用途说明
Observer用于 PTQ统计 min/max 用来 校准计算 scale 和 zero_point
FakeQuantize用于 QAT模拟量化误差,保留梯度流动,支持训练
3.1 Eager Mode Quantization
import torch# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):def __init__(self):super().__init__()# QuantStub converts tensors from floating point to quantizedself.quant = torch.ao.quantization.QuantStub()self.conv = torch.nn.Conv2d(1, 1, 1)self.bn = torch.nn.BatchNorm2d(1)self.relu = torch.nn.ReLU()# DeQuantStub converts tensors from quantized to floating pointself.dequant = torch.ao.quantization.DeQuantStub()def forward(self, x):x = self.quant(x)x = self.conv(x)x = self.bn(x)x = self.relu(x)x = self.dequant(x)return x# create a model instance
model_fp32 = M()# model must be set to eval for fusion to work
model_fp32.eval()# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')# fuse the activations to preceding layers, where applicable
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,[['conv', 'bn', 'relu']])# Prepare the model for QAT. This inserts observers and fake_quants in
# the model needs to be set to train for QAT logic to work
# the model that will observe weight and activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())# run the training loop (not shown)
training_loop(model_fp32_prepared)# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)
3.2 FX Graph Mode Quantization (maintenance)
import torch
from torch.ao.quantization import (get_default_qconfig_mapping,get_default_qat_qconfig_mapping,QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copymodel_fp = UserModel()#
# post training dynamic/weight_only quantization
## we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs are needed to trace the model
example_inputs = (input_fp32)
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)#
# post training static quantization
#model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)#
# quantization aware training for static quantization
#model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)
3.3 PyTorch 2 Export Quantization
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.export import export_for_training
from torch.ao.quantization.quantizer import (XNNPACKQuantizer,get_symmetric_quantization_config,
)class M(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(5, 10)def forward(self, x):return self.linear(x)# initialize a floating point model
float_model = M().eval()# define calibration function
def calibrate(model, data_loader):model.eval()with torch.no_grad():for image, target in data_loader:model(image)# Step 1. program capture
# NOTE: this API will be updated to torch.export API in the future, but the captured
# result should mostly stay the same
m = export_for_training(m, *example_inputs).module()
# we get a model with aten ops# Step 2. quantization
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they
# want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
# or prepare_qat_pt2e for Quantization Aware Training
m = prepare_pt2e(m, quantizer)# run calibration
# calibrate(m, sample_inference_data)
m = convert_pt2e(m)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/86290.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/86290.shtml
英文地址,请注明出处:http://en.pswp.cn/web/86290.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ES和 Kafka 集群搭建过程中的典型问题、配置规范及最佳实践

Kafka 集群搭建与配置经验库文档&#xff08;完整会话汇总&#xff09; 一、会话问题分类与解决方案 1. Elasticsearch 映射解析错误 问题现象&#xff1a; {"error":{"root_cause":[{"type":"mapper_parsing_exception","re…

Linux-信号量

目录 POSIX信号量 信号量的原理 信号量的概念 申请信号量失败被挂起等待 信号量函数 二元信号量模拟实现互斥功能 基于环形队列的生产消费模型 下面环形队列采用数组模拟&#xff0c;用模运算来模拟环状特性&#xff0c;类似如此 空间资源和数据资源 生产者和消费者申请…

Unity2D 街机风太空射击游戏 学习记录 #14 环射和散射组合 循环屏幕道具

概述 这是一款基于Unity引擎开发的2D街机风太空射击游戏&#xff0c;笔者并不是游戏开发人&#xff0c;作者是siki学院的凉鞋老师。 笔者只是学习项目&#xff0c;记录学习&#xff0c;同时也想帮助他人更好的学习这个项目 作者会记录学习这一期用到的知识&#xff0c;和一些…

vue3 定时刷新

在Vue 3中实现定时刷新&#xff0c;你可以使用多种方法。这里列举几种常见的方法&#xff1a; 方法1&#xff1a;使用setInterval 这是最直接的方法&#xff0c;你可以在组件的mounted钩子中使用setInterval来定时执行某些操作&#xff0c;例如重新获取数据。 <template&…

局域网环境下浏览器安全限制的实用方法

在现代 Web 开发和网络应用中&#xff0c;我们常常会遇到浏览器出于安全考虑对某些功能进行限制的情况。例如麦克风、摄像头、地理位置等敏感功能&#xff0c;通常只能在 HTTPS 协议或 localhost 下使用。然而在局域网开发、测试或特定应用场景中&#xff0c;我们可能需要突破这…

如果你在为理解RDA、PCA 和 PCoA而烦恼,不妨来看看丨TomatoSCI分析日记

当你学习了 RDA、PCA 和 PCoA 这三种常见排序方法后&#xff0c;脑子里是不是也冒出过类似的疑问&#xff1a; PCA、PCoA、RDA 不都能画图吗&#xff1f;是不是可以互相替代&#xff1f; RDA 图上也有样本点&#xff0c;那我还需要 PCoA 干什么&#xff1f; ... 这些看似“…

MySQL (二):范式设计

在 MySQL 数据库设计中&#xff0c;范式设计是构建高效、稳定数据库的关键环节。合理的范式设计能够减少数据冗余、消除操作异常&#xff0c;让数据组织更加规范和谐。然而&#xff0c;过度追求范式也可能带来多表联合查询效率降低的问题。本文将深入讲解第一范式&#xff08;1…

什么是财务共享中心?一文讲清财务共享建设方案

目录 一、财务共享中心是什么 1.标准化流程 2.集中化处理 3.智能化系统 4.专业化分工 二、财务共享中心的四大模块 1. 共享系统 2. 共享流程 3. 共享组织 4. 共享数据 三、为什么很多财务共享中心做不下去&#xff1f; 1.只搬人&#xff0c;不换流程 2.系统买了&a…

001 双指针

双指针 双指针&#xff08;Two Pointers&#xff09; 双指针&#xff08;Two Pointers&#xff09; 对撞指针&#xff08;Opposite Direction Two Pointers&#xff09;&#xff1a; 对撞指针从两端向中间移动&#xff0c;一个指针从最左端开始&#xff0c;另一个最右端开始&a…

【unitrix】 4.7 库数字取反(not.rs)

一、源码 这段代码是用Rust语言实现的一个库&#xff0c;主要功能是对数字进行位取反操作&#xff08;按位NOT运算&#xff09;。 /*库数字取反* 编制人: $ource* 修改版次:0版完成版* 本版次创建时间: 2025年6月25日* 最后修改时间: 无* 待完善问题&#xff1a;无*/ use cor…

在ASP.NET Core WebApi中使用日志系统(Serilog)

一.引言 日志是构建健壮 Web API 的重要组成部分&#xff0c;能够帮助我们追踪请求、诊断问题、记录关键事件。在 .Net 中&#xff0c;日志系统由内置的 Microsoft.Extensions.Logging 抽象提供统一接口&#xff0c;并支持多种第三方日志框架&#xff08;如 Serilog、NLog 等&…

(链表:哈希表 + 双向链表)146.LRU 缓存

题目 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 LRU是Least Recently Used的缩写&#xff0c;即最近最少使用&#xff0c;是一种常用的页面置换算法&#xff0c;选择最近最久未使用的页面予以淘汰。该算法赋予每个页面一个访问字段&#xff0c;用来记…

Go Web开发框架实践:模板渲染与静态资源服务

Gin 不仅适合构建 API 服务&#xff0c;也支持 HTML 模板渲染和静态资源托管&#xff0c;使其可以胜任中小型网站开发任务。 一、模板渲染基础 1. 加载模板文件 使用 LoadHTMLGlob 或 LoadHTMLFiles 方法加载模板&#xff1a; r : gin.Default() r.LoadHTMLGlob("templ…

缓存与加速技术实践-Kafka消息队列

目录 #1.1消息队列 1.1.1什么是消息队列 1.1.2消息队列的特征 1.1.3为什么需要消息队列 #2.1ksfka基础与入门 2.1.1kafka基本概念 2.1.2kafka相关术语 2.1.3kafka拓扑架构 #3.1zookeeper概述介绍 3.1.1zookeeper应用举例 3.1.2zookeeper的工作原理是什么&#xff1f; 3.1.3z…

鸿蒙前后端部署教程

第一步&#xff1a;部署Java后端 打开IDEA编辑器 第二步&#xff1a;用DevEco Studio运行鸿蒙端项目 然后按WinR键调出Win的命令行&#xff0c;输入ipconfig 打开后端IDEA可以查看数据库情况&#xff0c;如下图

Python 常用定时任务框架介绍及代码举例

文章目录 Python 常用定时任务框架简介&#x1f9e9; 一、轻量级方案&#xff08;适合简单任务&#xff09;1. **schedule库** ⚙️ 二、中级方案&#xff08;平衡功能与复杂度&#xff09;2. **APScheduler**3. **Celery Celery Beat** &#x1f680; 三、异步专用方案&#…

使用redis服务的redisson架构实现分布式锁

加锁 /*** 尝试为指定的许可证 ID 获取分布式锁。如果锁已被占用&#xff0c;则立即抛出业务异常。** param licenseId 需要加锁的许可证 ID&#xff08;即锁名称&#xff09;* return true 表示成功获取锁&#xff0c;但请注意&#xff1a;* 锁实际持有时间为 30 秒…

HTML表格元素

HTML表格元素深度解析与实战应用 一、表格基本结构与语义化 1. 基础表格元素详解 <table> 容器元素 核心作用&#xff1a;定义表格容器重要属性&#xff1a; border&#xff1a;已废弃&#xff0c;应使用CSS设置边框aria-label/aria-labelledby&#xff1a;为屏幕阅读…

如何使用 Dockerfile 创建自定义镜像

使用 Dockerfile 创建自定义镜像的过程非常清晰&#xff0c;通常包括定义基础镜像、安装依赖、复制代码、设置环境变量和启动命令等步骤。下面详细讲解从零创建自定义镜像的完整流程。 一、什么是 Dockerfile&#xff1f; Dockerfile 是一个文本文件&#xff0c;定义了如何构建…

设置AWS EC2默认使用加密磁盘

问题 EC2磁盘需要使用默认加密。这里需要设置一下默认加密。 EC2