目录

一、PyTorch 数据加载的核心组件

1.1 Dataset 类的核心方法

1.2 DataLoader 的作用

二、加载 CSV 数据实战

2.1 自定义 CSV 数据集

2.2 使用 TensorDataset 快速加载

三、加载图像数据实战

3.1 自定义图像数据集

3.2 使用 ImageFolder 快速加载

四、加载官方数据集

五、总结


在深度学习项目中,数据加载是模型训练的第一步,也是至关重要的一步。PyTorch 提供了灵活的数据加载工具,让我们能够轻松处理各种类型的数据。本文将结合实际代码,详细讲解如何使用 PyTorch 加载 CSV 数据和图像数据,帮助初学者快速掌握数据加载的核心技巧。

一、PyTorch 数据加载的核心组件

PyTorch 的数据加载主要依赖两个核心类:DatasetDataLoader

  • Dataset:负责数据的读取和预处理,是所有自定义数据集的基类
  • DataLoader:负责批量加载数据,支持打乱顺序、多线程加载等功能

1.1 Dataset 类的核心方法

自定义数据集需要继承Dataset类,并实现以下三个方法:

class CustomDataset(Dataset):def __init__(self, ...):  # 初始化数据集,加载文件路径等passdef __len__(self):  # 返回数据集大小return len(self.data)def __getitem__(self, index):  # 根据索引返回样本return sample, label

1.2 DataLoader 的作用

DataLoader像是一个 "搬运工",将Dataset中的数据按批次搬运给模型:

dataloader = DataLoader(dataset=dataset,  # 要加载的数据集batch_size=32,    # 批次大小shuffle=True,     # 是否打乱数据num_workers=2     # 多线程加载
)

二、加载 CSV 数据实战

CSV 文件是存储表格数据的常用格式,比如学生成绩表、特征数据表等。下面我们通过实际代码讲解如何加载 CSV 数据。

2.1 自定义 CSV 数据集

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pdclass CsvDataset(Dataset):def __init__(self, filepath):# 读取CSV文件df = pd.read_csv(filepath)# 删除不需要的列(学号、姓名)df.drop(['学号', '姓名'], axis=1, inplace=True)# 提取特征和标签x = df.iloc[1:, :-1]  # 从第二行开始,取除最后一列外的所有列作为特征y = df.iloc[1:, -1]   # 从第二行开始,取最后一列作为标签# 转换为Tensorself.data = torch.tensor(x.values, dtype=torch.float)self.labels = torch.tensor(y.values, dtype=torch.float)def __len__(self):return len(self.data)def __getitem__(self, index):sample = self.data[index]label = self.labels[index]return sample, label# 测试代码
def test_csv_dataset():filepath = '大数据答辩成绩表.csv'dataset = CsvDataset(filepath)print(f"数据集大小: {len(dataset)}")print(f"第一个样本: {dataset[0]}")test_csv_dataset()

2.2 使用 TensorDataset 快速加载

如果数据已经是 Tensor 格式,可以使用TensorDataset快速创建数据集,无需自定义类:

def test_tensor_dataset():filepath = '大数据答辩成绩表.csv'df = pd.read_csv(filepath)df.drop(['学号', '姓名'], axis=1, inplace=True)x = df.iloc[1:, :-1]y = df.iloc[1:, -1]# 转换为Tensordata = torch.tensor(x.values, dtype=torch.float)labels = torch.tensor(y.values, dtype=torch.float)# 使用TensorDatasetdataset = TensorDataset(data, labels)print(f"第一个样本: {dataset[0]}")

三、加载图像数据实战

处理图像数据时,我们需要考虑图像的读取、大小调整、格式转换等问题。下面介绍两种加载图像数据的方法。

3.1 自定义图像数据集

import os
import cv2
from torch.utils.data import Datasetclass PicDataset(Dataset):def __init__(self, filepath):self.filepaths = []  # 存储图像路径self.labels = []     # 存储标签dirnames = []        # 存储类别名称# 遍历文件夹for root, dirs, files in os.walk(filepath):if len(dirs) > 0:dirnames = dirs  # 获取类别文件夹名称for file in files:f_path = os.path.join(root, file)self.filepaths.append(f_path)# 根据文件夹名称确定标签classname = root.split('\\')[-1]self.labels.append(dirnames.index(classname))def __len__(self):return len(self.filepaths)def __getitem__(self, index):filepath = self.filepaths[index]# 读取图像img = cv2.imread(filepath)# 调整图像大小为112x112img = cv2.resize(img, (112, 112))# 转换为Tensor并调整维度 (HWC -> CHW)t_img = torch.tensor(img)t_img = t_img.permute(2, 0, 1)label = self.labels[index]return t_img, label# 测试代码
def test_pic_dataset():filepath = r'E:\人工智能\深度学习\dataset\butterfly'dataset = PicDataset(filepath)print(f"数据集大小: {len(dataset)}")img, label = dataset[0]print(f"图像形状: {img.shape}, 标签: {label}")

3.2 使用 ImageFolder 快速加载

PyTorch 的ImageFolder是加载图像数据集的便捷工具,特别适合以下结构的数据集:

root/class1/img1.jpgimg2.jpgclass2/img1.jpgimg2.jpg

使用方法如下:

from torchvision.datasets import ImageFolder
from torchvision import transformsdef test_image_folder():filepath = r'E:\人工智能\深度学习\dataset\butterfly'# 定义图像转换transform = transforms.Compose([transforms.Resize((112, 112)),  # 调整大小transforms.ToTensor(),          # 转换为Tensor])# 使用ImageFolder加载数据dataset = ImageFolder(root=filepath, transform=transform)print(f"类别: {dataset.classes}")print(f"数据集大小: {len(dataset)}")# 创建DataLoaderdataloader = DataLoader(dataset=dataset,batch_size=1,shuffle=True)# 显示一张图像for img, label in dataloader:print(f"图像形状: {img.shape}")print(f"标签: {label}")breaktest_image_folder()

四、加载官方数据集

PyTorch 提供了许多常用的公开数据集(如 MNIST、CIFAR 等),可以直接下载使用:

from torchvision import datasets, transformsdef test_mnist_dataset():# 定义转换transform = transforms.Compose([transforms.ToTensor()])# 加载MNIST训练集dataset = datasets.MNIST(root='../dataset',  # 数据保存路径train=True,         # 训练集download=True,      # 如果没有数据则下载transform=transform)# 创建DataLoaderdataloader = DataLoader(dataset=dataset,batch_size=1,shuffle=True)# 显示一张图像for img, label in dataloader:print(f"图像形状: {img.shape}")print(f"标签: {label}")breaktest_mnist_dataset()

五、总结

本文介绍了 PyTorch 加载不同类型数据的方法,包括:

  1. 加载 CSV 数据:可以自定义CsvDataset,或使用TensorDataset快速加载
  2. 加载图像数据:可以自定义PicDataset,或使用ImageFolder加载按类别组织的图像
  3. 加载官方数据集:直接使用torchvision.datasets中的类

掌握数据加载的技巧,可以为后续的模型训练打下坚实基础。在实际项目中,需要根据数据的具体格式和特点,选择合适的加载方式,并进行必要的预处理。

希望本文能帮助大家快速上手 PyTorch 的数据加载,如果你有任何问题或建议,欢迎在评论区留言讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/89716.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/89716.shtml
英文地址,请注明出处:http://en.pswp.cn/web/89716.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

程序人生,开启2025下半年

时光匆匆,2025年已然过去一半。转眼来到了7月份。 回望过去上半年,可能你也经历了职场的浮沉、生活的跌宕、家庭的变故。 而下半年,生活依旧充满了各种变数。 大环境的起起伏伏、生活节奏的加快,都让未来的不确定性愈发凸显。 在这…

在 .NET Core 中创建 Web Socket API

要在 ASP.NET Core 中创建 WebSocket API,您可以按照以下步骤操作:设置新的 ASP.NET Core 项目打开 Visual Studio 或您喜欢的 IDE。 创建一个新的 ASP.NET Core Web 应用程序项目。 选择API模板,因为这将成为您的 WebSocket API 的基础。在启…

Python 之地址编码识别

根据输入地址,利用已有的地址编码文件,构造处理规则策略识别地址的编码。 lib/address.json 地址编码文件(这个文件太大,博客里放不下,需要的话可以到 gitcode 仓库获取:https://gitcode.com/TomorrowAndT…

kafka的部署

目录 一、kafka简介 1.1、概述 1.2、消息系统介绍 1.3、点对点消息传递模式 1.4、发布-订阅消息传递模式 二、kafka术语解释 2.1、结构概述 2.2、broker 2.3、topic 2.4、producer 2.5、consumer 2.6、consumer group 2.7、leader 2.8、follower 2.9、partition…

小语种OCR识别技术实现原理

小语种OCR(光学字符识别)技术的实现原理涉及计算机视觉、自然语言处理(NLP)和深度学习等多个领域的融合,其核心目标是让计算机能够准确识别并理解不同语言的印刷或手写文本。以下是其关键技术实现原理的详细解析&#…

GPT:让机器拥有“创造力”的语言引擎

当ChatGPT写出莎士比亚风格的十四行诗,当GitHub Copilot自动生成编程代码,背后都源于同一项革命性技术——**GPT(Generative Pre-trained Transformer)**。今天,我们将揭开这项“语言魔术”背后的科学原理!…

LeetCode|Day19|14. 最长公共前缀|Python刷题笔记

LeetCode|Day19|14. 最长公共前缀|Python刷题笔记 🗓️ 本文属于【LeetCode 简单题百日计划】系列 👉 点击查看系列总目录 >> 📌 题目简介 题号:14. 最长公共前缀 难度:简单…

安全事件响应分析--基础命令

----万能密码oror1 or # 1or11 1 or 11安全事件响应分析------***windoes***------方法开机启动有无异常文件 【开始】➜【运行】➜【msconfig】文件排查 各个盘下的temp(tmp)相关目录下查看有无异常文件 :Windows产生的 临时文件 可以通过查看日志且通过筛…

基于C#+SQL Server实现(Web)学生选课管理系统

学生选课管理系统的设计与开发一、项目背景学生选课管理系统是一个学校不可缺少的部分,传统的人工管理档案的方式存在着很多的缺点,如:效率低、保密性差等,所以开发一套综合教务系统管理软件很有必要,它应该具有传统的…

垃圾回收(GC)

内存管理策略,在业务进程运行的过程中,由垃圾收集器以类似守护协程的方式在后台运行,按照指定策略回收不再被使用的对象,释放内存空间进行回收 优势: 屏蔽内存回收的细节:屏蔽复杂的内存管理工作&#xff0…

Datawhale AI夏令营-机器学习

比赛简介 「用户新增预测挑战赛」是由科大讯飞主办的一项数据科学竞赛,旨在通过机器学习方法预测用户是否为新增用户 比赛属于二分类任务,评价指标采用F1分数,分数越高表示模型性能越好。 如果你有一份带标签的表格型数据,只要…

Spring IOC容器在Web环境中是如何启动的(源码级剖析)?

文章目录一、Web 环境中的 Spring MVC 框架二、Web 应用部署描述配置传统配置(web.xml):Java配置类(Servlet 3.0):三、核心启动流程详解1. 启动流程图2. ★容器初始化入口:ContextLoaderListene…

18个优质Qt开源项目汇总

1,Clementine Music Player Clementine Music Player 是一个功能完善、跨平台的开源音乐播放器,非常适合用于学习如何开发媒体类应用,尤其是跨平台桌面应用。它基于 Qt 框架开发,支持多种操作系统,包括 Windows、macO…

计算机视觉:AI 的 “眼睛” 如何看懂世界?

1. 什么是计算机视觉:让机器 “看见” 并 “理解” 的技术1.1 计算机视觉的核心目标计算机视觉(CV)是人工智能的一个重要分支,它让计算机能够 “看懂” 图像和视频 —— 不仅能捕捉像素信息,还能分析内容、提取语义&am…

华为OD刷题记录

华为OD刷题记录 刷过的题 入门 1、进制 2、NC61 doing 订阅专栏

QT学习教程(二十五)

双缓冲技术&#xff08;Double Buffering&#xff09;&#xff08; 2、公有函数实现&#xff09;#include <QtGui> #include <cmath> using namespace std; #include "plotter.h"以上代码为文件的开头&#xff0c;在这里把std 的名空间加入到当前的全…

设计模式笔记_结构型_装饰器模式

1.装饰器模式介绍装饰器模式是一种结构型设计模式&#xff0c;允许你动态地给对象添加行为&#xff0c;而无需修改其代码。它的核心思想是将对象放入一个“包装器”中&#xff0c;这个包装器提供了额外的功能&#xff0c;同时保持原有对象的接口不变。想象一下&#xff0c;你有…

day25 力扣90.子集II 力扣46.全排列 力扣47.全排列 II

子集II给你一个整数数组 nums &#xff0c;找出并返回所有该数组中不同的递增子序列&#xff0c;递增子序列中 至少有两个元素 。你可以按 任意顺序 返回答案。数组中可能含有重复元素&#xff0c;如出现两个整数相等&#xff0c;也可以视作递增序列的一种特殊情况。示例 1&…

Solidity 中的`bytes`

在 Solidity 中&#xff0c;bytes 和 bytes32 都是用来保存二进制数据的类型&#xff0c;但它们的长度、使用场景、Gas 成本完全不同。✅ 一句话区分类型一句话总结bytes32定长 32 字节&#xff0c;适合做哈希、地址、标识符等固定长度数据。bytes动态长度字节数组&#xff0c;…

初学者STM32—PWM驱动电机与舵机

一、简介 上一节课主要学习了输出比较和PWM的基本原理和结构&#xff0c;本节课就主要以实践为主通过STM32最小系统板和驱动器控制舵机和直流电机。 上一节课的坐标 初学者STM32—输出比较与PWM-CSDN博客 二、舵机 舵机是一种根据输入PWM信号占空比来控制输出角度的装置 输…