一. 初步认识模型部署
1. 什么是ONNX?
ONNX 就是一个 中间人 或 通用翻译器。它让你在喜欢的框架(如 PyTorch)里训练好模型后,能轻松地把它变成一种 标准格式。然后,这个标准格式的模型可以被 很多不同的工具和硬件 (通过 ONNX Runtime 或其他支持 ONNX 的引擎) 理解和高效地运行,大大简化了模型从训练到实际应用的部署过程。它的目标是实现 “一次训练,随处运行”。
关键组成部分:
-
ONNX 格式: 基于 Protobuf 定义的文件格式,存储了模型的网络结构(计算图)、参数(权重、偏置)和元数据。
-
ONNX 操作符集: 定义了一组标准的、不断扩充的原子操作(如卷积、矩阵乘、激活函数等),这些操作是构建模型的基本单元。不同的框架在导出时,需要将其特有的操作映射到这些标准操作上。
-
ONNX 运行时:
-
ONNX Runtime (ORT): 由微软维护的高性能推理引擎,专门用于在各种硬件平台(CPU, GPU, FPGA, NPU等)上高效运行 ONNX 模型。它是 ONNX 生态中非常重要和流行的运行时。
-
其他运行时: 许多其他推理引擎和硬件加速库也原生支持加载和运行 ONNX 模型。
-
主要优势:
-
框架无关性: 打破训练框架的壁垒。
-
硬件灵活性: 方便地将模型部署到多样化的硬件环境。
-
部署效率: 避免了为每个目标平台重复开发模型推理代码。
-
生态系统: 拥有庞大的社区和众多厂商支持(微软、Meta、NVIDIA、Intel、AMD、Arm、华为、高通等)。
-
优化机会: ONNX Runtime 等引擎可以对模型图进行各种优化(如图优化、算子融合、量化),显著提升推理速度和降低资源消耗。
2. 模型部署
参考 https://zhuanlan.zhihu.com/p/516920606
在软件工程中,模型部署是把开发完成的软件投入使用的过程,包括环境配置、软件安装等步骤。那么对于深度学习来说,模型部署就是让训练好的模型在特定环境中运行的过程。会遇到的一些难题:
1)运行模型所需要的环境难以配置。深度学习模型通常是有一些框架编写,比如PyTorch、Tensor Flow。由于框架规模、依赖环境的限制,这些框架不适合在手机、开发板等生产环境中安装。
2)深度学习模型的结构通常比较庞大,需要大量的算力才能满足实时运行的需求。模型运行效率需要优化。
因为这些难题,模型部署不能靠简单的环境配置和安装完成。流水线大致如下:
为了让模型最终能够部署到某个环境上,开发者可以使用任意一种深度学习框架来定义网络结构,并通过训练确定网络中的参数。之后,模型的结构和参数会被转换成一种只描述网络结构的中间表示,一些针对网络结构的优化会在中间表示上进行。最后,用面向硬件的高性能编程框架(如CUDA,OpenCL)编写,能高效执行深度学习网络中算子的推理引擎会把中间表示转换成特定的文件格式,并在对应的硬件平台上高效运行模型。
这一条流水线解决了模型部署中两大问题:使用对接深度学习框架和推理引擎的中间表示,开发者不必担心如何在新环境中运行各个复杂的框架;通过中间表示的网络结构优化和推理引擎对运算的底层优化,模型的运算效率大幅提升。
2.1 配置部署环境
让我们用 PyTorch 实现一个超分辨率模型,并把模型部署到 ONNX Runtime 这个推理引擎上。
import os
import cv2
import numpy as np
import requests
import torch
import torch.onnx
from torch import nn
class SuperResolutionNet(nn.Module):def __init__(self, upscale_factor):super().__init__()self.upscale_factor = upscale_factorself.img_upsampler = nn.Upsample(scale_factor=self.upscale_factor,mode='bicubic',align_corners=False)self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)self.conv2 = nn.Conv2d(64,32,kernel_size=1, padding=0)self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2)self.relu = nn.ReLU()def forward(self, x):x = self.img_upsampler(x)out = self.relu(self.conv1(x))out = self.relu(self.conv2(out))out = self.conv3(out)return out
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth', 'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png']
names = ['srcnn.pth', 'face.png']
for url,name in zip(urls, names):if not os.path.exists(name):open(name, 'wb').write(requests.get(url).content)def init_torch_model():torch_model = SuperResolutionNet(upscale_factor=3)state_dict = torch.load('srcnn.pth')['state_dict']for old_key in list(state_dict.keys()):new_key = '.'.join(old_key.split('.')[1:])state_dict[new_key] = state_dict.pop(old_key)torch_model.load_state_dict(state_dict)torch_model.eval()return torch_modelmodel = init_torch_model()
input_img = cv2.imread('face.png').astype(np.float32)input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)
torch_output = model(torch.from_numpy(input_img)).detach().numpy()# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8) # Show image
cv2.imwrite("face_torch.png", torch_output)
SRCNN 先把图像上采样到对应分辨率,再用 3 个卷积层处理图像。为了方便起见,我们跳过训练网络的步骤,直接下载模型权重(由于 MMEditing 中 SRCNN 的权重结构和我们定义的模型不太一样,我们修改了权重字典的 key 来适配我们定义的模型),同时下载好输入图片。为了让模型输出成正确的图片格式,我们把模型的输出转换成 HWC 格式,并保证每一通道的颜色值都在 0~255 之间。如果脚本正常运行的话,一幅超分辨率的人脸照片会保存在 “face_torch.png” 中。
在PyTorch模型测试正确后,我们来正式开始部署这个模型。我们下一步的任务是把pyTorch模型转换成用中间表示ONNX描述的模型。
2.2 中间表示- ONNX
介绍ONNX之前,我们先从本质上来认识一下神经网络的结构。神经网络实际上知识描述了数据计算的过程,其结构可以用计算图来表示。比如a + b可以用下图的计算图来表示:
为了加速计算,一些框架会使用对神经网络“先编译,后执行”的静态图来描述网络。静态图的缺点是难以描述控制流(比如 if-else 分支语句 和 for 循环语句),直接对其引入控制语句会导致产生不同的计算图。比如循环执行n次a = a+b,对于不同的n,会产生不同的计算图:
ONNX(Open Neural Network Exchange)是Facebook和 微软在2017年共同发布的,用于标准描述计算图的一种格式。目前,在数家机构的共同维护下,ONNX已经对接了多种深度学习架构和多推理引擎。因此,ONNX被当成了深度学习框架到推理引擎的桥梁,就像编译器的中间语言一样。由于各个框架兼容性不一,我们通常只用ONNX表示更容易部署的静态图。
我们用下面的代码把PyTorch的模型转换成ONNX格式的模型:
x = torch.randn(1, 3, 256, 256) with torch.no_grad(): torch.onnx.export( model, x, "srcnn.onnx", opset_version=11, input_names=['input'], output_names=['output'])
其中,torch.onnx.export是PyTorch自带的把模型转换成ONNX格式的函数。让我们先来看一下前三个必选参数:前三个参数分别是要转换的模型、模型的任意一组输入、导出的ONNX文件的文件名。转换模型时,需要原模型和输出文件迷宫是很容易理解的,但为什么需要为模型提供一组输入呢?这就涉及到ONNX的转换原理来。从PyTorch的模型到ONNX的模型,本质上是一种语言上的翻译。直觉上的想法是像编译器一样彻底解析原模型的代码,记录所有控制流。但是前面我们通常只用ONNX记录不考虑控制流的静态图。因此,PyTorch提供了一种叫做追踪(trace)的模型转换方法:给定一组输入,再实际执行一遍模型,即把这组输入对应的计算图记录下来,保存为ONNX格式。export函数用的就是追踪导出方法,需要给任意一组输入,让模型跑起来。我们测试图片的三通道,256*256大小的,这里也构造一个同样形状的随机张量。
剩下的参数中,opset_version表示ONNX算子集的版本。深度学习的发展会不断的诞生新算子,为了支持这些新增的算子,ONNX会经常发布新的算子集。我们领opset_version=11,即使用第11个ONNX算子集,是因为SRC NN中的bicubic(双三次插值)在opset11中才得到支持。剩下的两个参数input_names, output_names是输入、输出tensor的名称,我们稍后会用到这些名称。
如果上述代码运行成功,目录下会新增一个"srcnn.onnx"的 ONNX 模型文件。我们可以用下面的脚本来验证一下模型文件是否正确。
import onnx onnx_model = onnx.load("srcnn.onnx")
try: onnx.checker.check_model(onnx_model)
except Exception: print("Model incorrect")
else: print("Model correct")
其中,onnx.load函数用于读取一个ONNX模型。onnx.checker.check_model用于检查模型格式是否正确,如果有错误的话会直接报错。我们的模型是正确的,控制台中应该会打印出“Model correct”
接下来,让我们来看一看 ONNX 模型具体的结构是怎么样的。我们可以使用 Netron (开源的模型可视化工具)来可视化 ONNX 模型。把 srcnn.onnx 文件从本地的文件系统拖入网站,即可看到如下的可视化结果:
点击 input 或者 output,可以查看 ONNX 模型的基本信息,包括模型的版本信息,以及模型输入、输出的名称和数据类型。
点击某一个算子节点,可以看到算子的具体信息。比如点击第一个 Conv 可以看到:
每个算子记录了算子属性、图结构、权重三类信息。
- 算子属性信息即图中attributs里的信息,对于卷积来说,算子属性包括了卷积大小(kernel_shape)、卷积步长(strides)等内容。这些算子属性最终会用来生成一个具体的算子。
- 图结构信息指算子节点在就按图中的名称、邻边信息。对于图中的卷积来说,该算子节点叫Conv_2,输入数据叫做11,输出数据叫做12。根据每个算子节点的图结构信息,就能完整地复原出网络的计算图。
- 权重信息指的是网络经过训练后,算子存储的权重信息。对于卷积来说,权重信息包括卷积核的权重值和卷积后的偏差值。点击图中 conv1.weight, conv1.bias 后面的加号即可看到权重信息的具体内容。
现在,我们有了SRCNN的ONNX模型。让我们看看最后该如何把这个模型运行起来。
2.3 推理引擎- ONNX Runtime
ONNX Runtime是由微软维护的一个跨平台机器学习推理加速器,也就是我们前面提到的“推理引擎”。ONNX Runtime是直接对接ONNX的,即ONNX Runtime是直接对接ONNX的,即ONNX Runtime可以直接读取并运行.onnx文件,而不需要再把.onnx格式的文件转换成其他格式的文件。也就是说,对于PyTorch-ONNX-ONNX Runtime这条部署流程线,只要在目标设备中得到.onnx文件,并在ONNX Runtime上运行模型,模型部署就算大功告成了。
通过刚刚的操作,我们把PyTorch编写的模型转换成ONNX,并通过可视化检查了模型的正确性。最后让我们用ONNX Runtime运行一下模型,完成模型部署的最后一步。
ONNX Runtime提供了Python接口。接着刚才的脚本,我们可以添加如下代码运行模型.
import onnxruntime ort_session = onnxruntime.InferenceSession("srcnn.onnx")
ort_inputs = {'input': input_img}
ort_output = ort_session.run(['output'], ort_inputs)[0] ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort.png", ort_output)
这段代码中,出去后处理操作外,和ONNX Runtime相关的代码只有三行。 让我们简单解析一下这三行代码。onnxruntime.InferenceSession用于获取一个ONNX Runtime推理器,其参数是用于推理的ONNX模型文件。推理器的run方法用于模型推理,其第一个参数为输出张量名的列表,第二个参数为输入值的字典。其中输入值字典的key为张量名,value为numpy类型的张量值。输入输出张量的名称需要和torch.onnx.export中设置的输入输出名对应。
如果代码正常运行的话,另一幅超分辨率照片会保存在"face_ort.png"中。这幅图片和刚刚得到的"face_torch.png"是一模一样的。这说明 ONNX Runtime 成功运行了 SRCNN 模型,模型部署完成了!以后有用户想实现超分辨率的操作,我们只需要提供一个 "srcnn.onnx" 文件,并帮助用户配置好 ONNX Runtime 的 Python 环境,用几行代码就可以运行模型了。或者还有更简便的方法,我们可以利用 ONNX Runtime 编译出一个可以直接执行模型的应用程序。我们只需要给用户提供 ONNX 模型文件,并让用户在应用程序选择要执行的 ONNX 模型文件名就可以运行模型了。
总结
- 模型部署,指把训练好的模型在特定的环境中运行的结果。模型部署要解决模型框架兼容性差和模型运行速度慢两大问题。
- 模型部署的常见流水线是“深度学习框架-中间表示-推理引擎”。其中比较常见的一个中间表示是ONNX。
- 深度学习模型实际上就是一个计算图。模型部署时通常把模型转换成静态的就按图,即没有控制流(分支语句和循环语句)的计算图。
- PyTorch框架自带对ONNX的支持,只需要构造一组随机的输入,并对模型调用torch.onnx.export即可完成Pytorch到ONNX的转换。
- 推理引擎ONNX Runtime对ONNX模型有原生的支持。给定一个.onnx文件,只需要简单的使用ONNX Runtime的Python API就可以完成模型推理。
二. 模型部署中常见的难题
一般模型部署会碰到以下几类困难:
- 模型的动态化。出于性能的考虑,各推理框架都默认模型的输入形状、输出形状、结构时静态的。而为了让模型的泛用性更强,部署时需要在尽可能不影响原来逻辑的前提下,让模型的输入输出或者结构动态话。
- 新算子的实现。深度学习日新月异,提出新算子的速度往往快于ONNX维护者支持的速度。为了部署新模型,部署工程师往往需要自己在ONNX和推理引擎中支持新算子。
- 中间表示于推理引擎的兼容问题。由于各推理引擎的实现不同,对ONNX难以形成统一的支持。为了确保模型在不同的推理引擎中有同样的运行效果,部署工程师往往得为某个推理引擎定制模型代码,这为模型部署引入了许多工作量。
3.1 实现动态放大的超分辨率模型
在原来的SRCNN中, 图片的放大比例是写死在模型里的
def init_torch_model(): torch_model = SuperResolutionNet(upscale_factor=3)
我们使用 upscale_factor 来控制模型的放大比例。初始化模型的时候,我们默认令 upscale_factor 为 3,生成了一个放大 3 倍的 PyTorch 模型。这个 PyTorch 模型最终被转换成了 ONNX 格式的模型。如果我们需要一个放大 4 倍的模型,需要重新生成一遍模型,再做一次到 ONNX 的转换。
现在我们希望图片的放大倍数可以自由设置。而我们给用户的只有一个.onnx文件和运行超分辨率模型的应用程序。我们不修改.onnx文件的前提下改变放大倍数。
因此,我们必须修改原来的模型,令模型的放大倍数变成推理时的输入。
import os
import cv2
import numpy as np
import requests
import torch
import torch.onnx
from torch import nn
from torch.nn.functional import interpolateclass SuperResolutionNet(nn.Module):def __init__(self, upscale_factor):super().__init__()self.upscale_factor = upscale_factorself.img_upsampler = nn.Upsample(scale_factor=self.upscale_factor,mode='bicubic',align_corners=False)self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)self.conv2 = nn.Conv2d(64,32,kernel_size=1, padding=0)self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2)self.relu = nn.ReLU()def forward(self, x, upscale_factor):x = interpolate(x,scale_factor = upscale_factor,mode ='bicubic',align_corners=False)out = self.relu(self.conv1(x))out = self.relu(self.conv2(out))out = self.conv3(out)return out
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth', 'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png']
names = ['srcnn.pth', 'face.png']
for url,name in zip(urls, names):if not os.path.exists(name):open(name, 'wb').write(requests.get(url).content)def init_torch_model():torch_model = SuperResolutionNet(upscale_factor=3)state_dict = torch.load('srcnn.pth')['state_dict']for old_key in list(state_dict.keys()):new_key = '.'.join(old_key.split('.')[1:])state_dict[new_key] = state_dict.pop(old_key)torch_model.load_state_dict(state_dict)torch_model.eval()return torch_modelmodel = init_torch_model()
input_img = cv2.imread('face.png').astype(np.float32)input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)
torch_output = model(torch.from_numpy(input_img), 3).detach().numpy()# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8) # Show image
cv2.imwrite("face_torch.png", torch_output)
SuperResolutionNet 未修改之前,nn.Upsample 在初始化阶段固化了放大倍数,而 PyTorch 的 interpolate 插值算子可以在运行阶段选择放大倍数。因此,我们在新脚本中使用 interpolate 代替 nn.Upsample,从而让模型支持动态放大倍数的超分。 在使用模型推理时,我们把放大倍数设置为 3。
torch_output = model(torch.from_numpy(input_img), 3).detach().numpy()
最后,图片保存在文件 "face_torch_2.png" 中。一切正常的话,"face_torch_2.png" 和 "face_torch.png" 的内容一模一样。
导出模型时:
x = torch.randn(1, 3, 256, 256)with torch.no_grad():torch.onnx.export(model, (x,3),"srcnn2.onnx",opset_version=11,input_names=['input', 'factor'],output_names = ['output'])
运行这些脚本时,会报一长串错误。没办法,我们碰到了模型部署中的兼容性问题。
3.2 自定义算子
直接使用Pytorch模型的话,我们修改几行代码就能实现模型输入的动态化。但在模型部署中,我们要花数倍的时间来设法解决这一问题。现在,让我们顺着解决问题的思路,体验一下模型部署的困难,并学习使用自定义算子的方式,解决超分辨率模型的动态化问题。
刚刚的报错是因为 PyTorch 模型在导出到 ONNX 模型时,模型的输入参数的类型必须全部是 torch.Tensor。而实际上我们传入的第二个参数" 3 "是一个整形变量。这不符合 PyTorch 转 ONNX 的规定。我们必须要修改一下原来的模型的输入。为了保证输入的所有参数都是 torch.Tensor 类型的,我们做如下修改:
... class SuperResolutionNet(nn.Module): def forward(self, x, upscale_factor): x = interpolate(x, scale_factor=upscale_factor.item(), mode='bicubic', align_corners=False) ... # Inference
# Note that the second input is torch.tensor(3)
torch_output = model(torch.from_numpy(input_img), torch.tensor(3)).detach().numpy() ... with torch.no_grad(): torch.onnx.export(model, (x, torch.tensor(3)), "srcnn2.onnx", opset_version=11, input_names=['input', 'factor'], output_names=['output'])
由于 PyTorch 中 interpolate 的 scale_factor 参数必须是一个数值,我们使用 torch.Tensor.item() 来把只有一个元素的 torch.Tensor 转换成数值。之后,在模型推理时,我们使用 torch.tensor(3) 代替 3,以使得我们的所有输入都满足要求。现在运行脚本的话,无论是直接运行模型,还是导出 ONNX 模型,都不会报错了。
但是,导出 ONNX 时却报了一条 TraceWarning 的警告。这条警告说有一些量可能会追踪失败。
/var/folders/pd/9txrcrys3rdfqk4hyxxtvszr0000gn/T/ipykernel_28173/3437727912.py:16: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! scale_factor = upscale_factor.item(),
这是怎么回事呢?让我们把生成的 srcnn2.onnx 用 Netron 可视化一下:
可以发现,虽然我们把模型推理的输入设置为两个,但是ONNX模型还是长的和原来一摸一样,很自由一个叫input的输入,这是由于我们使用了torch.Tensor.item()把数据从Tensor里取出来,而导出ONNX模型时这个操作时无法被记录的,只好报了一条Trace Warning。这导致interpolate插值函数的放大倍数还是被设置成了“3”这个固定值,我们导出的“srcnn2.onnx”和最开始的“srcnn.onnx”完全相同。
直接修改原来的模型似乎行不通,我们得从 PyTorch 转 ONNX 的原理入手,强行令 ONNX 模型明白我们的想法了。
仔细观察 Netron 上可视化出的 ONNX 模型,可以发现在 PyTorch 中无论是使用最早的 nn.Upsample,还是后来的 interpolate,PyTorch 里的插值操作最后都会转换成 ONNX 定义的 Resize 操作。也就是说,所谓的PyTorch转ONNX实际上就是把每个PyTorch的操作映射成了ONNX定义的算子。
其中,展开scales,可以看到scales是一个长度为4的一维张量,其内容为[1,1,3,3]
一维张量[1, 1, 3, 3]表示Resize操作每一个维度的缩放系数;其类型为Initializer,表示这个值是根据常量初始化出来的。如果我们能够自己生成一个ONNX的Resize算子,让scales成为一个可变量而不是常量,就像它上面的X一样,那这个超分辨模型就能动态缩放了。
现有实现插值的 PyTorch 算子有一套规定好的映射到 ONNX Resize 算子的方法,这些映射出的 Resize 算子的 scales 只能是常量,无法满足我们的需求。我们得自己定义一个实现插值的 PyTorch 算子,然后让它映射到一个我们期望的 ONNX Resize 算子上。
下面的脚本定义了一个 PyTorch 插值算子,并在模型里使用了它。我们先通过运行模型来验证该算子的正确性:
class NewInterpolate(torch.autograd.Function): @staticmethod def symbolic(g, input, scales): return g.op("Resize", input, g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)), scales, coordinate_transformation_mode_s="pytorch_half_pixel", cubic_coeff_a_f=-0.75, mode_s='cubic', nearest_mode_s="floor") @staticmethod def forward(ctx, input, scales): scales = scales.tolist()[-2:] return interpolate(input, scale_factor=scales, mode='bicubic',
def symbolic(g, input, scales):
-
g
:表示一个图(graph)上下文,用于构建ONNX图
input
:要调整大小的输入张量 -
scales
:缩放因子,是一个一维张量(长度为输入维度数),指定每个维度上的缩放比例
参数详解:
-
"Resize"
:要创建的ONNX算子的类型(调整大小操作)
input
:输入张量,即需要调整大小的数据
g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
:
创建一个空的常量张量作为roi
(感兴趣区域)参数
这里使用空张量表示使用整个输入区域
dtype=torch.float32
指定数据类型为32位浮点数scales
:-
缩放因子张量
例如:对于图像数据(N, C, H, W),scales可能为[1.0, 1.0, 2.0, 2.0]表示高度和宽度各放大2倍
-
-
coordinate_transformation_mode_s="pytorch_half_pixel"
:-
坐标变换模式
pytorch_half_pixel
:使用PyTorch风格的半像素坐标变换
确保ONNX的resize操作与PyTorch的interpolate()行为一致
-
-
cubic_coeff_a_f=-0.75
:-
三次插值的系数
-0.75
是常用的值(对应Catmull-Rom样条插值)
仅当插值模式为'cubic'时生效
-
-
mode_s='cubic'
:-
插值模式
'cubic'
表示使用三次样条插值
其他可能值:'nearest', 'linear'等
-
-
nearest_mode_s="floor"
:-
当插值模式为nearest时使用的舍入方法
"floor"
表示向下取整
这里虽然模式是cubic,但ONNX要求必须指定此参数
-
先理清一下思路,我们希望新的插值算子有两个输入,一个是被用于操作的图像,一个是图像的放缩比例。前面讲到,为了对接ONNX中的Resize算子的scales参数,这个放缩比例是一个【1, 1, x, x】张量,其中x为放大倍数。在之前放大3倍的模型中,这个参数被固定成了【1,1,3,3】。因此,在插值算子中,我们希望模型的第二个输入是一个【1,1,w,h】的张量,其中w和h分别为图片宽和高的放大倍数。
搞清楚了插值算子的输入,再看一看算子的具体实现。算子的推理行为由算子的forward方法决定。该方法的第一个参数必须为ctx,后面的参数为算子的自定义输入,我们设置两个输入,分别为被操作的图像和放缩比例。为保证推理正确,需要把【1,1,w, h】格式的输入对接到原来的interpolate函数上。我们的做法是截取输入张量的后两个元素,把这两个元素以list的格式传入interpolate的scale_factor参数。
接下来,我们要决定新算子映射到ONNX算子的方法,映射到ONNX的方法由一个算子的symbolic方法决定。symbolic方法的第一个参数必须是g,之后的参数是算子的自定义输入,和forward函数一样。ONNX算子的具体定义由g.op实现。g.op的每个参数都可以映射到ONNX中的算子属性:
对于其他参数,我们可以照着现在的Resize算子填。而要注意的是,我们现在希望scales参数是由输入动态决定的。因此,在填入ONNX的scales时,我们要把symbolic方法的输入参数中scales填入。
接着,让我们把新模型导出成ONNX模型:
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():torch.onnx.export(model, (x, factor),"srcnn3.onnx",opset_version=11,input_names=['input', 'factor'],output_names=['output'])
可以看到,正如我们所期望的,导出的 ONNX 模型有了两个输入!第二个输入表示图像的放缩比例。
运行上面的代码,可以得到一个边长放大4倍的超分辨率图片 "face_ort_3.png"。动态的超分辨率模型生成成功了!只要修改 input_factor,我们就可以自由地控制图片的缩放比例。
我们刚刚的工作,实际上是绕过 PyTorch 本身的限制,凭空“捏”出了一个 ONNX 算子。事实上,我们不仅可以创建现有的 ONNX 算子,还可以定义新的 ONNX 算子以拓展 ONNX 的表达能力。后续教程中我们将介绍自定义新 ONNX 算子的方法。
总结
- 模型部署中常见的几类困难有:模型的动态化;新算子的实现;框架间的兼容。
- PyTorch转ONNX,实际上就是把每一个操作转化成ONNX定义的某一个算子。比如对于PyTorch中的Upsample和interpolate,在转ONNX后最终都会成为ONNX的Resize算子。
- 通过修改继承自torch.autograd.Function算子的symbolic方法,可以改变该算子映射到ONNX算子的行为。
三. PyTorch转ONNX详解
ONNX 是目前模型部署中最重要的中间表示之一。学懂了 ONNX 的技术细节,就能规避大量的模型部署问题。
在把 PyTorch 模型转换成 ONNX 模型时,我们往往只需要轻松地调用一句torch.onnx.export
就行了。这个函数的接口看上去简单,但它在使用上还有着诸多的“潜规则”。在这篇教程中,我们会详细介绍 PyTorch 模型转 ONNX 模型的原理及注意事项。除此之外,我们还会介绍 PyTorch 与 ONNX 的算子对应关系,以教会大家如何处理 PyTorch 模型转换时可能会遇到的算子支持问题。
4.1 torch.onnx.export细节
在这一节里,我们将详细介绍 PyTorch 到 ONNX 的转换函数—— torch.onnx.export
。我们希望大家能够更加灵活地使用这个模型转换接口,并通过了解它的实现原理来更好地应对该函数的报错(由于模型部署的兼容性问题,部署复杂模型时该函数时常会报错)。
1. 计算图导出方法
(帮助理解为主,新的pytorch计算图有所改变)
TorchScript 是一种序列化和优化PyTorch模型的格式,在优化过程中,一个torch.nn.Module模型会被转换成TorchScript的torch.jit.ScriptModule模型。现在,TorchScript也被当成一种中间表示使用。
torch.onnx.export中需要的模型实际上是一个torch.jit.ScriptModule。而把普通的PyTorch模型转成一个这样的TorchScript模型,有跟踪(trace)和记录(script)两种导出计算图的方法。如果给torch.onnx.export传入了一个普通PyTorch模型(torch.nn.Module),那么这个模型会默认使用trace方法导出:
回忆一下我们第一篇教程知识:跟踪法只能通过实际运行一遍模型的方法导出模型的静态图,即无法识别出模型中的控制流(如循环);记录法则能通过解析模型来正确记录所有的控制流。我们以下面这段代码为例来看一看这两种转换方法的区别:
import torch class Model(torch.nn.Module): def __init__(self, n): super().__init__() self.n = n self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): for i in range(self.n): x = self.conv(x) return x models = [Model(2), Model(3)]
model_names = ['model_2', 'model_3'] for model, model_name in zip(models, model_names): dummy_input = torch.rand(1, 3, 10, 10) dummy_output = model(dummy_input) model_trace = torch.jit.trace(model, dummy_input) model_script = torch.jit.script(model) torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx',input_names=['input'], # 新增:指定输入名称output_names=['output'], # 新增:指定输出名称dynamic_axes={ # 新增:处理动态维度'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})torch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx',input_names=['input'], # 新增:指定输入名称output_names=['output'], # 新增:指定输出名称dynamic_axes={ # 新增:处理动态维度'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
在这段代码里,我们定义了一个带循环的模型,模型通过参数n
来控制输入张量被卷积的次数。之后,我们各创建了一个n=2
和n=3
的模型。我们把这两个模型分别用跟踪和记录的方法进行导出。
值得一提的是,由于这里的两个模型(model_trace, model_script)是TorchScript模型,export函数已经不需要再运行一遍模型了。(如果模型是用跟踪法得到的,那么在执行torch.jit.trace的时候就运行过一遍了;而用记录法导出时,模型不需要实际运行)参数中的dummy input和dummy output仅仅是为了获取输入和输出张量的类型和形状。
运行上面的代码,我们把得到的4个onnx文件用Netron可视化:
首先看跟踪法得到的ONNX模型结构,可以看出来,对于不同的n,ONNX模型结构式不一样的。
而用记录法的话,最终的ONNX模型用Loop节点来表示循环。这样哪怕对于不同的n,ONNX模型也有同样的结构。
2. 参数讲解
了解完转换函数的原理后,我们来详细介绍一下该函数的主要参数作用:
function) def export(model: Module | ExportedProgram | ScriptModule | ScriptFunction,args: tuple[Any, ...] = (),f: str | PathLike | None = None,*,kwargs: dict[str, Any] | None = None,export_params: bool = True,verbose: bool | None = None,input_names: Sequence[str] | None = None,output_names: Sequence[str] | None = None,opset_version: int | None = None,dynamic_axes: Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None = None,keep_initializers_as_inputs: bool = False,dynamo: bool = False, external_data: bool = True,dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,custom_translation_table: dict[(...) -> Any, ((...) -> Any) | Sequence[(...) -> Any]] | None = None,report: bool = False,optimize: bool = True,verify: bool = False,profile: bool = False,dump_exported_program: bool = False,artifacts_dir: str | PathLike = ".",fallback: bool = False,training: TrainingMode = _C_onnx.TrainingMode.EVAL,operator_export_type: OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,do_constant_folding: bool = True,custom_opsets: Mapping[str, int] | None = None,export_modules_as_functions: bool | Collection[type[Module]] = False,autograd_inlining: bool = True ) -> (ONNXProgram | None)
input_names, output_names
设置输入和输出张量的名称。如果不设置的话,会自动分配一些简单的名字(如数字)。
ONNX 模型的每个输入和输出张量都有一个名字。很多推理引擎在运行 ONNX 文件时,都需要以“名称-张量值”的数据对来输入数据,并根据输出张量的名称来获取输出数据。在进行跟张量有关的设置(比如添加动态维度)时,也需要知道张量的名字。
在实际的部署流水线中,我们都需要设置输入和输出张量的名称,并保证 ONNX 和推理引擎中使用同一套名称。
opset_version
转换时参考哪个 ONNX 算子集版本,默认为 9。后文会详细介绍 PyTorch 与 ONNX 的算子对应关系。
dynamic_axes
指定输入输出张量的哪些维度是动态的。
为了追求效率,ONNX 默认所有参与运算的张量都是静态的(张量的形状不发生改变)。但在实际应用中,我们又希望模型的输入张量是动态的,尤其是本来就没有形状限制的全卷积模型。因此,我们需要显式地指明输入输出张量的哪几个维度的大小是可变的。
我们来看一个dynamic_axes
的设置例子:
import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv(x) return x model = Model()
dummy_input = torch.rand(1, 3, 10, 10)
model_names = ['model_static.onnx',
'model_dynamic_0.onnx',
'model_dynamic_23.onnx'] dynamic_axes_0 = { 'in' : [0], 'out' : [0]
}
dynamic_axes_23 = { 'in' : [2, 3], 'out' : [2, 3]
} torch.onnx.export(model, dummy_input, model_names[0],
input_names=['in'], output_names=['out'])
torch.onnx.export(model, dummy_input, model_names[1],
input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_0)
torch.onnx.export(model, dummy_input, model_names[2],
input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_23)
首先,我们导出 3 个 ONNX 模型,分别为没有动态维度、第 0 维动态、第 2 第 3 维动态的模型。
在这份代码里,我们是用列表的方式表示动态维度,例如:
dynamic_axes_0 = { 'in' : [0], 'out' : [0]
}
由于 ONNX 要求每个动态维度都有一个名字,这样写的话会引出一条 UserWarning,警告我们通过列表的方式设置动态维度的话系统会自动为它们分配名字。一种显式添加动态维度名字的方法如下:
dynamic_axes_0 = { 'in' : {0: 'batch'}, 'out' : {0: 'batch'}
}
我们在模型导出计算图时用的是一个形状为(1, 3, 10, 10)
的张量。现在,我们来尝试以形状分别是(1, 3, 10, 10), (2, 3, 10, 10), (1, 3, 20, 20)
为输入,用ONNX Runtime运行一下这几个模型,看看哪些情况下会报错,并保存对应的报错信息。得到的输出信息应该如下:
Input[0] on model model_static.onnx succeed.
Input[1] on model model_static.onnx error.
Input[2] on model model_static.onnx error.
Input[0] on model model_dynamic_0.onnx succeed.
Input[1] on model model_dynamic_0.onnx succeed.
Input[2] on model model_dynamic_0.onnx error.
Input[0] on model model_dynamic_23.onnx succeed.
Input[1] on model model_dynamic_23.onnx error.
Input[2] on model model_dynamic_23.onnx succeed.
可以看出,形状相同的(1, 3, 10, 10)
的输入在所有模型上都没有出错。而对于batch(第 0 维)或者长宽(第 2、3维)不同的输入,只有在设置了对应的动态维度后才不会出错。我们可以错误信息中找出是哪些维度出了问题。比如我们可以用以下代码查看input[1]
在model_static.onnx
中的报错信息:
print(exceptions[(1, 'model_static.onnx')]) # output
# [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: in for the following indices index: 0 Got: 2 Expected: 1 Please fix either the inputs or the model.
这段报错告诉我们名字叫in
的输入的第 0 维不匹配。本来该维的长度应该为 1,但我们的输入是 2。实际部署中,如果我们碰到了类似的报错,就可以通过设置动态维度来解决问题。
完整的export例子:
import torch
import torch.onnx# 定义模型
class MyModel(torch.nn.Module):def forward(self, x):return torch.nn.functional.relu(x)model = MyModel().eval()# 导出配置
torch.onnx.export(model=model,args=torch.randn(1, 3, 224, 224),f="my_model.onnx",input_names=["input"],output_names=["output"],dynamic_axes={"input": {0: "batch_size"},"output": {0: "batch_size"}},opset_version=14,do_constant_folding=True,verbose=True,training=torch.onnx.TrainingMode.EVAL
)
3. 使用提示
通过学习之前的知识,我们基本掌握了 torch.onnx.export
函数的部分实现原理和参数设置方法,足以完成简单模型的转换了。但在实际应用中,使用该函数还会踩很多坑。这里我们模型部署团队把在实战中积累的一些经验分享给大家。
使模型在 ONNX 转换时有不同的行为
有些时候,我们希望模型在导出至 ONNX 时有一些不同的行为模型在直接用 PyTorch 推理时有一套逻辑,而在导出的ONNX模型中有另一套逻辑。比如,我们可以把一些后处理的逻辑放在模型里,以简化除运行模型之外的其他代码。torch.onnx.is_in_onnx_export()
可以实现这一任务,该函数仅在执行 torch.onnx.export()
时为真。以下是一个例子:
import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv(x) if torch.onnx.is_in_onnx_export(): x = torch.clip(x, 0, 1) return x
这里,我们仅仅在模型导出时把输出张量的数值限制在【0,1】之间。使用is_in_onnx_Export确实能让我们方便地在代码中添加和模型部署相关的逻辑。但是,这些代码对只关心模型训练的开发者和用户来说很不友好,突兀的部署逻辑会降低代码整体的可读性。同时,is_in_onnx_export只能在每个需要添加的部署逻辑的地方“打补丁”,难以进行统一的管理。(MMDeploy重新机制可以规避这些问题)。
4.2 PyTorch 对 ONNX 的算子支持
在确保torch.onnx.export()
的调用方法无误后,PyTorch 转 ONNX 时最容易出现的问题就是算子不兼容了。这里我们会介绍如何判断某个 PyTorch 算子在 ONNX 中是否兼容,以助大家在碰到报错时能更好地把错误归类。而具体添加算子的方法我们会在之后的文章里介绍。
在转换普通的torch.nn.Module
模型时,PyTorch 一方面会用跟踪法执行前向推理,把遇到的算子整合成计算图;另一方面,PyTorch 还会把遇到的每个算子翻译成 ONNX 中定义的算子。在这个翻译过程中,可能会碰到以下情况:
- 该算子可以一对一地翻译成一个 ONNX 算子。
- 该算子在 ONNX 中没有直接对应的算子,会翻译成一至多个 ONNX 算子。
- 该算子没有定义翻译成 ONNX 的规则,报错。
那么,该如何查看 PyTorch 算子与 ONNX 算子的对应情况呢?由于 PyTorch 算子是向 ONNX 对齐的,这里我们先看一下 ONNX 算子的定义情况,再看一下 PyTorch 定义的算子映射关系。
ONNX算子文档
ONNX 算子的定义情况,都可以在官方的算子文档中查看。这份文档十分重要,我们碰到任何和 ONNX 算子有关的问题都得来”请教“这份文档。https://github.com/onnx/onnx/blob/main/docs/Operators.md
这份文档中最重要的开头的这个算子变更表格。表格的第一列是算子名,第二列是该算子发生变动的算子集版本号,也就是我们之前在torch.onnx.export
中提到的opset_version
表示的算子集版本号。通过查看算子第一次发生变动的版本号,我们可以知道某个算子是从哪个版本开始支持的;通过查看某算子小于等于opset_version
的第一个改动记录,我们可以知道当前算子集版本中该算子的定义规则。
通过点击表格中的链接,我们可以查看某个算子的输入、输出参数规定及使用示例。比如上图是 Relu 在 ONNX 中的定义规则,这份定义表明 Relu 应该有一个输入和一个输入,输入输出的类型相同,均为 tensor。
4.3 PyTorch 对 ONNX 算子的映射
在 PyTorch 中,和 ONNX 有关的定义全部放在 torch.onnx
目录中
https://github.com/pytorch/pytorch/tree/main/torch/onnx
其中,symbolic_opset{n}.py
(符号表文件)即表示 PyTorch 在支持第 n 版 ONNX 算子集时新加入的内容。我们之前讲过, bicubic 插值是在第 11 个版本开始支持的。我们以它为例来看看如何查找算子的映射情况。
首先,使用搜索功能,在torch/onnx
文件夹搜索"bicubic",可以发现这个这个插值在第 11 个版本的定义文件中:
其中,symbolic_opset{n}.py
(符号表文件)即表示 PyTorch 在支持第 n 版 ONNX 算子集时新加入的内容。我们之前讲过, bicubic 插值是在第 11 个版本开始支持的。我们以它为例来看看如何查找算子的映射情况。
首先,使用搜索功能,在torch/onnx
文件夹搜索"bicubic",可以发现这个这个插值在第 11 个版本的定义文件中:
我们按照代码的调用逻辑,逐步跳转直到最底层的 ONNX 映射函数:
upsample_bicubic2d = _interpolate("upsample_bicubic2d", 4, "cubic") -> def _interpolate(name, dim, interpolate_mode): return sym_help._interpolate_helper(name, dim, interpolate_mode) -> def _interpolate_helper(name, dim, interpolate_mode): def symbolic_fn(g, input, output_size, *args): ... return symbolic_fn
最后,在symbolic_fn
中,我们可以看到插值算子是怎么样被映射成多个 ONNX 算子的。其中,每一个g.op
就是一个 ONNX 的定义。比如其中的 Resize
算子就是这样写的:
return g.op("Resize", input, empty_roi, empty_scales, output_size, coordinate_transformation_mode_s=coordinate_transformation_mode, cubic_coeff_a_f=-0.75, # only valid when mode="cubic" mode_s=interpolate_mode, # nearest, linear, or cubic nearest_mode_s="floor") # only valid when mode="nearest"
通过在前面提到的ONNX 算子文档中查找 Resize 算子的定义,我们就可以知道这每一个参数的含义了。用类似的方法,我们可以去查询其他 ONNX 算子的参数含义,进而知道 PyTorch 中的参数是怎样一步一步传入到每个 ONNX 算子中的。
掌握了如何查询 PyTorch 映射到 ONNX 的关系后,我们在实际应用时就可以在 torch.onnx.export()
的opset_version
中先预设一个版本号,碰到了问题就去对应的 PyTorch 符号表文件里去查。如果某算子确实不存在,或者算子的映射关系不满足我们的要求,我们就可能得用其他的算子绕过去,或者自定义算子了。
总结
- 跟踪法和记录法在导出带控制语句的计算图时有什么区别。
torch.onnx.export()
中该如何设置input_names, output_names, dynamic_axes
。- 使用
torch.onnx.is_in_onnx_export()
来使模型在转换到 ONNX 时有不同的行为。 - 如何查询 ONNX 算子文档(https://github.com/onnx/onnx/blob/main/docs/Operators.md)。
- 如何查询 PyTorch 对某个 ONNX 版本的新特性支持情况。
- 如何判断 PyTorch 对某个 ONNX 算子是否支持,支持的方法是怎样的。
四. 在 PyTorch 中支持更多 ONNX 算子
模型部署入门系列教程持续更新啦,在上一篇教程中,我们系统地学习了 PyTorch 转 ONNX 的方法,可以发现 PyTorch 对 ONNX 的支持还不错。但在实际的部署过程中,难免碰到模型无法用原生 PyTorch 算子表示的情况。这个时候,我们就得考虑扩充 PyTorch,即在 PyTorch 中支持更多 ONNX 算子。
而要使 PyTorch 算子顺利转换到 ONNX ,我们需要保证以下三个环节都不出错:
- 算子在 PyTorch 中有实现
- 有把该 PyTorch 算子映射成一个或多个 ONNX 算子的方法
- ONNX 有相应的算子
可在实际部署中,这三部分的内容都可能有所缺失。其中最坏的情况是:我们定义了一个全新的算子,它不仅缺少 PyTorch 实现,还缺少 PyTorch 到 ONNX 的映射关系。但所谓车到山前必有路,对于这三个环节,我们也分别都有以下的添加支持的方法:
- PyTorch 算子
- 组合现有算子
- 添加 TorchScript 算子
- 添加普通 C++ 拓展算子
- 映射方法
- 为 ATen 算子添加符号函数
- 为 TorchScript 算子添加符号函数
- 封装成
torch.autograd.Function
并添加符号函数
- ONNX 算子
- 使用现有 ONNX 算子
- 定义新 ONNX 算子
4.1 支持ATen算子
实际的部署过程中,我们都有可能会碰到一个最简单的算子缺失问题:算子在ATen中已经实现了,ONNX中也有相关算子的定义,但是相关算子映射成ONNX的规则没有写。在这种情况下,我们只需要为ATen算子补充描述映射规则的符号函数就行了。
ATen是PyTorch内置的C++张量计算库,PyTorch算子在底层绝大多数计算都是用ATen实现的。
Asinh算子在ATen中有实现,却缺少了映射到ONNX算子的符号函数。在这里,我们来尝试为它补充符号函数,并导出一个包含这个算子的ONNX模型。
获取ATen中算子接口定义
为了编写符号函数,我们需要获得asinh推理接口的输入参数定义。这时,我们要去torch/_C/_VariableFunctions.pyi
和 torch/nn/functional.pyi
这两个文件中搜索我们刚刚得到的这个算子名。这两个文件是编译PyTorch时本地自动生成的文件,里面包含了ATen算子的PyTorch调用接口。通过搜索,我们可以知道 asinh
在文件 torch/_C/_VariableFunctions.pyi
中,其接口定义为:
def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
经过这些步骤,我们确认了缺失的算子名为asinh, 它是一个有实现的ATen算子。我们还记下了asinh的调用接口。接下来,我们要为它补充符号函数,使它转换成ONNX模型时不再报错。
添加符号函数
到目前为止,我们已经多次接触了定义PyTorch到ONNX映射规则的符号函数了。
符号函数,可以堪称PyTorch算子类的一个静态方法。在把PyTorch模型转换成ONNX模型时,各个PyTorch算子的符号函数会被依次调用,以完成PyTorch算子到ONNX算子的转换。
符号函数的定义一般如下:
def symbolic(g: torch._C.Graph, input_0: torch._C.Value, input_1: torch._C.Value, ...):
其中,torch._C.Graph和torch._C.Value都对应PyTorch的C++实现里的一些类。我们只需要知道第一个参数就固定叫g,它表示和计算图相关的内容;后面的每个参数都表示算子的输入,需要和算子的前向推理接口的输入相同。对于ATen算子来收,他们的向前推理接口就是上述两个.pyi文件里的函数接口。
g有一个方法op。在把Pytorch算子转换成ONNX算子时,需要在符号函数中调用此方法来为最终的计算图添加一个ONNX算子。其定义如下:
def op(name: str, input_0: torch._C.Value, input_1: torch._C.Value, ...)
其中,第一个参数时算子名称。如果该算子是普通的ONNX算子,只需要把它在ONNX官方文档里的名称填进去即可。
在最简单的情况下,我们只要把 PyTorch 算子的输入用g.op()
一一对应到 ONNX 算子上即可,并把g.op()
的返回值作为符号函数的返回值。在情况更复杂时,我们转换一个 PyTorch 算子可能要新建若干个 ONNX 算子。
补充完了背景知识,让我们回到 asinh
算子上,来为它编写符号函数。我们先去翻阅一下 ONNX 算子文档,学习一下我们在符号函数里的映射关系 g.op()
里应该怎么写。Asinh
的文档写道:该算子有一个输入 input
,一个输出 output
,二者的类型都为张量。
到这里,我们已经完成了信息收集环节。我们在上一小节得知了 asinh
的推理接口定义,在这一小节里收集了 ONNX 算子 Asinh
的定义。现在,我们可以用代码来补充这二者的映射关系了。在刚刚导出 asinh
算子的代码中,我们添加以下内容:
import torchclass Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):return torch.asinh(x)#from torch.onnx.symbolic_registry import register_op
from torch.onnx import register_custom_op_symbolicdef asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
# 注册到特定操作
register_custom_op_symbolic('aten::asinh', # 注意使用正确的操作名asinh_symbolic,11 # opset版本
)#register_op('asinh', asinh_symbolic, '', 9)model = Model()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, input, 'asinh.onnx')
这里的asinh_symbolic就是asinh的符号函数。从除g以外的第二个函数参数开始,其输入参数应该严格对应它在ATen中的定义:
def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
在符号函数的函数体中,g.op("Asinh", input)
则完成了 ONNX 算子的定义。其中,第一个参数"Asinh"
是算子在 ONNX 中的名称。至于第二个参数 input
,如我们刚刚在文档里所见,这个算子只有一个输入,因此我们只要把符号函数的输入参数 input
对应过去就行。ONNX 的 Asinh
的输出和 ATen 的 asinh
的输出是一致的,因此我们直接把 g.op()
的结果返回即可。
定义完符号函数后,我们要把这个符号函数和原来的 ATen 算子“绑定”起来。这里,我们要用到 register_op
这个 PyTorch API 来完成绑定。如示例所示,只需要一行简单的代码即可把符号函数 asinh_symbolic
绑定到算子 asinh
上:
register_op('asinh', asinh_symbolic, '', 9)
register_op第一个参数是目标ATen算子名,第二个是要注册的符号函数,这两个参数很好理解。第三个参数是算子的“域”,对于普通 ONNX 算子,直接填空字符串即可。第四个参数表示向哪个算子集版本注册。我们遵照 ONNX 标准,向第 9 号算子集注册。值得注意的是,这里向第 9 号算子集注册,不代表较新的算子集(第 10 号、第 11 号……)都得到了注册。在示例中,我们先只向第 9 号算子集注册。
测试算子
在完成了一份自定义算子后,我们一定要测试一下算子的正确性。一般我们要用 PyTorch 运行一遍原算子,再用推理引擎(比如 ONNX Runtime)运行一下 ONNX 算子,最后比对两次的运行结果。对于我们刚刚得到的 asinh.onnx
,可以用如下代码来验证:
import onnxruntime
import torch
import numpy as npclass Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):return torch.asinh(x)model = Model()
input = torch.rand(1, 3, 10, 10)
torch_output = model(input).detach().numpy()sess = onnxruntime.InferenceSession('asinh.onnx')
ort_output = sess.run(None, {'onnx::Asinh_0': input.numpy()})[0]assert np.allclose(torch_output, ort_output)
4.2 支持TorchScript算子
对于一些比较复杂的运算,仅使用 PyTorch 原生算子是无法实现的。这个时候,就要考虑自定义一个 PyTorch 算子,再把它转换到 ONNX 中了。新增 PyTorch 算子的方法有很多,PyTorch 官方比较推荐的一种做法是添加 TorchScript 算子 。
由于添加算子的方法较繁琐,我们今天跳过新增 TorchScript 算子的内容,以可变形卷积(Deformable Convolution)算子为例,介绍为现有 TorchScript 算子添加 ONNX 支持的方法。
可变形卷积(Deformable Convolution)是在 Torchvision 中实现的 TorchScript 算子,虽然尚未得到广泛支持,但是出现在许多模型中。
有了支持 ATen 算子的经验之后,我们可以知道为算子添加符号函数一般要经过以下几步:
- 获取原算子的前向推理接口。
- 获取目标 ONNX 算子的定义。
- 编写符号函数并绑定。
在为可变形卷积添加符号函数时,我们也可以尝试走一遍这个流程。
其中,torchvision.ops.DeformConv2d就是Torchvision中的可变形卷积层。相比于普通卷积,可变形卷积的其他参数都大致相同,唯一的区别就是在推理时需要多输入一个表示偏移量的张量。
然后,我们查询算子的前向推理接口。DeformConv2d
层最终会调用 deform_conv2d
这个算子。我们可以在 torchvision/csrc/ops/deform_conv2d.cpp
中查到该算子的调用接口:
m.def(TORCH_SELECTIVE_SCHEMA( "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, ...... bool use_mask) -> Tensor"));
自定义ONNX算子
很遗憾的是,如果我们去 ONNX 的官方算子页面搜索 "deform",将搜不出任何内容。目前,ONNX 还没有提供可变形卷积的算子,我们要自己定义一个 ONNX 算子了。
我们在前面讲过,g.op()
是用来定义 ONNX 算子的函数。对于 ONNX 官方定义的算子,g.op()
的第一个参数就是该算子的名称。而对于一个自定义算子,g.op()
的第一个参数是一个带命名空间的算子名,比如:
g.op("custom::deform_conv2d, ...)
其中,"::"前面的内容就是我们的命名空间。该概念和 C++ 的命名空间类似,是为了防止命名冲突而设定的。如果在 g.op()
里不加前面的命名空间,则算子会被默认成 ONNX 的官方算子。
PyTorch 在运行 g.op()
时会对官方的算子做检查,如果算子名有误,或者算子的输入类型不正确, g.op()
就会报错。为了让我们随心所欲地定义新 ONNX 算子,我们必须设定一个命名空间,给算子取个名,再定义自己的算子。
我们在第一篇教程讲过:ONNX 是一套标准,本身不包括实现。在这里,我们就简略地定义一个 ONNX 可变形卷积算子,而不去写它在某个推理引擎上的实现。在后续的文章中,我们再介绍在各个推理引擎中添加新 ONNX 算子支持的方法。此处,我们只关心如何导出一个包含新 ONNX 算子节点的 onnx 文件。因此,我们可以为新算子编写如下简单的符号函数
import torch
import torchvision class Model(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(3, 18, 3) self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3) def forward(self, x): return self.conv2(x, self.conv1(x)) from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args @parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "none")
def symbolic(g, input, weight, offset, mask, bias, stride_h, stride_w, pad_h, pad_w, dil_h, dil_w, n_weight_grps, n_offset_grps, use_mask): return g.op("custom::deform_conv2d", input, offset) register_custom_op_symbolic("torchvision::deform_conv2d", symbolic, 9) model = Model()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, input, 'dcn.onnx')
在这个符号函数中,我们以刚刚搜索到的算子输入参数作为符号函数的输入参数,并只用 input
和 offset
来构造一个简单的 ONNX 算子。
这段代码中,最令人疑惑的就是装饰器 @parse_args
了。简单来说,TorchScript 算子的符号函数要求标注出每一个输入参数的类型。比如"v"表示 Torch 库里的 value
类型,一般用于标注张量,而"i"表示 int 类型,"f"表示 float 类型,"none"表示该参数为空。具体的类型含义可以在 torch.onnx.symbolic_helper.py
(https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_helper.py)中查看。这里输入参数中的 input, weight, offset, mask, bias
都是张量,所以用"v"表示。后面的其他参数同理。我们不必纠结于 @parse_args
的原理,根据实际情况对符号函数的参数标注类型即可。
代码成功运行的话,我们应该能得到如下的 ONNX 模型:
4.3 使用torch.autograd.Function
最后,我们来学习一种简单的为 PyTorch 添加 C++ 算子实现的方法,来代替较为复杂的新增 TorchScript 算子。同时,我们会用 torch.autograd.Function
封装这个新算子。torch.autograd.Function
能完成算子实现和算子调用的隔离。不管算子是怎么实现的,它封装后的使用体验以及 ONNX 导出方法会和原生的 PyTorch 算子一样。这是我们比较推荐的为算子添加 ONNX 支持的方法。
为了应对更复杂的情况,我们来自定义一个奇怪的 my_add
算子。这个算子的输入张量 a, b
,输出 2a + b
的值。我们会先把它在 PyTorch 中实现,再把它导出到 ONNX 中。
为PyTorch添加C++拓展
为 PyTorch 添加简单的 C++ 拓展还是很方便的。对于我们定义的 my_add
算子,可以用以下的 C++ 源文件来实现。我们把该文件命名为 "my_add.cpp":
// my_add.cpp #include <torch/torch.h> torch::Tensor my_add(torch::Tensor a, torch::Tensor b)
{ return 2 * a + b;
} PYBIND11_MODULE(my_lib, m)
{ m.def("my_add", my_add);
}
由于在 PyTorch 中添加 C++ 拓展和模型部署关系不大,这里我们仅给出这个简单的示例,并不对其原理做过多讲解。
在这段代码中,torch::Tensor
就是 C++ 中 torch 的张量类型,它的加法和乘法等运算符均已重载。因此,我们可以像对普通标量一样对张量做加法和乘法。
轻松地完成了算子的实现后,我们用 PYBIND11_MODULE
来为 C++ 函数提供 Python 调用接口。这里的 my_lib
是我们未来要在 Python 里导入的模块名。双引号中的 my_add
是 Python 调用接口的名称,这里我们对齐 C++ 函数的名称,依然用 "my_add"这个名字。
之后,我们可以编写如下的 Python 代码并命名为 "setup.py",来编译刚刚的 C++ 文件:
from setuptools import setup
from torch.utils import cpp_extension setup(name='my_add', ext_modules=[cpp_extension.CppExtension('my_lib', ['my_add.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension})
这段代码使用了 Python 的 setuptools 编译功能和 PyTorch 的 C++ 拓展工具函数,可以编译包含了 torch 库的 C++ 源文件。这里我们需要填写的只有模块名和模块中的源文件名。我们刚刚把模块命名为 my_lib
,而源文件只有一个 my_add.cpp
,因此拓展模块那一行要写成 ext_modules=[cpp_extension.CppExtension('my_lib', ['my_add.cpp'])],
。
之后,像处理普通的 Python 包一样执行安装命令,我们的 C++ 代码就会自动编译了。
python setup.py develop
用 torch.autograd.Function封装
直接用 Python 接口调用 C++ 函数不太“美观”,一种比较优雅的做法是把这个调用接口封装起来。这里我们用 torch.autograd.Function
来封装算子的底层调用:
import torch
import my_lib
class MyAddFunction(torch.autograd.Function): @staticmethod def forward(ctx, a, b): return my_lib.my_add(a, b) @staticmethod def symbolic(g, a, b): two = g.op("Constant", value_t=torch.tensor([2])) a = g.op('Mul', a, two) return g.op('Add', a, b)
我们在前面的教程中已经见过 torch.autograd.Function
,这里我们正式地对其做一个介绍。Function
类本身表示 PyTorch 的一个可导函数,只要为其定义了前向推理和反向传播的实现,我们就可以把它当成一个普通 PyTorch 函数来使用。
PyTorch 会自动调度该函数,合适地执行前向和反向计算。对模型部署来说,Function
类有一个很好的性质:如果它定义了 symbolic
静态方法,该 Function
在执行 torch.onnx.export()
时就可以根据 symbolic
中定义的规则转换成 ONNX 算子。这个 symbolic
就是前面提到的符号函数,只是它的名称必须是 symbolic
而已。
在 forward
函数中,我们用 my_lib.my_add(a, b)
就可以调用之前写的C++函数了。这里 my_lib
是库名,my_add
是函数名,这两个名字是在前面C++的 PYBIND11_MODULE
中定义的。
在 symbolic
函数中,我们用 g.op()
定义了三个算子:常量、乘法、加法。这里乘法和加法的用法和前面提到的 asinh
一样,只需要根据 ONNX 算子定义规则把输入参数填入即可。而在定义常量算子时,我们要把 PyTorch 张量的值传入 value_t
参数中。
在 ONNX 中,我们需要把新建常量当成一个算子来看待,尽管这个算子并不会以节点的形式出现在 ONNX 模型的可视化结果里。
把算子封装成 Function
后,我们可以把 my_add
算子用起来了。
my_add = MyAddFunction.apply class MyAdd(torch.nn.Module): def __init__(self): super().__init__() def forward(self, a, b): return my_add(a, b)
在这份代码里,我们先用 my_add = MyAddFunction.apply
获取了一个奇怪的变量。这个变量是用来做什么的呢?其实,apply
是torch.autograd.Function
的一个方法,这个方法完成了 Function
在前向推理或者反向传播时的调度。我们在使用 Function
的派生类做推理时,不应该显式地调用 forward()
,而应该调用其 apply
方法。
这里我们使用 my_add = MyAddFunction.apply
把这个调用方法取了一个更简短的别名 my_add
。以后在使用 my_add
算子时,我们应该忽略 MyAddFunction
的实现细节,而只通过 my_add
这个接口来访问算子。这里 my_add
的地位,和 PyTorch 的 asinh
, interpolate
, conv2d
等原生函数是类似的。
有了访问新算子的接口后,我们可以进一步把算子封装成一个神经网络中的计算层。我们定义一个叫做的 MyAdd
的 torch.nn.Module
,它封装了my_add
,就和封装了conv2d
的 torch.nn.Conv2d
一样。
测试算子
费了好大的功夫来“包装”我们的新算子后,我们终于可以来使用它了。和之前的测试流程一样,让我们用下面的代码来导出一个包含新算子的 ONNX 模型,并验证一下它是否正确。
model = MyAdd()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, (input, input), 'my_add.onnx')
torch_output = model(input, input).detach().numpy() import onnxruntime
import numpy as np
sess = onnxruntime.InferenceSession('my_add.onnx')
ort_output = sess.run(None, {'a': input.numpy(), 'b': input.numpy()})[0] assert np.allclose(torch_output, ort_output)
在这份代码中,我们直接把 MyAdd
作为要导出的模型。我们计算了一个 PyTorch 模型的运行结果,又导出 ONNX 模型,计算了 ONNX 模型在 ONNX Runtime 上的运算结果。如果一切正常的话,这两个结果是一样的,这份代码不会报任何错误,没有任何输出。
可视化一下 my_add.onnx
,可以看出,和我们设计得一样,my_add
算子被翻译成了两个 ONNX 算子节点(其中常量算子被放入了 Mul
的参数中)。
整理一下,整个流程的 Python 代码如下:
import torch
import my_lib
class MyAddFunction(torch.autograd.Function): @staticmethod def forward(ctx, a, b): return my_lib.my_add(a, b) @staticmethod def symbolic(g, a, b): two = g.op("Constant", value_t=torch.tensor([2])) a = g.op('Mul', a, two) return g.op('Add', a, b) my_add = MyAddFunction.apply class MyAdd(torch.nn.Module): def __init__(self): super().__init__() def forward(self, a, b): return my_add(a, b) model = MyAdd()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, (input, input), 'my_add.onnx')
torch_output = model(input, input).detach().numpy() import onnxruntime
import numpy as np
sess = onnxruntime.InferenceSession('my_add.onnx')
ort_output = sess.run(None, {'a': input.numpy(), 'b': input.numpy()})[0] assert np.allclose(torch_output, ort_output)
总结
- ATen 是 PyTorch 的 C++ 张量运算库。通过查询
torch/_C/_VariableFunctions.pyi
和torch/nn/functional.pyi
,我们可以知道 ATen 算子的 Python 接口定义。 - 用
register_custom_op_symbolic
可以为 ATen 算子补充注册符号函数 - 用
register_custom_op_symbolic
可以为 TorchScript 算子补充注册符号函数 - 如何在 PyTorch 里添加 C++ 拓展
- 如何用
torch.autograd.Function
封装一个自定义 PyTorch 算子 - 如何编写符号函数
symbolic(g, ...)
。 - 如何用
g.op()
把一个 PyTorch 算子映射成一个或多个 ONNX 算子,或者是自定义的 ONNX 算子。
五. ONNX 模型的修改与调试
不知道大家会不会有这样一些疑问:ONNX 模型在底层是用什么格式存储的?如何不依赖深度学习框架,只用 ONNX 的 API 来构造一个 ONNX 模型?如果没有源代码,只有一个 ONNX 模型,该如何对这个模型进行调试?别急,今天我们就来为大家一一揭晓。
在这期教程里,我们将围绕 ONNX 这一套神经网络定义标准本身,探究 ONNX 模型的构造、读取、子模型提取、调试。首先,我们会学习 ONNX 的底层表示方式。之后,我们会用 ONNX API 构造和读取模型。最后,我们会利用 ONNX 提供的子模型提取功能,学习如何调试 ONNX 模型。
5.1 ONNX的底层实现
1. ONNX的存储格式
ONNX在底层时用Protobuf定义的。Protobuf,全称Protocol Buffer,是Google提出的一套表示和序列化数据的机制。使用protobuf时,用户需要先写一份数据定义文件,再根据这份定义文件把数据存储进一份二进制文件。可以说,数据定义文件就是数据类,二进制文件就是数据类的实例。
这里给出一个 Protobuf 数据定义文件的例子:
message Person { required string name = 1; required int32 id = 2; optional string email = 3;
}
这段定义表示在Person这种数据类型中,必须包含 name、id这两个字段,选择性包含email字段。根据这份定义文件,用户就可以选择一种编程语言,定义一个含有成员变量name、id、email的Person类,把这个类的某个实例用Protobuf存储成二进制文件;反之,用户也可以用二进制文件和对应的数据定义文件,读取出一个Person类的实例。而对于ONNX,Protobuf的数据定义文件在其开源库,这些文件定义了神经网络中模型、节点、张量的数据类型规范;而二进制文件就是我们熟悉的“.onnx”文件,每一个onnx文件按照数据定义规范,存储了一个神经网络的所有相关数据。直接用protobuf生成ONNX模型还是比较麻烦。幸亏的事,ONNX提供了很多使用API,我们可以在完全不了解Protobuf前提下,构造和读取ONNX模型。
2. ONNX的结构定义
再用API对ONNX模型操作之前,我们还需要先了解一下ONNX的结构定义规则,学习一下ONNX在Protobuf定义文件里事怎样描述一个神经网络的。
回想一下,神经网络本质上是一个计算图,计算图的节点是算子,边是参与运算的张量。而通过可视化ONNX模型,我们知道ONNX记录了所有算子节点的属性信息,并把参与运算的张量信息存储在算子节点的输入输出信息中。实际上,ONNX模型的结构可以用类图大致表示如下:
如图所示,一个 ONNX 模型可以用 ModelProto
类表示。ModelProto
包含了版本、创建者等日志信息,还包含了存储计算图结构的 graph
。GraphProto
类则由输入张量信息、输出张量信息、节点信息组成。张量信息 ValueInfoProto
类包括张量名、基本数据类型、形状。节点信息 NodeProto
类包含了算子名、算子输入张量名、算子输出张量名。
让我们来看一个具体的例子。假如我们有一个描述 output=a*x+b
的 ONNX 模型 model
,用 print(model)
可以输出以下内容:
ir_version: 8
graph { node { input: "a" input: "x" output: "c" op_type: "Mul" } node { input: "c" input: "b" output: "output" op_type: "Add" } name: "linear_func" input { name: "a" type { tensor_type { elem_type: 1 shape { dim {dim_value: 10} dim {dim_value: 10} } } } } input { name: "x" type { tensor_type { elem_type: 1 shape { dim {dim_value: 10} dim {dim_value: 10} } } } } input { name: "b" type { tensor_type { elem_type: 1 shape { dim {dim_value: 10} dim {dim_value: 10} } } } } output { name: "output" type { tensor_type { elem_type: 1 shape { dim { dim_value: 10} dim { dim_value: 10} } } } }
}
opset_import {version: 15}
对应上文中的类图,这个模型的信息由ir_version, opset_import等全局信息和graph图信息组成。而graph包含一个乘法节点、一个加法节点、三个输入张量a,x,b以及一个输出张量output。在下一节里,我们会会API构造出这个模型,并输出这段结果。
3. 读写ONNX模型
3.1 构造ONNX模型
在上一小节中,我们知道了 ONNX 模型是按以下的结构组织起来的:
- ModelProto
- GraphProto
- NodeProto
- ValueInfoProto
- GraphProto
现在,让我们抛开PyTorch,尝试完全用ONNX的PyThon API构造一个描述线性函数output=a*x +b的ONNX模型。我们将根据上面的结构,自底向上地构造这个模型。首先我们可以用helper.make_tensor_value_info构造出一个描述张量信息的ValueInfoProto对象。如前面的类图所示,我们要传入张量,他们的表示方法都是一样的。因此,这里我们用类似的方式为三个输入a,x,b和一个输出output构造ValueInfoProto对象。如下:
import onnx
from onnx import helper
from onnx import TensorProtoa = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])
之后,我们要构造算子节点信息NodeProto,这可以通过在helper.make_node中传入算子类型、输入算子名、输出算子名这三个信息来实现。我们这里先构造了描述c=a*x的乘法节点,再构造了output=c+b的加法节点。如下:
mul = helper.make_node('Mul', ['a', 'b'], ['c'])
add = helper.make_node('Add', ['c', 'b'], ['output'])
在计算机中,图一般是用一个节点集和一个边集表示的。而ONNX巧妙地把边的信息保存在了节点信息里,省去了保存边集的步骤。在ONNX,如果某节点的输入名和之前某节点的输出名相同,就默认了这两个节点是相连的。如上图例子所示:Mul节点定义了输出c, Add节点定义了输入c, 则Mul节点和Add节点是相连的。
正是因为这种边的隐式定义规则,所以ONNX对节点的输入由一定的要求:一个节点的输入,要么是整个模型的输入,要么是之前某个节点的输出。如果我们把a, x, b中的某个输入节点从计算图中拿出,或者把Mul的输出从c改成d,则最终的ONNX模型都是不满足标准的。
一个不满足标准的 ONNX 模型可能无法被推理引擎正确识别。ONNX 提供了 API
onnx.checker.check_model
来判断一个 ONNX 模型是否满足标准。
接下来,我们用helper.make_graph来构造计算图GraphProto。helper.make_graph函数需要传入节点、图名称、输入张量信息、输出张量信息这4个参数。如下面的代码所示,我们把之前构造出来的NodeProto对象和ValueInfoProto对象按照顺序传入即可。
graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output])
这里make_graph的节点参数有一个要求:计算图的节点必须以拓扑序给出。
我们以刚刚构造出来的这个计算图为研究对象,通过下图展示的两个例子来直观理解拓扑序。
这里我们只关注 Mul
和 Add
节点以及它们之间的边 c
。在情况 1 中:如果我们的节点以 [Mul, Add]
顺序给出,那么遍历到 Add
时,它的输入 c
可以在之前的Mul
的输出中找到。但是,如情况 2 所示:如果我们的节点以 [Add, Mul]
的顺序给出,那么 Add
就找不到输入边,计算图也无法成功构造出来了。这里的 [Mul, Add]
就是符合有向图的拓扑序的,而 [Add, Mul]
则不满足。
最后,我们用helper.make_model把计算图GraphProto封装进模型ModelProto里,一个ONNX模型就构造完成了。make_model函数中还可以添加模型制作者、版本等信息。
model = helper.make_model(graph)
构造完模型之后,我们用下面这三行代码来检查模型正确性、把模型以文本形式输出、存储到一个 ".onnx" 文件里。这里用 onnx.checker.check_model
来检查模型是否满足 ONNX 标准是必要的,因为无论模型是否满足标准,ONNX 都允许我们用 onnx.save
存储模型。我们肯定不希望生成一个不满足标准的模型。
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx')
成功执行这些代码的话,程序会以文本格式输出模型的信息,其内容应该和我们在上一节展示的输出一样。
整理一下,用 ONNX Python API 构造模型的代码如下:
import onnx
from onnx import helper
from onnx import TensorProtoa = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])mul = helper.make_node('Mul', ['a', 'x'], ['c'])
add = helper.make_node('Add', ['c', 'b'], ['output'])graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output])
model = helper.make_model(graph)
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx')
老规矩,我们可以用 ONNX Runtime 运行模型,来看看模型是否正确:
import onnxruntime
import numpy as np sess = onnxruntime.InferenceSession('linear_func.onnx')
a = np.random.rand(10, 10).astype(np.float32)
b = np.random.rand(10, 10).astype(np.float32)
x = np.random.rand(10, 10).astype(np.float32) output = sess.run(['output'], {'a': a, 'b': b, 'x': x})[0] assert np.allclose(output, a * x + b)
一切顺利的话,这段代码不会有任何报错信息。这说明我们的模型等价于执行 a * x + b
这个计算。
3.2 读取并修改ONNX模型
通过用 API 构造 ONNX 模型,我们已经彻底搞懂了 ONNX 由哪些模块组成。现在,让我们看看该如何读取现有的".onnx"文件并从中提取模型信息。
首先,我们可以用下面的代码读取一个 ONNX 模型:
import onnx
model = onnx.load('linear_func.onnx')
print(model)
之前在输出模型时,我们传给onnx.save的是一个ModelProto对象。同理,我们用上面的onnx.load读取ONNX模型时,我们收获的也是一个Model.Proto对象。输出这个对象后,我们应该得到和之前完全相同的输出。
接下来,我们来看看怎么把图GraphProto
、节点 NodeProto
、张量信息 ValueInfoProto
读取出来:
graph = model.graph
node = graph.node
input = graph.input
output = graph.output
print(node)
print(input)
print(output)
使用如上这些代码,我们可以分别访问模型的图、节点、张量信息。这里大家或许会有疑问:该怎样找出 graph.node,graph.input
中 node, input
这些属性名称呢?其实,属性的名称就写在每个对象的输出里。我们以 print(node)
的输出为例:
[input: "a"
input: "x"
output: "c"
op_type: "Mul" ,
input: "c"
input: "b"
output: "output"
op_type: "Add" ]
在这段输出中,我们能看出 node
其实就是一个列表,列表中的对象有属性 input, output, op_type
(这里 input
也是一个列表,它包含的两个元素都显示出来了)。我们可以用下面的代码来获取 node
里第一个节点 Mul
的属性:
node_0 = node[0]
node_0_inputs = node_0.inputnode_0_outputs = node_0.outputinput_0 = node_0_inputs[0]
print(input_0) #a
input_1 = node_0_inputs[1]
print(input_1) #x
output = node_0_outputs[0]
print(output) #c
op_type = node_0.op_type
print(op_type) #Mul
当我们想知道 ONNX 模型某数据对象有哪些属性时,我们不必去翻 ONNX 文档,只需要先把数据对象输出一下,然后在输出结果找出属性名即可。
读取 ONNX 模型的信息后,修改 ONNX 模型就是一件很轻松的事了。我们既可以按照上一小节的模型构造方法,新建节点和张量信息,与原有模型组合成一个新的模型,也可以在不违反 ONNX 规范的前提下直接修改某个数据对象的属性。
这里我们来看一个直接修改模型属性的例子:
import onnx
model = onnx.load('linear_func.onnx') node = model.graph.node
node[1].op_type = 'Sub' onnx.checker.check_model(model)
onnx.save(model, 'linear_func_2.onnx')
在读入之前的 linear_func.onnx
模型后,我们可以直接修改第二个节点的类型 node[1].op_type
,把加法变成减法。这样,我们的模型描述的是 a * x - b
这个线性函数。大家感兴趣的话,可以用 ONNX Runtime 运行新模型 linear_func_2.onnx
,来验证一下它和 a * x - b
是否等价。
4. 调试ONNX模型
在实际部署中,如果用深度学习框架导出的 ONNX 模型出了问题,一般要通过修改框架的代码来解决,而不会从 ONNX 入手,我们把 ONNX 模型当成一个不可修改的黑盒看待。
现在,我们已经深入学习了 ONNX 的原理,可以尝试对 ONNX 模型本身进行调试了。在这一节里,让我们看看该如何巧妙利用 ONNX 提供的子模型提取功能,对 ONNX 模型进行调试。
4.1 子模型提取
ONNX 官方为开发者提供了子模型提取(extract)的功能。子模型提取,顾名思义,就是从一个给定的 ONNX 模型中,拿出一个子模型。这个子模型的节点集、边集都是原模型中对应集合的子集。让我们来用 PyTorch 导出一个复杂一点的 ONNX 模型,并在它的基础上执行提取操作:
import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.convs1 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3)) self.convs2 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3)) self.convs3 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3)) self.convs4 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3)) def forward(self, x): x = self.convs1(x) x1 = self.convs2(x) x2 = self.convs3(x) x = x1 + x2 x = self.convs4(x) return x model = Model()
input = torch.randn(1, 3, 20, 20) torch.onnx.export(model, input, 'whole_model.onnx')
在前面的章节中,我们学过,ONNX 的边用同名张量表示的。也就是说,这里的边序号,实际上是前一个节点的输出张量序号和后一个节点的输入张量序号。由于这个模型是用 PyTorch 导出的,这些张量序号都是 PyTorch 自动生成的。
接着,我们可以下面的代码提取出一个子模型:
import onnx onnx.utils.extract_model('whole_model.onnx', 'partial_model.onnx', ['/convs1/convs1.2/Conv_output_0'], ['/convs4/convs4.0/Conv_output_0'])
paritial mode的子模型可视化结果如下图:
通过观察代码和输出图,应该不难猜出这段代码的作用是把原计算图从边 22 到边 28 的子图提取出来,并组成一个子模型。onnx.utils.extract_model
就是完成子模型提取的函数,它的参数分别是原模型路径、输出模型路径、子模型的输入边(输入张量)、子模型的输出边(输出张量)。
直观地来看,子模型提取就是把输入边到输出边之间的全部节点都取出来。那么,这个功能在使用上有什么限制呢?基于 whole_model.onnx
, 我们来看一看三个子模型提取的示例。
添加额外输出
我们在提取时新设定了一个输出张量,如下面的代码所示:
onnx.utils.extract_model('whole_model.onnx', 'submodel_1.onnx', ['/convs1/convs1.2/Conv_output_0'], ['/convs2/convs2.1/Conv_output_0', '31'])
添加冗余输入
如果我们还是像开始一样提取边 22 到边 28 之间的子模型,但是多添加了一个输入 input.1
,那么提取出的子模型会有一个冗余的输入 input.1
,如下面的代码所示:
onnx.utils.extract_model('whole_model.onnx', 'submodel_2.onnx', ['22', 'input.1'], ['28'])
从下图可以看到:无论给这个输入传入什么值,都不会影响子模型的输出。可以认为如果只用子模型的部分输入就能得到输出,那么那些”较早“的多出来的输入就是冗余的。
输入信息不足
这次,我们尝试提取的子模型输入是边 24,输出是边 28。如下面的代码和图所示:
#error
onnx.utils.extract_model('whole_model.onnx', 'submodel_3.onnx', ['24'], ['28'])
从图中可以看出,想通过边 24 计算边 28 的结果,至少还需要输入边 26,或者更上面的边。仅凭借边 24 是无法计算出边 28 的结果的,因此这样提取子模型会报错。
通过上面几个使用示例,我们可以整理出子模型提取的实现原理:新建一个模型,把给定的输入和输出填入。之后把图的所有有向边反向,从输出边开始遍历节点,碰到输入边则停止,把这样遍历得到的节点做为子模型的节点。
如果还没有彻底弄懂这个提取原理,没关系,我们只要尽量保证在填写子模型的输入输出时,让输出恰好可以由输入决定即可。
5. 输出 ONNX 中间节点的值
在使用 ONNX 模型时,最常见的一个需求是能够用推理引擎输出中间节点的值。这多见于深度学习框架模型和 ONNX 模型的精度对齐中,因为只要能够输出中间节点的值,就能定位到精度出现偏差的算子。我们来看看如何用子模型提取实现这一任务。
在刚刚的第一个子模型提取示例中,我们添加了一条原来模型中不存在的输出边。用同样的原理,我们可以在保持原有输入输出不变的同时,新增加一些输出,提取出一个能输出中间节点的”子模型“。例如:
onnx.utils.extract_model('whole_model.onnx', 'more_output_model.onnx', ['input.1'], ['31','/convs3/convs3.1/Conv_output_0', '/convs2/convs2.1/Conv_output_0'])
这样,用 ONNX Runtime 运行 more_output_model.onnx
这个模型时,我们就能得到更多的输出了。
为了方便调试,我们还可以把原模型拆分成多个互不相交的子模型。这样,在每次调试时,可以只对原模型的部分子模块调试。比如:
onnx.utils.extract_model('whole_model.onnx', 'debug_model_1.onnx', ['input.1'], ['/convs2/convs2.0/Conv_output_0'])
onnx.utils.extract_model('whole_model.onnx', 'debug_model_2.onnx', ['/convs1/convs1.1/Conv_output_0'], ['/convs3/convs3.1/Conv_output_0'])
onnx.utils.extract_model('whole_model.onnx', 'debug_model_3.onnx', ['/convs1/convs1.2/Conv_output_0'], ['/convs2/convs2.1/Conv_output_0'])
onnx.utils.extract_model('whole_model.onnx', 'debug_model_4.onnx', ['/convs3/convs3.1/Conv_output_0', '/convs2/convs2.1/Conv_output_0'],['31'])
子模型提取固然是一个便利的 ONNX 调试工具。但是,在实际的情况中,我们一般是用 PyTorch 等框架导出 ONNX 模型。这里有两个问题:
- 一旦 PyTorch 模型改变,ONNX 模型的边序号也会改变。这样每次提取同样的子模块时都要重新去 ONNX 模型里查序号,如此繁琐的调试方法是不会在实践中采用的。
- 即使我们能保证 ONNX 的边序号不发生改变,我们也难以把 PyTorch 代码和 ONNX 节点对应起来——当模型结构变得十分复杂时,要识别 ONNX 中每个节点的含义是不可能的。
在 MMDeploy 中,我们为 PyTorch 模型添加了模型分块功能。使用这个功能,我们可以通过只修改 PyTorch 模型的实现代码来把原模型导出成多个互不相交的子 ONNX 模型。我们会在后续教程中对其介绍。
总结
在这篇教程中,我们抛开了 PyTorch,学习了 ONNX 模型本身的知识。老规矩,我们来总结一下这篇教程的知识点:
- ONNX 使用 Protobuf 定义规范和序列化模型。
- 一个 ONNX 模型主要由
ModelProto
,GraphProto
,NodeProto
,ValueInfoProto
这几个数据类的对象组成。 - 使用
onnx.helper.make_xxx
,我们可以构造 ONNX 模型的数据对象。 onnx.save()
可以保存模型,onnx.load()
可以读取模型,onnx.checker.check_model()
可以检查模型是否符合规范。onnx.utils.extract_model()
可以从原模型中取出部分节点,和新定义的输入、输出边构成一个新的子模型。- 利用子模型提取功能,我们可以输出原 ONNX 模型的中间结果,实现对 ONNX 模型的调试。
至此,我们对 ONNX 相关知识的学习就告一段落了。回顾一下,我们先学习了 PyTorch 转 ONNX 有关 API 的用法;接着,我们学习了如何用自定义算子解决 PyTorch 和 ONNX 表达能力不足的问题;最后我们单独学习了 ONNX 模型的调试方法。通过对 ONNX 由浅入深的学习,我们基本可以应对模型部署中和 ONNX 有关的绝大多数问题了。
六. 实现 PyTorch-ONNX 精度对齐工具
精度对齐,是模型部署中重要的一个环节。把深度学习框架模型转换成中间表示模型后,部署工程师们要做的第一件事就是精度对齐,确保模型的计算结果与之前相当。精度对齐时最常用的方法就是使用测试集评估一遍中间表示模型,看看模型的评估指标 准确度和相似度是否下降。
而在 PyTorch 到 ONNX 这条部署路线上,这种精度对齐方式有一些不便:一旦我们发现 PyTorch 模型和 ONNX 模型的评估指标有了出入,我们很难去追踪精度是在哪一个模块出了问题。这是因为 PyTorch 和 ONNX 模块总是难以对应。如下面的例子所示:
假设我们现在有一个由很多卷积块 convs1, convs2...
组成的网络,我们想对齐 PyTorch 模型和 ONNX 模型的精度。第一步,我们想比较第一个卷积块的输出 x = self.convs1(x)
。模块在PyTorch 模型中的输出可以很轻松地得到,可是,这个输出究竟对应 ONNX 模型里的哪一个输出呢?在小模型里,我们或许能够通过阅读 PyTorch 模型的源码,推断出每个 ONNX 模块与 PyTorch 模块的对应关系;但是,在大模型中,我们是难以建立 PyTorch 与 ONNX 的对应关系的。
在这篇教程中,我们就来利用之前学过的自定义算子、子模型提取等工具,实现一个简单的 PyTorch-ONNX 精度对齐工具。
6.1 设计思路
为了把 PyTorch 和 ONNX 模块对应起来,我们可以使用一种储存了调试信息的自定义算子,如下图所示:
我们可以定义一个叫做Debug的ONNX算子,它有一个属性调试名name。而由于每一个ONNX算子节点又自带了输出张量的名称,这样一来,ONNX节点的输出名和调试名绑定在了一起。我们可以顺着PyTorch里调试名,找到对应ONNX里的输出,完成PyTorch和ONNX的对应。
比如在上图的例子中,我们把第一个卷积块输出x=self.convs1(x)接入一个带有调试名x_0的调试算子。在最后生成的ONNX模型中,假设调试名x_0对应的输出张量叫做a。知道了这一信息后,我们只需要先运行一遍 PyTorch 模型,记录第一个卷积块的输出;再运行一遍 ONNX 模型,用上篇教程中提到的截取 ONNX 中间结果的方法,记录中间张量 a
的值。这样,我们就可以对齐某 PyTorch 模块和它对应的 ONNX 模块的输出了。
6.2 代码实现
debug算子
首先,我们需要实现之前提到的Debug算子:
import torch
class DebugOp(torch.autograd.Function):@staticmethoddef forward(ctx, x, name):return x@staticmethoddef symbolic(g, x, name):return g.op("my::Debug", x, name_s=name)debug_apply = DebugOp.apply
Debug 算子的调用接口有两个参数:输入张量 x
和调试名 name
。为了把这个算子“伪装”成一个普通的算子,使之能正常地参与推理、构建计算图的操作,我们还是需要正确定义对输入 x
进行操作的 forward
函数。而在表示 PyTorch 与 ONNX 映射规则的 symbolic
函数里,我们要定义一个带有调试名的 ONNX 算子,并把输入的 name
传给算子。
由于 Debug 算子本身不表示任何计算,因此在 forward
函数中,直接把输入 x
返回即可。
而 symbolic
函数定义了一个新算子 my::Debug
:算子有一个输入 x
,一个属性 name
。我们直接把算子调用接口里的 x
,name
传入即可。
这里需要补充介绍算子定义函数 g.op()
的一些规范。在g.op()
中,算子的属性需要以 {attibute_name}_{type}=attibute_value
这样的格式传入。其中 {attibute_name}
为属性名,{type}
指定了算子属性的数据类型。比如说我们上面的算子属性写成 name_s
,实际上是定义了一个字符串类型,名字叫做 name
的属性。除了表示字符串类型的 _s
外,还有表示 float
型的 _f
,表示 tensor
型的 _t
。
在完成算子的定义后,我们可以通过 debug_apply = DebugOp.apply
获取算子的调用接口。这样以后就可以通过 debug_apply(x, name)
来使用这个算子了。
Debugger类
接着,我们来实现精度对齐工具的核心——Debugger 类。这个类包含了实现精度对齐所需的所有操作。其定义如下:
Debugger类有三个成员变量:
- torch_value 记录了运行PyTorch模型后每个调试张量的值
- onnx_value 记录了运行ONNX模型后每个调试张量的值
- output_debug_name:记录了把调试张量加入ONNX的输出后,每个输出张量的调试名
稍后我们会在类实现的代码中看到这些成员变量的具体用法。
Debugger类有以下方法:
- debug封装了之前变好的debug_apply。该方法需要在原PyTorch模型中调用,可以为导出的ONNX的模型添加Debug算子节点,同时记录PyTorch调试张量值。
- extract_debug_model和ONNX的子模型提取函数的用法类似,可以把带调试节点的ONNX模型转化成一个可以输出调试张量的ONNX模型。
- run_debug_model会使用ONNX Runtime运行模型,得到ONNX调试张量值。
- print_debug_result会比较PyTorch和ONNX的调试张量值,输出比较的结果
这4个方法会一次被调用:
生成调试节点
def debug(self, x, name):self.torch_value[name] = x.detach().cpu().numpy()return debug_apply(x, name)
如前文所述,debug完成两件事:记录PyTorch模型中的调试张量的值、添加Debug节点。我们使用self.torch_value[name]=x.detach().cpu().numpy()把调试张量转成numpy格式并保存进torch_value词典里。之后,我们调用之前编写的debug_apply算子。
提取调试模型
import onnx
def extract_debug_model(self, input_path, output_path):model = onnx.load(input_path)inputs = [input.name for input in model.graph.input]outputs = []for node in model.graph.node:if node.op_type == 'Debug':#记录调试张量名debug_name = node.attribute[0].s.decode('ASCII')self.output_debug_name.append(debug_name)#添加输入output_name = node.output[0]outputs.append(output_name)#转换Debug 节点为Indentity节点node.np_type = 'Identity'node.domain = ''del node.attribute[:]e = onnx.util.Extractor(model)extracted = e.extrac_model(inputs, outputs)onnx.save(extracted, output_path)
在PyTorch模型中插入debug方法后,我们可以得到一个包含了若干Debug节点的ONNX模型。 但是这个ONNX模型不是我们最终拿来执行的模型。为了得到Debug节点的输出(即调试张量的值),我们需要做三项处理以提取出一个可运行的调试模型:
- 记录每个调试张量的调试名,为之后对齐PyTorch、ONNX调试张量值做准备。
- 把所有Debug节点的输出加入到整个模型的输出中,这样在运行模型后就能得到这些中间节点的输出了。
- 自定义的Debug节点在推理引擎中时没有实现的,为了让处理湖的ONNX模型运行起来,需要把Debug节点转化成可运行的Identity(恒等)节点。
完成了这三项处理后,我们才能进行模型提取。下面,我们来看看模型提取和这几项处理是怎么实现的。
首先,看一下和模型提取有关的代码:
model = onnx.load(input_path)
inputs = [input.name for input in model.graph.input]
outputs = [] # 获取 outputs
... # 调用提取模型 API
e = onnx.utils.Extractor(model)
extracted = e.extract_model(inputs, outputs) # 保存模型
onnx.save(extracted, output_path)
在提取模型时,我们要准备新模型的输入和输出。输入张量inputs还是保持原状,而输出张量outputs会在之后填入Debug节点的输出。获取完outputs后,我们调用提取模型的API,得到处理过后的模型,并保存此模型。
接着,看一下主处理逻辑:
for node in model.graph.node: if node.op_type == 'Debug': ...
为了获取和Debug节点相关的信息,我们需要遍历ONNX模型的所有节点,找出那些类型为Debug的节点,对这些节点执行操作。
下面的代码实现了记录调试张量名:
debug_name = node.attribute[0].s.decode('ASCII')
self.output_debug_name.append(debug_name)
这段代码的作用是:从节点的第一个属性(即name)中取出调试名信息,并存入output_debug_name中。节点第一个属性的值可以通过node.attribute[0]获得。由于name是属性是字符串,这里要用.s获取属性的字符串值。又由于ONNX是以二进制的形式保存所有数据的,这里要用.decode(‘ASCII’)把二进制字符串转成一个文本字符串。
接下来的代码用于填写新模型输出outputs:
output_name = node.output[0]
outputs.append(output_name)
node.output[0]就是debug节点的输出张量在ONNX里的名称,把这个名称加入新模型的输出后,只需要运行新模型,就可以得到该输出张量的值了。
最后这段代码用于更改Debug节点的类型:
node.op_type = 'Identity'
node.domain = ''
del node.attribute[:]
为了消除 ONNX 不支持的 Debug 节点,一种比较简单的方式是直接把 Debug 节点修改成不执行任何操作的 Indentity
类型的节点。为了做这个转换,我们要先修改节点类型名 node.op_type
为Identity
,再把节点的域(即命名空间)node.domain
修改成空,最后删除节点的所有属性,保证节点符合 ONNX 的规范。
回忆一下,如果一个节点的 domain
为空,这个节点就会被当成一个 ONNX 原生算子节点。
运行调试模型
在生成调试节点时, 我们已经顺便记录了Pytorch模型调试张量的值,下一步,我们要运行调试模型,记录ONNX模型调试张量的值。实现如下:
import onnxruntime
def run_debug_model(self, input, debug_model):sess = onnxruntime.InferenceSession(debug_model, providers=['CPUExecutionProvider'])onnx_outputs = sess.run(None, input)for name, value in zip(self.output_debug_name, onnx_outputs):self.onnx_value[name] = value
在运行调试模型前,我们要给出模型输入、模型名这两个参数。根据这些参数,run_debug_model会调用ONNX runtime的API,对ONNX模型进行推理。在得到了ONNX模型的输出后,用使用上一部得到的output_debug_name信息,填写onnx_value,把ONNX中间运算结果绑定到调试名上。完成这些步骤之后,我们就有足够的信息做精度对齐了。
def print_debug_result(self):for name in self.torch_value.keys():if name in self.onnx_value:mse = np.mean((self.torch_value[name] - self.onnx_value[name])**2)
最后,我们同时遍历 self.torch_value
和 self.onnx_value
这两个词典,比较同一个张量在 PyTorch 模型和 ONNX 模型里的输出。在循环体中,我们只需要使用 self.torch_value[name]
和 self.onnx_value[name]
就可以访问同一个张量在 PyTorch 里的值和在 ONNX 里的值。作为示例,这里我们可以计算二者的均方误差 mse
,以此为精度对齐的依据。
使用方法
实现了精度对齐工具后,我们来看看该怎么把这个工具用起来。
现在,假设我们得到了一个这样的模型:
class Model(torch.nn.Module): def __init__(self): super().__init__() self.convs1 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3, 1, 1), torch.nn.Conv2d(3, 3, 3, 1, 1), torch.nn.Conv2d(3, 3, 3, 1, 1)) self.convs2 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3, 1, 1), torch.nn.Conv2d(3, 3, 3, 1, 1)) self.convs3 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3, 1, 1), torch.nn.Conv2d(3, 3, 3, 1, 1)) self.convs4 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3, 1, 1), torch.nn.Conv2d(3, 3, 3, 1, 1), torch.nn.Conv2d(3, 3, 3, 1, 1)) def forward(self, x): x = self.convs1(x) x = self.convs2(x) x = self.convs3(x) x = self.convs4(x) return x torch_model = Model()
没错!这就是本文开头展示的那个全卷积网络。现在我们想对齐 convs1
至 convs4
这每一个卷积块的输出精度,该怎么使用之前写好的精度对齐工具呢?
首先,我们生成管理类 Debugger
的一个实例:
debugger = Debugger()
之后,我们要设法把 Debug 节点插入原模型:
from types import MethodType def new_forward(self, x): x = self.convs1(x) x = debugger.debug(x, 'x_0') x = self.convs2(x) x = debugger.debug(x, 'x_1') x = self.convs3(x) x = debugger.debug(x, 'x_2') x = self.convs4(x) x = debugger.debug(x, 'x_3') return x torch_model.forward = MethodType(new_forward, torch_model)
我们可以为原模型新写一个 forward
函数。在这个新的函数函数中,我们可以通过 debugger.debug
把每一个输出张量标记起来,并各取一个不重复的调试名。
有了 new_forward
函数,我们需要使用 MethodType
这个 Python API 把这个函数变成模型实例 torch_model
的一个成员方法,确保 torch_model
的 forward
函数能够被正确替换。
实现了”狸猫换太子“般巧妙的操作后,我们就可以使用 PyTorch API 导出一个带有 Debug 节点的 ONNX 模型了:
dummy_input = torch.randn(1, 3, 10, 10)
torch.onnx.export(torch_model, dummy_input, 'before_debug.onnx', input_names=['input'])
由于 torch.onnx.export
模型使用的是跟踪法,模型的 forward
函数会被执行一次, debugger.debug
操作可以把 PyTorch 模型的调试张量输出记录在 debugger.torch_value
里。
这个 before_debug.onnx
模型的部分可视化结果如下:
接下来,我们替换掉所有 Debug 节点,并记录每个 Debug 输出张量的 ONNX 名与调试名的对应关系:
debugger.extract_debug_model('before_debug.onnx', 'after_debug.onnx')
这步操作得到的 after_debug.onnx
模型的部分可视化结果如下:
我们可以使用下面的代码运行这个模型:
debugger.run_debug_model({'input':dummy_input.numpy()}, 'after_debug.onnx')
这样,ONNX 模型的调试张量输出会记录在 debugger.onnx_value
里。
总算,一切准备工作结束了。我们可以轻轻松松地用一行代码输出精度对齐的结果:
debugger.print_debug_result()
这个函数大致会输出以下内容:
x_0 MSE: 8.465450562766819e-16
x_1 MSE: 1.4122021817221354e-16
x_2 MSE: 6.501743508551734e-17
x_3 MSE: 1.7635199492054931e-16
这份输出表明,在这一轮精度对齐测试中,所有模块的精度误差都很小。我们几乎可以认为,ONNX 模型的运行结果等价于 PyTorch 模型的运行结果。
如果有某些模块的误差比较大,我们可以深入子模块,去加更多的 debug 节点,看看是哪一步、哪一个算子出现了问题。