@浙大疏锦行 python day37.

内容:

  • 保存模型只需要保存模型的参数即可,使用的时候直接构建模型再导入参数即可
# 保存模型参数
torch.save(model.state_dict(), "model_weights.pth")# 加载参数(需先定义模型结构)
model = MLP()  # 初始化与训练时相同的模型结构
model.load_state_dict(torch.load("model_weights.pth"))
# model.eval()  # 切换至推理模式(可选)
  • 也可以同时保存模型 + 参数
# 保存整个模型
torch.save(model, "full_model.pth")# 加载模型(无需提前定义类,但需确保环境一致)
model = torch.load("full_model.pth")
model.eval()  # 切换至推理模式(可选)
  • 保存训练状态,用于在训练过程中保存中间状态
# # 保存训练状态
# checkpoint = {
#     "model_state_dict": model.state_dict(),
#     "optimizer_state_dict": optimizer.state_dict(),
#     "epoch": epoch,
#     "loss": best_loss,
# }
# torch.save(checkpoint, "checkpoint.pth")# # 加载并续训
# model = MLP()
# optimizer = torch.optim.Adam(model.parameters())
# checkpoint = torch.load("checkpoint.pth")# model.load_state_dict(checkpoint["model_state_dict"])
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# start_epoch = checkpoint["epoch"] + 1  # 从下一轮开始训练
# best_loss = checkpoint["loss"]# # 继续训练循环
# for epoch in range(start_epoch, num_epochs):
#     train(model, optimizer, ...)
  • 对于跨框架保存时,需要保存为onnx文件
  • 早停策略:在训练过程中,训练集的损失不断下降,但是验证集或者测试集的损失反而上升,此时这种情况称为过拟合;针对这种情况,我们可以引入早停策略,用于提前终止训练;
  • 可以在训练达到某一个epoch次数时进行一次验证,将此次结果和历史最佳结果进行比较,如果结果更好则保留,如果不好则会使临时记录变量自增,一旦连续自增到预设值,直接停止。这种设计是为了防止出现震荡波动情况,避免偶然性
  • 可以结合上面的保存断点 breakpoint
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt
from tqdm import tqdm  # 导入tqdm库用于进度条显示
import warnings
warnings.filterwarnings("ignore")  # 忽略警告信息# 设置GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 归一化数据
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 将数据转换为PyTorch张量并移至GPU
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)  # 输入层到隐藏层self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)  # 隐藏层到输出层def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 实例化模型并移至GPU
model = MLP().to(device)# 分类问题使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 使用随机梯度下降优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
num_epochs = 20000  # 训练的轮数# 用于存储每200个epoch的损失值和对应的epoch数
train_losses = []  # 存储训练集损失
test_losses = []   # 存储测试集损失
epochs = []# ===== 新增早停相关参数 =====
best_test_loss = float('inf')  # 记录最佳测试集损失
best_epoch = 0                 # 记录最佳epoch
patience = 50                # 早停耐心值(连续多少轮测试集损失未改善时停止训练)
counter = 0                    # 早停计数器
early_stopped = False          # 是否早停标志
# ==========================start_time = time.time()  # 记录开始时间# 创建tqdm进度条
with tqdm(total=num_epochs, desc="训练进度", unit="epoch") as pbar:# 训练模型for epoch in range(num_epochs):# 前向传播outputs = model(X_train)  # 隐式调用forward函数train_loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad()train_loss.backward()optimizer.step()# 记录损失值并更新进度条if (epoch + 1) % 200 == 0:# 计算测试集损失model.eval()with torch.no_grad():test_outputs = model(X_test)test_loss = criterion(test_outputs, y_test)model.train()train_losses.append(train_loss.item())test_losses.append(test_loss.item())epochs.append(epoch + 1)# 更新进度条的描述信息pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})# ===== 新增早停逻辑 =====if test_loss.item() < best_test_loss: # 如果当前测试集损失小于最佳损失best_test_loss = test_loss.item() # 更新最佳损失best_epoch = epoch + 1 # 更新最佳epochcounter = 0 # 重置计数器# 保存最佳模型torch.save(model.state_dict(), 'best_model.pth')else:counter += 1if counter >= patience:print(f"早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。")print(f"最佳测试集损失出现在第{best_epoch}轮,损失值为{best_test_loss:.4f}")early_stopped = Truebreak  # 终止训练循环# ======================# 每1000个epoch更新一次进度条if (epoch + 1) % 1000 == 0:pbar.update(1000)  # 更新进度条# 确保进度条达到100%if pbar.n < num_epochs:pbar.update(num_epochs - pbar.n)  # 计算剩余的进度并更新time_all = time.time() - start_time  # 计算训练时间
print(f'Training time: {time_all:.2f} seconds')# ===== 新增:加载最佳模型用于最终评估 =====
if early_stopped:print(f"加载第{best_epoch}轮的最佳模型进行最终评估...")model.load_state_dict(torch.load('best_model.pth'))
# ================================# 可视化损失曲线
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()# 在测试集上评估模型
model.eval()
with torch.no_grad():outputs = model(X_test)_, predicted = torch.max(outputs, 1)correct = (predicted == y_test).sum().item()accuracy = correct / y_test.size(0)print(f'测试集准确率: {accuracy * 100:.2f}%')    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/92155.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/92155.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/92155.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ORACLE进阶操作

1 事务 事务的任务便是使数据库从一种状态变换成为另一种状态&#xff0c;这不同于文件系统&#xff0c;它是数据库所特用的。 所有的数据库中&#xff0c;事务只针对DML&#xff08;增删改)&#xff0c;不针对select select只能查看其他事务提交或回滚的数据&#xff0c;不能查…

Modbus 的一些理解

疑问&#xff1a;&#xff08;使用的是Modbustcp&#xff09;我在 Modbus slave 上面设置了slave地址为1&#xff0c;位置为40001的位置的值为1&#xff0c;40001这个位置上面的值是怎么存储的&#xff0c;存储在哪里的&#xff1f;他们是怎么进行交互的&#xff1f;在Modbus协…

【运动控制框架】WPF运动控制框架源码,可用于激光切割机,雕刻机,分板机,点胶机,插件机等设备,开箱即用

WPF运动控制框架源码&#xff0c;可用于激光切割机&#xff0c;雕刻机&#xff0c;分板机&#xff0c;点胶机&#xff0c;插件机等设备&#xff0c;考虑到各运动控制硬件不同&#xff0c;视觉应用功能&#xff08;应用视觉软件&#xff09;也不同&#xff0c;所以只开发各路径编…

RabbitMQ-日常运维命令

作者介绍&#xff1a;简历上没有一个精通的运维工程师。请点击上方的蓝色《运维小路》关注我&#xff0c;下面的思维导图也是预计更新的内容和当前进度(不定时更新)。中间件&#xff0c;我给它的定义就是为了实现某系业务功能依赖的软件&#xff0c;包括如下部分:Web服务器代理…

【Linux基础知识系列】第九十篇 - 使用awk进行文本处理

在Linux系统中&#xff0c;文本处理是一个常见的任务&#xff0c;尤其是在处理日志文件、配置文件和数据文件时。awk是一个功能强大的文本处理工具&#xff0c;广泛用于数据提取、分析和格式化。它不仅可以处理简单的文本文件&#xff0c;还可以处理复杂的结构化数据&#xff0…

第二十七天(数据结构:图)

图&#xff1a;是一种非线性结构形式化的描述: G{V,R}V:图中各个顶点元素(如果这个图代表的是地图&#xff0c;这个顶点就是各个点的地址)R:关系集合&#xff0c;图中顶点与顶点之间的关系(如果是地图&#xff0c;这个关系集合可能就代表的是各个地点之间的距离)在顶点与顶点…

数据赋能(386)——数据挖掘——迭代过程

概述重要性如下&#xff1a;提升挖掘效果&#xff1a;迭代过程能不断优化数据挖掘模型&#xff0c;提高挖掘结果的准确性和有效性&#xff0c;从而更好地满足业务需求。适应复杂数据&#xff1a;数据往往具有复杂性和多样性&#xff0c;通过迭代可以逐步探索和适应数据的特点&a…

什么是键值缓存?让 LLM 闪电般快速

一、为什么 LLMs 需要 KV 缓存&#xff1f;大语言模型&#xff08;LLMs&#xff09;的文本生成遵循 “自回归” 模式 —— 每次仅输出一个 token&#xff08;如词语、字符或子词&#xff09;&#xff0c;再将该 token 与历史序列拼接&#xff0c;作为下一轮输入&#xff0c;直到…

16.Home-懒加载指令优化

问题1&#xff1a;逻辑书写位置不合理问题2&#xff1a;重复监听问题已经加载完毕但是还在监听

Day116 若依融合mqtt

MQTT 1.MQTT协议概述MQTT是一种基于发布/订阅模式的轻量级消息传输协议&#xff0c;设计用于低带宽、高延迟或不稳定的网络环境&#xff0c;广泛应用于物联网领域1.1 MQTT协议的应用场景1.智能家居、车联网、工业物联网&#xff1a;MQTT可以用于连接各种家电设备和传感器&#…

PyTorch + PaddlePaddle 语音识别

PyTorch PaddlePaddle 语音识别 目录 概述环境配置基础理论数据预处理模型架构设计完整实现案例模型训练与评估推理与部署性能优化技巧总结 语音识别&#xff08;ASR, Automatic Speech Recognition&#xff09;是将音频信号转换为文本的技术。结合PyTorch和PaddlePaddle的…

施耐德 Easy Altivar ATV310 变频器:高效电机控制的理想选择(含快速调试步骤及常见故障代码)

施耐德 Easy Altivar ATV310 变频器&#xff1a;高效电机控制的理想选择&#xff08;含快速调试步骤&#xff09;在工业自动化领域&#xff0c;变频器作为电机控制的核心设备&#xff0c;其性能与可靠性直接影响整个生产系统的效率。施耐德电气推出的 Easy Altivar ATV310 变频…

搭建邮件服务器概述

一、电子邮件应用解析标准邮件服务器&#xff08;qq邮箱&#xff09;&#xff1a;1&#xff09;提供电子邮箱&#xff08;lvbuqq.com&#xff09;及存储空间2&#xff09;为客户端向外发送邮件给其他邮箱&#xff08;diaochan163.com&#xff09;3&#xff09;接收/投递其他邮箱…

day28-NFS

1.每日复盘与今日内容1.1复盘Rsync:本地模式、远程模式&#x1f35f;&#x1f35f;&#x1f35f;&#x1f35f;&#x1f35f;、远程守护模式&#x1f35f;&#x1f35f;&#x1f35f;&#x1f35f;&#x1f35f;安装、配置Rsync启动、测试服务备份案例1.2今日内容NFS优缺点NFS服…

二叉搜索树--通往高阶数据结构的基石

目录 前言&#xff1a; 1、二叉搜索树的概念 2、二叉搜索树性能分析 3、二叉搜索树的实现 BinarySelectTree.h test.cpp 4、key 和 key / value&#xff08; map 和 set 的铺垫 &#xff09; 前言&#xff1a; 又回到数据结构了&#xff0c;这次我们将要学习一些复杂的…

Profinet转Ethernet IP网关接入五轴车床上下料机械手控制系统的配置实例

本案例为西门子1200PLC借助PROFINET转EtherNet/IP网关与搬运机器人进行连接的配置案例。所需设备包括&#xff1a;西门子1200PLC、Profinet转EtherNet/IP网关以及发那科&#xff08;Fanuc&#xff09;机器人。开启在工业自动化控制领域广泛应用、功能强大且专业的西门子博图配置…

专题二_滑动窗口_长度最小的子数组

引入&#xff1a;滑动窗口首先&#xff0c;这是滑动窗口的第一道题&#xff0c;所以简短的说一下滑动窗口的思路&#xff1a;当我们题目要求找一个满足要求的区间的时候&#xff0c;且这个区间的left和right指针&#xff0c;都只需要同向移动的时候&#xff0c;就可以使用滑动窗…

解锁高效开发:AWS 前端 Web 与移动应用解决方案详解

告别繁杂的部署与运维&#xff0c;AWS 让前端开发者的精力真正聚焦于创造卓越用户体验。在当今快速迭代的数字环境中&#xff0c;Web 与移动应用已成为企业与用户交互的核心。然而&#xff0c;前端开发者常常面临诸多挑战&#xff1a;用户认证的复杂性、后端 API 的集成难题、跨…

北京JAVA基础面试30天打卡04

1. 单例模式的实现方式及线程安全 单例模式&#xff08;Singleton Pattern&#xff09;确保一个类只有一个实例&#xff0c;并提供一个全局访问点。以下是常见的单例模式实现方式&#xff0c;以及如何保证线程安全&#xff1a; 单例模式的实现方式饿汉式&#xff08;Eager Init…

Redis 缓存三大核心问题:穿透、击穿与雪崩的深度解析

引言在现代互联网架构中&#xff0c;缓存是提升系统性能、降低数据库压力的核心手段之一。而 Redis 作为高性能的内存数据库&#xff0c;凭借其丰富的数据结构、灵活的配置选项以及高效的网络模型&#xff0c;已经成为缓存领域的首选工具。本文将从 Redis 的基本原理出发&#…