如果需要在独立于训练脚本的新脚本中部署模型,这种情况模型和权重在内存中不存在,因此需要构造一个模型类的对象,然后将存储的权重加载到模型中。

加载模型参数,验证模型的性能,并在测试数据集上部署模型

from torch import nn
from torchvision import models# 定义一个resnet18模型,不使用预训练参数
model_resnet18 = models.resnet18(pretrained=False)
# 获取模型的全连接层的输入特征数
num_ftrs = model_resnet18.fc.in_features
# 定义分类的类别数
num_classes=10
# 将全连接层的输出特征数改为分类的类别数
model_resnet18.fc = nn.Linear(num_ftrs, num_classes)import torch 
path2weights="./models/resnet18_pretrained.pt"
# 加载预训练的ResNet18模型权重
model_resnet18.load_state_dict(torch.load(path2weights))
# 将ResNet-18模型设置为评估模式
model_resnet18.eval();
# 检查CUDA是否可用
if torch.cuda.is_available():# 如果可用,将设备设置为CUDAdevice = torch.device("cuda")# 将模型移动到CUDA设备上model_resnet18=model_resnet18.to(device)def deploy_model(model,dataset,device, num_classes=10,sanity_check=False):# 获取数据集的长度len_data=len(dataset)# 初始化输出张量y_out=torch.zeros(len_data,num_classes)# 初始化真实标签张量y_gt=np.zeros((len_data),dtype="uint8")# 将模型移动到指定设备model=model.to(device)# 初始化时间列表elapsed_times=[]with torch.no_grad():for i in range(len_data):# 获取数据集中的一个样本x,y=dataset[i]# 将真实标签存入张量y_gt[i]=y# 记录开始时间start=time.time()    # 将输入数据传入模型进行预测yy=model(x.unsqueeze(0).to(device))# 将预测结果存入张量y_out[i]=torch.softmax(yy,dim=1)# 计算预测时间elapsed=time.time()-start# 将预测时间存入列表elapsed_times.append(elapsed)# 如果进行完整性检查,则跳出循环if sanity_check is True:break# 计算平均预测时间inference_time=np.mean(elapsed_times)*1000# 打印平均预测时间print("average inference time per image on %s: %.2f ms " %(device,inference_time))# 返回预测结果和真实标签return y_out.numpy(),y_gt
from torchvision import datasets
import torchvision.transforms as transforms# 数据转换
data_transformer = transforms.Compose([transforms.ToTensor()])path2data="./data"# 加载数据
test0_ds=datasets.STL10(path2data, split='test', download=True,transform=data_transformer)
print(test0_ds.data.shape)

from sklearn.model_selection import StratifiedShuffleSplit# 创建StratifiedShuffleSplit对象,设置分割次数为1,测试集大小为0.2,随机种子为0
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)# 获取test0_ds的索引
indices=list(range(len(test0_ds)))# 获取test0_ds的标签
y_test0=[y for _,y in test0_ds]# 对索引和标签进行分割
for test_index, val_index in sss.split(indices, y_test0):# 打印测试集和验证集的索引print("test:", test_index, "val:", val_index)# 打印测试集和验证集的大小print(len(val_index),len(test_index))

from torch.utils.data import Subset# 从test0_ds中选取val_index索引的子集,赋值给val_ds
val_ds=Subset(test0_ds,val_index)
# 从test0_ds中选取test_index索引的子集,赋值给test_ds
test_ds=Subset(test0_ds,test_index)
# 定义均值
mean=[0.4467106, 0.43980986, 0.40664646]
# 定义标准差
std=[0.22414584,0.22148906,0.22389975]
# 定义一个名为test0_transformer的变量,用于将一系列的图像变换操作组合在一起
test0_transformer = transforms.Compose([# 将图像转换为Tensor类型transforms.ToTensor(),# 对图像进行归一化操作,使用mean和std作为均值和标准差transforms.Normalize(mean, std),])   
# 将test0_transformer赋值给test0_ds的transform属性
test0_ds.transform=test0_transformer
import time
import numpy as np# 调用deploy_model函数,传入model_resnet18,val_ds,device和sanity_check参数,返回y_out和y_gt
y_out,y_gt=deploy_model(model_resnet18,val_ds,device=device,sanity_check=False)
# 打印y_out和y_gt的形状
print(y_out.shape,y_gt.shape)

from sklearn.metrics import accuracy_score# 将y_out中的最大值索引赋值给y_pred
y_pred = np.argmax(y_out,axis=1)
# 打印y_pred和y_gt的形状
print(y_pred.shape,y_gt.shape)# 计算并打印y_pred和y_gt的准确率
acc=accuracy_score(y_pred,y_gt)
print("accuracy: %.2f" %acc)

 

# 部署模型,得到预测结果和真实标签
y_out,y_gt=deploy_model(model_resnet18,test_ds,device=device)# 取出预测结果中概率最大的类别
y_pred = np.argmax(y_out,axis=1)# 计算准确率
acc=accuracy_score(y_pred,y_gt)# 打印准确率
print(acc)

from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
np.random.seed(1)# 定义一个函数,用于显示图像
def imshow(inp, title=None):# 定义图像的均值和标准差mean=[0.4467106, 0.43980986, 0.40664646]std=[0.22414584,0.22148906,0.22389975]# 将图像从tensor转换为numpy数组,并转置inp = inp.numpy().transpose((1, 2, 0))# 将均值和标准差转换为numpy数组mean = np.array(mean)std = np.array(std)# 将图像的像素值进行归一化inp = std * inp + mean# 将像素值限制在0和1之间inp = np.clip(inp, 0, 1)# 显示图像plt.imshow(inp)# 如果有标题,则显示标题if title is not None:plt.title(title)# 暂停0.001秒plt.pause(0.001) # 定义网格大小
grid_size=16
# 随机生成4个索引
rnd_inds=np.random.randint(1,len(test_ds),grid_size)
# 打印随机生成的索引
print("image indices:",rnd_inds)# 根据索引获取对应的图像和标签
x_grid_test=[test_ds[i][0] for i in rnd_inds]
y_grid_test=[(y_pred[i],y_gt[i]) for i in rnd_inds]# 将图像转换为网格
x_grid_test=utils.make_grid(x_grid_test, nrow=4, padding=2)
# 打印网格的形状
print(x_grid_test.shape)# 设置图像的大小
plt.rcParams['figure.figsize'] = (10, 10)
# 显示网格
imshow(x_grid_test,y_grid_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/96765.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/96765.shtml
英文地址,请注明出处:http://en.pswp.cn/web/96765.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FS950R08A6P2B 双通道汽车级IGBT模块Infineon英飞凌 电子元器件核心解析

一、核心解析:FS950R08A6P2B 是什么?1. 电子元器件类型FS950R08A6P2B 是英飞凌(Infineon) 推出的一款 950A/800V 双通道汽车级IGBT模块,属于功率半导体模块。它采用 EasyPACK 2B 封装,集成多个IGBT芯片和二…

【系列文章】Linux中的并发与竞争[05]-互斥量

【系列文章】Linux中的并发与竞争[05]-互斥量 该文章为系列文章:Linux中的并发与竞争中的第5篇 该系列的导航页连接: 【系列文章】Linux中的并发与竞争-导航页 文章目录【系列文章】Linux中的并发与竞争[05]-互斥量一、互斥锁二、实验程序的编写2.1驱动…

TensorRT 10.13.3: Limitations

Limitations Shuffle-op can not be transformed to no-op for perf improvement in some cases. For the NCHW32 format, TensorRT takes the third-to-last dimension as the channel dimension. When a Shuffle-op is added like [N, ‘C’, H, 1] -> [‘N’, C, H], the…

Python与Go结合

Python与Go结合的方法Python和Go可以通过多种方式结合使用,通常采用跨语言通信或集成的方式。以下是几种常见的方法:使用CFFI或CGO进行绑定Python可以通过CFFI(C Foreign Function Interface)调用Go编写的库,而Go可以通…

C++ 在 Visual Studio Release 模式下,调试运行与直接运行 EXE 的区别

前言 在 Visual Studio (以下简称 VS) 中开发 C 项目时,我们常常需要在 Debug 和 Release 两种构建模式之间切换。Debug 模式适合开发和调试,而 Release 模式则针对生产环境,进行代码优化以提升性能。然而,即使在 Release 模式下&…

南京方言数据集|300小时高质量自然对话音频|专业录音棚采集|方言语音识别模型训练|情感计算研究|方言保护文化遗产数字化|语音情感识别|方言对话系统开发

引言与背景 随着人工智能技术的快速发展,语音识别和自然语言处理领域对高质量方言数据的需求日益增长。南京方言作为江淮官话的重要分支,承载着丰富的地域文化和语言特色,在语言学研究和方言保护方面具有重要价值。本数据集精心采集了300小时…

基于LSTM深度学习的电动汽车电池荷电状态(SOC)预测

基于LSTM深度学习的电动汽车电池荷电状态(SOC)预测 摘要 电动汽车(EV)的普及对电池管理系统(BMS)提出了极高的要求。电池荷电状态(State of Charge, SOC)作为BMS最核心的参数之一&am…

Golang语言之数组、切片与子切片

一、数组先记住数组的核心特点:盒子大小一旦定了就改不了(长度固定),但盒子里的东西能换(元素值可变)。就像你买了个能装 3 个苹果的铁皮盒,想多装 1 个都不行,但里面的苹果可以换成…

速通ACM省铜第四天 赋源码(G-C-D, Unlucky!)

目录 引言: G-C-D, Unlucky! 题意分析 逻辑梳理 代码实现 结语: 引言: 因为今天打了个ICPC网络赛,导致坐牢了一下午,没什么时间打题目了,就打了一道题,所以,今天我们就只讲一题了&…

数据链路层总结

目录 (一)以太网(IEEE 802.3) (1)以太网的帧格式 (2)帧协议类型字段 ①ARP协议 (横跨网络层和数据链路层的协议) ②RARP协议 (二&#xff…

Scala 新手实战三案例:从循环到条件,搞定基础编程场景

Scala 新手实战三案例:从循环到条件,搞定基础编程场景 对 Scala 新手来说,单纯记语法容易 “学完就忘”,而通过小而精的实战案例巩固知识点,是掌握语言的关键。本文精选三个高频基础场景 ——9 乘 9 乘法口诀表、成绩等…

java学习笔记----标识符与变量

1.什么是标识符?Java中变量、方法、类等要素命名时使用的字符序列,称为标识符。 技巧:凡是自己可以起名字的地方都叫标识符。 比如:类名、方法名、变量名、包名、常量名等 2.标识符的命名规则由26个英文字母大小写,0-9,或$组成 数字不可以开…

AI产品经理面试宝典第93天:Embedding技术选型与场景化应用指南

1. Embedding技术演进全景解析 1.1 稀疏向量:关键词匹配的基石 1.1.1 问:请说明稀疏向量的适用场景及技术特点 答:稀疏向量适用于关键词精确匹配场景,典型实现包括TF-IDF、BM25和SPLADE。其技术特征表现为50,000+高维向量且95%以上位置为零值,通过余弦或点积计算相似度…

【Mermaid.js】从入门到精通:完美处理节点中的空格、括号和特殊字符

文章标签: Mermaid, Markdown, 前端开发, 数据可视化, 流程图 文章摘要: 你是否在使用 Mermaid.js 绘制流程图时,仅仅因为节点文本里加了一个空格或括号,整个图就渲染失败了?别担心,这几乎是每个 Mermaid 新…

多技术融合提升环境生态水文、土地土壤、农业大气等领域的数据分析与项目科研水平

一:空间数据获取与制图1.1 软件安装与应用1.2 空间数据介绍1.3海量空间数据下载1.4 ArcGIS软件快速入门1.5 Geodatabase地理数据库二:ArcGIS专题地图制作2.1专题地图制作规范2.2 空间数据的准备与处理2.3 空间数据可视化:地图符号与注记2.4 研…

【音视频】Android NDK 与.so库适配

一、名词解析 名词全称核心说明Android NDKNative Development Kit在SDK基础上增加“原生”开发能力,支持使用C/C编写代码,用于开发需要调用底层能力的模块(如音视频、加密算法等).so库Shared Object即共享库,由NDK编…

SpringBoot 轻量级一站式日志可视化与JVM监控

一、项目初衷Java 应用开发的同学都知道,项目上线后,日志的可视化查询与 JVM 的可视化监控是一件非常重要的事。 市面上成熟方案一般是采用 ELK/EFK 实现日志可视化,采用 Actuator Prometheus Grafana 实现 JVM 监控。 这两套都是非常优秀的…

【Leetcode hot 100】101.对称二叉树

问题链接 101.对称二叉树 问题描述 给你一个二叉树的根节点 root , 检查它是否轴对称。 示例 1: 输入:root [1,2,2,3,4,4,3] 输出:true 示例 2: 输入:root [1,2,2,null,3,null,3] 输出:…

Zynq开发实践(FPGA之选择开发板)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】我们之所以选用zynq开发板,就在于它支持arm软件开发,也支持fpga开发,甚至可以运行linux,这是之前没有…

Flutter Riverpod 3.0 发布,大规模重构下的全新状态管理框架

在之前的 《注解模式下的 Riverpod 有什么特别之处》我们聊过 Riverpod 2.x 的设计和使用原理,同时当时我们就聊到作者已经在开始探索 3.0 的重构方式,而现在随着 Riverpod 3.0 的发布,riverpod 带来了许多细节性的变化。 当然,这…