前言

  Tensorboard最初是tensorflow的可视化工具,被用于机器学习实验的可视化,后来也适配了pytorch。Tensorboard是一个前端web界面,,能够从文件里面读取数据并展示它(比如损失、准确率、网络图)。具体使用可以参考。

tensorboard安装使用教程https://zhuanlan.zhihu.com/p/420943896

        Tensorboard的安装比较简单,这里我使用conda安装:

conda install tensorboard
tensorboard --version
# 2.19.0

1. 常用操作

    1.1 导入库文件和生成对象

        我们的第一个任务就是 将我们想可视化的数据写入到tensorboard可以读取的文件中,这可以通过下面来实现:

from torch.utils.tensorboard import SummaryWriter  # 导入库文件        
writer = SummaryWriter('./tensorboard')        # 生成tensorboard对象
# ...
writer.close()

   1.2 数字 (scalar)

writer.add_scalar(tag, scalar_value, global_step=None, walltime=None) # tag (string): 曲线名称 。
# scalar_value (float): 数字常量值 。
# global_step (int, optional): 训练的当前步数。
# walltime (float, optional): 记录发生的时间,默认为 time.time()

   1.3 图像 (image)

writer.add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')# tag (string): 图像名称。
# img_tensor (torch.Tensor / numpy.array): 图像数据。
# global_step (int, optional): 训练的当前步数。
# walltime (float, optional): 记录发生的时间,默认为 time.time()
# dataformats (string, optional): 图像数据的格式,默认为 'CHW',即C:Channel。H:Height。W:Width。还可以是 'HWC' 或 'HW' 。

   1.4 直方图 (histogram)

writer.add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None)# tag (string): 数据名称(名称相同,多次存入图片,可以形成视频效果)。
# values (torch.Tensor, numpy.array, or string/blobname): 直方图的数据(训练参数,注意力分数等)。
# global_step (int, optional): 训练的当前步数。
# bins (string, optional): 取值有 ‘tensorflow’、‘auto’、‘fd’ 等, 表示元素个数。
# walltime (float, optional): 记录发生的时间,默认为 time.time()。
# max_bins (int, optional): 表示元素最大个数。

   1.5 模型结构图 (graph)

writer.add_graph(model, input_to_model=None, verbose=False, **kwargs)# model (torch.nn.Module): 网络模型。
# input_to_model (torch.Tensor or list of torch.Tensor, optional): 模型输入参数

   1.6 嵌入向量 (embedding)

writer.add_embedding(mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None)# mat (torch.Tensor or numpy.array): 数据点,shape:NxHxW。
# metadata (list or torch.Tensor or numpy.array, optional): 分类标签,长度=N。
# label_img (torch.Tensor, optional): shape:Nx1xHxW 的张量。
# global_step (int, optional): 训练的当前步数。
# tag (string, optional): 数据名称。

   1.7 打开web端展示结果

# 命令行输入:
tensorboard --logdir=tensorboard

2. Tensorboard 示例

   2.1 代码示例

        我们在上一个博客中的CNN训练代码中,补充tensorboard对应的部分,去记录并可视化训练过程中的训练数据、损失/精确度变化、权重/梯度变化等等。

import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from CNN_network import Network,train_set
from torch.utils.tensorboard import SummaryWriterdef get_num_correct(preds, labels):return preds.argmax(dim=1).eq(labels).sum().item()network = Network()
train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
optimizer = optim.Adam(network.parameters(), lr=0.01)# -生成对象(1.1),写入图像(1.3)和模型结构图(1.5)----------------------------------------
images, labels = next(iter(train_loader))  # get图像和标签
grid = torchvision.utils.make_grid(images) # 将一个包含多张图像的batch tensor,转换成一个可视化的图像网格writer = SummaryWriter('./tensorboard')writer.add_image('images', grid)
writer.add_graph(network, images)
# ---------------------------------------------------------------------------------for epoch in range(10):total_loss = 0total_correct = 0for batch in train_loader: # Get batchimages, labels = batchpreds = network(images) # Pass batchloss = F.cross_entropy(preds, labels) # calculate lossoptimizer.zero_grad() # 梯度归零loss.backward() # caculate gradientsoptimizer.step() # updata weighttotal_loss += loss.item()total_correct += get_num_correct(preds, labels)# -写入标量(1.2)和直方图(1.4)---------------------------------------------------------# 记录每个epoch的相应值,分别画曲线图和直方图writer.add_scalar('loss', total_loss, epoch)writer.add_scalar('accuracy', total_correct/len(train_loader), epoch)writer.add_scalar('number correct', total_correct/len(train_loader), epoch)writer.add_histogram('conv1.bias', network.conv1.bias, epoch)writer.add_histogram('conv1.weight', network.conv1.weight, epoch)writer.add_histogram('conv1.weight.grad', network.conv1.weight.grad, epoch)
# ---------------------------------------------------------------------------------print('epoch:', epoch, 'total_loss:', total_loss, 'total_correct:', total_correct)writer.close()

        执行上述代码后,我们会在./tensorboard文件夹中生成一个记录文件,我们可以通过执行下述命令,会返回一个网站,点击即可打开tensorboard web端查看

tensorboard --logdir=tensorboard

   2.2 web可视化结果展示

        点击进入后,首页大致是下面这样,这里蓝色方框里就对应刚刚在代码里添加的 标量图、直方图等等

  • 数据图像及模型结构图

  • 标量图:纵轴为我们记录的数据,横轴为epoch,也可以在左边蓝色方框位置修改线的平滑度以及横轴

  • 直方图:纵轴为epoch,横轴为tensor的值,每个直方图代表训练过的9个epoch中的1个,显示了 tensor权重/梯度等 倾向于集中在哪个区域的信息,用于发现异常。

  • 分布图:分布图会随同直方图一起出现,纵轴为tensor的值,横轴为epoch,分布图显示了这些tensor如何随着训练的进行而变化。 较暗的区域显示值在某个区域停留了更长的时间, 如果担心模型权重在每个epoch 都没有正确更新,可以使用此选项发现这些问题。

3. 快速对批量超参数进行实验

        在上面的案例中,这些操作我们通过Python或者R也可以实现,Tensorboard真正强大的地方在于可以批量对不同组合的超参数进行实验分析。

        同样拿上面的例子,我们想对 学习率分别取0.1,0.01,0.001,对batch_size分别取100,1000,10000,并对以上两者作笛卡尔积两两组合,看看哪种组合效果最好。下面首先在3.1和3.2补充两个细节点。

   3.1 如何避免多层循环做笛卡尔积

        常规我们做组合一般是这样:

batch_size_list = [100,1000,10000]
lr_list = [0.01,0.001,0.0001]for batch_size in batch_size_list:for lr in lr_list:xxx

        但是一般如果组合选项特别多,就会出现多层for循环嵌套,因此我们可以改进这一点

from itertools import product# 定义一个参数字典
hyperparam = dict(lr = [.01,.001],batch_size = [100,1000],shuffle = [True,False]
)# 返回字典中 每个键值对 对应的 值 来获取参数列表
param_values = [v for v in hyperparam.values()]
print(param_values)
# [[0.01, 0.001], [100, 1000], [True, False]]# 将参数传递给 product(),构建笛卡尔积
# *用于将列表中的每个值作为一个参数,而不是将列表本身用作参数
for lr,batch_size,shuffle in product(*param_values):print(lr,batch_size,shuffle)# 0.01 100 True
# 0.01 100 False
# 0.01 1000 True
# 0.01 1000 False
# 0.001 100 True
# 0.001 100 False
# 0.001 1000 True
# 0.001 1000 False

   3.2 Tensorboard文件的命名修改

SummaryWriter 的构造函数通常是:

SummaryWriter(log_dir=None, comment='', filename_suffix='')# log_dir: 指定存储 TensorBoard 日志文件的目录路径。
# comment: 可选的字符串,会附加到自动生成的日志目录名后面(如果你没有指定 log_dir)
# filename_suffix: 可选的字符串后缀,会添加到生成的事件文件名的末尾。可用来进一步区分同一目录下的不同日志文件

Tensorboard文件的命名逻辑(如果不指定log_dir):

# PyTorch version 1.1.0 SummaryWriter class
if not log_dir:import socketfrom datetime import datetimecurrent_time = datetime.now().strftime('%b%d_%H-%M-%S')log_dir = os.path.join('runs', current_time + '_' + socket.gethostname() + comment)
self.log_dir = log_dir

   3.3 完整代码

import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from CNN_network import Network,train_set
from torch.utils.tensorboard import SummaryWriter
from itertools import product
import osdef get_num_correct(preds, labels):return preds.argmax(dim=1).eq(labels).sum().item()# ----------------------------------------------------------------------------------------------------------
hyperparam = dict(lr = [.01,.001],batch_size = [100,1000],shuffle = [True,False]
)
param_values = [v for v in hyperparam.values()]
# ----------------------------------------------------------------------------------------------------------for lr,batch_size,shuffle in product(*param_values):network = Network()train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)# train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=shuffle, drop_last=True) # 丢弃最后一个batchoptimizer = optim.Adam(network.parameters(), lr=lr)# -生成对象(1.1),写入图像(1.3)和模型结构图(1.5)images, labels = next(iter(train_loader))  # get图像和标签grid = torchvision.utils.make_grid(images) # 将一个包含多张图像的batch tensor,转换成一个可视化的图像网格# ---------------------------------------------------------------main_log_dir = './tensorboard' # 主目录os.makedirs(main_log_dir, exist_ok=True)comment = f'batch_size={batch_size}, shuffle={shuffle}, lr={lr}'run_log_dir = os.path.join(main_log_dir, comment)writer = SummaryWriter(log_dir=run_log_dir)# ---------------------------------------------------------------writer.add_image('images', grid)writer.add_graph(network, images)for epoch in range(10):total_loss = 0total_correct = 0for batch in train_loader: # Get batchimages, labels = batchpreds = network(images) # Pass batchloss = F.cross_entropy(preds, labels) # calculate lossoptimizer.zero_grad() # 这里十分重要!!!  pytorch会累加梯度,所以在每个循环时,都必须先将梯度归零loss.backward() # caculate gradientsoptimizer.step() # updata weight# 作相应修改------------------------------------------# 训练集不一定恰巧被batch平均分配,可以丢弃最后一个batch,或者使用每个batch第一个轴的长度代替batch_size# total_loss += loss.item() * batch_sizetotal_loss += loss.item() * images.shape[0]total_correct += get_num_correct(preds, labels)# 写入标量(1.2)和直方图(1.4)# 记录每个epoch的相应值,分别画曲线图和直方图writer.add_scalar('loss', total_loss/100, epoch)writer.add_scalar('accuracy', total_correct/len(train_set), epoch)writer.add_scalar('number correct', total_correct, epoch)for name, weight in network.named_parameters():writer.add_histogram(name, weight, epoch)writer.add_histogram(f'{name}.grad', weight.grad, epoch)print('epoch:', epoch, 'total_loss:', total_loss, 'total_correct:', total_correct)writer.close()

        我们可以从左下角红色方框处看到所有的组合,从图中可明显看出,当batch_sizes设为1000,学习率为0.001时,建模效果比较差

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/95745.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/95745.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/95745.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言————实战项目“扫雷游戏”(完整代码)

无论是找工作面试,还是课设大作业、考研,都离不开实战项目的积累,如果你能把一个项目搞明白,并且给别人熟练的讲出来,即使你没有过项目经历,也可以说是非常加分的,下面来沉浸式体验一下这款扫雷…

数据结构之加餐篇 -顺序表和链表加餐

目录一、链表分割二、随机链表的复制总结一、链表分割 链表分割 题目描述的意思就如下图: 也就是把1,2挪到前面,6,3,5挪到后面,前者的相对顺序不发生改变 这里要想往后挪就要先遍历,遍历到6…

JSP与Servlet整合数据库开发:构建Java Web应用的全栈指南

JSP与Servlet整合数据库开发:构建Java Web应用的全栈指南 概述 在Java Web开发领域,JSP(JavaServer Pages)与Servlet是构建动态Web应用的核心技术组合。Servlet作为Java EE的基础组件,负责处理客户端请求、执行业务逻…

设计五种算法精确的身份证号匹配

问题定义与数据准备 我们有两个Excel文件: small.xlsx: 包含约5,000条记录。large.xlsx: 包含约140,000条记录。 目标:快速、高效地从large.xlsx中找出所有其“身份证号”字段存在于small.xlsx“身份证号”字段中的记录,并将这些匹配的记录保…

Spring 框架(IoC、AOP、Spring Boot) 的必会知识点汇总

目录:🧠 一、Spring 框架概述1. Spring 的核心功能2. Spring 模块化结构🧩 二、IoC(控制反转)核心知识点1. IoC 的核心思想2. Bean 的定义与管理3. IoC 容器的核心接口4. Spring Bean 的创建方式🧱 三、AOP…

简单工厂模式(Simple Factory Pattern)​​ 详解

✅作者简介:大家好,我是 Meteors., 向往着更加简洁高效的代码写法与编程方式,持续分享Java技术内容。 🍎个人主页: Meteors.的博客 💞当前专栏: 设计模式 ✨特色专栏: 知识分享 &…

新电脑硬盘如何分区?3个必知技巧避免“空间浪费症”!

刚到手的新电脑,硬盘就像一间空荡荡的大仓库,文件扔进去没多久就乱成一锅粥?别急,本文会告诉你新电脑硬盘如何分区,这些方法不仅可以帮你给硬盘分区,还可以调整/合并分区大小等。所以,本文的分区…

【微知】git submodule的一些用法总结(不断更新)

文章目录综述要点细节如何新增一个submodule?如何手动.gitmodules修改首次增加一个submodule?git submodule init,init子命令依据.gitmodules.gitmodules如何命令修改某个成员以及同步?如果submodule需要修改分支怎么办&#xff1…

【Spring Cloud微服务】9.一站式掌握 Seata:架构设计与 AT、TCC、Saga、XA 模式选型指南

文章目录一、Seata 框架概述二、核心功能特性三、整体架构与三大角色1. Transaction Coordinator (TC) - 事务协调器(Seata Server)2. Transaction Manager (TM) - 事务管理器(集成在客户端)3. Resource Manager (RM) - 资源管理器…

AI赋能!Playwright带飞UI自动化脚本维护

80%的自动化脚本因一次改版报废? 开发随意改动ID导致脚本集体崩溃?背景UI自动化在敏捷开发席卷行业的今天,UI自动化测试深陷一个尴尬困局:需求迭代速度(平均2周1次)> 脚本维护速度(平…

Redis、Zookeeper 与关系型数据库分布式锁方案对比及性能优化实战指南

Redis、Zookeeper 与关系型数据库分布式锁方案对比及性能优化实战指南 1. 问题背景介绍 在分布式系统中,多节点并发访问共享资源时,如果不加锁或加锁不当,会导致数据不一致、超卖超买、竞态条件等问题。常见的分布式锁方案包括基于Redis、Zoo…

网络安全A模块专项练习任务十一解析

任务十一:IP安全协议配置任务环境说明: (Windows 2008)系统:用户名Administrator,密码Pssw0rd1.指定触发SYN洪水攻击保护所必须超过的TCP连接请求数阈值为5;使用组合键winR,输入regedit打开注册表编辑器&am…

金蝶中间件适配HGDB

文章目录环境文档用途详细信息环境 系统平台:Microsoft Windows (64-bit) 10 版本:5.6.5 文档用途 本文章主要介绍金蝶中间件简单适配HGDB。 详细信息 一、金蝶中间件Apusic安装与配置 1.Apusic安装与配置 Windows和Linux下安装部署过程相同。 &…

使用a标签跳转之后,会刷新一次,这个a标签添加的样式就会消失

<ul class"header-link"><li><a href"storeActive.html">到店活动</a></li><li><a href"fuwu.html">服务</a></li><li><a href"store.html">门店</a></l…

线程池实现及参数详解

线程池概述 Java线程池是一种池化技术&#xff0c;用于管理和复用线程&#xff0c;减少线程创建和销毁的开销&#xff0c;提高系统性能。Java通过java.util.concurrent包提供了强大的线程池支持。 线程池参数详解 1. 核心参数 // 创建线程池的完整构造函数 ThreadPoolExecu…

K8S 部署 NFS Dynamic Provisioning(动态存储供应)

K8S 部署 NFS Dynamic Provisioning&#xff08;动态存储供应&#xff09; 本文档提供完整的 K8s NFS 动态存储部署流程&#xff0c;包含命名空间创建、RBAC 权限配置、Provisioner 部署、StorageClass 创建及验证步骤。 2. 部署步骤 2.1 创建命名空间 首先创建独立的命名空间 …

JavaEE 进阶第二期:开启前端入门之旅(二)

专栏&#xff1a;JavaEE 进阶跃迁营 个人主页&#xff1a;手握风云 目录 一、VS Code开发工具的搭建 1.1. 创建.html文件 1.2. 安装插件 1.3. 快速生成代码 二、HTML常见标签 2.1. 换行标签 2.2. 图片标签: img 2.3. 超链接 三、表格标签 四、表单标签 4.1. input标…

【RNN-LSTM-GRU】第二篇 序列模型原理深度剖析:从RNN到LSTM与GRU

本文将深入探讨循环神经网络&#xff08;RNN&#xff09;的核心原理、其面临的长期依赖问题&#xff0c;以及两大革命性解决方案——LSTM和GRU的门控机制&#xff0c;并通过实例和代码帮助读者彻底理解其工作细节。1. 引言&#xff1a;时序建模的数学本质在上一篇概述中&#x…

Qt---状态机框架QState

QState是Qt状态机框架&#xff08;Qt State Machine Framework&#xff09;的核心类&#xff0c;用于建模离散状态以及状态间的转换逻辑&#xff0c;广泛应用于UI交互流程、设备状态管理、工作流控制等场景。它基于UML状态图规范设计&#xff0c;支持层次化状态、并行状态、历史…

GitHub 热榜项目 - 日榜(2025-09-02)

GitHub 热榜项目 - 日榜(2025-09-02) 生成于&#xff1a;2025-09-02 统计摘要 共发现热门项目&#xff1a;14 个 榜单类型&#xff1a;日榜 本期热点趋势总结 本期GitHub热榜呈现AI Agent生态爆发趋势&#xff0c;Koog、Activepieces等项目推动多平台智能体开发框架成熟。语…