Explain the similarities and differences between register_buffer and Parameter.
Similarities

| Aspect | Description |
| --- | --- |
| Tracking | Both are added to `state_dict`, so they are saved when the model is saved. |
| Binding to the `Module` | Both migrate automatically when the model is moved with `.cuda()` / `.cpu()` / `.float()` etc. |
| Part of `nn.Module` | Both can be accessed as module attributes, e.g. `self.x`. |
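As a quick illustration of these shared behaviors, here is a minimal sketch (the module name `SharedDemo` and the attribute names `w` / `b` are invented for illustration) showing that a Parameter and a registered buffer both end up in `state_dict` and both follow dtype/device conversions:

```python
import torch
import torch.nn as nn

class SharedDemo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3))       # trainable parameter
        self.register_buffer("b", torch.zeros(3))   # non-trainable buffer

m = SharedDemo()
print(list(m.state_dict().keys()))   # ['w', 'b'] -- both end up in state_dict
m = m.double()                       # both follow .float()/.double()/.to(device)
print(m.w.dtype, m.b.dtype)          # torch.float64 torch.float64
print(m.w.shape, m.b.shape)          # both accessible as attributes: m.w, m.b
```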
Differences

| | `torch.nn.Parameter` | `register_buffer` |
| --- | --- | --- |
| Trainable parameter | ✅ Yes, treated as a parameter the model needs to optimize (included in `model.parameters()`) | ❌ No, never updated by the optimizer |
| Gradient computation | `requires_grad=True` by default, participates in backpropagation | `requires_grad=False` by default, does not participate in backpropagation |
| Typical use | Weights, biases and other learnable parameters | Constants or state such as means, variances, masks, positional encodings, e.g. the running mean/var in BatchNorm |
| How to register | `self.w = nn.Parameter(tensor)` or `self.register_parameter("w", nn.Parameter(...))` | `self.register_buffer("buf", tensor)` |
| Appears in `parameters()` | ✅ Yes | ❌ No |
| Registered by direct assignment | ✅ Yes, plain assignment is enough | ❌ No, it must go through `register_buffer()`, otherwise it is not recorded in `state_dict` |
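These differences can be checked directly in an interpreter. Below is a minimal sketch (the module name `DiffDemo` and its attribute names are invented for illustration):

```python
import torch
import torch.nn as nn

class DiffDemo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3))        # direct assignment is enough for a Parameter
        self.register_buffer("buf", torch.ones(3))   # a buffer must be registered explicitly
        self.plain = torch.zeros(3)                  # plain tensor attribute: tracked by neither

m = DiffDemo()
print([n for n, _ in m.named_parameters()])     # ['w']   -- seen by the optimizer
print([n for n, _ in m.named_buffers()])        # ['buf'] -- not returned by parameters()
print(m.w.requires_grad, m.buf.requires_grad)   # True False
print(list(m.state_dict().keys()))              # ['w', 'buf'] -- 'plain' is not saved
```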
Usage recommendations

| Scenario | Recommended |
| --- | --- |
| The tensor needs to be optimized | `nn.Parameter` |
| Only recorded or used in computation, never optimized | `register_buffer` |
| State of a custom module (e.g. BatchNorm) | `register_buffer` |
| Positional encodings, attention masks | `register_buffer` |
| Needed in the saved model but not trained | `register_buffer` |
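A typical case from this table is a BatchNorm-style module: the affine scale/shift are learned (`nn.Parameter`), while the running statistics are updated in-place during forward and should be saved but never optimized (`register_buffer`). The sketch below is a simplified illustration only, not a drop-in replacement for `nn.BatchNorm1d`:

```python
import torch
import torch.nn as nn

class TinyNorm(nn.Module):
    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(num_features))   # learnable shift
        # running statistics: saved in state_dict, moved with the module, never optimized
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
        self.momentum = momentum
        self.eps = eps

    def forward(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            with torch.no_grad():  # buffers are updated in-place, outside autograd
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean, var = self.running_mean, self.running_var
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
```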
Below is a small test I wrote myself. Run ToyModel, ToyModel2 and ToyModel3 in turn, save each one and load it back, and the behavior of these two mechanisms should become very clear.
```python
# models.py -- imported below via `from models import ...`
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyModel(nn.Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.a1 = 1  # plain Python attributes: not tracked, not saved in state_dict
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        return out


class ToyModel2(nn.Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.a1 = 1
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        self.b1 = nn.Parameter(torch.randn(outChannels))  # trainable, in parameters() and state_dict

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        out += self.b1
        return out


class ToyModel3(nn.Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.a1 = 1
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        self.b1 = nn.Parameter(torch.randn(outChannels))
        # buffer: saved in state_dict (persistent=True) but never optimized
        self.register_buffer("c1", torch.ones_like(self.b1), persistent=True)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        out += self.b1
        out += self.c1
        return out
```
```python
# save script: build ToyModel3, inspect its parameters/buffers, save the state_dict
import logging
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

from models import ToyModel, ToyModel2, ToyModel3

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(lineno)s - %(message)s",
)

if __name__ == "__main__":
    savePath = Path("toymodel3.pth")
    logger = logging.getLogger(__name__)

    inp = torch.randn(3, 5)
    model = ToyModel3(inp.size(1), inp.size(1) * 2)
    pred = model(inp)
    logger.info(f"{pred.size()=}")

    for m in model.modules():
        logger.info(m)
    for name, param in model.named_parameters():
        logger.info(f"{name=}, {param.size()=}, {param.requires_grad=}")
    for name, buffer in model.named_buffers():
        logger.info(f"{name=}, {buffer.size()=}")

    torch.save(model.state_dict(), savePath)
```
```python
# load script: rebuild ToyModel3 and restore the saved state_dict
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

from models import ToyModel, ToyModel2, ToyModel3

if __name__ == "__main__":
    savePath = Path("toymodel3.pth")

    inp = torch.randn(3, 5)
    model = ToyModel3(inp.size(1), inp.size(1) * 2)

    ckpt = torch.load(savePath, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt)

    pred = model(inp)
    print(f"{pred.size()=}")

    for m in model.modules():
        print(m)
    for name, param in model.named_parameters():
        print(f"{name=}, {param.size()=}, {param.requires_grad=}")
    for name, buffer in model.named_buffers():
        print(f"{name=}, {buffer.size()=}")
```
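One quick way to confirm what actually ends up on disk is to load the checkpoint saved above and print its keys. With ToyModel3, the Parameter `b1` and the buffer `c1` should appear alongside the linear layer's weights, while the plain Python attributes `a1` / `a2` should not. This is a small sketch, assuming the save script above has already produced toymodel3.pth:

```python
import torch

# inspect the keys stored in the checkpoint
ckpt = torch.load("toymodel3.pth", map_location="cpu", weights_only=True)
print(sorted(ckpt.keys()))
# expected: ['b1', 'c1', 'linear.bias', 'linear.weight'] -- a1/a2 are absent
```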