@浙大疏锦行https://blog.csdn.net/weixin_45655710知识点回顾:

  1. tensorboard的发展历史和原理
  2. tensorboard的常见操作
  3. tensorboard在cifar上的实战:MLP和CNN模型

作业:对resnet18在cifar10上采用微调策略下,用tensorboard监控训练过程。

核心:

  1. 数据加载和模型创建:复用之前的函数,保持模块化。

  2. SummaryWriter初始化:创建TensorBoard的写入器,并自动处理日志目录,避免覆盖。

  3. train_and_evaluate函数:创建一个总控函数,封装了完整的“冻结-解冻”训练循环,并在其中集成了TensorBoard的各种日志记录功能。

  4. TensorBoard日志记录

  • 模型图谱 (Graph):在训练开始前,记录模型的计算图。
  • 标量 (Scalars):实时记录训练集和测试集的损失(Loss)与准确率(Accuracy),以及学习率(Learning Rate)的变化。
  • 图像 (Images):记录输入的样本图像和每个epoch结束时预测错误的样本。
  • 直方图 (Histograms):定期记录模型各层权重(Weights)和梯度(Gradients)的分布,用于诊断训练状态。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter # 导入TensorBoard的核心类
import matplotlib.pyplot as plt
import os
import time
from tqdm import tqdm
import torchvision # 确保torchvision被导入以使用make_grid# --- 步骤 1: 准备数据加载器 (保持不变) ---
def get_cifar10_loaders(batch_size=128):"""获取CIFAR-10的数据加载器,包含数据增强"""train_transform = transforms.Compose([transforms.RandomResizedCrop(224), # ResNet通常在224x224的图像上预训练transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet的标准化参数])test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)return train_loader, test_loader# --- 步骤 2: 模型创建与冻结/解冻函数 (保持不变) ---
def create_resnet18(pretrained=True, num_classes=10):"""创建并修改ResNet18模型"""model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)return modeldef set_freeze_state(model, freeze=True):"""冻结或解冻模型的特征提取层"""print(f"--- {'冻结' if freeze else '解冻'} 特征提取层 ---")for name, param in model.named_parameters():if 'fc' not in name: # 只训练最后的全连接层param.requires_grad = not freeze# --- 步骤 3: 封装了TensorBoard的训练与评估总控函数 ---
def train_with_tensorboard(model, device, train_loader, test_loader, epochs, freeze_epochs, writer):"""使用TensorBoard监控的完整训练流程"""# 初始化优化器和损失函数criterion = nn.CrossEntropyLoss()# 初始只优化未冻结的参数optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5, verbose=True)# --- TensorBoard初始记录 ---print("正在记录初始信息到TensorBoard...")dataiter = iter(train_loader)images, _ = next(dataiter)writer.add_graph(model, images.to(device)) # 记录模型图img_grid = torchvision.utils.make_grid(images[:16]) # 取16张图预览writer.add_image('CIFAR-10 样本图像', img_grid)print("✅ 初始信息记录完成。")# 开始训练global_step = 0for epoch in range(1, epochs + 1):# --- 解冻控制 ---if epoch == freeze_epochs + 1:set_freeze_state(model, freeze=False)# 解冻后需要为优化器加入所有参数optimizer = optim.Adam(model.parameters(), lr=1e-4) # 使用更小的学习率进行全局微调print("优化器已更新以包含所有参数,学习率已降低。")# --- 训练部分 ---model.train()train_loss, train_correct, train_total = 0, 0, 0loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}] Training", leave=False)for data, target in loop:data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item() * data.size(0)_, pred = output.max(1)train_correct += pred.eq(target).sum().item()train_total += data.size(0)writer.add_scalar('Train/Batch_Loss', loss.item(), global_step)global_step += 1loop.set_postfix(loss=loss.item())loop.close()# 记录Epoch级训练指标avg_train_loss = train_loss / train_totalavg_train_acc = 100. * train_correct / train_totalwriter.add_scalar('Train/Epoch_Loss', avg_train_loss, epoch)writer.add_scalar('Train/Epoch_Accuracy', avg_train_acc, epoch)# --- 评估部分 ---model.eval()test_loss, test_correct, test_total = 0, 0, 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)test_loss += loss.item() * data.size(0)_, pred = output.max(1)test_correct += pred.eq(target).sum().item()test_total += data.size(0)# 记录Epoch级测试指标avg_test_loss = test_loss / test_totalavg_test_acc = 100. * test_correct / test_totalwriter.add_scalar('Test/Epoch_Loss', avg_test_loss, epoch)writer.add_scalar('Test/Epoch_Accuracy', avg_test_acc, epoch)# 记录权重和梯度的直方图 (每个epoch记录一次)for name, param in model.named_parameters():writer.add_histogram(f'Weights/{name}', param, epoch)if param.grad is not None:writer.add_histogram(f'Gradients/{name}', param.grad, epoch)# 更新学习率调度器scheduler.step(avg_test_loss)writer.add_scalar('Train/Learning_Rate', optimizer.param_groups[0]['lr'], epoch)print(f"Epoch {epoch} 完成 | 训练准确率: {avg_train_acc:.2f}% | 测试准确率: {avg_test_acc:.2f}%")# --- 步骤 4: 主执行流程 ---
if __name__ == "__main__":# --- 配置 ---EPOCHS = 15FREEZE_EPOCHS = 5 # 先冻结训练5轮,再解冻训练10轮BATCH_SIZE = 64DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")# --- TensorBoard 初始化 ---log_dir = "runs/resnet18_finetune_cifar10"version = 1while os.path.exists(f"{log_dir}_v{version}"):version += 1log_dir = f"{log_dir}_v{version}"writer = SummaryWriter(log_dir)print(f"TensorBoard 日志将保存在: {log_dir}")# --- 开始实验 ---train_loader, test_loader = get_cifar10_loaders(batch_size=BATCH_SIZE)model = create_resnet18(pretrained=True).to(DEVICE)set_freeze_state(model, freeze=True) # 初始冻结print("\n--- 开始使用ResNet18微调模型 ---")print("训练完成后,在终端运行 `tensorboard --logdir=runs` 来查看可视化结果。")train_with_tensorboard(model, DEVICE, train_loader, test_loader, EPOCHS, FREEZE_EPOCHS, writer)writer.close() # 关闭writerprint("\n✅ 训练完成,TensorBoard日志已保存。")

解析

1.数据预处理适配 (get_cifar10_loaders)

图像尺寸ResNet系列是在224x224的ImageNet图像上预训练的。虽然它们也能处理32x32的CIFAR-10图像,但为了更好地利用预训练权重,一个常见的做法是将小图像放大224x224。我们在transforms中加入了transforms.RandomResizedCrop(224)transforms.Resize(256) / transforms.CenterCrop(224)来实现这一点。

标准化参数:使用了ImageNet数据集的标准化均值和标准差,这是使用在ImageNet上预训练的模型的标准做法

2.模块化训练流程 (train_with_tensorboard)

将整个包含“冻结-解冻”逻辑的训练循环封装成一个函数,使得主程序非常简洁。

该函数接收一个writer对象作为参数,所有TensorBoard的日志记录都在这个函数内部完成。

3.TensorBoard全面监控

  • 模型图 (add_graph):在训练开始前,将模型的结构图写入日志,方便在GRAPHS标签页查看。
  • 图像 (add_image):将一批原始训练样本写入日志,可以在IMAGES标签页直观地看到输入数据。
  • 标量 ( add_scalar )

Batch级:记录了每个训练批次的损失(Train/Batch_Loss),可以观察到最细粒度的训练动态。

Epoch级:记录了每个轮次结束后的训练和测试的损失准确率,以及学习率的变化。这能让我们在同一个图表中清晰地对比训练集和测试集的性能曲线,判断过拟合。

  • 直方图 (add_histogram):每个轮次结束后,记录模型所有可训练参数的权重分布梯度分布。这对于高级调试非常有用,可以帮助判断是否存在梯度消失/爆炸,或者权重是否更新正常。

4.清晰的执行逻辑

if __name__ == "__main__":中,代码逻辑非常清晰:设置参数 -> 初始化TensorBoard写入器 -> 准备数据 -> 创建模型 -> 调用总控函数开始训练 -> 结束并关闭写入器。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/89323.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/89323.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/89323.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023年全国硕士研究生招生考试英语(一)试题总结

文章目录 题型与分值分布完形填空错误 1:考察连词 or 前后内容之间的逻辑关系错误2:错误3:错误4:这个错得最有价值,因为压根没读懂错误5:学到的短语: 仔细阅读排序/新题型翻译小作文大作文 题型…

react-数据Mock实现——json-server

什么是mock? 在前后端分离的开发模式下,前端可以在没有实际后端接口的支持下先进行接口数据的模拟,进行正常的业务功能开发 json-server实现数据Mock json-server是一个node的包,可以在不到30秒内获得零编码的完整Mock服务 实现…

使用POI导入解析excel文件

首先校验 /*** 校验导入文件* param file 上传的文件* return 校验结果,成功返回包含成功状态的AjaxResult,失败返回包含错误信息的AjaxResult*/private AjaxResult validateImportFile(MultipartFile file) {if (file.isEmpty()) {return AjaxResult.er…

从0开始学习计算机视觉--Day06--反向传播算法

尽管解析梯度可以让我们省去巨大的计算量,但如果函数比较复杂,对这个损失函数进行微分计算会变得很困难。我们通常会用反向传播技术来递归地调用链式法则来计算向量每一个方向上的梯度。具体来说,我们将整个计算过程的输入与输入具体化&#…

企业流程知识:《学习观察:通过价值流图创造价值、消除浪费》读书笔记

《学习观察:通过价值流图创造价值、消除浪费》读书笔记 作者:迈克鲁斯(Mike Rother),约翰舒克(John Shook) 出版时间:1999年 历史地位:精益生产可视化工具的黄金标准&am…

Day02_C语言IO进程线程

01.思维导图 02.将当前的时间写入到time. txt的文件中,如果ctrlc退出之后,在再次执行支持断点续写 1.2022-04-26 19:10:20 2.2022-04-26 19:10:21 3.2022-04-26 19:10:22 //按下ctrlc停止,再次执行程序 4.2022-04-26 20:00:00 5.2022-04-26 2…

FFmpeg中TS与MP4格式的extradata差异详解

在视频处理中,extradata是存储解码器初始化参数的核心元数据,直接影响视频能否正确解码。本文深入解析TS和MP4格式中extradata的结构差异、存储逻辑及FFmpeg处理方案。 📌 一、extradata的核心作用 extradata是解码必需的参数集合&#xff0…

【CV数据集介绍-40】Cityscapes 数据集:助力自动驾驶的语义分割神器

🧑 博主简介:曾任某智慧城市类企业算法总监,目前在美国市场的物流公司从事高级算法工程师一职,深耕人工智能领域,精通python数据挖掘、可视化、机器学习等,发表过AI相关的专利并多次在AI类比赛中获奖。CSDN…

SAP月结问题9-FAGLL03H与损益表中研发费用金额不一致(FAGLL03H Bug)

SAP月结问题9-FAGLL03H与损益表中研发费用金额不一致(S4 1709) 财务反馈,月结后核对数据时发现FAGLL03H导出的研发费用与损益表中的研发费用不一致,如下图所示: 对比FAGLL03H与损益表对应的明细,发现FAGLL03H与损益表数据存在倍数…

HTML inputmode 属性详解

inputmode 是一个 HTML 属性&#xff0c;用于指定用户在编辑元素或其内容时应使用的虚拟键盘布局类型。它主要影响移动设备和平板电脑的输入体验。 语法 <input inputmode"value"> <!-- 或 --> <textarea inputmode"value"></texta…

软考中级【网络工程师】第6版教材 第1章 计算机网络概述

考点分析&#xff1a; 本章重要程度&#xff1a;一般&#xff0c;为后续章节做铺垫&#xff0c;有总体认识即可&#xff0c;选择题1-2分高频考点&#xff1a;OSI模型、TCP/IP模型、每个层次的功能、协议层次新教材变化&#xff1a;删除网络结构、删除X.25、更新互联网发展【基本…

Mysql事务与锁

数据库并发事务 数据库一般都会并发执行多个事务&#xff0c;多个事务可能会并发的对相同的一批数据进行增删改查操作&#xff0c;可能就会导致我们说的脏写、脏读、不可重复读、幻读这些问题。为了解决这些并发事务的问题&#xff0c;数据库设计了事务隔离机制、锁机制、MVCC多…

Bilibili多语言字幕翻译扩展:基于上下文的实时翻译方案设计

Bilibili多语言字幕翻译扩展&#xff1a;基于上下文的实时翻译方案设计 本文介绍了一个Chrome扩展的设计与实现&#xff0c;该扩展可以为Bilibili视频提供实时多语言字幕翻译功能。重点讨论了字幕翻译中的上下文问题及其解决方案。 该项目已经登陆Chrome Extension Store: http…

热血三国野地名将列表

<!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>野地名将信息表</title><style>tabl…

【记录】Word|Word创建自动编号的多级列表标题样式

文章目录 前言创建方式第一种方法&#xff1a;从“定义多级列表”中直接绑定已有样式第二种方法&#xff1a;通过已有段落创建样式&#xff0c;再绑定补充说明 尾声 前言 这世上荒唐的事情不少&#xff0c;但若说到吊诡&#xff0c;Word中的多级列表样式设定&#xff0c;倒是能…

使用mavros启动多机SITL仿真

使用mavros启动多机SITL仿真 方式1&#xff1a;使用roslaunch一键启动Step1&#xff1a;创建一个新的 ROS 包或放到现有包里Step2&#xff1a;编辑 multi_mavros.launchStep3&#xff1a;构建工作空间并 source 环境Step4&#xff1a;构建工作空间并 source 环境 方式2&#xf…

Flutter 网络栈入门,Dio 与 Retrofit 全面指南

面向多年 iOS 开发者的零阻力上手 写在前面 你在 iOS 项目中也许习惯了 URLSession、Alamofire 或 Moya。 换到 Flutter 后&#xff0c;等价的「组合拳」就是 Dio Retrofit。 本文将带你一次吃透两套库的安装、核心 API、进阶技巧与最佳实践。 1. Dio&#xff1a;Flutter 里的…

工作室考核源码(带后端)

题目内容可更改 下载地址:https://mcwlkj.lanzoub.com/iUF3z300tgfe 如图所示

数字孪生技术为UI前端提供全面支持:实现产品的可视化配置与定制

hello宝子们...我们是艾斯视觉擅长ui设计、前端开发、数字孪生、大数据、三维建模、三维动画10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩! 一、引言&#xff1a;数字孪生驱动产品定制的技术革命 在消费升级与工业 4.0 的双重驱动下&a…

通往物理世界自主智能的二元实在论与罗塞塔协议

序章&#xff1a;AI的“两种文化”之争——我们是否在构建错误的“神”&#xff1f; 自诞生以来&#xff0c;人工智能领域始终存在着一场隐秘的“两种文化”之争。一方是符号主义与逻辑的信徒&#xff0c;他们追求可解释、严谨的推理&#xff0c;相信智能的核心在于对世界规则…