@浙大疏锦行

  1. 图像数据的格式:灰度和彩色数据
  2. 模型的定义
  3. 显存占用的4种地方
    1. 模型参数+梯度参数
    2. 优化器参数
    3. 数据批量所占显存
    4. 神经元输出中间状态
  4. batchisize和训练的关系

一、 图像数据的介绍

    图像数据,相较于结构化数据(表格数据)他的特点在于他每个样本的的形状并不是(特征数,),而是(宽,高,通道数)

    结构化数据(如表格)的形状通常是 (样本数, 特征数),例如 (1000, 5) 表示 1000 个样本,每个样本有 5 个特征。图像数据的形状更复杂,需要保留空间信息(高度、宽度、通道),因此不能直接用一维向量表示。其中颜色信息往往是最开始输入数据的通道的含义,因为每个颜色可以用红绿蓝三原色表示,因此一般输入数据的通道数是 3。   

1.1 灰度图像

# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
# 可视化原始图像(需要反归一化)
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 显示灰度图像plt.show()print(f"Label: {label}")
imshow(image)

    MNIST 数据集是手写数字的 灰度图像,每个像素点的取值范围为 0-255(黑白程度),因此 通道数为 1。图像尺寸统一为 28×28 像素。

1.2 彩色图像

    在 PyTorch 中,图像数据的形状通常遵循 (通道数, 高度, 宽度) 的格式(即 Channel First 格式),这与常见的 (高度, 宽度, 通道数)(Channel Last,如 NumPy 数组)不同。---注意顺序关系,

注意点:

1. 如果用matplotlib库来画图,需要转换下顺序,我们后续介绍

2. 模型输入通常需要 批次维度(Batch Size),形状变为 (批次大小, 通道数, 高度, 宽度)。例如,批量输入 10 张 MNIST 图像时,形状为 (10, 1, 28, 28)。

# 打印一张彩色图像,用cifar-10数据集
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np# 设置随机种子确保结果可复现
torch.manual_seed(42)
# 定义数据预处理步骤
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化处理
])# 加载CIFAR-10训练集
trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform
)# 创建数据加载器
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True
)# CIFAR-10的10个类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 随机选择一张图片
sample_idx = torch.randint(0, len(trainset), size=(1,)).item()
image, label = trainset[sample_idx]# 打印图片形状
print(f"图像形状: {image.shape}")  # 输出: torch.Size([3, 32, 32])
print(f"图像类别: {classes[label]}")# 定义图像显示函数(适用于CIFAR-10彩色图像)
def imshow(img):img = img / 2 + 0.5  # 反标准化处理,将图像范围从[-1,1]转回[0,1]npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 调整维度顺序:(通道,高,宽) → (高,宽,通道)plt.axis('off')  # 关闭坐标轴显示plt.show()# 显示图像
imshow(image)

二、 图像相关的神经网络的定义

2.1 黑白图像模型的定义

# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
import matplotlib.pyplot as plt# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)
# 定义两层MLP神经网络
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量self.layer1 = nn.Linear(784, 128)  # 第一层:784个输入,128个神经元self.relu = nn.ReLU()  # 激活函数self.layer2 = nn.Linear(128, 10)  # 第二层:128个输入,10个输出(对应10个数字类别)def forward(self, x):x = self.flatten(x)  # 展平图像x = self.layer1(x)   # 第一层线性变换x = self.relu(x)     # 应用ReLU激活函数x = self.layer2(x)   # 第二层线性变换,输出logitsreturn x# 初始化模型
model = MLP()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # 将模型移至GPU(如果可用)from torchsummary import summary  # 导入torchsummary库
print("\n模型结构信息:")
summary(model, input_size=(1, 28, 28))  # 输入尺寸为MNIST图像尺寸

我们关注和之前结构化MLP的差异

1. 输入需要展平操作

    MLP 的输入层要求输入是一维向量,但 MNIST 图像是二维结构(28×28 像素),形状为 [1, 28, 28](通道 × 高 × 宽)。nn.Flatten()展平操作 将二维图像 “拉成” 一维向量(784=28×28 个元素),使其符合全连接层的输入格式。

    其中不定义这个flatten方法,直接在前向传播的过程中用 x = x.view(-1, 28 * 28) 将图像展平为一维向量也可以实现

2. 输入数据的尺寸包含了通道数input_size=(1, 28, 28)

3. 参数的计算

  • 第一层 layer1(全连接层)

权重参数:输入维度 × 输出维度 = 784 × 128 = 100,352

偏置参数:输出维度 = 128

合计:100,352 + 128 = 100,480

  • 第二层 layer2(全连接层)

权重参数:输入维度 × 输出维度 = 128 × 10 = 1,280

偏置参数:输出维度 = 10

合计:1,280 + 10 = 1,290

  • 总参数:100,480(layer1) + 1,290(layer2) = 101,770

2.2 彩色图像模型的定义

class MLP(nn.Module):def __init__(self, input_size=3072, hidden_size=128, num_classes=10):super(MLP, self).__init__()# 展平层:将3×32×32的彩色图像转为一维向量# 输入尺寸计算:3通道 × 32高 × 32宽 = 3072self.flatten = nn.Flatten()# 全连接层self.fc1 = nn.Linear(input_size, hidden_size)  # 第一层self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)  # 输出层def forward(self, x):x = self.flatten(x)  # 展平:[batch, 3, 32, 32] → [batch, 3072]x = self.fc1(x)      # 线性变换:[batch, 3072] → [batch, 128]x = self.relu(x)     # 激活函数x = self.fc2(x)      # 输出层:[batch, 128] → [batch, 10]return x# 初始化模型
model = MLP()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # 将模型移至GPU(如果可用)from torchsummary import summary  # 导入torchsummary库
print("\n模型结构信息:")
summary(model, input_size=(3, 32, 32))  # CIFAR-10 彩色图像(3×32×32)
  •  第一层 layer1(全连接层)

权重参数:输入维度 × 输出维度 = 3072 × 128 = 393,216

偏置参数:输出维度 = 128

合计:393,216 + 128 = 393,344

  • -第二层 layer2(全连接层)

权重参数:输入维度 × 输出维度 = 128 × 10 = 1,280

偏置参数:输出维度 = 10

合计:1,280 + 10 = 1,290

  •  总参数:393,344(layer1) + 1,290(layer2) = 394,634

 2.3 模型定义与batchsize的关系

    实际定义中,输入图像还存在batchsize这一维度。在 PyTorch 中,模型定义和输入尺寸的指定不依赖于 batch_size,无论设置多大的 batch_size,模型结构和输入尺寸的写法都是不变的。

class MLP(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten() # nn.Flatten()会将每个样本的图像展平为 784 维向量,但保留 batch 维度。self.layer1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.layer2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)  # 输入:[batch_size, 1, 28, 28] → [batch_size, 784]x = self.layer1(x)   # [batch_size, 784] → [batch_size, 128]x = self.relu(x)x = self.layer2(x)   # [batch_size, 128] → [batch_size, 10]return x

    PyTorch 模型会自动处理 batch 维度(即第一维),无论 batch_size 是多少,模型的计算逻辑都不变。batch_size 是在数据加载阶段定义的,与模型结构无关。

    summary(model, input_size=(1, 28, 28))中的input_size不包含 batch 维度,只需指定样本的形状(通道 × 高 × 宽)。

三、显存占用的主要组成部分

    昨天说到了在面对数据集过大的情况下,由于无法一次性将数据全部加入到显存中,所以采取了分批次加载这种方式。即一次只加载一部分数据,保证在显存的范围内。

    那么显存设置多少合适呢?如果设置的太小,那么每个batchsize的训练不足以发挥显卡的能力,浪费计算资源;如果设置的太大,会出现OOT(out of memory)

显存一般被以下内容占用:

1. 模型参数与梯度:模型的权重(Parameters)和对应的梯度(Gradients)会占用显存,尤其是深度神经网络(如 Transformer、ResNet 等),一个 1 亿参数的模型(如 BERT-base),单精度(float32)参数占用约 400MB(1e8×4Byte),加上梯度则翻倍至 800MB(每个权重参数都有其对应的梯度)。

2. 部分优化器(如 Adam)会为每个参数存储动量(Momentum)和平方梯度(Square Gradient),进一步增加显存占用(通常为参数大小的 2-3 倍)

3. 其他开销。

from torch.utils.data import DataLoader# 定义训练集的数据加载器,并指定batch_size
train_loader = DataLoader(dataset=train_dataset,  # 加载的数据集batch_size=64,          # 每次加载64张图像shuffle=True            # 训练时打乱数据顺序
)# 定义测试集的数据加载器(通常batch_size更大,减少测试时间)
test_loader = DataLoader(dataset=test_dataset,batch_size=1000,shuffle=False
)

3.1 模型参数与梯度(FP32 精度)

  • 1字节(Byte)= 8位(bit),是计算机存储的最小寻址单位。  
  • 位(bit)是二进制数的最小单位(0或1),例如`0b1010`表示4位二进制数。
  • 1KB=1024字节;1MB=1024KB=1,048,576字节

3.2 优化器状态

  SGD

  • SGD优化器**不存储额外动量**,因此无额外显存占用。  
  • SGD 随机梯度下降,最基础的优化器,直接沿梯度反方向更新参数。
  • 参数更新公式:w = w - learning_rate * gradient

 Adam

  • Adam优化器:自适应学习率优化器,结合了动量(Momentum)和梯度平方的指数移动平均。  
  • 每个参数存储动量(m)和平方梯度(v),占用约 `101,770 × 8 Byte ≈ 806 KB`  
  • 动量(m):每个参数对应一个动量值,数据类型与参数相同(float32),占用 403 KB。
  • 梯度平方(v):每个参数对应一个梯度平方值,数据类型与参数相同(float32),占用 403 KB。

3.3.数据批量(batch_size)的显存占用

  • 单张图像尺寸:`1×28×28`(通道×高×宽),归一化转换为张量后为`float32`类型  

          单张图像显存占用:`1×28×28×4 Byte = 3,136 Byte ≈ 3 KB`  

  • 批量数据占用:`batch_size × 单张图像占用`  

          例如:`batch_size=64` 时,数据占用为 `64×3 KB ≈ 192 KB`  

         `batch_size=1024` 时,数据占用为 `1024×3 KB ≈ 3 MB`

3.4. 前向/反向传播中间变量

  • 对于两层MLP,中间变量(如`layer1`的输出)占用较小:  

  - `batch_size×128`维向量:`batch_size×128×4 Byte = batch_size×512 Byte`  

  - 例如`batch_size=1024`时,中间变量约 `512 KB`

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/92730.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/92730.shtml
英文地址,请注明出处:http://en.pswp.cn/web/92730.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

十八、MySQL-DML-数据操作-插入(增加)、更新(修改)、删除

DML数据操作添加数据更新(修改)数据删除数据总结代码: -- DML:数据操作语言-- -- DML:插入数据-insert -- 1.为tb_emp表的username,name,gender 字股插入值insert into tb_emp(username,name,gender,create_time,update_time) values (Toki,小时,2,now()…

Linux 安装 JDK 8u291 教程(jdk-8u291-linux-x64.tar.gz 解压配置详细步骤)​

一、准备工作 ​下载 JDK 安装包​ 去 Oracle 官网或者可信的镜像站下载: ​jdk-8u291-linux-x64.tar.gz​ (这是一个压缩包,不是安装程序,解压就能用) ​jdk-8u291-linux-x64.tar.gz​下载链接:https://pa…

蓝桥杯----锁存器、LED、蜂鸣器、继电器、Motor

(七)、锁存器1、原理蓝桥杯中数据传入口都是P0,也就是数码管段选、位选数据、LED亮灭的数据、蜂鸣器启动或禁用的数据,外设启动或者关闭都需要通过P0写入数据,那么如何这样共用一个端口会造成冲突嘛,答案是肯定的。所以蓝桥杯加入…

AI热点周报(8.3~8.9):OpenAI重返开源,Anthropic放大招,Claude4.1、GPT5相继发布

名人说:博观而约取,厚积而薄发。——苏轼《稼说送张琥》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 目录一、OpenAI的"开源回归":时隔5年的战略大转弯1. GPT-OSS系列&a…

《Kubernetes部署篇:基于x86_64+aarch64架构CPU+containerd一键离线部署容器版K8S1.33.3高可用集群》

总结:整理不易,如果对你有帮助,可否点赞关注一下? 更多详细内容请参考:企业级K8s集群运维实战 一、部署背景 由于业务系统的特殊性,我们需要针对不同的客户环境部署基于containerd容器版 K8S 1.33.3集群&a…

Linux抓包命令tcpdump详解笔记

文章目录一、tcpdump 是什么?二、基本语法三、常用参数说明四、抓包示例(通俗易懂)1. 抓所有数据包(默认 eth0)2. 指定接口抓包3. 抓取端口 80 的数据包(即 HTTP 请求)4. 抓取访问某个 IP 的数据…

抖音、快手、视频号等多平台视频解析下载 + 磁力嗅探下载、视频加工(提取音频 / 压缩等)

跟你们说个安卓上的下载工具,还挺厉害的。它能支持好多种下载方式,具体多少种我没细数,反正挺全乎的。​ 平时用得最多的就是视频解析,像抖音、快手、B 站上那些视频,想存下来直接用它就行,连海外视频的也能…

【iOS】JSONModel源码学习

JSONModel源码学习前言JSONModel的使用最基础的使用转换属性名称自定义错误模型嵌套JSONModel的继承源码实现initWithDictionaryinit__doesDictionaryimportDictionary优点前言 之前了解过JSONModel的一些使用方法等,但是对于底层实现并不清楚了解,今天…

SmartMediaKit 模块化音视频框架实战指南:场景链路 + 能力矩阵全解析

✳️ 引言:从“内核能力”到“模块体系”的演进 自 2015 年起,大牛直播SDK(SmartMediaKit)便致力于打造一个可深度嵌入、跨平台兼容、模块自由组合的实时音视频基础能力框架。经过多轮技术迭代与场景打磨,该 SDK 已覆…

【第5话:相机模型1】针孔相机、鱼眼相机模型的介绍及其在自动驾驶中的作用及使用方法

相机模型介绍及相机模型在自动驾驶中的作用及使用方法 相机模型是计算机视觉中的核心概念,用于描述真实世界中的点如何投影到图像平面上。在自动驾驶系统中,相机模型用于环境感知,如物体检测和场景理解。下面我将详细介绍针孔相机模型和鱼眼相…

推荐一款优质的开源博客与内容管理系统

Halo是一款由Java Spring Boot打造的开源博客与内容管理系统(CMS),在 GitHub上拥有超过36K Start的活跃开发者社区。它使用GPL‑3.0授权开源,稳定性与可维护性极高。 Halo的设计简洁、注重性能,同时保持高度灵活性&a…

【GPT入门】第43课 使用LlamaFactory微调Llama3

【GPT入门】第43课 使用LlamaFactory微调Llama31.环境准备2. 下载基座模型3.LLaMA-Factory部署与启动4. 重新训练![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/e7aa869f8e2c4951a0983f0918e1b638.png)1.环境准备 采购autodl服务器,24G,GPU,型号3090&am…

计算机网络:如何理解目的网络不再是一个完整的分类网络

这一理解主要源于无分类域间路由(CIDR)技术的广泛应用,它打破了传统的基于类的IP地址分配方式。具体可从以下方面理解: 传统分类网络的局限性:在早期互联网中,IP地址被分为A、B、C等固定类别,每…

小米开源大模型 MiDashengLM-7B:不仅是“听懂”,更能“理解”声音

目录 前言 一、一枚“重磅炸弹”:开源,意味着一扇大门的敞开 二、揭秘MiDashengLM-7B:它究竟“神”在哪里? 2.1 “超级耳朵” 与 “智慧大脑” 的协作 2.2 突破:从 “听见文字” 到 “理解世界” 2.3 创新训练&a…

mysql出现大量redolog、undolog排查以及解决方案

排查步骤 监控日志增长情况 -- 查看InnoDB状态 SHOW ENGINE INNODB STATUS;-- 查看redo log配置和使用情况 SHOW VARIABLES LIKE innodb_log_file%; SHOW VARIABLES LIKE innodb_log_buffer_size;-- 查看undo log信息 SHOW VARIABLES LIKE innodb_undo%;检查长时间运行的事务 -…

华为网路设备学习-28(BGP协议 三)路由策略

目录: 一、BGP路由汇总1、注:使用network命令注入的BGP不会被自动汇总2、主类网络号计算过程如下:3.示例 开启BGP路由自动汇总bgp100 开启BGP路由自动汇总import-route 直连路由 11.1.1.0 /24对端 为 10.1.12.2 AS 2004.手动配置BGP路…

微信小程序中实现表单数据实时验证的方法

一、实时验证的基本实现思路表单实时时验证通过监听表单元素的输入事件,在用户输入过程中即时对数据进行校验,并并即时反馈验证结果,主要实现步骤包括:为每个表单字段绑定输入事件在事件处理函数中获取当前输入值应用验证规则进行…

openpnp - 顶部相机如果超过6.5米影响通讯质量,可以加USB3.0信号放大器延长线

文章目录openpnp - 顶部相机如果超过6.5米影响通讯质量,可以加USB3.0信号放大器延长线概述备注ENDopenpnp - 顶部相机如果超过6.5米影响通讯质量,可以加USB3.0信号放大器延长线 概述 手头有1080x720x60FPS的摄像头模组备件,换上后&#xff…

【驱动】RK3576-Debian系统使用ping报错:socket operation not permitted

1、问题描述 在RK3576-Debian系统中,连接了Wifi后,测试网络通断时,报错: ping www.csdn.net ping: socktype: SOCK_RAW ping: socket: Operation not permitted ping: => missing cap_net_raw+p capability or setuid?2、原因分析 2.1 分析打印日志 socktype: SOCK…

opencv:图像轮廓检测与轮廓近似(附代码)

目录 图像轮廓 cv2.findContours(img, mode, method) 绘制轮廓 轮廓特征与近似 轮廓特征 轮廓近似 轮廓近似原理 opencv 实现轮廓近似 轮廓外接矩形 轮廓外接圆 图像轮廓 cv2.findContours(img, mode, method) mode:轮廓检索模式(通常使用第四个模式&am…