文章目录

  • 一、Dataset 与 DataLoader 功能介绍
    • 抽象类Dataset的作用
    • DataLoader 作用
    • 两者关系
  • 二、自定义Dataset类
    • Dataset的三个重要方法
      • `__len__()`方法
      • `_getitem__()`方法
      • `__init__` 方法
  • 三、现成的torchvision.datasets模块
    • MNIST举例
    • COCODetection举例
    • `torchvision.datasets.MNIST`使用举例
    • `torchvision.datasets.CocoDetection`举例

一、Dataset 与 DataLoader 功能介绍

抽象类Dataset的作用

简单来说,就是将原始数据(可能是图片、文本、音频等各种格式)整理成模型可以处理的格式,为后续的数据加载和处理做准备。功能是定义数据集的基本属性数据获取方式

  • 初始化数据路径:在Dataset类的__init__方法中,通常会初始化数据存放的路径,以及一些数据预处理的操作,比如指定图片数据集图片所在文件夹路径,文本数据集文本文件路径等 。包含 加载数据/读取数据、预处理数据、图像增强 等一系列操作
  • 获取单个样本及其标签:通过实现__getitem__方法,根据给定的索引(dataloader返回的),返回相应的数据样本和对应的标签。例如在图片分类任务中,给定索引后,返回该索引对应的图片数据(经过预处理,如调整尺寸、归一化等)以及图片的类别标签。
  • 统计样本数量:通过实现__len__方法,返回数据集中样本的总数,方便在训练和评估过程中知道数据规模 。

DataLoader 作用

DataLoader是在Dataset的基础上,提供了一种更加高效、便捷地加载数据的方式,它可以将Dataset返回的单个样本,按照指定的方式进行打包(如组成batch)、打乱顺序等操作,从而满足模型训练和评估的需求。

  • 创建数据批次,指定数据打包输出规则:通过batch_size参数,将Dataset中的单个样本打包成一个个批次(batch)的数据。

    • collate_fn指定如何从NNN张训练集选出一个batch的Nbatch_size\frac{N}{batch\_size}batch_sizeN张图片。
    • 例如batch_size=32,那么DataLoader每次会从Dataset中取出32个样本组成一个batch。每次迭代,返回的是 一个batch 的数据
  • 自定义数据采样,指定数据迭代读取规则:

    • 一般使用自定义的采样器(Sampler),实现对数据的特殊采样方式,比如分层采样(在类别不均衡的数据集中,保证每个batch中各类别的样本比例与原始数据集相似)等。
    • dataset对象是dataloader的一个参数,通过dataset让dataloader知道训练集一共多少图片,从而知道共跌代多少次。
  • 数据打乱:通过shuffle参数设置是否在每个epoch开始时打乱数据顺序,这样可以避免模型在训练时对数据产生特定的依赖,有助于模型学习到更通用的特征,提高模型的泛化能力 。

  • 多进程加载:通过num_workers参数设置多进程加载数据,从而加快数据加载速度,尤其是在数据量较大、数据预处理较为复杂的情况下,多进程可以充分利用CPU资源,减少数据加载时间,避免数据加载成为训练过程中的瓶颈 。

两者关系

  • Dataset是数据的基础容器,定义了如何获取数据集中的单个样本;

  • DataLoader则是Dataset的上层应用,负责按照特定规则(如批量处理、打乱顺序等)从Dataset中高效地加载数据,供模型进行训练、验证和测试等操作。

  • 可以说,Dataset是数据的来源和基本操作接口,DataLoader则是为了更好地适配模型训练需求,对Dataset的数据进行进一步处理和组织的工具。

二、自定义Dataset类

所谓的 自定义 dataset ,即自己去写一个 Dataset 类,要满足两个要求:

  • 一般需要继承自 torch.utils.data.Dataset
    • 继承 torch.utils.data.Dataset 主要目的是为了与 DataLoader 保持兼容,确保数据集遵循 DataLoader 的接口标准,方便后续使用 PyTorch 提供的工具,比如 :批量加载、打乱数据、并行处理等功能
  • 并且满足和DataLoader进行交互的规范 :
    • 因为DataLoader会调用 Datasetlen()getitem() 方法,所以自定义 Dataset 类必须实现这两个方法,如此才能保证 DataLoader 可以正确地加载和操作你的数据集
  • 兼容训练和推理阶段

Dataset的三个重要方法

创建自定义 Dataset类时,必须实现的3个方法 :__init__()__len__()__getitem__()
这些方法定义了数据集的基本结构和行为,也是 DataLoader 可以正确的从 Dataset 中读取数据的基础。

__len__()方法

DataLoader是通过Dataset的 __len__(),得知训练集一共多少数据样本的。

def __len__(self):return len(self.file_list)
  • 返回值:数据集中的样本的总数。
  • 作用:
    • 方便通过调用 len(dataset) 来获取数据量,其中 dataset 为 Dataset 对象
    • Dataloader 会用它和 batch_size 一起来计算一个 epoch 要迭代多少个 steps:
      steps=len(dataset)batch_sizesteps = \frac{len(dataset)}{batch\_size}steps=batch_sizelen(dataset)
    • DataLoader调用len方法的代码封装在源码了,所以看不到显式调用。DataLoader得到一共NNN个数据样本后,生成000 ~ N−1N-1N1的索引。再根据batch_size和是否打乱,生成一个batch的索引列表,再将每个索引idx传入到Dataset的_getitem__()方法中返回得到图片和索引return image, label

_getitem__()方法

作用: 根据给定的索引返回数据集中的一个样本。这是用于获取数据集中单个样本的方法。

def __getitem__(self, idx):# 通过索引idx,获取图片地址img_nameimg_name = os.path.join(self.data_folder, self.file_list[idx])# 根据图片地址img_name读取对应图像original_imageoriginal_image = Image.open(img_name)# 通过索引idx获取图片对应的标签(这里举的例子的标签含在图片名中)label = img_name.split('_')[-1].split('.')[0]# 图像预处理和数据增强(仅训练阶段)if self.train:image = self.transform(original_image)else:image = self.transform(original_image)# 返回处理好的一张图像和标签return image, label
  • 接收参数: index(idx)是单个数据样本的索引,由DataLoader传来的
  • 返回值: 返回数据集中索引指定的样本。通常是一个包含输入数据和对应标签的元组。这里可以根据自己的需求,进行自定义。

DataLoader返回的是一个batch的数据,具体是:

  • DataLoader的采样器sampler根据数据总量和batch_size=2,和采样方法(举例为顺序采样)得到第一次迭代结果为索引列表[0, 1]
  • DataLoader分别把索引0和1给Dataset,__getitem__()方法返回出对应单个索引的图片和标签。
  • 把得到的一个batch的两组图片和标签给collate_fn函数进行打包并以一种数据结构储存,由DataLoader返回

__init__ 方法

  • 参数: 根据需要传递一些参数,例如文件路径、数据转换等。
  • 作用: 构造方法,配好len和getitem方法做一些初始化工作,需要什么数据,就传入进来赋值到成员属性。
def __init__(self, data_folder, train, transform=None):self.data_folder = data_folderself.transform = transformself.file_list = os.listdir(data_folder)# 把文件名读取出来,存入到file_list,方便len方法获取数据量self.train = train

例如:设置文件路径selfl.data_folder、定义数据转换的transforms、当前是训练阶段还是验证阶段的布尔值train等。

三、现成的torchvision.datasets模块

对于一些公开的数据集,可以直接用torchvision.datasets模块的现成的Dataset类。

Pytorch官方文档的torchvision的Dataset列出了可使用的数据集的Dataset,实现了getitem和len方法

在这里插入图片描述

MNIST举例

这里以Image classification任务的MNIST(mixed national institute of standards and technology database)数据集举例,点入详情页课查看:
在这里插入图片描述

在这里插入图片描述

train_dataset = torchvision.datasets.MNIST(root,    train=True,               transform=None,  target_transform= None  download=True)

参数:

  • root:数据集存放的路径
  • download:是否下载数据集,默认为False 。配合root参数:
    • 若设置download=True
      • root目录下没有该数据集,数据集将会被下载到root指定的位置。
      • root目录下已经存在该数据集,则不会重新下载,而是会直接使用已存在的数据,以节省时间
    • 若设置download=False,程序将会在root指定的位置查找数据集,如果数据集不存在,则会抛出错误。
  • train
    • 如果是True,下载训练集trainin.pt
    • 如果是False,下载测试集test.pt。默认是True
  • transform:接收torchvision.transforms的对象,一系列作用在PIL图片上的转换操作,用于对数据集的图像预处理和数据增强。
  • target_transform:对target处理,一般不用。因为出来target出来一般用自定义的Dataset,因为图像处理和target处理要放一个transform里写

COCODetection举例

Image detection任务的COCO数据集
在这里插入图片描述

注意:对于一部分数据集比如torchvision.datasets.CocoDetection,Pytorch不提供下载功能 (具体情况取决于数据集的来源和许可协议),就没有download参数。
所以在使用 torchvision.datasets.CocoDetection 这个现成的Dataset类之前,需要确保已经下载并淮备好COCO数
据集的图像和标注文件。然后使用torchvision.datasets.CocoDetection 类来加载 COCO数据集。

torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None, transforms=None)
  • root:指定图片地址(本地已经下载下来的图像地址)
  • annFile:指定标注文件地址(本地已经下载下来的标注文件地址)
  • transform:图像处理 (用于PIL)
  • target_transform:标注处理
  • transforms:图像和标注的处理

torchvision.datasets.MNIST使用举例

训练集和验证集分别实例化一个Dataset类(torchvision.datasets.MNIST)的对象,传入的transforms参数都为实例化的transforms.Compose对象my_transform。数据集下载到当下文件所在目录下。

import torchvision
from torchvision.transforms import transforms
import torch.utils.data as data
import matplotlib.pyplot as pltbatch_size = 5# transforms.Compose的对象,传入到transforms参数
my_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],  # mean=[0.485, 0.456, 0.406]std=[0.5])])  # std=[0.229, 0.224, 0.225]train_dataset = torchvision.datasets.MNIST(root="./",train=True,transform=my_transform,download=True)val_dataset = torchvision.datasets.MNIST(root="./",train=False,transform=my_transform,download=True)

在这里插入图片描述

  • 可以看的在当下目录下出现了一个MNIST文件夹,
  • .gz后缀的是下载的压缩文件,程序自动解压为同名的二进制文件
  • Dataset会自动处理好二进制文件,最终从DataLoader跌代出来的是正常的单通道灰度图。

将定义出的训练集和验证集的Dataset对象,分别作为参数传入到两个DataLoader,得到两个DataLoader对象

train_loader = data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)val_loader = data.DataLoader(val_dataset,batch_size=batch_size,shuffle=True)

分别调用量Dataset的len方法,输出数据量。再将train_loader转换为迭代器iter(train_loader),通过next方法得到一个batch的image和label。
打印出一个batch的image的shape。[5, 1, 28, 28]分别指batch_size,图片通道数,图像长宽。
打印出标签label列表。
最后可视化一个batch的图和标签。

print(len(train_dataset))
print(len(val_dataset))image, label = next(iter(train_loader))
print(image.shape)
print(label)for i in range(batch_size):plt.subplot(1, batch_size, i + 1)plt.title(label[i].item())plt.axis("off")plt.imshow(image[i].permute(1, 2, 0))plt.show()

在这里插入图片描述

torchvision.datasets.CocoDetection举例

需要把数据集的下载地址换掉,换成你的 COCO数据集地址

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import functional as F
import randomdef collate_fn_coco(batch):return tuple(zip(*batch))coco_det = datasets.CocoDetection(root="./COCO2017/train2017",annFile="./COCO2017/annotations/instances_train 2017.json")sampler = torch.utils.data.SequentialSampler(coco_det)  # RandomSampler
batch_sampler = torch.utils.data.BatchSampler(sampler, 1, drop_last=True)
data_loader = torch.utils.data.DataLoader(coco_det,batch_sampler=batch_sampler,collate_fn=collate_fn_coco)# 可视化
iterator = iter(data_loader)
imgs, gts = next(iterator)
img,  gts_one_img = imgs[0], gts[0]bboxes = []
ids = []
for gt in gts_one_img:bboxes.append([gt['bbox'][0],gt['bbox'][1],gt['bbox'][2],gt['bbox'][3]])ids.append(gt['category_id'])fig, ax = plt.subplots()
for box, id in zip(bboxes, ids):x = int(box[0])y = int(box[1])w = int(box[2])h = int(box[3])rect = plt.Rectangle((x, y), w, h, edgecolor='r', linewidth=2, facecolor='none')ax.add_patch(rect)ax.text(x, y, id, backgroundcolor="r")plt.axis("off")
plt.imshow(img)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/94476.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/94476.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/94476.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python爬虫实战:研究python_reference库,构建技术研究数据系统

1. 引言 1.1 研究背景与意义 在大数据时代,数据已成为重要的生产要素。互联网作为全球最大的信息库,蕴含着海量有价值的数据。如何从纷繁复杂的网络信息中快速、准确地提取所需数据,成为各行各业面临的重要课题。网络爬虫技术作为数据获取的关键手段,能够模拟人类浏览网页…

Web开发系列-第15章 项目部署-Docker

第15章 项目部署-Docker Docker技术能够避免部署对服务器环境的依赖,减少复杂的部署流程。 轻松部署各种常见软件、Java项目 参考文档:‌‬‌‍‍‌‍⁠⁠‍‍‬‌‍‌‬⁠‌‬第十五章:…

微软无界鼠标(Mouse without Borders)安装及使用:多台电脑共用鼠标键盘

文章目录一、写在前面二、下载安装1、两台电脑都下载安装2、被控端3、控制端主机三、使用一、写在前面 在办公中,我们经常会遇到这种场景,自己带着笔记本电脑外加公司配置的台式机。由于两台电脑,所以就需要搭配两套键盘鼠标。对于有限的办公…

nodejs 编程基础01-NPM包管理

1:npm 包管理介绍 npm 是nodejs 的包管理工具,类似于java 的maven 和 gradle 等,用来解决nodejs 的依赖包问题 使用场景:1. 从NPM 服务骑上下载或拉去别人编写好的第三方包到本地进行使用2. 将自己编写代码或软件包发布到npm 服务器供他人使用…

基于Mediapipe_Unity_Plugin实现手势识别

GitHub - homuler/MediaPipeUnityPlugin: Unity plugin to run MediaPipehttps://github.com/homuler/MediaPipeUnityPlugin 实现了以下: public enum HandGesture { None, Stop, ThumbsUp, Victory, OK, OpenHand } 核心脚本&#xff1a…

Android 项目构建编译概述

主要内容是Android AOSP源码的管理方式,项目源码的构建和编译,用到比如git、repo、gerrit一些命令工具,以及使用Soong编译系统,编写Android.bp文件的格式样式。 1. Android操作系统堆栈概述 Android 是一个针对多种不同设备类型打…

Python爬虫08_Requests聚焦批量爬取图片

一、Requests聚焦批量爬取图片 import re import requests import os import timeurl https://www.douban.com/ userAgent {User-Agent:Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:122.0) Gecko/20100101 Firefox/122.0}#获取整个浏览页面 page_text requests.get(urlur…

Spring Cloud系列—简介

目录 1 单体架构 2 集群与分布式 3 微服务架构 4 Spring Cloud 5 Spring Cloud环境和工程搭建 5.1 服务拆分 5.2 示例 5.2.1 数据库配置 5.2.2 父子项目创建 5.2.3 order_service子项目结构配置 5.2.4 product_service子项目结构配置 5.2.5 服务之间的远程调用 5.…

【普中STM32精灵开发攻略】--第 1 章 如何使用本攻略

学习本开发攻略主要参考的文档有《STM32F1xx 中文参考手册》和《Cortex M3权威指南(中文)》,这两本都是 ST 官方手册,尤其是《STM32F1xx 中文参考手册》,里面包含了 STM32F1 内部所有外设介绍,非常详细。大家在学习 STM32F103的时…

【Docker】RK3576-Debian上使用Docker安装Ubuntu22.04+ROS2

1、简述 RK3576自带Debian12系统,如果要使用ROS2,可以在Debian上直接安装ROS2,缺点是有的ROS包需要源码编译;当然最好是使用Ubuntu系统,可以使用Docker安装,或者构建Ubuntu系统,替换Debian系统。 推荐使用Docker来安装Ubuntu22.04,这里会有个疑问,是否可以直接使用Do…

解决docker load加载tar镜像报json no such file or directory的错误

在使用docker加载离线镜像文件时,出现了json no such file or directory的错误,刚开始以为是压缩包拷贝坏了,重新拷贝了以后还是出现了问题。经过网上查找方案,并且自己实践,采用下面的简单方法就可以搞定。 归结为一句…

《协作画布的深层架构:React与TypeScript构建多人实时绘图应用的核心逻辑》

多人在线协作绘图应用的构建不仅是技术栈的简单组合,更是对实时性、一致性与用户体验的多维挑战。基于React与TypeScript开发这类应用,需要在图形绘制的基础功能之外,解决多用户并发操作的同步难题、状态回溯的逻辑冲突以及大规模协作的性能瓶颈。每一层架构的设计,都需兼顾…

智慧社区(八)——社区人脸识别出入管理系统设计与实现

在社区安全管理日益智能化的背景下,传统的人工登记方式已难以满足高效、精准的管理需求。本文将详细介绍一套基于人脸识别技术的社区出入管理系统,该系统通过整合腾讯云 AI 接口、数据库设计与业务逻辑,实现了居民出入自动识别、记录追踪与访…

嵌入式开发学习———Linux环境下IO进程线程学习(四)

进程相关函数fork创建一个子进程,子进程复制父进程的地址空间。父进程返回子进程PID,子进程返回0。pid_t pid fork(); if (pid 0) { /* 子进程代码 */ } else { /* 父进程代码 */ }getpid获取当前进程的PID。pid_t pid getpid();getppid获取父进程的P…

标记-清除算法中的可达性判定与Chrome DevTools内存分析实践

引言 在现代前端开发中,内存管理是保证应用性能与用户体验的核心技术之一。作为JavaScript运行时的基础机制,标记-清除算法(Mark-and-Sweep) 通过可达性判定决定哪些内存需要回收,而Chrome DevTools提供的Memory工具则为开发者提供了深度的内…

微算法科技(NASDAQ:MLGO)基于量子重加密技术构建区块链数据共享解决方案

随着信息技术的飞速发展,数据已成为数字经济时代的核心生产要素。数据的共享和安全往往是一对难以调和的矛盾。传统的加密方法在面对日益强大的计算能力和复杂的网络攻击时,安全性受到了挑战。微算法科技(NASDAQ:MLGO)通过引入量子重加密技术…

FastAPI快速入门P2:与SpringBoot比较

欢迎来到啾啾的博客🐱。 记录学习点滴。分享工作思考和实用技巧,偶尔也分享一些杂谈💬。 有很多很多不足的地方,欢迎评论交流,感谢您的阅读和评论😄。 目录引言1 FastAPI事件管理2 类的使用2.1 初始化方法对…

SAP-ABAP: Open SQL集合函数COUNT(统计行数)、SUM(数值求和)、AVG(平均值)、MAX/MIN(极值)深度指南

SAP Open SQL集合函数深度指南 1. 核心价值与特性函数作用关键特性COUNT统计行数用COUNT(*)包含NULL值行,COUNT(字段)排除NULLSUM数值求和自动过滤NULL值,结果类型与源字段相同AVG平均值必须用TYPE f接收,否则四舍五入导致精度丢失MAX/MIN极值…

【docker】UnionFS联合操作系统

Linux 的 Namespace、CGroups 和 UnionFS 三大技术支撑了 Docker 的实现。 一、为什么需要联合文件系统?在传统操作系统中,每个文件系统都是独立的孤岛。但当我们需要:合并多个目录的内容保持基础系统不变的同时进行修改高效共享重复文件内容…

CTF-XXE 漏洞解题思路总结

一、XXE 漏洞简介XXE (XML External Entity) 漏洞允许攻击者通过构造恶意的 XML 输入,强迫服务器的 XML 解析器执行非预期的操作。在 CTF 场景中,最常见的利用方式是让解析器读取服务器上的敏感文件,并将其内容返回给攻击者。二、核心攻击载荷…