DataLoader 是 PyTorch 中处理数据的核心组件,它提供了高效的数据加载、批处理和并行处理功能。下面是一个全面的 DataLoader 实战指南,包含代码示例和最佳实践。

基础用法:简单数据加载

import torch
from torch.utils.data import Dataset, DataLoader# 1. 创建自定义数据集
class SimpleDataset(Dataset):def __init__(self, size=1000):self.data = torch.randn(size, 3, 32, 32)  # 模拟图像数据self.labels = torch.randint(0, 10, (size,))  # 0-9的标签def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 2. 创建DataLoader
dataset = SimpleDataset(1000)
dataloader = DataLoader(dataset,batch_size=64,       # 批大小shuffle=True,        # 是否打乱数据num_workers=4,       # 使用4个进程加载数据pin_memory=True      # 使用固定内存(加速GPU传输)
)# 3. 使用DataLoader
for epoch in range(3):print(f"Epoch {epoch+1}")for batch_idx, (data, targets) in enumerate(dataloader):# 数据自动分批:data.shape = [64, 3, 32, 32], targets.shape = [64]if batch_idx % 10 == 0:print(f"  Batch {batch_idx}: {data.shape}, {targets.shape}")print("Epoch completed\n")

高级功能:自定义数据集与转换

图像数据集示例

import os
from PIL import Image
from torchvision import transformsclass CustomImageDataset(Dataset):def __init__(self, img_dir, transform=None):self.img_dir = img_dirself.transform = transformself.img_names = [f for f in os.listdir(img_dir) if f.endswith('.jpg')]# 假设文件名格式为 "label_imageid.jpg",例如 "3_001.jpg"self.labels = [int(f.split('_')[0]) for f in self.img_names]def __len__(self):return len(self.img_names)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_names[idx])image = Image.open(img_path).convert('RGB')label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 定义数据转换
transform = transforms.Compose([transforms.Resize((256, 256)),      # 调整大小transforms.RandomHorizontalFlip(),   # 随机水平翻转transforms.RandomRotation(15),       # 随机旋转 ±15度transforms.ToTensor(),               # 转为Tensor [0,1]transforms.Normalize(                # 标准化mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 创建数据集和DataLoader
dataset = CustomImageDataset('/path/to/images', transform=transform)
dataloader = DataLoader(dataset,batch_size=32,shuffle=True,num_workers=4,collate_fn=lambda batch: tuple(zip(*batch))  # 自定义批处理函数
)

文本数据集示例

from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizerclass TextDataset(Dataset):def __init__(self, file_path, max_len=100):self.max_len = max_lenself.tokenizer = get_tokenizer('basic_english')# 读取文本数据和标签self.texts = []self.labels = []with open(file_path, 'r', encoding='utf-8') as f:for line in f:label, text = line.split('\t')self.labels.append(int(label))self.texts.append(text.strip())# 构建词汇表self.vocab = build_vocab_from_iterator((self.tokenizer(text) for text in self.texts),specials=['<unk>', '<pad>'])self.vocab.set_default_index(self.vocab['<unk>'])def __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]tokens = self.tokenizer(text)# 将token转换为索引indices = [self.vocab[token] for token in tokens]# 截断或填充序列if len(indices) > self.max_len:indices = indices[:self.max_len]else:indices = indices + [self.vocab['<pad>']] * (self.max_len - len(indices))return torch.tensor(indices), self.labels[idx]# 自定义批处理函数(处理变长序列)
def collate_fn(batch):texts, labels = zip(*batch)# 找到批次中最长序列的长度max_len = max(len(t) for t in texts)# 填充所有序列到相同长度padded_texts = []for text in texts:padding = torch.zeros(max_len - len(text), dtype=torch.long)padded_texts.append(torch.cat((text, padding)))return torch.stack(padded_texts), torch.tensor(labels)# 创建DataLoader
text_dataset = TextDataset('/path/to/text_data.txt', max_len=100)
text_dataloader = DataLoader(text_dataset,batch_size=32,shuffle=True,num_workers=2,collate_fn=collate_fn  # 使用自定义批处理函数
)

性能优化技巧

1. 使用并行加载

# 根据CPU核心数设置num_workers
import os
num_workers = min(4, os.cpu_count())  # 使用不超过4个或CPU核心数的workerdataloader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=num_workers,pin_memory=True,  # 对于GPU训练非常重要persistent_workers=True  # 保持worker进程活动(PyTorch 1.7+)
)

2. 数据预取

from torch.utils.data import DataLoader, PrefetchGenerator# 使用预取生成器(PyTorch 1.7+)
dataloader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=4,prefetch_factor=2  # 每个worker预取的批次数
)# 或者使用自定义预取
class PrefetchLoader:def __init__(self, loader, device):self.loader = loaderself.device = deviceself.stream = torch.cuda.Stream() if device.type == 'cuda' else Nonedef __iter__(self):first = Truefor batch in self.loader:if self.stream is not None:with torch.cuda.stream(self.stream):batch = self._preprocess(batch)else:batch = self._preprocess(batch)if not first and self.stream is not None:torch.cuda.current_stream().wait_stream(self.stream)first = Falseyield batchdef _preprocess(self, batch):data, target = batchreturn data.to(self.device, non_blocking=True), target.to(self.device, non_blocking=True)# 使用自定义预取
device = torch.device('cuda')
prefetch_dataloader = PrefetchLoader(dataloader, device)

3. 内存映射文件处理大文件

import numpy as np
import torch
from torch.utils.data import Datasetclass MmapDataset(Dataset):def __init__(self, file_path, shape, dtype=np.float32):self.data = np.memmap(file_path, dtype=dtype, mode='r', shape=shape)def __len__(self):return self.data.shape[0]def __getitem__(self, idx):return torch.from_numpy(np.array(self.data[idx]))

分布式数据加载

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler# 初始化分布式环境
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()# 创建分布式采样器
sampler = DistributedSampler(dataset,num_replicas=world_size,rank=rank,shuffle=True,seed=42
)# 创建分布式DataLoader
dist_dataloader = DataLoader(dataset,batch_size=64,sampler=sampler,num_workers=4,pin_memory=True,drop_last=True  # 丢弃最后不完整的批次
)# 在每个进程中
for epoch in range(10):# 设置epoch确保所有进程的shuffle一致dist_dataloader.sampler.set_epoch(epoch)for batch in dist_dataloader:# 处理批次数据pass

数据增强策略

图像增强

from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2# 使用torchvision
torchvision_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 使用Albumentations(更丰富的增强)
albumentations_transform = A.Compose([A.RandomResizedCrop(224, 224),A.HorizontalFlip(p=0.5),A.VerticalFlip(p=0.2),A.Rotate(limit=30),A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.9),A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),ToTensorV2()
])# 在数据集类中使用
def __getitem__(self, idx):img_path = self.img_paths[idx]image = cv2.imread(img_path)image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)if self.transform:augmented = self.transform(image=image)image = augmented['image']return image, self.labels[idx]

文本增强

import nlpaug.augmenter.word as naw# 创建文本增强器
augmenter = naw.ContextualWordEmbsAug(model_path='bert-base-uncased', action="substitute",  # 替换、插入等aug_p=0.1  # 增强比例
)# 在数据集中使用
def __getitem__(self, idx):text = self.texts[idx]if self.augment and random.random() < 0.5:  # 50%概率增强text = augmenter.augment(text)# 后续处理...

数据可视化与调试

import matplotlib.pyplot as plt
import numpy as npdef show_batch(dataloader, n=4):"""显示一批图像及其标签"""dataiter = iter(dataloader)images, labels = next(dataiter)fig, axes = plt.subplots(1, n, figsize=(15, 4))for i in range(n):img = images[i].permute(1, 2, 0).numpy()  # CHW -> HWCimg = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])  # 反归一化img = np.clip(img, 0, 1)axes[i].imshow(img)axes[i].set_title(f"Label: {labels[i].item()}")axes[i].axis('off')plt.show()# 使用
show_batch(dataloader, n=8)

常见问题解决方案

1. 内存不足

# 解决方案1:使用更小的批大小
dataloader = DataLoader(dataset, batch_size=16)# 解决方案2:使用内存映射文件
# 如前文的MmapDataset示例# 解决方案3:使用IterableDataset
from torch.utils.data import IterableDatasetclass LargeIterableDataset(IterableDataset):def __init__(self, file_path, chunk_size=1000):self.file_path = file_pathself.chunk_size = chunk_sizedef __iter__(self):with open(self.file_path, 'r') as f:chunk = []for line in f:chunk.append(process_line(line))  # 自定义处理函数if len(chunk) == self.chunk_size:yield from chunkchunk = []if chunk:yield from chunk# 使用
dataset = LargeIterableDataset('large_file.txt')
dataloader = DataLoader(dataset, batch_size=64)

2. Windows多进程问题

# 解决方案:将主代码放入if __name__ == '__main__'块中
if __name__ == '__main__':# 在这里创建DataLoaderdataloader = DataLoader(dataset, num_workers=4)# 训练代码...

3. 数据加载成为瓶颈

# 解决方案1:增加num_workers
dataloader = DataLoader(dataset, num_workers=os.cpu_count())# 解决方案2:使用预取
# 如前文的PrefetchLoader示例# 解决方案3:使用更快的存储(如SSD代替HDD)# 解决方案4:使用更高效的数据格式(如HDF5、LMDB)

最佳实践总结

  1. 批大小选择:根据GPU内存选择最大可用批大小

  2. Worker数量:设置为CPU核心数的1-2倍

  3. 固定内存:GPU训练时始终设置pin_memory=True

  4. 数据增强:在CPU上执行,避免占用GPU资源

  5. 分布式训练:使用DistributedSampler确保数据正确分区

  6. 内存优化:对大文件使用内存映射或IterableDataset

  7. 预取策略:使用内置prefetch_factor或自定义预取

  8. 数据验证:定期可视化批次数据确保数据增强有效

  9. 资源监控:监控CPU/GPU利用率,识别瓶颈

  10. 格式优化:使用高效数据格式(如TFRecord、LMDB)加速IO

通过合理配置DataLoader,你可以显著提高模型训练效率,充分利用硬件资源,加速模型迭代过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/90998.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/90998.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/90998.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot单元测试类拿不到bean报空指针异常

原代码package com.atguigu.gulimall.product;import com.aliyun.oss.OSSClient; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; impo…

持续集成 简介环境搭建

1. 持续集成简介 1.1 持续集成的作用 随着互联网的蓬勃发展,软件生命周期模型也经历了几个比较大的阶段,从最初的瀑布模型,到 V 模型,再到现在的敏捷或者 devops,不论哪个阶段,项目从立项到交付几乎都离不开以下几个过程,开发、构建、测试和发布,而且一直都在致力于又…

关于 java:11. 项目结构、Maven、Gradle 构建系统

一、Java 项目目录结构标准1.1 Java 项目标准目录结构总览标准 Java 项目目录结构&#xff08;以 Maven / Gradle 通用结构为基础&#xff09;&#xff1a;project-root/ ├── src/ │ ├── main/ │ │ ├── java/ # 主业务逻辑代码&#xff08;核心…

大数据的安全挑战与应对

在大数据时代&#xff0c;大数据安全问题已成为开发者最为关注的核心议题之一。至少五年来&#xff0c;大数据已融入各类企业的运营体系&#xff0c;而采用先进数据分析解决方案的组织数量仍在持续增长。本文将明确当前市场中最关键的大数据安全问题与威胁&#xff0c;概述企业…

PostgreSQL ERROR: out of shared memory处理方式

系统允许的总锁数 SELECT (SELECT setting::int FROM pg_settings WHERE name max_locks_per_transaction) * (SELECT setting::int FROM pg_settings WHERE name max_connections) (SELECT setting::int FROM pg_settings WHERE name max_prepared_transactions);当锁大于…

Django 模型(Model)

1. 模型简介 ORM 简介 MVC 框架中一个重要的部分就是 ORM,它实现了数据模型与数据库的解耦,即数据模型的设计不需要依赖于特定的数据库,通过简单的配置就可以轻松更换数据库。即直接面向对象操作数据,无需考虑 sql 语句。 ORM 是“对象-关系-映射”的简称,主要任务是:…

深入解析Hadoop RPC:技术细节与推广应用

Hadoop RPC框架概述在分布式系统的核心架构中&#xff0c;远程过程调用&#xff08;RPC&#xff09;机制如同神经网络般连接着各个计算节点。Hadoop作为大数据处理的基石&#xff0c;其自主研发的RPC框架不仅支撑着内部组件的协同运作&#xff0c;更以独特的工程哲学诠释了分布…

为什么玩游戏用UDP,看网页用TCP?

故事场景&#xff1a;两种不同的远程沟通方式假设你需要和远方的朋友沟通一件重要的事情。方式一&#xff1a;TCP — 打一个重要的电话打电话是一种非常严谨、可靠的沟通方式。• 1. 建立连接 (三次握手):• 你拿起电话&#xff0c;拨号&#xff08;SYN&#xff09;。• 朋友那…

【EGSR2025】材质+扩散模型+神经网络相关论文整理随笔(二)

High-Fidelity Texture Transfer Using Multi-Scale Depth-Aware Diffusion 这篇文章可以从一个带有纹理的几何物体出发&#xff0c;将其身上的纹理自动提取并映射到任意的几何拓扑结构上&#xff08;见下图红线左侧&#xff09;&#xff1b;或者从一个白模几何对象出发&#x…

深度学习图像分类数据集—玉米粒质量识别分类

该数据集为图像分类数据集&#xff0c;适用于ResNet、VGG等卷积神经网络&#xff0c;SENet、CBAM等注意力机制相关算法&#xff0c;Vision Transformer等Transformer相关算法。 数据集信息介绍&#xff1a;玉米粒质量识别分类&#xff1a;[crush, good, mul] 训练数据集总共有3…

Unity VR手术模拟系统架构分析与数据流设计

Unity VR手术模拟系统架构分析与数据流设计 前言 本文将深入分析一个基于Unity引擎开发的多人VR手术模拟系统。该系统采用先进的网络架构设计&#xff0c;支持多用户实时协作&#xff0c;具备完整的手术流程引导和精确的工具交互功能。通过对系统架构和数据管道的详细剖析&…

【Spring Boot】Spring Boot 4.0 的颠覆性AI特性全景解析,结合智能编码实战案例、底层架构革新及Prompt工程手册

Spring Boot 4.0 的颠覆性AI特性全景解析&#xff0c;结合智能编码实战案例、底层架构革新及Prompt工程手册一、Spring Boot 4.0 核心AI能力矩阵二、AI智能编码插件实战&#xff08;Spring AI Assistant&#xff09;1. 安装与激活2. 实时代码生成场景3. 缺陷预测与修复三、AI引…

audiobookshelf-web 项目怎么运行

git clone https://github.com/audiobookshelf/audiobookshelf-web.git cd audiobookshelf-web npm i 启动项目 npm run dev http://localhost:3000/

扫描文件 PDF / 图片 纠斜 | 图片去黑边 / 裁剪 / 压缩

问题&#xff1a;扫描后形成的 PDF 或图片文档常存在变形倾斜等问题&#xff0c;手动调整颇为耗时费力。 一、PDF 纠斜 - Adobe Acrobat DC 1、所用功能 扫描和 OCR&#xff1a; 识别文本&#xff1a;在文件中 → 设置 确定后启动扫描&#xff0c;识别过程中自动纠偏。 2、…

适配器模式:兼容不兼容接口

将一个类的接口转换成客户端期望的另一个接口&#xff0c;解决接口不兼容问题。代码示例&#xff1a;// 目标接口&#xff08;客户端期望的格式&#xff09; interface ModernPrinter {void printDocument(String text); }// 被适配的旧类&#xff08;不兼容&#xff09; class…

流程控制:从基础结构到跨语言实践与优化

流程控制 一、流程控制基础概念与核心价值 &#xff08;一&#xff09;流程控制定义与本质 流程控制是通过特定逻辑结构决定程序执行顺序的机制&#xff0c;核心是控制代码运行路径&#xff0c;包括顺序执行、条件分支、循环迭代三大核心逻辑。其本质是将无序的指令集合转化为有…

Http与Https区别和联系

一、HTTP 详解 HTTP&#xff08;HyperText Transfer Protocol&#xff09;​​ 是互联网数据通信的基础协议&#xff0c;用于客户端&#xff08;浏览器&#xff09;与服务器之间的请求-响应交互 核心特性​​&#xff1a; 1.无连接&#xff08;Connectionless&#xff09;​​…

飞算JavaAI:开启 Java 开发 “人机协作” 新纪元

每日一句 明天是新的一天&#xff0c; 你也不再是昨天的你。 目录每日一句一、需求到架构&#xff1a;AI深度介入开发“源头设计”1.1 需求结构化&#xff1a;自然语言到技术要素的精准转化1.2 架构方案生成&#xff1a;基于最佳实践的动态适配二、编码全流程&#xff1a;从“…

Qt项目锻炼——TODO(五)

发现问题如果是自己创建的ui文件&#xff0c;怎么包含进自己的窗口类并且成为ui成员&#xff1f;一般来说Qt designer 会根据你.ui文件生成对应的ui_文件名这个类&#xff08;文件名是ui文件名&#xff09;&#xff0c;它包含了所有 UI 组件&#xff08;如按钮、文本框、标签等…

Vue框架之模板语法全面解析

Vue框架之模板语法全面解析一、模板语法的核心思想二、插值表达式&#xff1a;数据渲染的基础2.1 基本用法&#xff1a;渲染文本2.2 纯HTML渲染&#xff1a;v-html指令2.3 一次性插值&#xff1a;v-once指令三、指令系统&#xff1a;控制DOM的行为3.1 条件渲染&#xff1a;v-if…