关于PyTorch的数据类型和使用的学习笔记  系统介绍了PyTorch的核心数据类型Tensor及其应用。Tensor作为多维矩阵数据容器,支持0-4维数据结构(标量到批量图像),并提供了多种数值类型(float32/int64等)。通过积木类比阐述了Tensor的维度概念,展示了创建、变形、随机生成等基础操作。重点演示了FashionMNIST数据集分类任务实战:构建包含两个全连接层的神经网络(QYNN),使用交叉熵损失和SGD优化器进行训练。

1 介绍

  PyTorch 是Torch的Python版本 是开源的神经网络框架 针对于GPU加速的深度神经网络编程

  Torch是一个经典的多维矩阵数据进行操作的张量(Tensor)库 在机器学习和其他数学密集型应用广泛应用 PyTorch的计算图是动态的 可以按照计算需求实时改变计算图

  PyTorch追求最少的封装 设计遵循Tensor->Variable->nn.Module 三个由低到高的抽象层次 分别代表高维数组(张量) 自动求导(变量) 神经网络(层/模块)三个抽象间联系紧密


2 基础数据类型

 2.1 图文说明

PyTorch处理的最基本的操作对象就是张量(Tensor)表示的就是一个多维矩阵 接下来将进行一个通俗的说明

   我们类比一下积木 ,Tensor就是构建一切模型和计算的最基本积木块。而PyThorch就是一个装数字的​​盒子​​,并且这个盒子可以有很多​​维度​​(几层架子)。

维度类比描述例子具体场景说明
0维张量(标量)一粒积木5, 3.14单个数值(如温度、概率值)
1维张量(向量)一行/一列整齐摆放的积木[1, 2, 3, 4]物体位置坐标、心电图波形数据
2维张量(矩阵)行列组成的积木板[[1,2,3],
[4,5,6]]
灰度图像(28x28像素)、
Excel表格数据
3维张量一摞多个积木板尺寸示例:[3,224,224]彩色图像(通道×高×宽)
MRI切片扫描数据
4维张量多个3维张量
打包的箱子
尺寸示例:[32,3,224,224]批量处理32张
224x224像素的RGB图像

 

 

而每一个数字也自己本身的数据类型:浮点型 和 整型

​数据类型​​位宽/精度​​通俗解释​​典型应用场景​​PyTorch创建方法​​内存占用​
​torch.float32​
(torch.float)
32位
单精度浮点
带小数点的数
(如3.14159)
深度学习模型参数
激活函数计算
.float()
dtype=torch.float32
4字节/元素
​torch.float64​
(torch.double)
64位
双精度浮点
高精度浮点数
(更多小数位)
科学计算
精密数值分析
.double()
dtype=torch.float64
8字节/元素
​torch.int32​
(torch.int)
32位整数普通整数
(如-1, 0, 42)
一般计数
简单索引
.int()
dtype=torch.int32
4字节/元素
​torch.int64​
(torch.long)
64位整数大范围整数
(更大或更精确)
​标签数据​
复杂索引
位置信息
.long()
dtype=torch.int64
8字节/元素
​torch.uint8​8位无符号整数0-255整数
(无负数)
​图像像素值​
(0=黑, 255=白)
.byte()
dtype=torch.uint8
1字节/元素
​torch.bool​布尔值True/False
(是/否)
条件判断
数据掩码
(如x>5)
.bool()
dtype=torch.bool
1字节/元素
​torch.complex64​64位复数复数表示
(实部+虚部浮点)
信号处理
量子计算
dtype=torch.complex648字节/元素
​torch.complex128​128位复数高精度复数高级物理计算
电磁场模拟
dtype=torch.complex12816字节/元素

        ​​之所以说 Tensor 是核心数据类型。是因为 PyTorch 几乎所有操作(神经网络运算、求梯度)都建立在处理 Tensor 之上。你需要把你的数据(数字、图像、文本数值化表示等)最终都放进这些不同形状(维度)的 Tensor 盒子里,PyTorch 才能处理和计算。​ 所有东西最终都变成 Tensor 的某种形式。维度 (shape) 决定了数据的基本结构(标量、向量、矩阵、图片、批量)。而 ​dtype 去指定格子的内容类型

2.2 代码实现

     然后 PyTorch实际的数据类型我们再使用代码实操一下

 2.2.1 基础张量创建​
# 张量的定义方式和Numpy一样 传入矩阵即可生成张量
import torch
a = torch.Tensor([[1,2],[3,4]])
print(a) # <class 'torch.Tensor'>
a = torch.eye(2)  # 创建2x2单位矩阵
print(a)           # 输出: tensor([[1., 0.], [0., 1.]])
​2.2.2 特殊张量初始化​
 = torch.zeros(3, 3)  # 3x3全0张量
c = torch.ones(3, 3)   # 3x3全1张量
d = torch.arange(1, 10, 2)  # [1,10)区间步长为2: [1,3,5,7,9]
e = torch.linspace(1, 10, 10)  # 1-10的10个等差值
f = torch.logspace(1, 10, 10)  # 10^1到10^10的10个对数间隔值
g = torch.logspace(1, 2, 10)   # 10^1到10^2的10个值
2.2.3 随机张量生成​
a1 = torch.rand(3, 3) # [0,1)均匀分布 a2 = torch.randn(3, 3) # 标准正态分布(μ=0, σ=1) a3 = torch.randint(1, 10, (5, 5)) # [1,10)区间的随机整数
2.2.4 ​NumPy互操作​
import numpy as np
a4 = np.array([1, 2])          # 创建NumPy数组
a5 = torch.from_numpy(a4)       # NumPy转PyTorch张量
# 类型转换: <class 'numpy.ndarray'> -> <class 'torch.Tensor'>
 2.2.5 张量形状操作​
a = torch.Tensor(2, 3, 128, 128)
print(a.shape)              # torch.Size([2, 3, 128, 128])
print(a[0].shape)          # torch.Size([3, 128, 128])
print(a[0][0].shape)       # torch.Size([128, 128])# 高级切片
print(a[:1, :1, :64, :64].shape)    # torch.Size([1, 1, 64, 64])
print(a[:1, :1, :64:2, :64:2].shape) # torch.Size([1, 1, 32, 32])
2.2.6 维度变换​
# 重塑形状
B = a.reshape(2, 3, -1)      # 展平后两维: (2,3,16384)
C = a.reshape(4, -1)        # (4, 24576)# 增删维度
a = a.unsqueeze(2)          # 添加维度: (2,3,1,128,128)
a = a.squeeze(1)            # 删除大小为1的维度# 维度交换
a = a.transpose(0, 1)       # 交换维度0和1: (3,2,128,128)
a = a.permute(1, 0, 3, 2)   # 维度重排: (3,2,128,128)
 2.2.7 维度扩展​ 
a = torch.randn(2, 1, 128, 128)
a = a.expand(2, 3, 128, 64)  # 复制数据扩展维度
# 要求: 扩展维度必须为1或与原尺寸一致
2.2.8 函数总结
​操作类型​​函数/语法​​关键特性​
基础创建eye()zeros()ones()初始化特殊矩阵
序列生成arange()linspace()控制步长/数量
随机生成rand()randn()randint()均匀/正态/整数分布
维度操作reshape()view()数据不复制改变形状
维度增删unsqueeze()squeeze()添加/移除大小为1的维度
维度交换transpose()permute()调整维度顺序
数据扩展expand()复制数据扩展张量(仅支持1->N的扩展)
NumPy互操作from_numpy()零拷贝数据共享

3 实战使用

使用FashionMNIST数据集(FashionMNIST 是一个经典的计算机视觉基准数据集,由德国电商巨头 Zalando 的研究团队于 2017 年创建,旨在替代过于简单的 MNIST 手写数字数据集。它包含 70,000 张 28x28 像素的时尚单品灰度图像,涵盖 10 个类别。)完成一个基本的图形分类任务

        我们将从环境配置 模型训练与评估 模型使用三个阶段讲起

3.1 环境配置

模型训练

import torch
import torch.nn as nn 
import torch.optim as optim # 导入优化器
from torchvision import datasets, transforms # 导入数据集和数据预处理库
from torch.utils.data import DataLoader # 数据加载库

模型使用

import os # 用于操作文件
import torch
import matplotlib.pyplot as plt
from torchvision import datasets,transforms # 用于数据集和数据变换
from PIL import Image # 用于图形操作
from torchvision.datasets import FashionMNIST # 用于加载FashionMNIST数据集
from train import QYNN, transform

一定要添加一个本地的解释器配置环境 以免冲突 

3.2 模型训练与评估

  train.py


import torch
import torch.nn as nn 
import torch.optim as optim # 导入优化器
from torchvision import datasets, transforms # 导入数据集和数据预处理库
from torch.utils.data import DataLoader # 数据加载库# 设置随机种子
torch.manual_seed(21)# 定义数据预处理
transform = transforms.Compose([transforms.Resize((28,28)),transforms.Grayscale(), #强制灰度图像(1通道)transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))  # 标准化图像数据 灰度图,只需要一个0.5
])# 加载FashionMNIST数据集
train_dataset = datasets.FashionMNIST(root='./FashionMNIST_images/train', train=True, download=True, transform=transform)  # 下载训练集
test_dataset = datasets.FashionMNIST(root='./FashionMNIST_images/test', train=False, download=True, transform=transform) #  下载测试集# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 对训练集进行打包,指定批次为64
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) # 对测试集进行打包# 打印数据集大小和样本检查
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")# 定义神经网络模型
class QYNN(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(28*28, 128)  # 第一个全连接层 先转换为一维向量self.fc2 = nn.Linear(128, 10)  # 第二个全连接层 输出10个类别def forward(self, x):x = torch.flatten(x, start_dim=1)  # 展平数据,方便进行全连接x = torch.relu(self.fc1(x))  # 非线性x = self.fc2(x) # 十分类 输出层return x # 检查是否有 GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 初始化模型
model = QYNN().to(device) # 将模型移植到 GPU 或 CPU 上# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵
optimizer = optim.SGD(model.parameters(), lr=0.01) # lr 学习率 用来调整模型收敛速度# 训练模型
epochs = 10
best_acc = 0 # 初始化最佳准确率
best_model_wts = None # 用于保存最佳权重
for epoch in range(epochs): # 0-9running_loss = 0.0model.train()  # 设置模型为训练模式for inputs, labels in train_loader:# 移动数据到GPU 或 CPUinputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad() # 梯度清零outputs = model(inputs) # 将图片塞进网络训练获得 输出 前向传播loss = criterion(outputs, labels) # 根据输出和标签做对比计算损失loss.backward() # 反向传播optimizer.step() # 更新参数running_loss += loss.item() # loss值累加print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")# 测试模型
model.eval() # 设置模型为评估模式
correct = 0 # 正确的数量
total = 0 # 样本总数
with torch.no_grad(): # 不用进行梯度计算for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs, 1) # _取到的最大值,可以不要, 我们需要的是最大值对应的索引 也就是label(predicted)total += labels.size(0) # 获取当前批次样本数量correct += (predicted == labels).sum().item() # 对预测对的值进行累加accuracy = 100 * correct / total # 计算准确率
print(f"Epoch{epoch+1}/{epochs},Accuracy on test set: {correct/total:.2%}")# 如果当前模型的准确率比之前的最佳准率好 则保存模型权重
if accuracy > best_acc: best_acc = accuracy
best_model_wts = model.state_dict() # 保存最佳模型的权重torch.save(model.state_dict(), "./FashionMNIST_images/model.pt")
print("Best model weights saved !")

3.3 模型使用 

test.py

import os # 用于操作文件
import torch
import matplotlib.pyplot as plt
from torchvision import datasets,transforms # 用于数据集和数据变换
from PIL import Image # 用于图形操作
from torchvision.datasets import FashionMNIST # 用于加载FashionMNIST数据集
from train import QYNN, transform# 定义数据集保存路径
data_dir = './FashionMNIST_images' # 数据集的根目录
train_dir = os.path.join(data_dir, 'train') # 训练集保存路经)
test_dir = os.path.join(data_dir, 'test')  # 测试集保存路径# 定义分类标签 FashionMNIST共有10个类别
class_names = ['T-shirt/top',   # 0: T恤/上衣'Trouser',       # 1: 裤子'Pullover',      # 2: 套头衫'Dress',         # 3: 连衣裙'Coat',          # 4: 外套'Sandal',        # 5: 凉鞋'Shirt',         # 6: 衬衫'Sneaker',       # 7: 运动鞋'Bag',           # 8: 包'Ankle boot'     # 9: 短靴
]model = QYNN()
model.load_state_dict(torch.load("./FashionMNIST_images/model.pt"))
model.eval()# 定义推理和可视化函数
def infer_and_visualize_image(image_path,model,classses):# 打开图形并进行预处理img = Image.open(image_path).convert('L') # 确保灰度图片img = transform(img).unsqueueeze(0) # 增加一个批次维度# 推理with torch.no_grad():output = model(img)_, predicted = torch.max(output, 1)# 可视化图形和预测结果plt.imshow(img.squeeze(), cmap='gray')plt.title(f"Predicted {classses[predicted[0]]}")plt.axis('off')plt.show()# 输入图形路径image_path = r""infer_and_visualize_image(image_path, model, class_names)

 训练效果展示


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/93417.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/93417.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/93417.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[python刷题模板] LogTrick

[python刷题模板] LogTrick 一、 算法&数据结构1. 描述2. 复杂度分析3. 常见应用4. 常用优化二、 模板代码1. 特定或值的最短子数组2. 找特定值3. 找位置j的最后一次被谁更新4. 问某个或和的数量三、其他四、更多例题五、参考链接一、 算法&数据结构 1. 描述 LogTric…

Vim与VS Code

Vim is a clone, with additions, of Bill Joys vi text editor program for Unix. It was written by Bram Moolenaar based on source for a port of the Stevie editor to the Amiga and first released publicly in 1991.其实这个本身不是 IDE &#xff08;只有在加入和配置…

[2025CVPR-图象分类方向]CATANet:用于轻量级图像超分辨率的高效内容感知标记聚合

​1. 研究背景与动机​ ​问题​&#xff1a;Transformer在图像超分辨率&#xff08;SR&#xff09;中计算复杂度随空间分辨率呈二次增长&#xff0c;现有方法&#xff08;如局部窗口、轴向条纹&#xff09;因内容无关性无法有效捕获长距离依赖。​现有局限​&#xff1a; SPI…

课题学习笔记3——SBERT

1 引言在构建基于知识库的问答系统时&#xff0c;"语义匹配" 是核心难题 —— 如何让系统准确识别 "表述不同但含义相同" 的问题&#xff1f;比如用户问 "对亲人的期待是不是欲&#xff1f;"&#xff0c;系统能匹配到知识库中 "追名逐利是欲…

在Word和WPS文字中把全角数字全部改为半角

大部分情况下我们在Word或WPS文字中使用的数字或标点符号都是半角&#xff0c;但是有时不小心按错了快捷键或者点到了输入法的全角半角切换图标&#xff0c;就输入了全角符号和数字。不用担心&#xff0c;使用它们自带的全角、半角转换功能即可快速全部转换回来。一、为什么会输…

数据结构的基本知识

一、集合框架1、什么是集合框架Java集合框架(Java Collection Framework),又被称为容器(container),是定义在java.util包下的一组接口(interfaces)和其实现类(classes).主要表现为把多个元素(element)放在一个单元中,用于对这些元素进行快速、便捷的存储&#xff08;store&…

WebStack-Hugo | 一个静态响应式导航主题

WebStack-Hugo | 一个静态响应式导航主题 #10 shenweiyan announced in 1.3-折腾 WebStack-Hugo | 一个静态响应式导航主题#10 ​编辑shenweiyan on Oct 23, 2023 6 comments 7 replies Return to top shenweiyan on Oct 23, 2023 Maintainer Via&#xff1a;我给自己…

01 基于sklearn的机械学习-机械学习的分类、sklearn的安装、sklearn数据集、数据集的划分、特征工程中特征提取与无量纲化

文章目录机械学习机械学习分类1. 监督学习2. 半监督学习3. 无监督学习4. 强化学习机械学习的项目开发步骤scikit-learn1 scikit-learn安装2 sklearn数据集1. sklearn 玩具数据集鸢尾花数据集糖尿病数据集葡萄酒数据集2. sklearn现实世界数据集20 新闻组数据集3. 数据集的划分特…

携全双工语音通话大模型亮相WAIC,Soul重塑人机互动新范式

近日&#xff0c;WAIC 2025在上海隆重开幕。作为全球人工智能领域的顶级盛会&#xff0c;本届WAIC展览聚焦底层能力的演进与具体垂类场景的融合落地。坚持“模应一体”方向、立足“AI社交”的具体场景&#xff0c;Soul App此次携最新升级的自研端到端全双工语音通话大模型亮相&…

第2章 cmd命令基础:常用基础命令(1)

Hi~ 我是李小咖&#xff0c;主要从事网络安全技术开发和研究。 本文取自《李小咖网安技术库》&#xff0c;欢迎一起交流学习&#x1fae1;&#xff1a;https://imbyter.com 本节介绍的命令有目录操作&#xff08;cd&#xff09;、清屏操作&#xff08;cls&#xff09;、设置颜色…

Java 10 新特性解析

Java 10 新特性解析 文章目录Java 10 新特性解析1. 引言2. 本地变量类型推断&#xff08;JEP 286&#xff09;2.1. 概述2.2. 使用场景2.3. 限制2.4. 与之前版本的对比2.5. 风格指南2.6. 示例代码2.7. 优点与注意事项3. 应用程序类数据共享&#xff08;JEP 310&#xff09;3.1. …

【WRF工具】服务器中安装编译GrADS

目录 安装编译 GrADS 所需的依赖库 conda下载库包 安装编译 GrADS 编译前检查依赖可用性 安装编译 GrADS 参考 安装编译 GrADS 所需的依赖库 以统一方式在 $HOME/WRFDA_LIBS/grads_deps 下安装所有依赖: # 选择一个目录用于安装所有依赖库 export DIR=$HOME/WRFDA_LIBS库包1…

数据结构之队列(C语言)

1.队列的定义&#xff1a; 队列&#xff08;Queue&#xff09;是一种基础且重要的线性数据结构&#xff0c;遵循先进先出&#xff08;FIFO&#xff09;​​ 原则&#xff0c;即最早入队的元素最先出队&#xff0c;与栈不同的是出队列的顺序是固定的。队列具有以下特点&#xff…

C#开发基础之深入理解“集合遍历时不可修改”的异常背后的设计

前言 欢迎关注【dotnet研习社】&#xff0c;今天我们聊聊一个基础问题“集合已修改&#xff1a;可能无法执行枚举操作”背后的设计。 在日常 C# 开发中&#xff0c;我们常常会操作集合&#xff08;如 List<T>、Dictionary<K,V> 等&#xff09;。一个新手开发者极…

【工具】图床完全指南:从选择到搭建的全方位解决方案

前言 在数字化内容创作的时代&#xff0c;图片已经成为博客、文档、社交媒体等平台不可或缺的元素。然而&#xff0c;如何高效、稳定地存储和分发图片资源&#xff0c;一直是内容创作者面临的重要问题。图床&#xff08;Image Hosting&#xff09;作为专门的图片存储和分发服务…

深度学习篇---PaddleDetection模型选择

PaddleDetection 是百度飞桨推出的目标检测开发套件&#xff0c;提供了丰富的模型库和工具链&#xff0c;覆盖从轻量级移动端到高性能服务器的全场景需求。以下是核心模型分类、适用场景及大小选择建议&#xff08;通俗易懂版&#xff09;&#xff1a;一、主流模型分类及适用场…

cmseasy靶机密码爆破通关教程

靶场安装1.首先我们需要下载一个cms靶场CmsEasy_7.6.3.2_UTF-8_20200422,下载后解压在phpstudy_pro的网站根目录下。2.然后我们去访问一下安装好的网站&#xff0c;然后注册和链接数据库3.不知道自己数据库密码的可以去小皮面板里面查看4.安装好后就可以了来到后台就可以了。练…

【C语言】指针深度剖析(一)

文章目录一、内存和地址1.1 内存的基本概念1.2 编址的原理二、指针变量和地址2.1 取地址操作符&#xff08;&&#xff09;2.2 指针变量和解引用操作符&#xff08;*&#xff09;2.2.1 指针变量2.2.2 指针类型的解读2.2.3 解引用操作符2.3 指针变量的大小三、指针变量类型的…

半导体企业选用的跨网文件交换系统到底应该具备什么功能?

在半导体行业的数字化转型过程中&#xff0c;跨网文件交换已成为连接研发、生产、供应链的关键纽带。半导体企业的跨网文件交换不仅涉及设计图纸、工艺参数等核心知识产权&#xff0c;还需要满足跨国协同、合规审计等复杂需求。那么&#xff0c;一款适合半导体行业的跨网文件交…

影刀RPA_初级课程_玩转影刀自动化_网页操作自动化

声明&#xff1a;相关内容来自影刀学院&#xff0c;本文章为自用笔记&#xff0c;切勿商用&#xff01;&#xff08;若有侵权&#xff0c;请联络删除&#xff09; 1. 基本概念与操作 1.1 正确处理下拉框元素&#xff08;先判断页面元素&#xff0c;后进行流程编制&#xff09;…