说下register_buffer和Parameter的异同

相同点

方面描述
追踪都会被加入 state_dict(模型保存时会保存下来)。
Module 的绑定都会随着模型移动到 cuda / cpu / float() 等而自动迁移。
都是 nn.Module 的一部分都可以通过模块属性访问,如 self.x

不同点

方面torch.nn.Parameterregister_buffer
是否是可训练参数✅ 是,会被视为模型需要优化的参数(model.parameters() 中包含)❌ 否,不会被优化器更新
梯度计算默认 requires_grad=True,参与反向传播默认 requires_grad=False,不参与反向传播
用途场景模型的权重、偏置等需要学习的参数均值、方差、mask、位置编码等常量或状态,如 BatchNorm 中的 running mean/var
注册方式self.w = nn.Parameter(tensor)self.register_parameter("w", nn.Parameter(...))self.register_buffer("buf", tensor)
是否显示在 parameters()✅ 会显示❌ 不会显示
是否能直接赋值注册✅ 可以直接赋值❌ 必须通过 register_buffer() 注册,否则不会记录到 state_dict

使用建议

情境推荐使用
需要优化nn.Parameter
只做记录或参与计算但不优化register_buffer
实现自定义模块(如 BatchNorm)时的状态register_buffer
使用位置编码、attention maskregister_buffer
模型保存中需要但不训练register_buffer

这里我自己写了一个测试代码,分别运行ToyModel1 2 3 保存并读取,相信会对这两个函数有很深刻的认识。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ToyModel(nn.Module):def __init__(self, inChannels, outChannels):super().__init__()self.a1 = 1 # 实例成员,不会保存在ckpt中self.a2 = 2self.linear = nn.Linear(inChannels, outChannels)self.init_weights()def init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):out = self.linear(x)return outclass ToyModel2(nn.Module):def __init__(self, inChannels, outChannels):super().__init__()self.a1 = 1 # 实例成员,不会保存在ckpt中self.a2 = 2self.linear = nn.Linear(inChannels, outChannels)self.init_weights()self.b1 = nn.Parameter(torch.randn(outChannels),) # 模型参数,requires_grad=True, 保存进ckptdef init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):out = self.linear(x)out += self.b1return outclass ToyModel3(nn.Module):def __init__(self, inChannels, outChannels):super().__init__()self.a1 = 1 # 实例成员,不会保存在ckpt中self.a2 = 2self.linear = nn.Linear(inChannels, outChannels)self.init_weights()self.b1 = nn.Parameter(torch.randn(outChannels),)self.register_buffer("c1", torch.ones_like(self.b1), persistent=True) # 类成员,requires_grad=False, 保存进ckpt,用于保存需要直接计算的常量,可以用self.c1访问def init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):out = self.linear(x)out += self.b1out += self.c1return out
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from pathlib import Pathfrom models import ToyModel2, ToyModel, ToyModel3logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)s - %(message)s')if __name__ == "__main__":savePath = Path("toymodel3.pth")logger = logging.getLogger(__name__)inp = torch.randn(3, 5)model = ToyModel3(inp.size(1), inp.size(1) * 2)pred = model(inp)logger.info(f"{pred.size()=}")for m in model.modules():logger.info(m)for name, param in model.named_parameters():logger.info(f"{name = }, {param.size() = }, {param.requires_grad=}")for name, buffer in model.named_buffers():logger.info(f"{name = }, {buffer.size() = }")torch.save(model.state_dict(), savePath)
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Pathfrom models import ToyModel, ToyModel2, ToyModel3if __name__ == "__main__":savePath = Path("toymodel3.pth")inp = torch.randn(3, 5)model = ToyModel3(inp.size(1), inp.size(1) * 2)ckpt = torch.load(savePath, map_location="cpu", weights_only=True)model.load_state_dict(ckpt)pred = model(inp)print(f"{pred.size()=}")for m in model.modules():print(m)for name, param in model.named_parameters():print(f"{name = }, {param.size() = }, {param.requires_grad=}")for name, buffer in model.named_buffers():print(f"{name = }, {buffer.size() = }")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/90600.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/90600.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/90600.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

吉吉巳资源整站源码完整打包,适用于搭建资源聚合/整合类站点,全网独家,拿来就用

想要搭建一个资源整合站点,如影视聚合类站点、资讯聚合类站点、图集聚合类站点等,需要花费大量的时间来查找合适的系统或源码。然后要去测试,修复bug,一直到能够正常的运营使用,花费的时间绝对不短,今天分享…

嵌入式学习的第三十五天-进程间通信-HTTP

TCP/IP协议模型:应用层:HTTP;传输层:TCP UDP;网络层:IPv4 IPv6网络接口层一、HTTP协议1. 万维网WWW(World Wide Web) 世界范围内的,联机式的信息储藏所。 万维网解决了获取互联网上的数据时需要解决的以下问题&#x…

es 和 lucene 的区别

1. Lucene 是“发动机”,ES 是“整车”Lucene:只是一个 Java 库,提供倒排索引、分词、打分等底层能力。你必须自己写代码处理索引创建、更新、删除、分片、分布式、故障恢复、API 封装等所有逻辑。Elasticsearch:基于 Lucene 的分…

AS32S601 系列 MCU芯片GPIO Sink/Source 能力测试方法

一、引言随着电子技术的飞速发展,微控制器(MCU)在工业控制、汽车电子、商业航天等众多领域得到了广泛应用。国科安芯推出的AS32S601 系列 MCU 以其卓越的性能和可靠性,成为了众多设计工程师的首选之一。为了确保其在实际应用中的稳…

JAVA-08(2025.07.24学习记录)

面向对象类package com.mm;public class Person {/*** 名词-属性*/String name;int age;double height;/*** 动词-方法*/public void sleep(String add) {System.out.println("我在" add "睡觉");}public String introduce() {return "我的名字是&q…

地下隧道管廊结构健康监测系统 测点的布设及设备选型

隧道监测背景 隧道所处地下环境复杂,在施工过程中会面临围堰变形、拱顶沉降、净空收敛、初衬应力变化、土体塌方等多种危险情况。在隧道营运过程中,也会受到材料退化、地震、人为破坏等因素影响,引发隧道主体结构的劣化和损坏,若不…

node.js卸载与安装超详细教程

文章目录一、卸载Step1:通过控制面板删除node版本Step2:删除node的安装目录Step3:查找.npmrc文件是否存在,有就删除。Step4:查看以下文件是否存在,有就删除Step5:打开系统设置,检查系…

飞算JavaAI“删除接口信息” 功能:3 步清理冗余接口,让管理效率翻倍

在飞算JavaAI的接口设计与管理流程中,“删除接口信息” 功能为用户提供了灵活调整接口方案的便利。该功能的存在,让用户能够在接口生命周期的前期(审核阶段)及时清理无需创建的接口,保证接口管理的简洁性与高效性。一、…

行业热点丨SimLab解决方案如何高效应对3D IC多物理场与ECAD建模挑战?

半导体行业正快速超越传统2D封装技术,积极采用 3D集成电路(3D ICs)和2.5D 先进封装等方案。这些技术通过异构芯粒、硅中介层和复杂多层布线实现更高性能与集成度。然而,由于电子计算机辅助设计(ECAD)数据规…

2025暑期—05神经网络-BP网络

按误差反向传播(简称误差反传)训练的多层前馈网络线性回归或者分类不需要使用神经元,原有最小二程即可。求解J依次变小。使用泰勒展开,只看第一阶。偏导是确定的,需要让J小于0的delta WkWk构造完成后 J(Wk1)已知&#…

qml的信号槽机制

qml的信号槽机制和qtwidget差不多,但是使用方法不一样,qtwidget一般直接用connect函数把信号和槽一绑定就完事了,qml分为自动绑定和手动绑定。信号自动绑定在一个组件里面定义一个信号,用signal定义,当事件触发&#x…

Unity国际版下载链接分享(非c1国内版)

转载Unity国际版下载链接分享(非c1国内版) - 哔哩哔哩 大家平时使用Unity注意一下会发现,现在我们下载的Unity版本号后面都一个c1,但是大家在B站学习时大神UP主们使用的Unity版本号大都是没有c1的。 例如:我在用的是…

第4章唯一ID生成器——4.1 分布式唯一ID

在复杂的系统中,每个业务实体都需要使用ID做唯一标识,以方便进行数据操作。例如,每个用户都有唯一的用户ID,每条内容都有唯一的内容ID,甚至每条内容下的每条评论都有唯一的评论ID。 4.1.1 全局唯一与UUID 在互联网还未…

图论水题日记

cf1805D 题意 给定一棵树,规定dis(u,v)≥kdis(u,v) \geq kdis(u,v)≥k时(u,v)(u,v)(u,v)之间存在一条无向边,求k(1,2,...n)k(1,2,...n)k(1,2,...n)时图中的连通块个数 思路 前置知识:树上一点到其最远的点一定是树直径的两个端点之一若一个点…

自定义线程

每个程序至少有一个线程 —— 主线程 主线程是程序的起点,你可以从它开始创建新的线程来执行任务。为此,你需要创建自定义线程,编写在线程中执行的代码,并启动它。 通过继承创建自定义线程 创建新线程有两种主要方式:继…

2025真实面试试题分析-安卓客户端开发

以下是对安卓客户端开发工程师面试问题的分类整理、领域占比分析及高频问题精选(基于​​85道问题,总出现次数118次​​)。按技术领域整合为​​7大核心类别​​,按占比排序并精选高频问题标注优先级(1-5🌟…

算法学习笔记:29.拓扑排序——从原理到实战,涵盖 LeetCode 与考研 408 例题

拓扑排序(Topological Sorting)是一种针对有向无环图(DAG)的线性排序算法,它将图中的顶点按照一定规则排列,使得对于图中的任意一条有向边 u→v,顶点 u 都排在顶点 v 之前。拓扑排序在任务调度、…

利用Web3加密技术保障您的在线数据安全

在这个信息爆炸的数字化时代,保护个人和企业数据安全变得尤为重要。Web3技术以其去中心化和加密特性,为在线数据安全提供了新的解决方案。本文将探讨Web3技术如何通过加密技术保障您的在线数据安全,并介绍如何有效利用这些技术。 什么是Web3技…

Vue实现el-checkbox单选并回显选中

先说需求 我要在页面进行checkbox单选并回显 第一步先把基本的页面写好噢&#xff1a;vue代码&#xff1a;别忘了写change啊<el-form-item label"按钮颜色:" prop"menuColor"><el-checkbox-group v-model"buttonColor" change"bin…

动态规划--序列找优问题【1】

一、说明 动态规划似乎针对问题很多&#xff0c;五花八门&#xff0c;似乎每一个问题都有一套具体算法。其实不是的&#xff0c;动态规划只有两类&#xff1a;1&#xff09;针对图的路径问题 2&#xff09;针对一个序列的问题。本篇讲动态规划针对序列的算法范例。 二、动态规划…