前言

处理数据样本的代码可能会变得杂乱无章且难以维护;为了获得更好的可读性和模块化,我们理想的情况是将数据集代码与模型训练代码解耦。PyTorch 提供了两个数据处理类:

torch.utils.data.DataLoadertorch.utils.data.Dataset,它们允许你使用预加载的数据集以及自己的数据。Dataset 存储样本及其相应的标签,而 DataLoaderDataset 周围包装一个可迭代对象,以便轻松访问样本。

PyTorch领域库提供了许多预加载的数据集(如FashionMNIST),这些数据集继承自 torch.utils.data.Dataset 并实现了特定于该数据的函数。它们可用于对模型进行原型设计和基准测试。你可以在以下位置找到它们:图像数据集、文本数据集 和 音频数据集

加载数据集

以下是一个如何从TorchVision加载 Fashion-MNIST 数据集的示例。时尚MNIST是Zalando的商品图像数据集,由60000个训练示例和10000个测试示例组成。每个示例包含一张28×28的灰度图像以及一个来自10个类别之一的关联标签。

我们使用以下参数加载时尚MNIST数据集:

  • root是存储训练/测试数据的路径
  • train 指定训练数据集或测试数据集,”
  • download=True 如果数据在root 中不可用,download=True 会从互联网下载数据。
  • transform and target_transform 分别指定特征和标签的变换
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

在这里插入图片描述

迭代和可视化数据集:

我们可以像列表一样手动对Datasets进行索引:training_data[index]。我们使用matplotlib来可视化训练数据中的一些样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述

创建自定义数据集

自定义数据集类必须实现三个函数:__init____len____getitem__。看一下这个实现;FashionMNIST图像存储在目录 img_dir 中,它们的标签则单独存储在CSV文件 annotations_file 中。

import os
import pandas as pd
from torchvision.io import decode_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = decode_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

__init__

__init__函数在实例化Dataset对象时运行一次。我们初始化包含图像的目录、注释文件以及两种变换(下一节将更详细介绍)。

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

__len__

__len__函数返回我们数据集中样本的数量。

def __len__(self):return len(self.img_labels)

__getitem__

__getitem__函数会根据给定的索引idx从数据集中加载并返回一个样本。根据该索引,它会确定图像在磁盘上的位置,使用decode_image将其转换为张量,从self.img_labels中的CSV数据中检索相应的标签,对它们调用变换函数(如果适用),并以元组形式返回张量图像和相应的标签。

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

使用DataLoaders为训练准备数据

Dataset 每次检索一个样本,获取我们数据集的特征和标签。在训练模型时,我们通常希望以 “小批量” 方式传递样本,在每个时期对数据进行重新洗牌以减少模型过拟合,并使用 Python 的 multiprocessing 来加速数据检索。

DataLoader 是一个可迭代对象,它通过简单的API为我们抽象了这种复杂性。

from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

遍历DataLoader

我们已将该数据集加载到DataLoader中,并可根据需要遍历数据集。下面的每次迭代都会返回一批train_featurestrain_labels(分别包含batch_size=64个特征和标签)。因为我们指定了shuffle=True,所以在遍历完所有批次后,数据会被打乱。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/918517.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/918517.shtml
英文地址,请注明出处:http://en.pswp.cn/news/918517.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaWeb 30 天入门:第七天 —— 异常处理机制

在前六天的学习中,我们掌握了 Java 的基础语法、面向对象核心特性、抽象类与接口等知识。今天我们将学习 Java 中的异常处理机制,这是保证程序健壮性的关键技术。在 JavaWeb 开发中,无论是用户输入错误、数据库连接失败还是网络异常&#xff…

编译器默认生成的c++类六大成员函数

编译器默认生成的c类六大成员函数 编译器默认生成的六大成员函数 当你定义一个空类时,例如: class Empty {};如果代码中没有显式定义任何成员函数,C编译器会在需要时(例如,代码中实际调用了这些函数)为你…

人工智能概念:常见的大模型微调方法

文章目录一、微调技术的底层逻辑1.1 预训练与微调的关系1.2 核心目标:适配任务与数据二、经典微调方法详解2.1 全量微调(Full Fine-Tuning)2.2 冻结层微调(Layer-Freezing Fine-Tuning)2.3 参数高效微调(Pa…

动态路由协议(一)

1. 动态路由 概述 静态路由在大网络里太麻烦(设备多、配置量大,拓扑变了还要手动改) 静态路由是由工程师手动配置和维护的路由条目,命令行简单明确,适用于小型或稳定的网络。静态路由有以下问题: 无法适…

LINUX812 shell脚本:if else,for 判断素数,创建用户

问题 [rootweb ~]# for((i2;i<n;i)) > if [ $n -ne $i ] && [ $((n%i)) -eq 0 ];then -bash: 未预期的符号 if 附近有语法错误 您在 /var/spool/mail/root 中有邮件 [rootweb ~]#[rootweb ~]# cat judgeprimeok.sh declare -i n read -p "please type the n…

游戏中角色持枪:玩家操控角色,角色转向时枪也要转向

角色持有枪&#xff0c;玩家&#xff08;你&#xff09;操控角色&#xff0c;那么&#xff0c;在角色转向时&#xff0c;枪也要转向。 先看看简单情况&#xff1a;假定角色只面向左或右方向&#xff0c;pygame中用这句来实现&#xff1a;pos self.facing * self.gun_offset s…

深度学习入门Day8:生成模型革命——从GAN到扩散模型

一、开篇&#xff1a;创造力的算法革命从昨天的Transformer到今天的生成模型&#xff0c;我们正从"理解"世界迈向"创造"世界。生成对抗网络(GAN)和扩散模型(Diffusion Model)代表了当前生成式AI的两大主流范式&#xff0c;它们让机器能够生成逼真的图像、音…

基于WRF-Chem的不同气溶胶的辐射效应的研究

前言目前我对于气溶胶辐射效应的理解就是设计敏感性实验&#xff0c;基础实验打开气溶胶参与辐射开关&#xff08;aer_ra_feedback&#xff09;&#xff0c;其他的实验则关闭气溶胶参与辐射过程开关&#xff0c;也有去掉某些气溶胶的影响&#xff0c;如黑碳&#xff08;BC&…

专题:2025人形机器人与服务机器人技术及市场报告|附130+份报告PDF汇总下载

原文链接&#xff1a;https://tecdat.cn/?p43583 当特斯拉Optimus在工厂里精准分拣电池&#xff0c;当普渡机器人在酒店完成跨楼层配送&#xff0c;一个万亿级的智能革命正在拉开序幕。服务机器人与人形机器人不再是实验室里的概念&#xff0c;而是正在重塑制造业、服务业的“…

JS 模块化与打包工具

一、模块化体系&#xff1a;ESM vs CJS 深入1.语法与静态性(1)ESM:静态语法&#xff0c;可被打包器做 Tree-shakingexport function play() {}export default ...import { play } from ./mod.js(2)CJS:运行时 require() , 分析能力弱&#xff0c;不利于 Tree-shaking2.Node 解析…

防御保护11

带宽管理 --- 设备对自身的流量进行管理和控制&#xff0c;去提供带宽保证、带宽限制等等功能。 带宽限制 带宽保证 连接数限制 应用场景 实现带宽管理 带宽通道 --- 定义了被管理对象所能使用的带宽资源 整体的保证带宽和最大带宽&#xff1b; SW1-SW2&#xff1a;VLAN 201 --…

[激光原理与应用-254]:理论 - 几何光学 - 自动对焦的原理

自动对焦&#xff08;Auto Focus, AF&#xff09;是现代光学系统&#xff08;如相机、手机摄像头、监控设备等&#xff09;的核心功能之一&#xff0c;其原理是通过检测成像面的清晰度或测量物体距离&#xff0c;驱动透镜组移动至最佳对焦位置。以下是自动对焦的详细原理及技术…

【Python办公】Mermaid代码转图片工具 - Tkinter GUI版本

目录 专栏导读 项目简介 功能特性 🎨 直观的图形界面 📝 代码编辑功能 🖼️ 图片生成与预览 💾 文件操作 ⚡ 性能优化 技术架构 核心技术栈 架构设计 安装与使用 环境要求 依赖安装 运行程序 使用步骤 代码示例 基本流程图 时序图 甘特图 核心代码解析 1. 主类结构 2. …

【Activiti】要点初探

Activiti 7.0.0配置 流程配置节点流程XML流程部署部署后会操作表&#xff1a;&#xff08;每部署一次增加一条记录&#xff09; ACT_RE_DEPLOYMENT 流程定义部署表 ACT_RE_PROCDEF 流程定义表 ACT_GE_BYTEARRAY 流程启动查看任务&#xff08;张三要查看准备办理任务&#xff0…

VBS 字符串处理

一. 字符串是由Unicode字符组成的一串字符。通常由数字&#xff0c;字母&#xff0c;符号组成。二. 常用函数1. 消除空格 Ltrim: 删除字符串左侧的空格。 Rtrim: 删除字符串右侧的空格。 trim: 删除字符串左侧和右侧的空格。a" hello " b"sx"msgbo…

《算法导论》第 21 章-用于不相交集合的数据结构

引言不相交集合&#xff08;Disjoint Set&#xff09;&#xff0c;也称为并查集&#xff08;Union-Find&#xff09;&#xff0c;是一种非常实用的数据结构&#xff0c;主要用于处理一些元素分组的问题。它支持高效的集合合并和元素查找操作&#xff0c;在很多算法中都有重要应…

基于51单片机RFID智能门禁系统红外人流量计数统计

1 系统功能介绍 本设计基于STC89C52单片机&#xff0c;集成RFID读卡器、红外避障传感器、继电器、LCD1602液晶显示和蜂鸣器&#xff0c;实现智能门禁与人流量统计功能。系统能够识别合法的RFID卡开门&#xff0c;并实时统计通过人数&#xff0c;具有安全报警和直观显示功能。具…

c#,vb.net全局多线程锁,可以在任意模块或类中使用,但尽量用多个锁提高效率

Public ReadOnly LockObj As New Object() 全局多线程锁 VB.NET模块中的LockObj 可以在任意模块或类中使用吧 在 VB.NET 中&#xff0c;模块&#xff08;Module&#xff09;中声明的 Public ReadOnly LockObj 可以被其他模块或类访问和使用&#xff0c;但需要注意其可见性范围…

企业安全运维服务计划书

安全运维服务计划书 一、概述 为保障企业信息系统安全、稳定、高效运行,防范各类网络安全风险,提升整体安全防护能力,特制定本安全运维服务计划书。本计划旨在通过系统化、规范化的安全运维流程,全面识别、评估、处置并持续监控企业网络环境中的安全风险,构建主动防御与…

小杰python高级(four day)——matplotlib库

1.绘制子图的方式pyplot中函数subplotFigure类中的函数add_subplotpyplot中函数subplotsfig, ax plt.subplots(nrows1, ncols1, *, sharexFalse, shareyFalse,squeezeTrue, subplot_kwNone, gridspec_kwNone, **fig_kw) 功能&#xff1a;绘制多个子图&#xff0c;可以一次生成…