目录

    • 什么是量化
    • 量化实现的原理
    • 实战
      • 准备数据
      • 执行量化
    • 验证量化
    • 结语

什么是量化

量化是一种常见的深度学习技术,其目的在于将原始的深度神经网络权重从高位原始位数被动态缩放至低位目标尾数。例如从FP32(32位浮点)量化值INT8(8位整数)权重。
FP32到INT8的整数缩放示意图,来源NVIDIA
这么做的目的,是为了在不影响神经网络的精度为前提下,减少模型运行时的内存消耗,提升推理系统整体的吞吐量。

量化实现的原理

量化的实现本质上以一种基于动态缩放的数值运算,因此在量化中,有几个重要的参数:

  • 缩放系数 ( s c a l e ) (scale) scale:用于表述从高位缩放至低位的缩放系数,如果没有它,量化就不存在了
  • x f x_f xf:代表输入的浮点高位值,一般是FP32或者FP64的输入值

那么,如何计算缩放系数 ( s c a l e ) (scale) scale呢?
首先,我们要找出输入值的最大值,原因是我们要找出整个输入的的量化范围,即:从哪里结束量化?因此,你可以使用 a m a x ( x f ) amax(x_f) amax(xf)来算出 x f x_f xf的极大值:

a m a x ( x f ) = m a x ( a b x ( x f ) ) amax(x_f) = max(abx(x_f)) amax(xf)=max(abx(xf))

找到最大值后,现在你要思考你需要多少位的量化,但通常在此之前,你需要算出你的输入数据最大可以容纳多少位数据:
n b i t = 2 ∗ a m a x ( x f ) n_{bit} = 2 *amax(x_f) nbit=2amax(xf)
在确定了这个值之后,除以你预期你的位数所能承载的最大数据量,就得到了缩放系数:
s c a l e = n b i t / p o w ( 2 , t b i t ) scale = n_{bit} / pow(2,t_{bit}) scale=nbit/pow(2,tbit)

好了,现在你有了两个量化过程中最重要的参数了,接下来就可以开始正式计算量化的结果了:
x q = C l i p ( R o u n d ( x f / s c a l e ) ) x_q = Clip(Round(x_f / scale)) xq=Clip(Round(xf/scale))

首先,我们需要将现有位数的输入除以我们得到的缩放系数,即得到了目标位数的浮点数据,但别忘了:我们在量化时通常是为了将浮点值操作量化为整数值操作,因此需要将其取整为整数。

那么? C l i p Clip Clip在做什么?因为我们不希望我们量化后结果的范围超出了目标位数的极大和极小值,因此使用 C l i p Clip Clip来裁切目标值为指定位数的极大和极小值。以INT8为例,则应该是:
x q = C l i p ( R o u n d ( x f / s c a l e ) , m i n = − 128 , m a x = 127 ) x_q = Clip(Round(x_f/scale),min=-128,max=127) xq=Clip(Round(xf/scale),min=128,max=127)
量化裁切示意图-来源NVIDIA

实战

说完了原理,我们该如何在ONNX中使用静态量化呢?
在这里,我们需要使用onnxruntime库来完成这个量化操作:

pip install onnxruntime-[target-ep]

其中,target-ep代表你期望模型在哪个类型的计算设备运行,如:

  • CUDA-GPU:则是pip install onnxruntime-gpu
  • DirectML:pip install onnxruntime-directml

准备数据

在准备数据时,我们不能像以前那样直接使用Dict[str,ndarray]的方式来调用静态量化,而是需要使用校准数据读取方式来读取:

from onnxruntime.quantization import (quantize_static, CalibrationDataReader,QuantType, QuantFormat, CalibrationMethod
)
from typing import *
# 创建一个DummpyDataReader类来继承CalibrationDataReader类
class DummyDataReader(CalibrationDataReader):def __init__(self, calibration_dataset:List[Dict[str,np.ndarray]]):self.dataset:List[Dict[str,np.ndarray]] = calibration_datasetself.enum_data:Any = None# 重载get_next迭代函数def get_next(self):if self.enum_data is None:self.enum_data:Iterator = iter(self.dataset)return next(self.enum_data, None)

接下来我们就可以准备输入数据了:

# 这里以Hubert Wav2Vec模型进行数据读取(1,audio_length)
# 采样率为16000Hz
import numpy as np
audio = np.load("./input.npy")
inputs = [{"feats": audio.astype(np.float32)},
]

执行量化

接下来我们就可以调用onnxruntime为我们提供的quantize_static函数了,在我们的实例中,会使用到如下的参数:

  • model_input [str]:输入的模型位置,通常为FP32的ONNX模型权重
  • model_output [str]:量化权重保存的位置
  • calibration_data_reader [CalibrationDataReader]:我们刚刚创建的校准数据读取类
  • quant_format [enum]:量化的格式,对于我们的实例中,使用QDQ(Quantize => Dequantize),即显示量化和反量化格式,因为我们不希望自己手动去算量化,对吧?事实证明使用这个模式的情况ONNXRuntime会自动帮你料理 s c a l e scale scale和零值点的计算,以及后续的反量化等。
  • activation_type [enum]:指定模型内部相关的激活函数使用什么数据类型来完成计算,在我们的例子中,QINT8相对合适,因为Wav2Vec是从音频中来提取特征表述,因此有符号比无符号效果会好很多。
  • weight_type [enum]:指定模型的权重是以什么数据类型来保存的,通常来说,如果你使用的是quantize_dynamic时,ONNXRuntime为了考虑兼容性,默认只会为你量化权重,而不会去管激活函数的量化。
  • calibrate_method [enum]:校准方法,指定在反量化阶段以什么方式来完成数据校准,ONNXRuntime支持下述的校准方式:
    • MinMax:极大极小值,这种校准方式适合基于特征表述的神经网络,如视觉模型,向量机
    • Entropy:基于熵,这种校准方式更适合于不确定性量化,即模型复杂度高,无法直接观测模型内部数据变化的神经网络,例如Transformer。适合处理高维度数据,对于我们这次示例中的Hubert十分有效,因为Hubert最终输出的特征向量大小是 ( b × n × 768 ) (b \times n \times 768) b×n×768
    • Percentile:基于百分位的数据校准模式,可以显著降低因量化产生的干扰值,但缺点就是容易***一刀切***,进而丢失数据
    • Distribution:基于分布的数据校准模式,当你看到***分布***这两字儿,你大概心里也应该有个谱了:没错,它是基于数据在FP32状态下的分布状态来进行对应比例的缩放校准的,而这也正是它的问题所在,即每进行一次校准时都有参考来FP32状态下的数据分布从而计算出INT8下可能的数据分布,因此对于时延要求不大的任务:如Diffusion可以用这类校准。

接下来我们就可以调用quantize_static()来执行静态量化了:

quantize_static(model_input="./hubert.onnx",model_output="./hubert_int8.onnx",calibration_data_reader=reader,quant_format=QuantFormat.QDQ,activation_type=QuantType.QInt8,weight_type=QuantType.QInt8,calibrate_method=CalibrationMethod.Entropy
)

之后你会看到这样的日志:

Collecting tensor data and making histogram ...
Finding optimal threshold for each tensor using 'entropy' algorithm ...
Number of tensors : 712
Number of histogram bins : 128 (The number may increase depends on the data it collects)
Number of quantized bins : 128

这在说明ONNXRuntime正在计算每个张量的最佳阈值和分布大小。

验证量化

接下来我们就可以正常读取这些模型写模型来看看不用位数下的输出精度了:

import onnxruntime as ort
# 加载FP32模型
model_fp32 = ort.InferenceSession("./hubert.onnx")
# 加载FP16模型
model_fp16 = ort.InferenceSession("./hubert_fp16.onnx")
# 加载INT8模型
model_int8 = ort.InferenceSession("./hubert_int8.onnx")
# 预测FP32
fp32_result = model_fp32.run(None,input_feed={"feats": audio.astype(np.float32)}
)
# 预测FP16
fp16_result = model_fp16.run(None,input_feed={"feats": audio.astype(np.float16)}
)
# 预测INT8
int8_result = model_int8.run(None,input_feed={"feats": audio.astype(np.float32)}
)# 绘制图像
import matplotlib.pyplot as plt
fig, ax = plt.subplots(3, 1, figsize=(8,6))ax[0].plot(fp32_result[0][0, 0, :], label="FP32")
ax[0].set_title("FP32 Output")ax[1].plot(fp16_result[0][0, 0, :], label="FP16")
ax[1].set_title("FP16 Output")ax[2].plot(int8_result[0][0, 0, :], label="INT8")
ax[2].set_title("INT8 Output")for a in ax:a.legend()a.grid()plt.tight_layout()
plt.show()

输出图像如下:
请添加图片描述
从图像也可以很明显的看出来:INT8的数据分布会更发散,虽然ONNXRuntime已经帮我们完成了反量化这一步骤。而FP16相比INT8则好看许多,虽然在浮点上位上少了很多表示位,但精度依然还是在线的,这也是量化时要权衡的问题:速度和精度,哪个对你的场景更重要?

结语

量化是一把双刃剑,虽然可以对比原来的推理环境实现大幅度的性能提升,但速度提升的代价就是精度的明显下降,因此在执行量化操作一定要权衡利弊,是否量化真的对你的场景真的很重要?你的任务是否真的很依赖那点儿因为降低精度而换回来的速度?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/85554.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/85554.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/85554.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【量子计算】格罗弗算法

文章目录 🔍 一、算法原理与工作机制⚡ 二、性能优势:二次加速的体现🌐 三、应用场景⚠️ 四、局限性与挑战🔮 五、未来展望💎 总结 格罗弗算法(Grover’s algorithm)是量子计算领域的核心算法之…

C++ 互斥量

在 C 中,互斥量(std::mutex)是一种用于多线程编程中保护共享资源的机制,防止多个线程同时访问某个资源,从而避免数据竞争(data race)和不一致的问题。 🔒 一、基础用法:s…

CSS Content符号编码大全

资源宝整理分享:​https://www.httple.net​ 前端开发中常用的特殊符号查询工具,包含Unicode编码和HTML实体编码,方便开发者快速查找和使用各种符号。支持基本形状、箭头、数学符号、货币符号等多种分类。 前端最常用符号 图标形状十进制十…

RPC常见问题回答

项目流程和架构设计 1.服务端的功能: 1.提供rpc调用对应的函数 2.完成服务注册 服务发现 上线/下线通知 3.提供主题的操作 (创建/删除/订阅/取消订阅) 消息的发布 2.服务的模块划分 1.网络通信模块 net 底层套用的moude库 2.应用层通信协议模块 1.序列化 反序列化数…

【JavaEE】(3) 多线程2

一、常见的锁策略 1、乐观锁和悲观锁 悲观锁:预测锁冲突的概率较高。在锁中加阻塞操作。乐观锁:预测锁冲突的概率较低。使用忙等/版本号等,不产生阻塞。 2、轻量级锁和重量级锁 重量级锁:加锁的开销较大,线程等待锁…

创客匠人服务体系解析:知识 IP 变现的全链路赋能模型

在知识服务行业深度转型期,创客匠人通过 “工具 陪跑 圈层” 的三维服务体系,构建了从 IP 定位到商业变现的完整赋能链条。这套经过 5 万 知识博主验证的模型,不仅解决了 “内容生产 - 流量获取 - 用户转化” 的实操难题,更推动…

国产ARM/RISCV与OpenHarmony物联网项目(六)SF1节点开发

一、终端节点功能设计 1. 功能说明 终端节点设计的是基于鸿蒙操作系统的 TCP 服务器程序,用于监测空气质量并提供远程控制功能。与之前的光照监测程序相比,这个程序使用 E53_SF1 模块(烟雾 / 气体传感器),主要功能包…

Plotly图表全面使用指南 -- Displaying Figures in Python

文中内容仅限技术学习与代码实践参考,市场存在不确定性,技术分析需谨慎验证,不构成任何投资建议。 在 Python 中显示图形 使用 Plotly 的 Python 图形库显示图形。 显示图形 Plotly的Python图形库plotly.py提供了多种显示图形的选项和方法…

getx用法详细解析以及注意事项

源码地址 在 Flutter 中,Get 是来自 get 包的一个轻量级、功能强大的状态管理与路由框架,常用于: 状态管理路由管理依赖注入(DI)Snackbar / Dialog / BottomSheet 管理本地化(多语言) 下面是 …

深度学习:人工神经网络基础概念

本文目录: 一、什么是神经网络二、如何构建神经网络三、神经网络内部状态值和激活值 一、什么是神经网络 人工神经网络(Artificial Neural Network, 简写为ANN)也简称为神经网络(NN),是一种模仿…

Unity2D 街机风太空射击游戏 学习记录 #12环射道具的引入

概述 这是一款基于Unity引擎开发的2D街机风太空射击游戏,笔者并不是游戏开发人,作者是siki学院的凉鞋老师。 笔者只是学习项目,记录学习,同时也想帮助他人更好的学习这个项目 作者会记录学习这一期用到的知识,和一些…

网站如何启用HTTPS访问?本地内网部署的https网站怎么在外网打开?

在互联网的世界里,数据安全已经成为了每个网站和用户都不得不面对的问题。近期,网络信息泄露事件频发,让越来越多的网站开始重视起用户数据的安全性,因此启用HTTPS访问成为了一个热门话题。作为一名网络安全专家,我希望…

计算机网络-----详解网络原理TCP/IP(上)

文章目录 📕1. UDP协议✏️1.1 UDP的特点✏️1.2 基于UDP的应用层协议 📕2. TCP协议✏️2.1 TCP协议段格式✏️2.2 TCP协议特点之确认应答✏️2.3 TCP协议特点之超时重传✏️2.4 TCP协议特点之连接管理✏️2.5 TCP协议特点之滑动窗口✏️2.6 TCP协议特点…

Lora训练

一种大模型高效训练方式&#xff08;PEFT&#xff09; 目标&#xff1a; 训练有限的ΔW&#xff08;权重更新矩阵&#xff09; ΔW为低秩矩阵→ΔWAB&#xff08;其中A的大小为dr, B的大小为rk&#xff0c;且r<<min(d,k)&#xff09;→ 原本要更新的dk参数量大幅度缩减…

蓝牙 5.0 新特性全解析:传输距离与速度提升的底层逻辑(面试宝典版)

蓝牙技术自 1994 年诞生以来,已经经历了多次重大升级。作为当前主流的无线通信标准之一,蓝牙 5.0 在 2016 年发布后,凭借其显著的性能提升成为了物联网(IoT)、智能家居、可穿戴设备等领域的核心技术。本文将深入解析蓝牙 5.0 在传输距离和速度上的底层技术逻辑,并结合面试…

Minio使用https自签证书

自签证书参考&#xff1a;window和ubuntu自签证书_windows 自签证书-CSDN博客 // certFilePath: 直接放在 resources 目录下 或者可以自定实现读取逻辑 // 读取的是 .crt 证书文件public static OkHttpClient createTrustingOkHttpClient(String certFilePath) throws Excep…

汽车前纵梁焊接总成与冲压件的高效自动化三维检测方案

汽车主体结构件上存在很多安装位&#xff0c;为保证汽车装配时的准确性&#xff0c;主体结构件需要进行全方位的尺寸和孔位置精度检测&#xff0c;以确保装配线的主体结构件质量合格。 前纵梁焊接总成是车身框架的核心承载部件&#xff0c;焊接总成由多片钣金冲压件焊接组成&a…

F接口基础.go

前言&#xff1a;接口是一组方法的集合&#xff0c;它定义了一个类型应该具备哪些行为&#xff0c;但不关心具体怎么实现这些行为。一个类型只要实现了接口中定义的所有方法&#xff0c;那么它就实现了这个接口。这种实现是隐式的&#xff0c;不需要显式声明。 目录 接口的定…

cartographer官方指导文件说明---第3章 cartographer前端算法流程介绍

cartographer官方指导文件说明 第3章 cartographer前端算法流程介绍 3.1 Scan Match扫描匹配 扫描匹配&#xff08;Scan Matching&#xff09;是 Cartographer 中实现局部SLAM的核心技术&#xff0c;它通过优化算法将当前激光扫描数据对齐到子图地图中。下面从计算过程、数学…

汽车整车厂如何用数字孪生系统打造“透明车间”

随着工业4.0时代的发展&#xff0c;数字孪生技术已成为现代制造业的重要利器。特别是在汽车整车厂&#xff0c;通过数字孪生系统的应用&#xff0c;能够有效打造一个“透明车间”&#xff0c;实现生产过程的全面可视化与实时监控&#xff0c;提高生产效率&#xff0c;降低成本&…