在 PyTorch 中,nn.Module
是所有神经网络模块的基类,用于构建和组织深度学习模型。它提供了一系列工具和功能,使模型的定义、训练和部署更加高效和灵活。
nn
= Neural Network(神经网络)
核心作用:
模块化设计:将复杂的神经网络拆分为多个可复用的子模块(如卷积层、全连接层、激活函数等)。
自动管理参数:通过
nn.Parameter
自动追踪模型中的可训练参数(如权重和偏置)。前向传播定义:通过重写
forward()
方法定义数据的流动路径。设备管理:一键将整个模型移至 GPU(如
model.to('cuda')
)。训练 / 评估模式切换:通过
model.train()
和model.eval()
切换模型状态(影响 Dropout、BatchNorm 等层)。
Torch.nn 组件库
torch.nn
是 PyTorch 中用于构建神经网络的核心库,提供了各种预定义的层、损失函数和工具类,帮助你快速搭建和训练深度学习模型。它的主要特点包括:
模块化设计:所有组件(如卷积层、全连接层)都继承自
nn.Module
,便于组织和复用。自动求导:与 PyTorch 的自动微分系统(
autograd
)无缝集成。GPU 支持:所有组件可轻松迁移至 GPU 运行。
torch.nn
是工具盒,nn.Module
是组装工具的框架。自定义模型时,你需要用 nn.Module
作为基类,并用 torch.nn
中的组件填充它。 nn.Module
是 torch.nn
模块库中的一个核心基类,用于定义和组织神经网络。它是 PyTorch 模型构建的基础,所有层和模型最终都继承自它。
官网示例代码:
import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))
代码练习:
import torch
from torch import nnclass myModule(nn.Module):def __init__(self):super().__init__()def forward(self,input):output=input+1return outputmodel1=myModule()
x = torch.tensor(1.0)
output=model1(x)
print(output)