AlexNet(详解)——从原理到 PyTorch 实现(含训练示例)

文章目录

  • AlexNet(详解)——从原理到 PyTorch 实现(含训练示例)
  • 1. 发展历史与比赛成绩
  • 2. AlexNet 的核心思想(一句话)
  • 3. 模型结构总览(概览表)
  • 4. 逐层计算举例(重点:尺寸 & 参数如何得到)
      • 例 1:Conv1 输出尺寸(两种常见约定)
      • 例 2:参数量计算(按层逐项示例)
  • 5. 关键设计点解析(为什么这些创新重要)
  • 6. PyTorch 实现(完整代码 —— 可复制粘贴)
  • 7. 训练与评估(实践步骤 + 超参数建议)
  • 8. 实验扩展(建议做的对比实验)
  • 9. 总结

简介(为什么写这篇文章)
AlexNet 是 2012 年由 Alex Krizhevsky、Ilya Sutskever 和 Geoffrey Hinton 提出的卷积神经网络,它在 ILSVRC-2012 上大幅度优于当时其他方法,标志着深度学习在大规模视觉识别上的一次转折。本文目标是:讲清 AlexNet 的发展/比赛成绩、核心思想与创新、逐层结构与维度/参数计算举例,并给出一个可运行的 PyTorch 实现与训练示例
论文地址


1. 发展历史与比赛成绩

  • 作者 / 时间:Alex Krizhevsky, Ilya Sutskever, Geoffrey Hinton,发表于 NIPS 2012(论文标题 ImageNet Classification with Deep Convolutional Neural Networks)。这是 AlexNet 的权威来源。([NeurIPS 会议录][1])
  • 比赛成绩:AlexNet 在 ILSVRC-2012 上获得显著胜出 —— top-5 错误率约 15.3%,相较第二名有非常大的优势(原文和后续资料中有对比说明)。这次胜利推动了深度卷积网络在计算机视觉领域的广泛应用。([NeurIPS 会议录][1], [维基百科][2])

2. AlexNet 的核心思想(一句话)

把较深的卷积网络(比当时常见的浅网络更深)、非饱和激活(ReLU)、大量数据(ImageNet)和 GPU 加速结合起来:通过 局部感受野 + 权值共享 + 下采样(池化) 学习层次化特征,并用若干技巧(ReLU、数据增强、Dropout、局部响应归一化等)防止过拟合与加速训练,从而在大型图像分类任务上取得突破。([NeurIPS 会议录][1])


3. 模型结构总览(概览表)

说明:不同实现(paper vs Caffe vs torchvision)在输入尺寸/补零细节上略有差异,常见将输入视为 227×227×3224×224×3。下表以常见重现(Caffe/多数教程)为例,输出大小基于 227×227(或对 224×224 做微调后也可得到相同的中间尺寸)。使用 AdaptiveAvgPool2d((6,6)) 可以避免输入尺寸差异导致的维度问题(后面代码中已采用)。本文中参数量计算以常见复现(final flatten = 256×6×6)为基准。([NeurIPS 会议录][1], [多伦多大学计算机系][3])

层号层类型kernel / stride / pad输出通道输出尺度(示例)备注
输入3227×227×3(或 224×224×3)先做 scale → crop
Conv1Conv 11×11, s=4, p=211×11 / 4 / 29655×55×96ReLU → LRN → MaxPool(3,2)
Pool1MaxPool 3×3, s=227×27×96
Conv2Conv 5×5, s=1, p=2, groups=25×5 / 1 / 225627×27×256ReLU → LRN → Pool
Pool2MaxPool 3×3, s=213×13×256
Conv3Conv 3×3, s=1, p=13×3 / 1 / 138413×13×384ReLU
Conv4Conv 3×3, s=1, p=1, groups=23×3 / 1 / 138413×13×384ReLU
Conv5Conv 3×3, s=1, p=1, groups=23×3 / 1 / 125613×13×256ReLU → Pool (->6×6×256)
FC6Linear40961×1×4096Dropout(0.5)
FC7Linear40961×1×4096Dropout(0.5)
FC8Linear1000logitsSoftmax / CrossEntropyLoss

注:paper 中 conv2、conv4、conv5 的“分组连接”(groups=2)设计最初出于 GPU 内存/并行的工程实现需要(在两块 GPU 上分别计算并部分连接),现代实现用 groups 可以在单 GPU 上复现该连接方式。([NeurIPS 会议录][1], [PyTorch Forums][4])


4. 逐层计算举例(重点:尺寸 & 参数如何得到)

下面先给出常用的卷积输出公式,然后做具体示例与参数量计算。

卷积输出尺寸公式(2D,单维):

O=⌊W−K+2PS⌋+1O = \left\lfloor\frac{W - K + 2P}{S}\right\rfloor + 1 O=SWK+2P+1

其中 WWW 是输入宽(高同理),KKK 是核大小,PPP 是 padding,SSS 是 stride,OOO 是输出宽(或高)。


例 1:Conv1 输出尺寸(两种常见约定)

  • 若用 输入 227×227,kernel=11,stride=4,pad=0,则
    O=(227−11)/4+1=55O=(227-11)/4+1=55O=(22711)/4+1=55 → 输出 55×55(很多 Caffe 实现采用 227);
  • 若用 输入 224×224,但加上 pad=2(常见复现做法),kernel=11,stride=4:
    O=⌊(224−11+2×2)/4⌋+1=⌊217/4⌋+1=54+1=55O=\lfloor(224-11+2×2)/4\rfloor+1=\lfloor217/4\rfloor+1=54+1=55O=⌊(22411+2×2)/4+1=217/4+1=54+1=55
    因此许多实现通过加 pad=2 在 224 和 227 的差异上取得相同的 55×55 输出(实现细节不同但逻辑等价)。([多伦多大学计算机系][3])

例 2:参数量计算(按层逐项示例)

参数(weights)数目 = out_channels × in_channels × kernel_h × kernel_w,再加上 out_channels 个偏置项(如果有 bias)。

举几个关键层(常见复现、flatten=256×6×6):

  • Conv196×3×11×11+96=34,848+96=34,94496 \times 3 \times 11 \times 11 + 96 = 34,848 + 96 = 34,94496×3×11×11+96=34,848+96=34,944 个参数。
  • Conv2(分组):paper 实现把输入通道分到两组(每组 48),卷积核为 5×5,输出 256:
    参数 = 256×48×5×5+256=307,200+256=307,456256 \times 48 \times 5 \times 5 + 256 = 307,200 + 256 = 307,456256×48×5×5+256=307,200+256=307,456
  • Conv3384×256×3×3+384=885,120384 \times 256 \times 3 \times 3 + 384 = 885,120384×256×3×3+384=885,120
  • Conv4384×192×3×3+384=663,936384 \times 192 \times 3 \times 3 + 384 = 663,936384×192×3×3+384=663,936
  • Conv5256×192×3×3+256=442,624256 \times 192 \times 3 \times 3 + 256 = 442,624256×192×3×3+256=442,624
  • FC6(输入 256×6×6=9216):4096×9216+4096=37,752,832+4096=37,756,9284096 \times 9216 + 4096 = 37,752,832 + 4096 = 37,756,9284096×9216+4096=37,752,832+4096=37,756,928
  • FC74096×4096+4096=16,781,3124096 \times 4096 + 4096 = 16,781,3124096×4096+4096=16,781,312
  • FC81000×4096+1000=4,096,000+1000=4,097,0001000 \times 4096 + 1000 = 4,096,000 + 1000 = 4,097,0001000×4096+1000=4,096,000+1000=4,097,000

把这些加起来(各层之和)大约 60,965,224 ≈ 61M 参数(paper 给出的规模约 60M 左右,和上面的逐层分解是一致的常见复现结果)。这说明:FC 层占了绝大多数参数。([Stack Overflow][5], [NeurIPS 会议录][1])


5. 关键设计点解析(为什么这些创新重要)

  1. ReLU(Rectified Linear Unit):比 sigmoid/tanh 的非线性更简单、不饱和、反向传播梯度消失更少,训练更快、收敛更好。AlexNet 强调 ReLU 这是性能跃升的重要因素之一。([NeurIPS 会议录][1])
  2. Local Response Normalization (LRN):paper 中用于增强局部“竞争性”,帮助泛化(现在 BN 更常用了)。LRN 在 conv1/conv2 后使用以稍微提升精度(但现代工作中效果有限)。([NeurIPS 会议录][1])
  3. 分组卷积(groups=2):paper 在 conv2/conv4/conv5 采用分组连接,最初是出于 “在 2 块GTX 580 GPU 上并行训练 / 内存受限” 的工程需要(每块 GPU 处理一半的通道并部分连接)。现在可用 groups 在单卡上复现。([NeurIPS 会议录][1], [PyTorch Forums][4])
  4. 重采样/池化(overlapping pooling):AlexNet 使用 kernel=3, stride=2 的 pooling(窗口有重叠),paper 指出重叠 pooling 相比不重叠可以略微提高泛化。([NeurIPS 会议录][1])
  5. 数据增强(包括 PCA lighting):两种简单但有效的数据扩增:图像随机裁切/左右翻转 + RGB 空间的 PCA 颜色扰动(paper 提到),这些“廉价”的增强能极大扩充 effective dataset 并降低过拟合。([NeurIPS 会议录][1])
  6. Dropout(FC 层):在 FC6/FC7 使用 dropout(0.5) 有效降低过拟合,显著提高泛化。([NeurIPS 会议录][1])

6. PyTorch 实现(完整代码 —— 可复制粘贴)

下面给出一个**忠实还原(常见复现)**的 AlexNet PyTorch 实现(包括 LRN、groups、Dropout、AdaptiveAvgPool,适配不同输入尺寸)。把整个代码直接复制到你的 .py/笔记本中即可运行/微调。

# alexnet_pytorch.py
import torch
import torch.nn as nn
import torch.nn.functional as Fclass AlexNetOriginal(nn.Module):def __init__(self, num_classes=1000, dropout=0.5):super(AlexNetOriginal, self).__init__()self.features = nn.Sequential(# Conv1: 3 -> 96, kernel 11, stride 4, pad 2nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2.0),nn.MaxPool2d(kernel_size=3, stride=2),# Conv2: 96 -> 256, kernel 5, pad 2, groups=2 (paper used 2 GPUs)nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2),nn.ReLU(inplace=True),nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2.0),nn.MaxPool2d(kernel_size=3, stride=2),# Conv3: 256 -> 384, kernel 3, pad 1nn.Conv2d(256, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),# Conv4: 384 -> 384, kernel 3, pad 1, groups=2nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2),nn.ReLU(inplace=True),# Conv5: 384 -> 256, kernel 3, pad 1, groups=2nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)# ensure fixed flatten size: use adaptive pooling -> 6x6self.avgpool = nn.AdaptiveAvgPool2d((6, 6))self.classifier = nn.Sequential(nn.Dropout(p=dropout),nn.Linear(256 * 6 * 6, 4096),nn.ReLU(inplace=True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),)def forward(self, x):x = self.features(x)x = self.avgpool(x)           # shape -> (N, 256, 6, 6)x = torch.flatten(x, 1)       # shape -> (N, 256*6*6)x = self.classifier(x)return x# Example: instantiate model
# model = AlexNetOriginal(num_classes=1000)
# print(model)

说明:

  • LocalResponseNorm 在 PyTorch 中可用,但在现代网络中通常被 BN(BatchNorm)替代。
  • groups=2 用于复现原论文的分组连接;在多 GPU 时可以映射到不同设备,在单 GPU 上也能按分组工作(等价于并行的两个卷积再 concat)。([NeurIPS 会议录][1], [PyTorch Forums][4])

7. 训练与评估(实践步骤 + 超参数建议)

数据准备(paper 的处理):

  1. 将训练图像缩放,使短边为 256(保持纵横比)。
  2. 从缩放图像提取随机 224×224 补丁(并随机镜像)用于训练;评估时使用中心裁剪(center crop)。
  3. 进行像素级的“lighting” PCA 扰动(paper 中提到的颜色主成分扰动)或使用更简单的 ColorJitter。([NeurIPS 会议录][1])

超参数(paper 的设置,作为起点):

  • 优化器:SGD(momentum)
  • 初始学习率:lr = 0.01(paper)
  • momentum = 0.9
  • weight_decay = 0.0005
  • batch_size = 128(如果 GPU 内存不足可降到 64/32)
  • 学习率衰减:当验证误差停滞时手动将 lr 降 ×0.1(paper 中总共减少 3 次,最终 lr≈1e-5)
  • 训练轮数:paper 大约训练 90 个 epoch(总耗时 5–6 天,用两块 GTX 580 GPU)。现实中通常用更强硬件或直接 fine-tune 预训练模型。([NeurIPS 会议录][1], [维基百科][2])

训练代码模板(伪代码,关键点):

# 伪代码概览(简化版,不含 DataLoader 构造)
model = AlexNetOriginal(num_classes=1000).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)  # 举例for epoch in range(epochs):model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)logits = model(images)loss = criterion(logits, labels)optimizer.zero_grad()loss.backward()optimizer.step()# 验证 & 学习率调整scheduler.step()# 记录 train/val loss 与 top-1/top-5 accuracy

评估(Top-1 / Top-5):
torch.topk 可以计算 top-k 准确率;ImageNet 采用 top-1 与 top-5 指标。

可行替代(如果你资源有限):

  • 直接使用 torchvision.models.alexnet(pretrained=True) 做 fine-tune(更快、更实用)。
  • 在 CIFAR-10 / CIFAR-100 或自定义小数据集上练习,先保证实现无误,再上大规模数据。([PyTorch Docs][6])

8. 实验扩展(建议做的对比实验)

  1. 激活函数对比:ReLU vs LeakyReLU vs ELU(训练速度与最终精度比较)。
  2. 池化方式:Overlapping MaxPool(paper) vs non-overlapping vs AveragePool。
  3. BN vs LRN:在 conv 层后替换 LRN 为 BatchNorm,观察训练稳定性与收敛速度(BN 通常更好)。
  4. Dropout 和 权重衰减的组合:研究不同 dropout 概率和 weight_decay 对泛化影响。
  5. 数据增强:比较仅随机裁剪/镜像与加入 ColorJitter / PCA lighting 的效果。
  6. 优化器/学习率策略:SGD+momentum vs Adam/AdamW vs cosine lr schedule。

做这些实验时,把对比的关键指标(train/val loss、top-1/top-5 accuracy、训练时间)画成曲线,会很直观。


9. 总结

  • AlexNet 的成功来自于“把深度网络 + ReLU + 大规模数据 + GPU + 一些工程技巧(数据增强、dropout、LRN、分组计算)”结合起来。它证明了深度卷积网络在大数据集上的巨大潜力,从而推动了后续更深、更高效模型的发展(如 VGG、GoogLeNet、ResNet 等)。([NeurIPS 会议录][1], [维基百科][2])
  • 实践建议:若只想快速上手并取得良好结果,优先选择 预训练模型 + 微调;若目标是理解与复现原论文,按本文给出的实现与训练超参做实验并记录对比会很有收获。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/96274.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/96274.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/96274.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《sklearn机器学习——指标和评分1》

3个不同的API可供评估模型预测质量: 评估器评分方法:评估器有一个score方法,它给计划解决的问题提供一个初始评估标准。这部分内容不在这里讨论,但会出现在每一个评估器的文件中。 评分参数:使用交叉验证(…

人工智能中的线性代数总结--简单篇

numpy库中的dot函数来计算矩阵和向量的点积def matrix_vector_dot_product(a, b):import numpy as npif (len(a[0]) ! len(b)):return -1# 使用tolist()将结果转换为列表return np.dot(a, b).tolist()原始方法def matrix_vector_dot_product(matrix, vector):if len(matrix[0])…

又是全网首创/纯Qt实现28181设备模拟器/rtp视频点播/桌面转28181/任意文件转28181/跨平台

一、前言说明 这个工具前前后后也算是废了不少功夫,最开始是因为28181服务端的组件已经完美实现,对照国标文档看了很多遍,逐个实现需要的交互协议,整体上比onvif协议要难不少,主要是涉及到的东西比较多,有…

安卓逆向(一)Ubuntu环境配置

一、Ubuntu 1、虚拟机 首先准备一个Ubuntu的虚拟机,就随便新建一个就行,我这里使用的是Ubuntu21.04,但是内存跟硬盘大小最好设置的稍微大一点。 2、基础环境 (1)解决apt-get update报错问题 apt-get是Linux系统中一个管…

Go 1.25在性能方面做了哪些提升?

Go 1.25 在性能方面带来了多项重要提升,主要有以下几个方面: 实验性垃圾回收器 GreenTea GC:针对小对象密集型应用优化,显著提升小对象标记和扫描性能,垃圾回收开销减少0-40%,暂停时间缩短,吞吐…

Python与XML文件处理详解(2续):xml.dom.minidom模块高阶使用方法

目录 第一部分:高级节点操作与遍历方法 1.1 更精确的节点导航 1.2 使用 cloneNode() 复制节点 1.3 节点插入、替换与高级管理 第二部分:文档创建与高级输出控制 2.1 使用 Document 工厂方法完整创建文档 2.2 高级输出与序列化控制 第三部分:实用工具函数与模式处理 …

如何利用 ChatGPT 辅助写作

引言 介绍人工智能辅助写作的兴起,ChatGPT 在写作领域的应用潜力,以及本文的核心目标。 ChatGPT 在写作中的核心功能 概述 ChatGPT 的主要功能,包括文本生成、润色、结构优化、灵感激发等。 利用 ChatGPT 辅助写作的具体方法 生成创意与灵感 …

【有鹿机器人自述】我在社区的365天:扫地、卖萌、治愈人心

大家好,我是有鹿巡扫机器人,编号RD-07。今天我想和大家分享这一年来的工作见闻——没错,我们机器人也会"观察"和"感受",尤其是在连合直租将我送到这个社区后,发生的点点滴滴让我拥有了前所未有的&…

第五十五天(SQL注入增删改查HTTP头UAXFFRefererCookie无回显报错复盘)

#数据库知识: 1、数据库名,表名,列名,数据 2、自带数据库,数据库用户及权限 3、数据库敏感函数,默认端口及应用 4、数据库查询方法(增加删除修改更新) #SQL注入产生原理&#xf…

怎么用 tauri 创建一个桌面应用程序(Electron)

以前用 Electron 做过一个桌面应用程序,打包体积确实很大,启动也很慢。这次先 tauri。 并且用 bun 代替 npm 速度更快,避免总是出现依赖问题。 前端用 react 为了学习下,用 js 先现在主流的 typescript。 安装 bun npm instal…

【通过Docker快速部署Tomcat9.0】

文章目录前言一、部署docker二、部署Tomcat2.1 创建存储卷2.2 运行tomcat容器2.3 查看tomcat容器2.4 查看端口是否监听2.5 防火墙开放端口三、访问Tomcat前言 Tomcat介绍 Tomcat 是由 Apache 软件基金会(Apache Software Foundation)开发的一个开源 Jav…

LabVIEW UI 分辨率适配

针对 LabVIEW UI 在不同分辨率下的适配,现有方案分三类:一是现有 VI 可通过 “VI 属性 - 窗口大小” 勾选比例保持或控件缩放选项快速调整,也可取消勾选或换等宽字体防控件移位;二是项目初期以最低目标分辨率为基准,用…

国产化FPGA开发板:2050-基于JFMK50T4(XC7A50T)的核心板

(IEB-PS-3051-邮票孔) 一、核心板概述 板卡基于JFMK50T4国产化FPGA芯片,设计的一款工业级核心板,板卡集成主芯片、电源、DDR、配置芯片,大大减轻客户的扩展开发困难。丰富的IO和4个GTP,让用户轻…

Webpack 核心原理剖析

时至今日,Webpack 已迭代到 5.x 版本,其功能模块的扩充和复杂度的提升使得源码学习成本陡增。官方文档的晦涩表述更是让许多开发者望而却步。然而,理解 Webpack 的核心原理对优化构建流程、定制化打包方案至关重要。本文将通过简化流程和代码…

移植Qt4.8.7到ARM40-A5

移植Qt4.8.7到ARM40-A5 主机平台:Ubuntu 16.04 LTS(x64) 目标平台:ARM40-A5 Qt版本:Qt4.8.7 ARM GCC编译工具链: arm-2014.05-29-arm-none-linux-gnueabi-i686-pc-linux-gnu.tar.bz2 ----------## Qt移植步骤 ## 1、了解Ubuntu&am…

C++_哈希

1. unordered系列关联式容器在C98中,STL提供了底层为红黑树结构的一系列关联式容器,在查询时效率可达到$log_2 N$,即最差情况下需要比较红黑树的高度次,当树中的节点非常多时,查询效率也不理想。最好 的查询是&#xf…

Redis 内存管理机制:深度解析与性能优化实践

🧠 Redis 内存管理机制:深度解析与性能优化实践 文章目录🧠 Redis 内存管理机制:深度解析与性能优化实践🧠 一、Redis 内存架构全景💡 Redis 内存组成结构📊 内存占用分布示例⚙️ 二、内存分配…

cargs: 一个轻量级跨平台命令行参数解析库

目录 1.简介 2.安装与集成 3.项目的目录结构及介绍 4.核心数据结构与函数 5.基本使用示例 6.应用案例和最佳实践 7.高级用法 8.与其他库的对比 9.总结 1.简介 cargs 是一个轻量级、无依赖的 C 语言命令行参数解析库,虽然本身是 C 库,但可以无缝…

【数学建模】质量消光系数在烟幕遮蔽效能建模中的核心作用

前言:欢迎各位光临本博客,这里小编带你直接手撕质量相关系数,文章并不复杂,愿诸君耐其心性,忘却杂尘,道有所长!!!! **🔥个人主页:IF’…

Java代码审计实战:XML外部实体注入(XXE)深度解析

Java代码审计实战:XML外部实体注入(XXE)深度解析XML外部实体注入(XXE)是Web应用程序中一种常见但又常常被忽视的漏洞。它利用了XML解析器解析XML文档时,允许引用外部实体这个特性。如果解析器没有禁用外部实…