一个基于 PyTorch 的完整模型训练流程

flyfish

训练步骤具体操作目的
1. 训练前准备设置随机种子、配置超参数(batch size、学习率等)、选择计算设备(CPU/GPU)确保实验可复现;统一控制训练关键参数;利用硬件加速训练
2. 数据预处理与加载对数据进行标准化/归一化、转换为张量;用DataLoader按batch加载数据统一输入格式,适配模型要求;高效分批读取数据,减少内存占用
3. 初始化组件定义模型结构并加载到计算设备;选择损失函数(如交叉熵)和优化器(如Adam)搭建训练核心框架:模型负责预测,损失函数量化误差,优化器负责参数更新
4. 训练循环(每个epoch)逐轮迭代优化模型参数
4.1 模型切换为训练模式model.train()启用dropout、批量归一化的训练模式,确保梯度计算有效
4.2 遍历训练数据(每个batch)逐批更新参数
4.2.1 清零梯度optimizer.zero_grad()消除历史梯度累积,确保当前batch的梯度计算独立
4.2.2 前向传播output = model(data)用当前模型参数对输入数据做预测,得到输出结果
4.2.3 计算损失loss = criterion(output, target)量化预测结果与真实标签的差距,作为优化目标
4.2.4 反向传播loss.backward()从损失值反向推导,计算所有可训练参数的梯度(参数对损失的影响程度)
4.2.5 参数更新optimizer.step()根据梯度,按优化器规则调整模型参数,减小损失
4.3 记录训练指标保存每个epoch的训练损失、准确率跟踪模型在训练集上的学习效果
5. 验证(每个epoch后)评估模型泛化能力
5.1 模型切换为评估模式model.eval()关闭dropout、固定批量归一化参数,确保评估稳定
5.2 关闭梯度计算with torch.no_grad():减少内存占用,加速验证过程(无需计算梯度)
5.3 计算验证指标计算验证损失、准确率评估模型在未见过的数据上的表现,判断泛化能力
6. 模型保存保存表现最优的模型参数(如验证准确率最高时)留存最佳模型,便于后续部署或继续训练
7. 训练后分析绘制损失/准确率曲线,统计训练时间直观展示训练过程,分析模型收敛状态和效率

前向传播→计算损失→反向传播→参数优化

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
import time# 设置随机种子,保证结果可复现
def set_seed(seed=42):torch.manual_seed(seed)torch.cuda.manual_seed(seed)np.random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False# 定义超参数
class Config:def __init__(self):self.batch_size = 64self.learning_rate = 0.001self.epochs = 10self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.save_path = './models'self.log_interval = 100# 定义简单的卷积神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)self.relu = nn.ReLU()self.dropout = nn.Dropout(0.5)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7)  # 展平x = self.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x# 准备数据
def prepare_data(config):# 定义数据变换transform = transforms.Compose([ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差])# 加载MNIST数据集train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)test_dataset = datasets.MNIST(root='./data',train=False,download=True,transform=transform)# 创建数据加载器train_loader = DataLoader(train_dataset,batch_size=config.batch_size,shuffle=True,num_workers=2)test_loader = DataLoader(test_dataset,batch_size=config.batch_size,shuffle=False,num_workers=2)return train_loader, test_loader# 训练函数
def train(model, train_loader, criterion, optimizer, config, epoch):model.train()  # 设置为训练模式train_loss = 0.0correct = 0total = 0# 使用tqdm显示进度条pbar = tqdm(train_loader, desc=f'Train Epoch {epoch}')for batch_idx, (data, target) in enumerate(pbar):data, target = data.to(config.device), target.to(config.device)# 清零梯度optimizer.zero_grad()# 前向传播output = model(data)loss = criterion(output, target)# 反向传播和优化loss.backward()optimizer.step()# 统计训练信息train_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 打印日志if batch_idx % config.log_interval == 0:pbar.set_postfix({'loss': f'{train_loss/(batch_idx+1):.6f}','accuracy': f'{100.*correct/total:.2f}%'})# 计算平均损失和准确率avg_loss = train_loss / len(train_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy# 验证函数
def validate(model, test_loader, criterion, config):model.eval()  # 设置为评估模式test_loss = 0.0correct = 0total = 0# 不计算梯度with torch.no_grad():for data, target in test_loader:data, target = data.to(config.device), target.to(config.device)output = model(data)test_loss += criterion(output, target).item()# 统计准确率_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 计算平均损失和准确率avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalprint(f'\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)\n')return avg_loss, accuracy# 保存模型
def save_model(model, optimizer, epoch, loss, config):# 创建保存目录if not os.path.exists(config.save_path):os.makedirs(config.save_path)# 保存模型状态torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}, f"{config.save_path}/model_epoch_{epoch}.pth")print(f"Model saved to {config.save_path}/model_epoch_{epoch}.pth")# 主函数
def main():# 初始化设置set_seed()config = Config()print(f"Using device: {config.device}")# 准备数据train_loader, test_loader = prepare_data(config)# 初始化模型、损失函数和优化器model = SimpleCNN().to(config.device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)# 记录训练过程中的指标history = {'train_loss': [],'train_acc': [],'val_loss': [],'val_acc': []}# 开始训练start_time = time.time()best_val_acc = 0.0for epoch in range(1, config.epochs + 1):print(f"\nEpoch {epoch}/{config.epochs}")print("-" * 50)# 训练train_loss, train_acc = train(model, train_loader, criterion, optimizer, config, epoch)history['train_loss'].append(train_loss)history['train_acc'].append(train_acc)# 验证val_loss, val_acc = validate(model, test_loader, criterion, config)history['val_loss'].append(val_loss)history['val_acc'].append(val_acc)# 保存最佳模型if val_acc > best_val_acc:best_val_acc = val_accsave_model(model, optimizer, epoch, val_loss, config)# 计算总训练时间end_time = time.time()total_time = end_time - start_timeprint(f"Training complete in {total_time:.0f}s ({total_time/config.epochs:.2f}s per epoch)")print(f"Best validation accuracy: {best_val_acc:.2f}%")# 绘制训练曲线plot_training_history(history)# 绘制训练历史
def plot_training_history(history):plt.figure(figsize=(12, 4))# 绘制损失曲线plt.subplot(1, 2, 1)plt.plot(history['train_loss'], label='Training Loss')plt.plot(history['val_loss'], label='Validation Loss')plt.title('Loss Curves')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()# 绘制准确率曲线plt.subplot(1, 2, 2)plt.plot(history['train_acc'], label='Training Accuracy')plt.plot(history['val_acc'], label='Validation Accuracy')plt.title('Accuracy Curves')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.tight_layout()plt.savefig('training_history.png')print("Training history plot saved as 'training_history.png'")plt.show()if __name__ == '__main__':main()
......
--------------------------------------------------
Train Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:07<00:00, 124.14it/s, loss=0.024222, accuracy=99.22%]Test set: Average loss: 0.0256, Accuracy: 9926/10000 (99.26%)Model saved to ./models/model_epoch_9.pthEpoch 10/10
--------------------------------------------------
Train Epoch 10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:07<00:00, 127.89it/s, loss=0.021473, accuracy=99.31%]Test set: Average loss: 0.0266, Accuracy: 9927/10000 (99.27%)Model saved to ./models/model_epoch_10.pth
Training complete in 85s (8.52s per epoch)
Best validation accuracy: 99.27%
Training history plot saved as 'training_history.png'

在这里插入图片描述
一、左侧:Loss Curves(损失曲线)
蓝色:训练损失(Training Loss)
橙色:验证损失(Validation Loss)

二、右侧:Accuracy Curves(准确率曲线)
蓝色:训练准确率(Training Accuracy)
橙色:验证准确率(Validation Accuracy)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/93074.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/93074.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/93074.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ffmpeg,ffplay, vlc,rtsp-simple-server,推拉流命令使用方法,及测试(二)

一、常用命令 ffmpeg 推流命令 : ffmpeg -re -i input.mp4 -c copy -f flv rtmp://39.105.129.233/myapp/ffmpeg -re -i input.mp4 -c copy -f flv rtsp://39.105.129.233/myapp/-re 读取流 -i 输入文件 -f # 指定推流formatffplay 拉流命令 : ffplay rtmp://39.105.129.233/m…

使用行为树控制机器人(三) ——通用端口

文章目录一、通用端口功能实现1. 功能实现1.1 头文件定义1.2 源文件实现1.3 main文件实现1.4 tree.xml 实现2. 执行结果使用行为树控制机器人(一) —— 节点使用行为树控制机器人(二) —— 黑板使用行为树控制机器人(三) —— 通用端口有了上述前两节我们已经可以实现节点间的通…

DataDome反爬虫验证技术深度解析:无感、滑块与设备验证全攻略

DataDome反爬虫验证技术深度解析&#xff1a;无感、滑块与设备验证全攻略 随着网络安全威胁的不断演进&#xff0c;企业对数据保护的需求日益增强。DataDome作为业界领先的反爬虫解决方案&#xff0c;以其三层防护机制在众多知名网站中得到广泛应用。本文将深入解析DataDome的…

RabbitMQ 消息转换器详解

RabbitMQ 消息转换器详解 一、为什么需要消息转换器&#xff1f; RabbitMQ 的消息传输协议只识别字节流&#xff1a; 发送对象时&#xff0c;需要序列化成字节数组接收消息时&#xff0c;需要将字节数组反序列化成对象 如果不使用消息转换器&#xff1a; 需要手动序列化和反序列…

内网穿透的应用-告别“现场救火”!用 cpolar远程调试让内网故障排查进入“云时代”

文章目录前言**常见困境与解决方案****实际应用价值**1. Remote JVM Debug2. 系统要求与环境准备2.1 服务器环境2.2 本地开发环境3. 内网服务器准备及开始3.1 安装cpolar配置支持远程ssh登录3.1.1 什么是cpolar&#xff1f;3.1.2 安装cpolar3.1.3 注册及配置cpolar系统服务3.1.…

Cherryusb UAC例程对接STM32内置ADC和PWM播放音乐和录音(下)=>UAC+STM32 ADC+PWM实现录音和播放

1. 程序基本框架整个程序框架, 与之前的一篇文章《Cherryusb UAC例程对接STM32内置ADC和DAC播放音乐和录音(中)>UACSTM32 ADCDAC实现录音和播放》基本一致, 只是这次将DAC替换成了PWM。因此这里不再赘述了。 2. audio_v1_mic_speaker_multichan_template.c的修改说明(略) 参…

1 JQ6500语音播报模块详解(STM32)

系列文章目录 文章目录系列文章目录前言1 JQ6500简介2 基本参数说明2.1 硬件参数2.2 模块管脚说明3 控制方式3.1 通信格式3.2 通信指令4 硬件设计5 软件设计5.1 main.c5.2 board_config5.2.1board_config.h5.2.2 board_config.c5.3 module_config5.3.1 module_config.h5.3.2 mo…

常用数据分析工具

Tableau丨Power BI丨FineBI丨SQL丨影刀丨Excel丨Python丨 参考视频&#xff1a;【戴师兄】数据分析有哪些必学工具&#xff1f;2023最新版&#xff01;Tableau丨Power BI丨FineBI丨SQL丨影刀丨Excel丨Python丨课程教程自学攻略_哔哩哔哩_bilibili 文档资料&#xff1a; 【戴师兄…

OBOO鸥柏丨智能会议平板教学查询一体机交互式触摸终端招标投标核心标底参数要求

整机参数要求&#xff1a;55寸/65寸/75寸/85-86寸/98寸/100寸/110寸/115寸智能会议平板教学触控一体机/智慧黑板触摸屏招标投标核心标底参数要求1、整机屏幕采用≥采用超高清原厂原包原装工业LCD液晶屏面板&#xff1b;具有高色域&#xff0c;显示动态视频、web及3D动画时&…

无人机在环保监测中的应用:低空经济发展的智能监测与高效治理

一、行业背景与技术革新 随着全球环境问题日益严峻&#xff0c;传统环保监测手段已难以满足现代环境管理的需求。固定监测站点建设成本高、覆盖范围有限&#xff0c;地面巡查效率低下且存在安全风险。在此背景下&#xff0c;无人机技术凭借其独特的空间优势和技术特性&#xff…

PO、BO、VO、DTO、POJO、DAO、DO基本概念

一、图解二、相关概念 1、PO&#xff08;Persistant Object - 持久化对象&#xff09; 核心定位&#xff1a; 直接与数据库表结构一一映射的对象&#xff0c;通常用于 ORM&#xff08;对象关系映射&#xff09;框架&#xff08;如 MyBatis、Hibernate&#xff09;中。 特点&…

todoList清单(HTML+CSS+JavaScript)

&#x1f30f;个人博客主页&#xff1a; 前言&#xff1a; 前段时间学习了JavaScript&#xff0c;然后写了一个todoList小项目&#xff0c;现在和大家分享一下我的清单以及如何实现的&#xff0c;希望对大家有所帮助 &#x1f525;&#x1f525;&#x1f525;文章专题&#xff…

Mac M1探索AnythingLLM+Ollama+知识库问答

AnythingLLM内置 RAG、AI Agent、可视化/无代码的 Agent 编排&#xff0c;支持多家模型与本地/云端向量库&#xff0c;并提供多用户与可嵌入的聊天组件&#xff0c;用来快速验证“知识 模型 工具”拼成的 AI 应用。 1 AnythingLLM、Ollama准备 1&#xff09;AnythingLLM 打…

【 Navicat Premium 17 完全图形化新手指南(从零开始)】

Navicat Premium 17 完全图形化新手指南&#xff08;从零开始&#xff09; 一、准备阶段&#xff1a;清理现有环境 1. 删除已创建的测试数据库&#xff08;如需重新开始&#xff09;打开Navicat Premium 17 双击桌面图标启动程序在左侧连接面板中找到你的MySQL连接&#xff08;…

Web学习笔记5

Javascript概述1、JS简介JS是运行在浏览器的脚本编程语言&#xff0c;最初用于Web表单的校验。现在的作用主要有三个&#xff1a;网页特效、表单验证、数据交互JS由三部分组成&#xff0c;分别是ECMAscript、DOM、BOM&#xff0c;其中ECMAscript规定了JS的基本语法和规则&#…

部署一个开源的证件照系统

以下数据来自官方网站,记录下来,方便自己 项目简介 &#x1f680; 谢谢你对我们的工作感兴趣。您可能还想查看我们在图像领域的其他成果&#xff0c;欢迎来信:zeyi.linswanhub.co. HivisionIDPhoto 旨在开发一种实用、系统性的证件照智能制作算法。 它利用一套完善的AI模型工作…

Linux客户端利用MinIO对服务器数据进行同步

接上篇 Windows客户端利用MinIO对服务器数据进行同步 本篇为Linux下 操作&#xff0c;先看下我本地的系统版本 所以我这里下载的话&#xff0c;是AMD64 文档在这 因为我这里只是需要用到客户端&#xff0c;获取数据而已&#xff0c;所以我只需要下载个MC工具用来数据获取就可以…

Docker 中部署 MySQL 5.7 并远程连接 Navicat 的完整指南

个人名片 &#x1f393;作者简介&#xff1a;java领域优质创作者 &#x1f310;个人主页&#xff1a;码农阿豪 &#x1f4de;工作室&#xff1a;新空间代码工作室&#xff08;提供各种软件服务&#xff09; &#x1f48c;个人邮箱&#xff1a;[2435024119qq.com] &#x1f4f1…

自己动手造个球平衡机器人

你是否曾对那些能够精妙地保持平衡的机器设备感到好奇&#xff1f; 从无人机到独轮平衡车&#xff0c;背后都蕴藏着复杂的控制系统。 今天&#xff0c;我们来介绍一个充满挑战与乐趣的项目——制作一个球平衡机器人。这不仅是一个酷炫的摆件&#xff0c;更是一次深入学习机器…

21.Linux HTTPS服务

Linux : HTTPS服务协议传输方式端口安全性HTTP明文传输80无加密&#xff0c;可被窃听HTTPS加密传输443HTTP SSL/TLS 数据加密&#xff08;防窃听&#xff09;身份认证&#xff08;防伪装&#xff09;完整性校验&#xff08;防篡改&#xff09;OpenSSL 证书操作核心命令命令选项…