​一、需求描述

实战四分为三部分来实现,第二部分是基于PyTorch的猫狗图像可视化训练的教程,实现了一个完整的猫狗分类模型训练流程,使用预训练的ResNet50模型进行迁移学习,并通过SwanLab进行实验跟踪。

效果图

​二、实现思路

总体思路

  1. 导入和初始化配置:设置训练超参数(学习率、批次大小、训练轮数等);
  2. 加载数据集:读取自定义数据集,并设置数据加载器;
  3. 模型构建:加载预训练的ResNet50模型,并修改全连接层适配二分类任务;
  4. 训练配置:定义交叉熵损失函数,设置Adam优化器;
  5. 模型训练:循环遍历训练轮次,在每轮次遍历每个批次的数据,并实时打印训练进度及记录损失值到SwanLab。

2.1 导入和初始化配置

import swanlab
num_epochs=20
lr=1e-4
batch_size=8
num_classes=2
device="cuda"swanlab.init(experiment_name="模型训练实验",description="猫狗分类",mode="local",config={"model":"resnet50","optim":"Adam","lr":lr,"batch_size":batch_size,"num_epochs":num_epochs,"num_class":num_classes,"device":device,}
)
  • import swanlab - 导入SwanLab库,用于实验跟踪和可视化
  • num_epochs=20 - 设置训练轮数为20轮
  • lr=1e-4 - 设置学习率为0.0001
  • batch_size=8 - 设置批次大小为8
  • num_classes=2 - 设置分类类别数为2(猫和狗)
  • device="cuda" - 设置使用GPU进行训练
  • swanlab.init() - 初始化SwanLab实验,记录实验配置参数

2.2 加载数据集

import readDataset
from torch.utils.data import DataLoader
train_dataset=readDataset.DatasetLoader(readDataset.ds_train)
train_loader=(DataLoader(train_dataset,batch_size=batch_size,shuffle=True))
  • import readDataset - 导入自定义的数据集读取模块
  • from torch.utils.data import DataLoader - 导入PyTorch的数据加载器
  • train_dataset=readDataset.DatasetLoader(readDataset.ds_train) - 创建训练数据集对象
  • train_loader=(DataLoader(train_dataset,batch_size=batch_size,shuffle=True)) - 创建数据加载器,设置批次大小并启用随机打乱

2.3 模型构建

import torch
import torchvision
from torchvision.models import ResNet50_Weightsmodel=torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
in_features=model.fc.in_features
model.fc=torch.nn.Linear(in_features,num_classes)
model.to(device)
  • import torch - 导入PyTorch深度学习框架
  • import torchvision - 导入计算机视觉库
  • from torchvision.models import ResNet50_Weights - 导入ResNet50预训练权重
  • model=torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - 加载预训练的ResNet50模型
  • in_features=model.fc.in_features - 获取全连接层的输入特征数
  • model.fc=torch.nn.Linear(in_features,num_classes) - 替换最后的全连接层,输出类别数为2
  • model.to(device) - 将模型移动到GPU设备

2.4 训练配置

criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=lr)
  • criterion=torch.nn.CrossEntropyLoss() - 定义交叉熵损失函数,适用于多分类问题
  • optimizer=torch.optim.Adam(model.parameters(),lr=lr) - 定义Adam优化器,设置学习率

2.5 模型训练

for epoch in range(num_epochs):model.train()for iter,(inputs,labels) in enumerate(train_loader):inputs,labels=inputs.to(device),labels.to(device)optimizer.zero_grad()outputs=model(inputs)loss=criterion(outputs,labels)loss.backward()optimizer.step()print('Epoch[{}/{}],Iteration[{}/{}],Loss:{:.4f}'.format(epoch+1,num_epochs,iter+1,len(train_loader),loss.item()))swanlab.log({"train_loss":loss.item()})
  • for epoch in range(num_epochs): - 外层循环,遍历每个训练轮次
  • model.train() - 设置模型为训练模式
  • for iter,(inputs,labels) in enumerate(train_loader): - 内层循环,遍历每个批次的数据
  • inputs,labels=inputs.to(device),labels.to(device) - 将输入数据和标签移动到GPU
  • optimizer.zero_grad() - 清空梯度
  • outputs=model(inputs) - 前向传播,获取模型预测结果
  • loss=criterion(outputs,labels) - 计算损失
  • loss.backward() - 反向传播,计算梯度
  • optimizer.step() - 更新模型参数
  • print(...) - 打印训练进度和损失值
  • swanlab.log({"train_loss":loss.item()}) - 记录损失值到SwanLab实验跟踪系统

三、完整代码

import swanlab
num_epochs=20
lr=1e-4
batch_size=8
num_classes=2
device="cuda"swanlab.init(experiment_name="模型训练实验",description="猫狗分类",mode="local",config={"model":"resnet50","optim":"Adam","lr":lr,"batch_size":batch_size,"num_epochs":num_epochs,"num_class":num_classes,"device":device,}
)import readDataset
from torch.utils.data import DataLoader
train_dataset=readDataset.DatasetLoader(readDataset.ds_train)
train_loader=(DataLoader(train_dataset,batch_size=batch_size,shuffle=True))import torch
import torchvision
from torchvision.models import ResNet50_Weightsmodel=torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
in_features=model.fc.in_features
model.fc=torch.nn.Linear(in_features,num_classes)
model.to(device)
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=lr)for epoch in range(num_epochs):model.train()for iter,(inputs,labels) in enumerate(train_loader):inputs,labels=inputs.to(device),labels.to(device)optimizer.zero_grad()outputs=model(inputs)loss=criterion(outputs,labels)loss.backward()optimizer.step()print('Epoch[{}/{}],Iteration[{}/{}],Loss:{:.4f}'.format(epoch+1,num_epochs,iter+1,len(train_loader),loss.item()))swanlab.log({"train_loss":loss.item()})

四、效果展示

  • PyCharm运行日志
    在这里插入图片描述
  • PyCharm终端日志
    在这里插入图片描述
  • SwanLab工作区
    在这里插入图片描述
  • 模拟训练实验的概览
    在这里插入图片描述
  • 模拟训练实验的实验图表
    在这里插入图片描述
  • 模拟训练实验的日志
    在这里插入图片描述
  • 模拟训练实验的实验环境
    在这里插入图片描述

五、问题与解决

问题一:ModuleNotFoundError: No module named ‘XXX’
解决一:pip install XXX

pip install 'swanlab[dashboard]'

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/89094.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/89094.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/89094.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

对比几个测试云的一些速度

最近被hosting vps主机的速度给困扰了&#xff0c;干脆放下手中的活 测试下 test.php放在网站根目录即可 代码如下&#xff1a; <?php /*** 最终版服务器性能测试工具* 测试项目&#xff1a;CPU运算性能、内存读写速度、硬盘IO速度、网络下载速度*/// 配置参数&#xff…

UE5 Grid3D 学习笔记

一、Neighbor Grid 3D 的核心作用 NeighborGrid3D 是一种基于位置的哈希查找结构&#xff0c;将粒子按空间位置划分到网格单元&#xff08;Cell&#xff09;中&#xff0c;实现快速邻近查询&#xff1a; 空间划分&#xff1a;将模拟空间划分为多个三维网格单元&#xff08;Cel…

Spring AI ——在springboot应用中实现基本聊天功能

文章目录 前言测试环境项目构建依赖引入指定openai 相关配置基于 application.yml 配置 Open AI 属性application.yml编写测试类测试请求基于读取后配置请求编写测试接口测试效果展示流式输出前言 AI 技术越来越火爆,作为Java开发人员也不能拖了后腿。 前段时间使用LangChain…

条件概率:不确定性决策的基石

条件概率是概率论中的核心概念&#xff0c;用于描述在已知某一事件发生的条件下&#xff0c;另一事件发生的概率。它量化了事件之间的关联性&#xff0c;是贝叶斯推理、统计建模和机器学习的基础。 本文由「大千AI助手」原创发布&#xff0c;专注用真话讲AI&#xff0c;回归技术…

搭建Flink分布式集群

1. 基础环境&#xff1a; 1.1 安装JDK 本次使用 jdk-11.0.26_linux-x64_bin.tar.gz 解压缩 tar -zxvf jdk-11.0.26_linux-x64_bin.tar.gz -C /usr/local/java/ 配置环境变量&#xff1a; vi /etc/profileJAVA_HOME/usr/local/java/jdk-11.0.26 CLASSPATH.:${JAVA_HOME}/li…

基于ssm校园综合服务系统微信小程序源码数据库文档

摘 要 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;校园综合服务被用户普遍使用&#xff0c;为方便用户能够可…

桌面小屏幕实战课程:DesktopScreen 17 HTTPS

飞书文档http://https://x509p6c8to.feishu.cn/docx/doxcn8qjiNXmw2r3vBEdc7XCBCh 源码参考&#xff1a; /home/kemp/work/esp/esp-idf/examples/protocols/https_request 源码下载方式参考&#xff1a; 源码下载方式 获取网站ca证书 openssl s_client -showcerts -connec…

uniapp上传gitee

右键点击项目&#xff0c;选择git提交&#xff0c;会弹出这样的弹窗 在Message输入框里面输入更新的内容&#xff0c;选择更新过的文件&#xff0c;然后点击commit 然后点击push 后面会让你填写gitee的用户名和密码 用户名就是邮箱 密码就是登录gitee的密码

重写(Override)与重载(Overload)深度解析

在Java面向对象编程中&#xff0c;多态性是一个核心概念&#xff0c;它允许我们以统一的方式处理不同类型的对象。而实现多态性的两种重要机制便是方法的“重写”&#xff08;Override&#xff09;与“重载”&#xff08;Overload&#xff09;。透彻理解这两者之间的区别与联系…

Go 语言中操作 SQLite

sqlite以其无需安装和配置&#xff1a;直接使用数据库文件&#xff0c;无需启动独立的数据库服务进程。 单文件存储&#xff1a;整个数据库&#xff08;包括表、索引、数据等&#xff09;存储在单个跨平台文件中&#xff0c;便于迁移和备份。 在应对的小型应用软件中.有着不可…

【硬核数学】2.3 AI的“想象力”:概率深度学习与生成模型《从零构建机器学习、深度学习到LLM的数学认知》

欢迎来到本系列的第八篇文章。在前七章中&#xff0c;我们已经构建了一个强大的深度学习工具箱&#xff1a;我们用张量来处理高维数据&#xff0c;用反向传播来高效地计算梯度&#xff0c;用梯度下降来优化模型参数。我们训练出的模型在分类、回归等任务上表现出色。 但它们有…

华为云Flexus+DeepSeek征文|Dify平台开发搭建口腔牙科24小时在线问诊系统(AI知识库系统)

引言&#xff1a;为什么需要口腔牙科24小时在线问诊系统&#xff1f; 在口腔医疗领域&#xff0c;“时间”是患者最敏感的需求之一——深夜牙齿突发疼痛、周末想提前了解治疗方案、异地患者无法及时到院……传统“工作时间在线”的咨询模式已无法满足用户需求。同时&#xff0…

嵌入式硬件中电容的基本原理与详解

大家好我们今天重讨论点知识点如下: 1.电容在电路中的作用 2.用生活中水缸的例子来比喻电容 3.电容存储能力原理 4.电容封装的种类介绍电容种类图片辨识 5.X 电容的作用介绍 6.Y 电容的作用介绍7.钽电容的优点及特性 7.钽电容的缺点及特性 8. 铝电解电容的优点及特性…

中央空调控制系统深度解析:从原理到智能AIOT运维

——附水冷式系统全电路图解与技术参数 一、中央空调系统架构与技术演进 1. 两大主流系统对比 技术趋势&#xff1a;2023年全球冷水机组市场占比达68%&#xff08;BSRIA数据&#xff09;&#xff0c;其核心优势在于&#xff1a; - 分区控温精度&#xff1a;0.5℃&#…

document.write 和 innerHTML、innerText 的区别

document.write 与 innerHTML、innerText 的区别 document.write 直接写入 HTML 文档流&#xff0c;若在页面加载完成后调用会覆盖整个文档。常用于动态生成内容&#xff0c;但会破坏现有 DOM 结构&#xff0c;不推荐在现代开发中使用。 document.write("<p>直接写…

日志分析与实时监控:Elasticsearch在DevOps中的核心作用

引言 在现代DevOps实践中&#xff0c;日志分析与实时监控是保障系统稳定性与性能的关键。Elasticsearch作为分布式搜索与分析引擎&#xff0c;凭借其高效的索引与查询能力&#xff0c;成为构建日志管理与监控系统的核心组件。本文将深入探讨Elasticsearch在DevOps中的应用&…

Unity Catalog 三大升级:Data+AI 时代的统一治理再进化

在刚刚落幕的 2025 Databricks Data AI Summit 上&#xff0c;Databricks 重磅发布了多项 Lakehouse 相关功能更新。其中&#xff0c;面向数据湖治理场景的统一数据访问与管理方案 —— Unity Catalog&#xff0c;迎来了三大关键升级&#xff1a;全面支持 Apache Iceberg、面向…

电容屏触摸不灵敏及跳点问题分析

在电容屏的使用过程中&#xff0c;触摸不灵敏和触点不精准是极为常见且让人困扰的问题。这些问题不仅影响用户的操作体验&#xff0c;在一些对触摸精度要求较高的场景&#xff0c;如工业控制、绘图设计等领域&#xff0c;还可能导致严重的后果。下面我们就来深入剖析一下这两个…

小程序学习笔记:导航、刷新、加载、生命周期

在小程序开发的领域中&#xff0c;掌握视图与逻辑相关的技能是打造功能完备、用户体验良好应用的关键。今天&#xff0c;咱们就来深入梳理一下小程序视图与逻辑的学习要点&#xff0c;并结合代码示例&#xff0c;让大家有更直观的理解。 一、页面之间的导航跳转 在小程序里实…

生成树基础实验

以太网交换网络中为了进行链路备份&#xff0c;提高网络可靠性&#xff0c;通常会使用冗余链路。但是使用冗余链路会在交换网络上产生环路&#xff0c;引发广播风暴以及 MAC地址表不稳定等故障现象&#xff0c;从而导致用户通信质量较差&#xff0c;甚至通信中断。 为解决交换…