一、理解 nn.Parameter
-
本质是什么?
nn.Parameter
是torch.Tensor
的一个子类。- 这意味着它继承了 Tensor 的所有属性和方法(如
.data
,.grad
,.requires_grad
,.shape
,.dtype
,.device
,.backward()
等)。 - 它本身不是一个函数或模块,而是一种特殊的张量类型。
-
核心目的/作用:
- 标识模型参数: 它的主要作用是标记一个 Tensor,告诉 PyTorch 这个 Tensor 是模型的一部分,是需要被优化器在训练过程中更新(学习) 的参数。
- 自动注册: 当一个
nn.Parameter
被分配为一个nn.Module
的属性时(通常在__init__
方法中),PyTorch 会自动将其注册到该 Module 的parameters()
列表中。这是最关键的特性! - 优化器目标: 优化器(如
torch.optim.SGD
,torch.optim.Adam
)通过调用model.parameters()
来获取所有需要更新的参数。只有注册了的nn.Parameter
(以及nn.Module
子模块中的nn.Parameter
)才会被包含在这个列表中。
-
与普通 Tensor (
torch.Tensor
) 的区别:特性 nn.Parameter
普通 torch.Tensor
自动注册 ✅ 当是 Module 属性时自动加入 parameters()
❌ 不会自动注册 优化器更新目标 ✅ 默认会被优化器更新 ❌ 默认不会被优化器更新 requires_grad
默认为 True
默认为 False
用途 定义模型需要学习的权重 (Weights) 和偏置 (Biases) 存储输入数据、中间计算结果、常量、缓冲区等 -
关键结论:
- 如果你想定义一个会被优化器更新的模型参数(权重、偏置),务必使用
nn.Parameter
包装你的 Tensor,并将其设置为nn.Module
的属性。 - 普通 Tensor 即使
requires_grad=True
,如果没有被注册(通过nn.Parameter
或register_parameter()
),也不会被优化器更新。它们可能用于存储需要梯度的中间状态或自定义计算。
- 如果你想定义一个会被优化器更新的模型参数(权重、偏置),务必使用
二、nn.Parameter
的创建与初始化方法大全
创建 nn.Parameter
的核心是:先创建一个 Tensor,然后用 nn.Parameter()
包装它。初始化方法的多样性体现在如何创建这个底层的 Tensor。以下是常见方法:
方法 1:直接包装 Tensor (最灵活)
import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self, input_size, output_size):super().__init__()# 方法 1a: 使用 torch.tensor 创建并包装self.weight = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) # 显式指定值 (不常用)# 方法 1b: 使用 torch 函数创建并包装 (最常用!)self.weight = nn.Parameter(torch.randn(input_size, output_size)) # 正态分布初始化self.bias = nn.Parameter(torch.zeros(output_size)) # 常数初始化 (0)def forward(self, x):return x @ self.weight + self.bias
- 优点: 绝对控制,可以使用任何创建 Tensor 的函数。
- 常用函数:
torch.randn(*size)
: 标准正态分布 (均值 0, 标准差 1) 初始化。最常用基础初始化。torch.rand(*size)
: [0, 1) 均匀分布初始化。torch.zeros(*size)
: 全 0 初始化 (常用于偏置)。torch.ones(*size)
: 全 1 初始化 (较少直接用于权重)。torch.full(size, fill_value)
: 用指定值填充。torch.empty(*size).uniform_(-a, a)
: 在[-a, a]
均匀分布初始化 (手动实现均匀分布)。torch.empty(*size).normal_(mean, std)
: 指定均值和标准差的正态分布初始化 (手动实现正态分布)。
方法 2:使用 torch.nn.init
模块 (推荐用于特定初始化策略)
PyTorch 提供了 torch.nn.init
模块,包含许多常用的、研究证明有效的初始化函数。这些函数通常原地修改传入的 Tensor。
class MyModule(nn.Module):def __init__(self, input_size, output_size):super().__init__()# 1. 先创建一个未初始化的或基础初始化的 Tensor (通常用 empty 或 zeros)self.weight = nn.Parameter(torch.empty(input_size, output_size))self.bias = nn.Parameter(torch.zeros(output_size))# 2. 应用 nn.init 函数进行初始化 (原地操作)nn.init.xavier_uniform_(self.weight) # Xavier/Glorot 均匀初始化 (适用于 tanh, sigmoid)# nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu') # Kaiming/He 正态初始化 (适用于 ReLU)nn.init.constant_(self.bias, 0.1) # 将偏置初始化为常数 0.1
- 优点: 使用经过验证的、针对不同激活函数设计的初始化策略,通常能获得更好的训练起点和稳定性。代码更清晰表达意图。
- 常用
nn.init
函数:- 常数初始化:
nn.init.constant_(tensor, val)
:用val
填充。nn.init.zeros_(tensor)
:全 0。nn.init.ones_(tensor)
:全 1。
- 均匀分布初始化:
nn.init.uniform_(tensor, a=0.0, b=1.0)
:[a, b]
均匀分布。
- 正态分布初始化:
nn.init.normal_(tensor, mean=0.0, std=1.0)
:指定mean
和std
的正态分布。
- Xavier / Glorot 初始化 (适用于饱和激活函数如 tanh, sigmoid):
nn.init.xavier_uniform_(tensor, gain=1.0)
:均匀分布,范围基于gain
和输入/输出神经元数 (fan_in/fan_out) 计算。nn.init.xavier_normal_(tensor, gain=1.0)
:正态分布,标准差基于gain
和 fan_in/fan_out 计算。
- Kaiming / He 初始化 (适用于 ReLU 及其变种):
nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
:均匀分布。nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
:正态分布。mode
:'fan_in'
(默认, 保持前向传播方差) 或'fan_out'
(保持反向传播方差)。nonlinearity
:'relu'
(默认) 或'leaky_relu'
。
- 正交初始化 (Orthogonal Initialization):
nn.init.orthogonal_(tensor, gain=1)
:生成正交矩阵(或行正交、列正交的张量),有助于缓解深度网络中的梯度消失/爆炸。
- 单位矩阵初始化 (Identity Initialization):
nn.init.eye_(tensor)
:尽可能将张量初始化为单位矩阵(对于非方阵,会初始化成尽可能接近单位矩阵的形式)。适用于某些 RNN 或残差连接。
- 对角矩阵初始化 (Diagonal Initialization):
nn.init.dirac_(tensor)
:尽可能初始化为 Dirac delta 函数(多维卷积核中常用,保留输入通道信息)。主要用于卷积层。
- 常数初始化:
方法 3:使用 nn.Linear
, nn.Conv2d
等内置模块 (隐式初始化)
当你使用 PyTorch 提供的标准层(如 nn.Linear
, nn.Conv2d
, nn.LSTM
等)时,它们内部已经使用 nn.Parameter
定义了权重和偏置,并自动应用了合理的默认初始化策略。
class MyModule(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(784, 256) # 内部包含 weight (Parameter) 和 bias (Parameter)self.fc2 = nn.Linear(256, 10)# 这些层的参数已经被初始化了 (通常是某种均匀或正态分布)
- 优点: 最方便,无需手动创建和初始化参数。对于标准层,默认初始化通常足够好。
- 查看/修改内置模块的初始化:
- 你可以通过
nn.Linear.weight.data
或nn.Conv2d.bias
访问这些内置的nn.Parameter
。 - 如果你想修改默认初始化,可以在创建层后,使用
nn.init
函数重新初始化它们的.weight
和.bias
:self.fc1 = nn.Linear(784, 256) nn.init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu') nn.init.constant_(self.fc1.bias, 0.0)
- 你可以通过
方法 4:从另一个模块或状态字典加载 (预训练/迁移学习)
# 假设 pretrained_model 是一个已经训练好的模型
pretrained_dict = pretrained_model.state_dict()# 创建新模型
new_model = MyModule()# 获取新模型的状态字典
model_dict = new_model.state_dict()# 1. 严格加载: 键必须完全匹配
new_model.load_state_dict(pretrained_dict)# 2. 非严格加载: 只加载键匹配的参数
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
new_model.load_state_dict(model_dict)# 3. 部分加载/初始化: 手动指定
new_model.some_layer.weight.data.copy_(pretrained_model.some_other_layer.weight.data) # 直接复制数据
- 优点: 迁移学习、微调、模型集成的基础。利用预训练知识加速训练或提升性能。
三、总结与最佳实践
-
何时用
nn.Parameter
?- 总是用它来定义你的模型层中需要被优化器更新的权重 (Weights) 和偏置 (Biases)。
- 对于模型中的输入数据、中间计算结果、常量、统计量(如 BatchNorm 的 running_mean)等不需要更新的张量,使用普通
torch.Tensor
(通常通过self.register_buffer()
注册为缓冲区,以便正确地在设备间移动和序列化)。
-
初始化方法选择:
- 新手/快速原型: 使用内置层 (
nn.Linear
,nn.Conv2d
等),它们有合理的默认初始化。 - 自定义层/需要特定策略:
- 基础:
torch.randn
(正态),torch.zeros
(偏置)。 - 推荐: 优先使用
torch.nn.init
模块中的函数:- 对于使用
tanh
/sigmoid
的网络:考虑nn.init.xavier_uniform_
/nn.init.xavier_normal_
。 - 对于使用
ReLU
/LeakyReLU
的网络:强烈推荐nn.init.kaiming_uniform_
/nn.init.kaiming_normal_
(根据mode
和nonlinearity
选择)。
- 对于使用
- 特殊需求: 常数 (
nn.init.constant_
),正交 (nn.init.orthogonal_
),单位矩阵 (nn.init.eye_
),Dirac (nn.init.dirac_
)。
- 基础:
- 新手/快速原型: 使用内置层 (
-
关键步骤:
- 在
__init__
中:- 使用
torch.*
函数或torch.empty
创建一个 Tensor。 - 用
nn.Parameter()
包装这个 Tensor。 - (可选但推荐) 使用
nn.init.*_
函数对这个nn.Parameter
进行原地初始化(如果使用torch.randn
/torch.zeros
等创建时已初始化,此步可省)。
- 使用
- 对于内置层,初始化已自动完成,但可以按需修改。
- 对于预训练模型,使用
load_state_dict
加载参数进行初始化。
- 在
-
注意:
- 初始化参数的大小 (
size
) 必须根据层的设计(输入维度、输出维度、卷积核大小等)正确设置。 - 确保
nn.Parameter
被设置为nn.Module
的属性(直接赋值),否则不会被自动注册到parameters()
中。 - 理解不同初始化方法背后的原理(如 Xavier/Glorot 考虑输入输出方差, Kaiming/He 考虑 ReLU 的修正)对于设计深层网络非常重要。
- 初始化参数的大小 (