PyTorch 的 DataLoader 是数据加载的核心组件,它能高效地批量加载数据并进行预处理。

Pytorch DataLoader基础概念

DataLoader基础概念
DataLoader是PyTorch基础概念
DataLoader是PyTorch中用于加载数据的工具,它可以:批量加载数据(batch loading)打乱数据(shuffling)并行加载数据(多线程)
自定义数据加载方式Dataloader的基本使用from torch.utils.data import Dataset, DataLoader

自定义数据集类

class MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __getitem__(self, index):return self.data[index], self.labels[index]def __len__(self):return len(self.data)

创建数据集实例

dataset = MyDataset(data, labels)

创建DataLoader

dataloader = DataLoader(dataset=dataset,      # 数据集batch_size=32,        # 批次大小shuffle=True,         # 是否打乱数据num_workers=4,        # 多进程加载数据的线程数drop_last=False       # 当样本数不能被batch_size整除时,是否丢弃最后一个不完整的batch
)
# 使用DataLoader迭代数据
for batch_data, batch_labels in dataloader:# 训练或推理代码pass

DataLoader重要参数详解

  1. dataset: 要加载的数据集,必须是Dataset类的实例 batch_size: 每个批次的样本数
  2. shuffle:是否在每个epoch重新打乱数据
  3. sampler:自定义从数据集中抽取样本的策略,如果指定了sampler,则shuffle必须为False
  4. num_workers:使用多少个子进程加载数据,0表示在主进程中加载。
  5. collate_fn:将一批数据整合成一个批次的函数,特别使用于处理不同长度的序列数据
  6. Pin_memory:如果为True,数据加载器会将张量复制到CUDA固定内存中,加速CPU到GPU的数据传输
  7. drop_last: 如果数据集大小不能被batch_size整除,是否丢弃最后一个不完整的批次。
  8. timeout:收集一个批次的超时值
  9. worker_init_fn:每个worker初始化时被调用的函数
  10. weight_sampler:参数决定是都使用加权采样器来平衡类别分布
if infinite_data_loader:data_loader = InfiniteDataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)
else:data_loader = DataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)n_class = len(data.classes)
return data_loader, n_class
这段代码决定了如何创建数据加载器,根据infinite_data_loader参数选择不同的加载器类型:
if infinite_data_loader:data_loader = InfiniteDataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)
else:data_loader = DataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)n_class = len(data.classes)
return data_loader, n_class

代码解析

这段代码基于infinite_data_loader参数创建不同类型的数据加载器:
当infinite_data_loader为True时:
创建InfiniteDataLoader实例
自定义的无限循环数据加载器,会持续提供数据而不会在一个epoch结束时停止
当infinite_data_loader为False时:
创建标准的PyTorch DataLoader实例
这是普通的数据加载器,一个epoch结束后会停止

共同参数:

dataset=data:要加载的数据集
batch_size=batch_size:每批数据的大小
shuffle=shuffle:是否打乱数据(之前代码中已设置)
num_workers=num_workers:用于并行加载数据的线程数
sampler=sampler:用于采样的策略(之前代码中已设置,可能是加权采样器)
**kwargs:其他可能的参数,如pin_memory、drop_last等

返回值:

data_loader:创建好的数据加载器
n_class = len(data.classes):数据集中的类别数量
InfiniteDataLoader的作用
在您的代码中定义了两种InfiniteDataLoader实现:一种作为DataLoader的子类,另一种是完全自定义的类。它们的共同目的是:
持续提供数据:当一个epoch结束后,自动重新开始,不会引发StopIteration异常
支持长时间训练:在需要长时间训练的场景中特别有用,如半监督学习或者领域适应
避免手动重置:不需要在每个epoch结束后手动重置数据加载器

使用场景

无限数据加载器特别适用于:
持续训练:模型需要无限期地训练,如自监督学习或强化学习
不均匀更新:源域和目标域数据需要不同频率的更新
流式训练:数据以流的形式到达,不需要明确的epoch边界
基于迭代而非epoch的训练:训练基于迭代次数而非数据epoch
最后的返回值n_class提供了数据集的类别数量,这对模型构建和评估都很重要,比如设置分类层的输出维度或计算平均类别准确率。
高级用法

1.自定义collate_fn处理变长序列

def collate_fn(batch):# 排序批次数据,按序列长度降序batch.sort(key=lambda x: len(x[0]), reverse=True)# 分离数据和标签sequences, labels = zip(*batch)# 计算每个序列的长度lengths = [len(seq) for seq in sequences]# 填充序列到相同长度padded_seqs = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)return padded_seqs, torch.tensor(labels), lengths

使用自定义的collate_fn

dataloader = DataLoader(dataset=text_dataset,batch_size=16,shuffle=True,collate_fn=collate_fn
)

2.使用Sampler进行不均衡数据采样
from torch.utils.data import WeightedRandomSampler

假设我们有类别不平衡问题,计算采样权重

class_count = [100, 1000, 500]  # 每个类别的样本数量
weights = 1.0 / torch.tensor(class_count, dtype=torch.float)
sample_weights = weights[target_list]  # target_list是每个样本的类别索引

创建WeightedRandomSampler

sampler = WeightedRandomSampler(weights=sample_weights,num_samples=len(sample_weights),replacement=True
)

使用sampler

dataloader = DataLoader(dataset=dataset,batch_size=32,sampler=sampler,  # 使用sampler时,shuffle必须为Falsenum_workers=4
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/80768.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/80768.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/80768.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTML、CSS 和 JavaScript 基础知识点

HTML、CSS 和 JavaScript 基础知识点 一、HTML 基础 1. HTML 文档结构 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.…

亚远景-对ASPICE评估体系的深入研究与分析

一、ASPICE评估体系的定义与背景 ASPICE&#xff08;Automotive Software Process Improvement and Capability Determination&#xff09;即汽车软件过程改进及能力测定模型&#xff0c;是由欧洲20多家主要汽车制造商共同制定的&#xff0c;专门针对汽车行业的软件开发过程评…

灰度图像和RGB图像在数据大小和编码处理方式差别

技术背景 好多开发者对灰度图像和RGB图像有些认知差异&#xff0c;今天我们大概介绍下二者差别。灰度图像&#xff08;Grayscale Image&#xff09;和RGB图像在编码处理时&#xff0c;数据大小和处理方式的差别主要体现在以下几个方面&#xff1a; 1. 通道数差异 图像类型通道…

从爬虫到网络---<基石9> 在VPS上没搞好Docker项目,把他卸载干净

1.停止并删除所有正在运行的容器 docker ps -a # 查看所有容器 docker stop $(docker ps -aq) # 停止所有容器 docker rm $(docker ps -aq) # 删除所有容器如果提示没有找到容器&#xff0c;可以忽略这些提示。 2.删除所有镜像 docker images # 查看所有镜像 dock…

Centos 上安装Klish(clish)的编译和测试总结

1&#xff0c;介绍 clish是一个类思科命令行补全与执行程序&#xff0c;它可以帮助程序员在nix操作系统上实现功能导引、命令补全、命令执行的程序。支持&#xff1f;&#xff0c;help, Tab按键。本文基于klish-2.2.0介绍编译和测试。 2&#xff0c;klish的编译 需要安装的库&…

理解计算机系统_并发编程(3)_基于I/O复用的并发(二):基于I/O多路复用的并发事件驱动服务器

前言 以<深入理解计算机系统>(以下称“本书”)内容为基础&#xff0c;对程序的整个过程进行梳理。本书内容对整个计算机系统做了系统性导引,每部分内容都是单独的一门课.学习深度根据自己需要来定 引入 接续上一帖理解计算机系统_并发编程(2)_基于I/O复用的并发…

系统可靠性分析:指标解析与模型应用全览

以下是关于系统可靠性分析中可靠性指标、串联系统与并联系统、混合系统、系统可靠性模型的相关内容&#xff1a; 一、可靠性指标 可靠度&#xff1a;是系统、设备或元件在规定条件和规定时间内完成规定功能的概率。假设一个系统由多个部件组成&#xff0c;每个部件都有其自身…

数字高程模型(DEM)公开数据集介绍与下载指南

数字高程模型&#xff08;DEM&#xff09;公开数据集介绍与下载指南 数字高程模型&#xff08;Digital Elevation Model, DEM&#xff09;广泛应用于地理信息系统&#xff08;GIS&#xff09;、水文模拟、城市规划、环境分析、灾害评估等领域。本文系统梳理了主流的DEM公开数据…

Python+大模型 day01

Python基础 计算机系统组成 基础语法 如:student_num 4.标识符要做到见名知意,增强代码的可读性 关键字 系统或者Python定义的,有特殊功能的字符组合 在学习过程中,文件名没有遵循标识符命名规则,是为了按序号编写文件方便查找复习 但是,在开发中,所有的Python文件名称必须…

C++引用编程练习

#include <iostream> using namespace std; double vals[] {10.1, 12.6, 33.1, 24.1, 50.0}; double& setValues(int i) { double& ref vals[i]; return ref; // 返回第 i 个元素的引用&#xff0c;ref 是一个引用变量&#xff0c;ref 引用 vals[i] } // 要调用…

机密虚拟机的威胁模型

本文将介绍近年兴起的机密虚拟机&#xff08;Confidential Virtual Machine&#xff09;技术所旨在抵御的威胁模型&#xff0c;主要关注内存机密性&#xff08;confidentiality&#xff09;和内存完整性&#xff08;integrity&#xff09;两个方面。在解释该威胁可能造成的问题…

【Rust trait特质】如何在Rust中使用trait特质,全面解析与应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

Simulink模型回调

Simulink 模型回调函数是一种特殊的 MATLAB 函数&#xff0c;可在模型生命周期的特定阶段自动执行。它们允许用户自定义模型行为、执行初始化任务、验证参数或记录数据。以下是各回调函数的详细说明&#xff1a; 1. PreLoadFcn 触发时机&#xff1a;Simulink 模型加载到内存之…

FPGA:Xilinx Kintex 7实现DDR3 SDRAM读写

在Xilinx Kintex 7系列FPGA上实现对DDR3 SDRAM的读写&#xff0c;主要依赖Xilinx提供的Memory Interface Generator (MIG) IP核&#xff0c;结合Vivado设计流程。以下是详细步骤和关键点&#xff1a; 1. 准备工作 硬件需求&#xff1a; Kintex-7 FPGA&#xff08;如XC7K325T&…

Python爬虫实战:研究进制流数据,实现逆向解密

1. 引言 1.1 研究背景与意义 在现代网络环境中,数据加密已成为保护信息安全的重要手段。许多网站和应用通过二进制流数据传输敏感信息,如视频、金融交易数据等。这些数据通常经过复杂的加密算法处理,直接分析难度较大。逆向工程进制流数据不仅有助于合法的数据获取与分析,…

Java Spring Boot项目目录规范示例

以下是一个典型的 Java Spring Boot 项目目录结构规范示例&#xff0c;结合了分层架构和模块化设计的最佳实践&#xff1a; text 复制 下载 src/ ├── main/ │ ├── java/ │ │ └── com/ │ │ └── example/ │ │ └── myapp/ │…

图像颜色理论与数据挖掘应用的全景解析

文章目录 一、图像颜色系统的理论基础1.1 图像数字化的本质逻辑1.2 颜色空间的数学框架1.3 量化过程的技术原理 二、主要颜色空间的深度解析2.1 RGB颜色空间的加法原理2.2 HSV颜色空间的感知模型2.3 CMYK颜色空间的减色原理 三、图像几何属性与高级特征3.1 分辨率与像素密度的关…

mysql两张关联表批量更新一张表存在数据,而另一张表不存在数据的sql

一、mysql两张关联表批量更新一张表存在、另一张表不存在的数据 创建user和user_order表 CREATE TABLE user (id varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL,id_card varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NU…

PNG转ico图标(支持圆角矩形/方形+透明背景)Python脚本 - 随笔

摘要 在网站开发或应用程序设计中&#xff0c;常需将高品质PNG图像转换为ICO格式图标。本文提供一份基于Pillow库实现的&#xff0c;能够完美保留透明背景且支持导出圆角矩形/方形图标的格式转换脚本。 源码示例 圆角方形 from PIL import Image, ImageDraw, ImageOpsdef c…

在线SQL转ER图工具

在线SQL转ER图网站 在数据库设计、软件开发或学术研究中&#xff0c;ER图&#xff08;实体-关系图&#xff09; 是展示数据库结构的重要工具。然而&#xff0c;手动绘制ER图不仅耗时费力&#xff0c;还容易出错。今天&#xff0c;我将为大家推荐一款非常实用的在线工具——SQL…