一、概述

  机器学习模型的训练通常在Python环境下完成,而现实生产环境的复杂性和多样性使得模型的部署成为一个值得关注的重点。不同应用场景下有不同适应的实现方式,这里主要介绍通过一种通用中间格式——ONNX(Open Neural Network Exchange),来实现机器学习模型在C++平台的部署。

二、步骤

  s1. Python环境中安装onnxruntime、skl2onnx工具模块;

  s2. Python环境中训练机器学习模型;

  s3. 将训练好的模型保存为.onnx格式的模型文件;

  s4. C++环境中安装Microsoft.ML.OnnxRuntime程序包;
(Visual Studio 2022中可通过项目->管理NuGet程序包完成快捷安装)

  S5. C++环境中加载模型文件,完成功能开发。

三、示例

  使用 Python 训练一个线性回归模型并将其导出为 ONNX 格式的文件,在C++环境下完成对模型的部署和推理。

1.Python训练和导出

(环境:Python 3.11,scikit-learn 1.6.1,onnxruntime 1.22.0,skl2onnx 1.19.1)

import numpy as np
import onnxruntime as ort
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType# 生成示例数据
X, y = make_regression(n_samples=100, n_features=5, random_state=42)# 训练线性回归模型
model = LinearRegression()
model.fit(X, y)# 定义输入格式
initial_type = [('input', FloatTensorType([None, 5]))]# 转换模型为 ONNX 格式
onnx_model = convert_sklearn(model, initial_types=initial_type)# 保存 ONNX 模型
with open("linear_regression.onnx", "wb") as f:f.write(onnx_model.SerializeToString())print("\n模型已保存为: linear_regression.onnx\n")# 测试导出的模型
ort_session = ort.InferenceSession("linear_regression.onnx")
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name# 创建一个测试样本
test_input = np.array([0.1, 0.2, 0.3, 0.4, 0.5]).reshape(1,5).astype(np.float32)# 运行推理
results = ort_session.run([output_name], {input_name: test_input})print(f"测试输入: {test_input}")
print(f"预测结果: {results[0]}")

在这里插入图片描述

2. C++ 部署和推理

(环境:C++ 14,Microsoft.ML.OnnxRuntime 1.22.0)

#include <iostream>
#include <vector>
#include <string>
#include <memory>
#include <onnxruntime_cxx_api.h>int main() {// 初始化环境Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXExample");// 初始化会话选项Ort::SessionOptions session_options;session_options.SetIntraOpNumThreads(1);session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);// 加载模型std::wstring model_path = L"linear_regression.onnx";Ort::Session session(env, model_path.c_str(), session_options);// 获取输入信息Ort::AllocatorWithDefaultOptions allocator;size_t num_inputs = session.GetInputCount();size_t num_outputs = session.GetOutputCount();// 假设只有一个输入和一个输出if (num_inputs != 1 || num_outputs != 1) {std::cerr << "模型必须有且仅有一个输入和一个输出" << std::endl;return 1;}// 获取输入名称、类型和形状std::string input_name = session.GetInputNameAllocated(0, allocator).get();Ort::TypeInfo input_type_info = session.GetInputTypeInfo(0);auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();ONNXTensorElementDataType input_type = input_tensor_info.GetElementType();std::vector<int64_t> input_dims = input_tensor_info.GetShape();// 获取输出名称std::string output_name = session.GetOutputNameAllocated(0, allocator).get();// 创建输入数据std::vector<float> input_data = { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f };size_t input_size = 5;// 创建输入张量std::vector<int64_t> input_shape = { 1, static_cast<int64_t>(input_size) };auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_data.data(),input_data.size(), input_shape.data(), 2);// 验证输入张量是否为张量if (!input_tensor.IsTensor()) {std::cerr << "创建的输入不是张量类型" << std::endl;return 1;}// 运行模型std::vector<const char*> input_names = { input_name.c_str() };std::vector<const char*> output_names = { output_name.c_str() };std::vector<Ort::Value> outputs = session.Run(Ort::RunOptions{ nullptr },input_names.data(),&input_tensor,1,output_names.data(),1);// 获取输出结果float* output_data = outputs[0].GetTensorMutableData<float>();Ort::TensorTypeAndShapeInfo output_info = outputs[0].GetTensorTypeAndShapeInfo();std::vector<int64_t> output_dims = output_info.GetShape();// 输出结果std::cout << "输入数据: ";for (float val : input_data) {std::cout << val << " ";}std::cout << std::endl;std::cout << "预测结果: ";for (size_t i = 0; i < output_info.GetElementCount(); ++i) {std::cout << output_data[i] << " ";}std::cout << std::endl;return 0;
}

在这里插入图片描述



End.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/90403.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/90403.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/90403.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

保姆级安装 Ruby 环境下载及安装教程, RubyInstaller下载及安装教程

一、下载安装 RubyInstaller 1.打开 RubyInstaller 官网&#xff1a;https://rubyinstaller.org/ 点击跳转, 官网界面如下图&#xff1a; 点击下载最新的 RubyDevkit 版本&#xff08;如 RubyDevkit 3.4.X (x64) &#xff09;。如下图所示&#xff1a; 注意点&#xff1a;如果…

SQL 一键生成 Go Struct!支持字段注释、类型映射、结构体命名规范

SQL 一键生成 Go Struct&#xff01;支持字段注释、类型映射、结构体命名规范 在 Golang 开发中&#xff0c;尤其是操作数据库时&#xff0c;我们经常会遇到这种场景&#xff1a; ✅ 拿到数据库建表 SQL&#xff0c;却要手动写 Go struct✅ 字段几十个、类型复杂&#xff0c;…

Web 前端框架选型:React、Vue 和 Angular 的对比与实践

Web 前端框架选型&#xff1a;React、Vue 和 Angular 的对比与实践 选择前端框架就像选择一个长期合作伙伴。错误的选择可能会让你的项目在未来几年内背负沉重的技术债务&#xff0c;而正确的选择则能让开发效率飞速提升。 经过多年的项目实践&#xff0c;我发现很多新人在框架…

C# 值拷贝、引用拷贝、浅拷贝、深拷贝

值拷贝定义&#xff1a;直接复制变量的值&#xff0c;适用于基本数据类型&#xff08;如int, float, char等&#xff09;。在 C# 中&#xff0c;值类型&#xff08;基本数据类型和结构体&#xff09;默认使用值拷贝。特点&#xff1a;创建原始值的完全独立副本&#xff0c;修改…

深度学习图像分类数据集—百种鸟类识别分类

该数据集为图像分类数据集&#xff0c;适用于ResNet、VGG等卷积神经网络&#xff0c;SENet、CBAM等注意力机制相关算法&#xff0c;Vision Transformer等Transformer相关算法。 数据集信息介绍&#xff1a;525种鸟类识别分类 训练数据集总共有84635张图片&#xff0c;每个文件夹…

零基础 “入坑” Java--- 八、类和对象(一)

文章目录一、初识面向对象二、类的定义和使用1.认识类2.类的定义格式三、类的实例化四、this引用五、对象的构造及初始化1.有关初始化2.构造方法3.就地初始化一、初识面向对象 Java是一门纯面向对象的语言&#xff08;OOP&#xff09;&#xff0c;在面向对象的世界里&#xff…

数字孪生技术引领UI前端设计新篇章:智能物联网的深度集成

hello宝子们...我们是艾斯视觉擅长ui设计、前端开发、数字孪生、大数据、三维建模、三维动画10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩!一、引言&#xff1a;数字孪生与物联网的共生革命在智能设备爆发式增长的今天&#xff0c;传统…

代码审计-shiro漏洞分析

一、关于shiro介绍 简单讲&#xff0c;shiro是apache旗下的一个Java安全框架&#xff0c;轻量级简单易上手&#xff0c;框架提供很多功能接口&#xff0c;常见的身份认证 、权限认证、会话管理、Remember 记住功能、加密等等。 二、漏洞分析 1.CVE-2019-12422-shiro550 漏洞原理…

EF提高性能(查询禁用追踪)(关闭延迟加载)

EF默认是支持延迟加载的&#xff0c;在加载一个表的数据时&#xff0c;会把关联表的数据一并加载&#xff0c;这样会影响性能。 一般建议关闭延迟加载可以提高EF加载的性能。还有其他方法提高性能&#xff08;查询禁用追踪&#xff09; 如果要实现延迟加载&#xff0c;必须满足…

Leetcode+JAVA+贪心III

134.加油站在一条环路上有 n 个加油站&#xff0c;其中第 i 个加油站有汽油 gas[i] 升。你有一辆油箱容量无限的的汽车&#xff0c;从第 i 个加油站开往第 i1 个加油站需要消耗汽油 cost[i] 升。你从其中的一个加油站出发&#xff0c;开始时油箱为空。给定两个整数数组 gas 和 …

Qt信号与槽机制及动态调用

Qt信号与槽机制及动态调用一、信号与槽1、Qt信号与槽机制概述2、信号与槽的基本使用3、信号与槽的特性4、使用Lambda表达式作为槽5、信号与槽的参数传递6、注意事项二、动态调用机制1、基本用法2、示例代码3、带参数的调用4、返回值处理5、信号与槽的动态连接6、动态方法调用7、…

K8s系列之:Kubernetes 的 OLM

K8s系列之:Kubernetes 的 OLM 什么是 Kubernetes 的 OLM什么是Kubernetes中的OperatorOLM 的功能OLM 的核心组件OLM优势OLM 的工作原理OLM 与 OperatorHub 的关系OLM示例场景什么是CRDoperator 和 CRD的关系为什么需要 CRD 和 OperatorCRD定义资源类型DebeziumServer如何使用d…

前端-HTML-day2

目录 1、无序列表 2、有序列表 3、定义列表 4、表格-基本使用 5、表格-结构标签 6、表格-合并单元格 7、表单-input基本使用 8、表单-input占位文本 9、表单-单选框 10、表单-上传多个文件 11、表单-多选框 12、表单-下拉菜单 13、表单-文本域 14、表单-label标签…

两种方式清除已经保存的git账号密码

方式一随便选择一个文件夹&#xff0c;然后鼠标右键-》TortoiseGit ->设置选择已保存的数据-》认证数据-》清除-》点击确定方式二 控制面板\用户帐户\凭据管理器-》windows凭据普通凭据-》找到git信息-》选择删除

Using Spring for Apache Pulsar:Message Production

1. Pulsar Template在Pulsar生产者端&#xff0c;Spring Boot自动配置提供了一个用于发布记录的PulsarTemplate。该模板实现了一个名为PulsarOperations的接口&#xff0c;并提供了通过其合约发布记录的方法。这些send API方法有两类&#xff1a;send和sendAsync。send方法通过…

CSS揭秘:10.平行四边形

前置知识&#xff1a;基本的css变形一、平行四边形 要实现一个平行四边形&#xff0c;可以使用CSS的skew变形属性来倾斜元素。 transform: skewX(-45deg);图-1显示容器和内容都出现了倾斜&#xff0c;该如何解决这个问题&#xff1f; 二、嵌套方案 我们通过将内容嵌套 div 并使…

深度学习 必然用到的 线性代数知识

把标量到张量、点积到范数全串起来&#xff0c;帮你从 0 → 1 搭建 AI 数学底座 &#x1f680; 1 标量&#xff1a;深度学习的最小单元 标量 就是一维空间里的“点”&#xff0c;只有大小没有方向。例如温度 52 F、学习率 0.001。 记号&#xff1a;普通小写 x&#xff1b;域&am…

OpenGL ES 纹理以及纹理的映射

文章目录开启纹理创建纹理绑定纹理生成纹理纹理坐标图像配置线性插值重复效果限制拉伸完整代码在 Android OpenGL ES 中使用纹理&#xff08;Texture&#xff09;可以显著提升图形渲染的质量和效率。以下是使用纹理的主要好处&#xff1a; 增强视觉真实感 纹理可以将复杂的图像…

从金字塔到个性化路径:AI 正在重新定义学习方式

几十年来&#xff0c;我们的教育系统始终遵循着一条熟悉的路线&#xff1a; 从小学、初中、高中&#xff0c;再到大学和研究生。这条标准化的路径&#xff08;K-12 到研究所&#xff09;结构清晰&#xff0c;却也缓慢。但在当今这个信息爆炸、知识快速更新、个性化需求高涨的时…

产品经理岗位职责拆解

以下是产品经理岗位职责的详细分解表&#xff0c;涵盖工作内容、核心动作及输出成果&#xff1a;岗位职责具体工作内容输出成果1. 日常版本迭代管理需求分析及PRD产出协调资源推动产品上线- 收集业务/用户需求&#xff0c;分析可行性及优先级- 撰写PRD文档&#xff0c;明确功能…