文章目录

  • 1 使用现有网络
  • 2 修改网络结构
    • 2.1 添加新层
    • 2.2 替换现有层
  • 3 保存网络模型
    • 3.1 完整保存
    • 3.2 参数保存(推荐)
  • 4 加载网络模型
    • 4.1 加载完整模型文件
    • 4.2 加载参数文件
  • 5 Checkpoint
    • 5.1 保存 Checkpoint
    • 5.2 加载 Checkpoint

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

​ PyTorch 通过torchvision.models提供预训练模型(如 VGG16)。

​ 网址链接:https://docs.pytorch.org/vision/stable/models.html。

1 使用现有网络

​ 以 VGG16 为例,进入网址:https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.vgg16.html#torchvision.models.vgg16。

image-20250531103635500

方法一:使用随机初始化权重

​ 将 weights 设置为 None,从 0 开始训练自己的网络。

vgg16_false = torchvision.models.vgg16(weights=None)  # 权重随机初始化

方法二:加载预训练权重

​ 也可以使用预训练好的网络参数,加载后可直接使用网络。
这将从官网上下载已训练好的模型文件。

vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)

​ 可打印网络查看其模型结构:

print(vgg16_true)
image-20250531104433678
...
image-20250531104447912

2 修改网络结构

2.1 添加新层

​ 使用add_module在分类器(classifier)后追加全连接层:

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
image-20250531104536554

2.2 替换现有层

​ 直接修改分类器的最后一层(如适配 CIFAR10 的 10 分类任务):

vgg16_false.classifier[6] = nn.Linear(4096, 10)  # 替换第6层
image-20250531104551228

3 保存网络模型

​ 使用torch.save()方法保存网络模型。文件扩展名推荐使用.pt.pth

3.1 完整保存

​ 将模型类和参数一并保存到文件中。

torch.save(vgg16, 'vgg16_method1.pth')  # 包含模型类和参数
  • 优点:加载时无需重新定义模型结构。
  • 缺点:文件较大,且依赖原始代码环境(见 4.1 节)。

3.2 参数保存(推荐)

​ 仅保存参数字典到文件中。

torch.save(vgg16.state_dict(), 'vgg16_method2.pth')  # 仅保存参数字典
  • 优点:文件小,灵活性强,适合生产部署。

示例

import torch
import torchvision.models
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)# 保存方式 1,模型结构 + 模型参数
torch.save(vgg16, 'vgg16_method1.pth')# 保存方式 2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')

4 加载网络模型

​ 使用torch.load()方法加载网络模型。

4.1 加载完整模型文件

​ 加载完整模型时,需将 weights_only 参数设置为 False。

model = torch.load('vgg16_method1.pth', weights_only=False)  # 需确保模型类已定义

​ 模型打印结果如下:

print(model)
image-20250531111142517

注意

​ 若保存自定义模型,加载时必须确保环境中也有该模型的定义,否则会出现报错。

  • model_save.py

    # model_save.pyimport torch
    from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3)def forward(self, x):return self.conv1(x)model = MyModel()
    torch.save(model, 'my_model_method1.pth')
    
  • model_load.py

    import torchmodel = torch.load('my_model_method1.pth', weights_only=False)  # 报错,找不到 MyModel 的定义
    

    先运行 model_save.py,再运行 model_load.py,则会出现以下报错:

image-20250531110244566

4.2 加载参数文件

​ 首先,使用torch.load()方法加载网络模型。

​ 使用模型时,需先创建匹配的网络结构,再使用model.load_state_dict()加载参数数据。

vgg16 = torchvision.models.vgg16(weights=None)
model_dict = torch.load('vgg16_method2.pth')
vgg16.load_state_dict(model_dict)  # 需结构匹配

​ 模型打印结果是参数字典:

print(model_dict)
image-20250531111411199

注意

​ 模型保存时若在 GPU 上,加载时需指定 map_location 为 cup。

torch.load('model.pth', map_location=torch.device('cpu'))

​ 将参数加载到模型后,手动迁移到 GPU:

model = MyModel()
model.load_state_dict(model_dict)
model.to('cuda:0')

5 Checkpoint

​ 使用 Checkpoint 可以在训练过程中定期保存模型的状态,以便在中断后可以恢复训练,或者在测试时使用最终的模型。文件扩展名推荐使用.tar

5.1 保存 Checkpoint

​ 要保存一个模型的 Checkpoint,通常需要保存以下数据:

  • 模型的 state_dict(状态字典);
  • 优化器的状态;
  • 额外的信息,如 epoch 等。
import torch# 假设 model 是你的模型,optimizer 是你的优化器
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss
}# 保存checkpoint
torch.save(checkpoint, 'checkpoint.tar')

5.2 加载 Checkpoint

​ 加载 Checkpoint,首先需要加载文件,然后将其内容恢复到模型和优化器的状态中。

# 假设 model 和 optimizer 是你的模型和优化器实例
checkpoint = torch.load('checkpoint.tar')model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果需要,可以继续训练
model.train()  # 确保模型处于训练模式

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/907924.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/907924.shtml
英文地址,请注明出处:http://en.pswp.cn/news/907924.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

批量导出CAD属性块信息生成到excel——CAD C#二次开发(插件实现)

本插件可实现批量导出文件夹内大量dwg文件的指定块名的属性信息到excel,效果如下: 插件界面: dll插件如下: 使用方法: 1、获取此dll插件。 2、cad命令行输入netload ,加载此dll(要求AutoCAD&…

在Linux环境里面,Python调用C#写的动态库,如何实现?

在Linux环境中,Python可以通过pythonnet(CLR的Python绑定)或subprocess调用C#动态库。以下是两种方法的示例: 方法1:使用pythonnet(推荐) 前提条件 安装Mono或.NET Core运行时安装pythonnet包…

小程序跳转H5或者其他小程序

1. h5跳转小程序有两种情况 &#xff08;1&#xff09;从普通浏览器打开的h5页面跳转小程序使用wx-open-launch-weapp可以实现h5跳转小程序 <wx-open-launch-weappstyle"display:block;"v-elseid"launch-btn":username"wechatYsAppid":path…

性能优化 - 案例篇:缓冲区

文章目录 Pre1. 引言2. 缓冲概念与类比3. Java I/O 中的缓冲实现3.1 FileReader vs BufferedReader&#xff1a;装饰者模式设计3.2 BufferedInputStream 源码剖析3.2.1 缓冲区大小的权衡与默认值 4. 异步日志中的缓冲&#xff1a;Logback 异步日志原理与配置要点4.1 Logback 异…

文档整合自动化

主要功能是按照JSON文件&#xff08;Sort.json&#xff09;中指定的顺序合并多个Word文档&#xff08;.docx&#xff09;&#xff0c;并清除文档中的所有超链接。最终输出合并后的文档名为"sorted_按章节顺序.docx"。 主要分为几个部分&#xff1a; 初始化配置 定…

嵌入式(C语言篇)Day13

嵌入式Day13 一段话总结 文档主要介绍带有头指针和尾指针的单链表的实现及操作&#xff0c;涵盖创建、销毁、头插、尾插、按索引/数据增删查、遍历等核心操作&#xff0c;强调头插/尾插时间复杂度为O(1)&#xff0c;按索引/数据操作需遍历链表、时间复杂度为O(n)&#xff0c;并…

【ASR】基于分块非自回归模型的流式端到端语音识别

论文地址:https://arxiv.org/abs/2107.09428 摘要 非自回归 (NAR) 模型在语音处理中越来越受到关注。 凭借最新的基于注意力的自动语音识别 (ASR) 结构,与自回归 (AR) 模型相比,NAR 可以在仅精度略有下降的情况下实现有前景的实时因子 (RTF) 提升。 然而,识别推理需要等待…

RNN循环网络:给AI装上“记忆“(superior哥AI系列第5期)

&#x1f504; RNN循环网络&#xff1a;给AI装上"记忆"&#xff08;superior哥AI系列第5期&#xff09; 嘿&#xff01;小伙伴们&#xff0c;又见面啦&#xff01;&#x1f44b; 上期我们学会了让AI"看懂"图片&#xff0c;今天要给AI装上一个更酷的技能——…

DAY41 CNN

可以看到即使在深度神经网络情况下&#xff0c;准确率仍旧较差&#xff0c;这是因为特征没有被有效提取----真正重要的是特征的提取和加工过程。MLP把所有的像素全部展平了&#xff08;这是全局的信息&#xff09;&#xff0c;无法布置到局部的信息&#xff0c;所以引入了卷积神…

【仿生系统】爱丽丝机器人的设想(可行性优先级较高)

非程序化、能够根据环境和交互动态产生情感和思想&#xff0c;并以微妙、高级的方式表达出来的能力 我们不想要一个“假”的智能&#xff0c;一个仅仅通过if-else逻辑或者简单prompt来模拟情感的机器人。您追求的是一种更深层次的、能够学习、成长&#xff0c;并形成独特“个性…

面向连接的运输:TCP

目录 TCP连接 TCP报文段结构 往返时间估计与超时 可靠数据传输 回退N步or超时重传 超时间隔加倍 快速重传 流量控制 TCP连接管理 三次握手 1. 客户端 → 服务器&#xff1a;SYN 包 2. 服务器 → 客户端&#xff1a;SYNACK 包 3. 客户端 → 服务器&#xff1a;AC…

SpringAI系列 - 升级1.0.0

目录 一、调整pom二、MessageChatMemoryAdvisor调整三、ChatMemory get方法删除lastN参数四、QuestionAnswerAdvisor调整Spring AI发布1.0.0正式版了😅 ,搞起… 一、调整pom <properties><java.version>17</java.version><spring-ai.version>

前端高频面试题2:JavaScript/TypeScript

1.什么是类数组对象 一个拥有 length 属性和若干索引属性的对象就可以被称为类数组对象&#xff0c;类数组对象和数组类似&#xff0c;但是不能调用数组的方法。常见的类数组对象有 arguments 和 DOM 方法的返回结果&#xff0c;还有一个函数也可以被看作是类数组对象&#xff…

Spring Security入门:创建第一个安全REST端点项目

项目初始化与基础配置 创建基础Spring Boot项目 我们首先创建一个名为ssia-ch2-ex1的空项目(该名称与配套源码中的示例项目保持一致)。项目需要添加以下两个核心依赖: org.springframework.bootspring-boot-starter-weborg.springframework.bootspring-boot-starter-secur…

秋招Day12 - 计算机网络 - UDP

说说TCP和UDP的区别&#xff1f; TCP使用无边界的字节流传输&#xff0c;可能发生拆包和粘包&#xff0c;接收方并不知道数据边界&#xff1b;UDP采用数据报传输&#xff0c;数据报之间相互独立&#xff0c;有边界。 应用场景方面&#xff0c;TCP适合对数据的可靠性要求高于速…

【QQ音乐】sign签名| data参数加密 | AES-GCM加密 | webpack (下)

1.目标 网址&#xff1a;https://y.qq.com/n/ryqq/toplist/26 我们知道了 sign P(n.data)&#xff0c;其中n.data是明文的请求参数 2.webpack生成data加密参数 那么 L(n.data)就是密文的请求参数。返回一个Promise {<pending>}&#xff0c;所以L(n.data) 是一个异步函数…

Codeforces Round 1028 (Div. 2)(A-D)

题面链接&#xff1a;Dashboard - Codeforces Round 1028 (Div. 2) - Codeforces A. Gellyfish and Tricolor Pansy 思路 要知道骑士如果没了那么这个人就失去了攻击手段&#xff0c;贪心的来说我们只需要攻击血量少的即可&#xff0c;那么取min比较一下即可 代码 void so…

【存储基础】存储设备和服务器的关系和区别

文章目录 1. 存储设备和服务器的区别2. 客户端访问数据路径场景1&#xff1a;经过服务器处理场景2&#xff1a;客户端直连 3. 服务器作为"中转站"的作用 刚开始接触存储的时候&#xff0c;以为数据都是存放在服务器上的&#xff0c;服务器和存储设备是一个东西&#…

macOS 安装 Grafana + Prometheus + Node Exporter

macOS 安装指南&#xff1a;Grafana Prometheus Node Exporter 目录简介&#x1f680; 快速开始 安装 Homebrew1. 安装 Homebrew2. 更新 Homebrew 安装 Node Exporter使用 Homebrew 安装验证 Node Exporter 安装 Prometheus使用 Homebrew 安装验证安装 安装 Grafana使用 Home…

不可变集合类型转换异常

记录一个异常&#xff1a;class java.util.ImmutableCollections$ListN cannot be cast to class java.util.ArrayList (java.util.ImmutableCollections$ListN and java.util.ArrayList 文章目录 1、原因2、解决方式一3、解决方式二4、关于不可变集合的补充4.1 JDK8和9的对比4…