CNN卷积神经网络

本章提供一个对CNN卷积网络的快速实现

全连接网络 VS 卷积网络

全连接神经网络之所以不太适合图像识别任务,主要有以下几个方面的问题:

  • 参数数量太多 考虑一个输入10001000像素的图片(一百万像素,现在已经不能算大图了),输入层有10001000=100万节点。假设第一个隐藏层有100个节点(这个数量并不多),那么仅这一层就有(1000*1000+1)*100=1亿参数,这实在是太多了!我们看到图像只扩大一点,参数数量就会多很多,因此它的扩展性很差。
  • 没有利用像素之间的位置信息 对于图像识别任务来说,每个像素和其周围像素的联系是比较紧密的,和离得很远的像素的联系可能就很小了。如果一个神经元和上一层所有神经元相连,那么就相当于对于一个像素来说,把图像的所有像素都等同看待,这不符合前面的假设。当我们完成每个连接权重的学习之后,最终可能会发现,有大量的权重,它们的值都是很小的(也就是这些连接其实无关紧要)。努力学习大量并不重要的权重,这样的学习必将是非常低效的。
  • 网络层数限制 我们知道网络层数越多其表达能力越强,但是通过梯度下降方法训练深度全连接神经网络很困难,因为全连接神经网络的梯度很难传递超过3层。因此,我们不可能得到一个很深的全连接神经网络,也就限制了它的能力。

那么,卷积神经网络又是怎样解决这个问题的呢?主要有三个思路:

  • 局部连接 这个是最容易想到的,每个神经元不再和上一层的所有神经元相连,而只和一小部分神经元相连。这样就减少了很多参数。
  • 权值共享 一组连接可以共享同一个权重,而不是每个连接有一个不同的权重,这样又减少了很多参数。
  • 下采样 可以使用Pooling来减少每层的样本数,进一步减少参数数量,同时还可以提升模型的鲁棒性。

对于图像识别任务来说,卷积神经网络通过尽可能保留重要的参数,去掉大量不重要的参数,来达到更好的学习效果。

卷积结构

卷积层

卷积层可以产生一组平行的特征图(feature map),它通过在输入图像上滑动不同的卷积核并执行一定的运算而组成。此外,在每一个滑动的位置上,卷积核与输入图像之间会执行一个元素对应乘积并求和的运算以将感受视野内的信息投影到特征图中的一个元素。这一滑动的过程可称为步幅 Z_s,步幅 Z_s 是控制输出特征图尺寸的一个因素。卷积核的尺寸要比输入图像小得多,且重叠或平行地作用于输入图像中,一张特征图中的所有元素都是通过一个卷积核计算得出的,也即一张特征图共享了相同的权重和偏置项。

池化层

池化(Pooling)是卷积神经网络中另一个重要的概念,它实际上是一种非线性形式的降采样。有多种不同形式的非线性池化函数,而其中“最大池化(Max pooling)”是最为常见的。它是将输入的图像划分为若干个矩形区域,对每个子区域输出最大值。

一个特征的精确位置远不及它相对于其他特征的粗略位置重要。池化层会不断地减小数据的空间大小,因此参数的数量和计算量也会下降,这在一定程度上也控制了过拟合。通常来说,CNN的网络结构中的卷积层之间都会周期性地插入池化层。池化操作提供了另一种形式的平移不变性。因为卷积核是一种特征发现器,我们通过卷积层可以很容易地发现图像中的各种边缘。但是卷积层发现的特征往往过于精确,我们即使高速连拍拍摄一个物体,照片中的物体的边缘像素位置也不大可能完全一致,通过池化层我们可以降低卷积层对边缘的敏感性。

全连接层

最后,在经过几个卷积和最大池化层之后,神经网络中的高级推理通过完全连接层来完成。就和常规的非卷积人工神经网络中一样,完全连接层中的神经元与前一层中的所有激活都有联系。因此,它们的激活可以作为仿射变换来计算,也就是先乘以一个矩阵然后加上一个偏差(bias)偏移量(向量加上一个固定的或者学习来的偏差量)。

卷积神经网络(LeNet)

模型实现

LeNet是最早发布的卷积神经网络之一,因其在计算机视觉任务中的高效性能而受到广泛关注。

用Pytorch框架实现此类模型非常简单。我们只需要实例化一个Sequential块并将需要的层连接在一起。

import torch
from torch import nn
from d2l import torch as d2lnet = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

将一个大小为28×28的单通道(黑白)图像通过LeNet。通过在每一层打印输出的形状,我们可以检查模型
在这里插入图片描述

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape:         torch.Size([1, 6, 28, 28])
Sigmoid output shape:        torch.Size([1, 6, 28, 28])
AvgPool2d output shape:      torch.Size([1, 6, 14, 14])
Conv2d output shape:         torch.Size([1, 16, 10, 10])
Sigmoid output shape:        torch.Size([1, 16, 10, 10])
AvgPool2d output shape:      torch.Size([1, 16, 5, 5])
Flatten output shape:        torch.Size([1, 400])
Linear output shape:         torch.Size([1, 120])
Sigmoid output shape:        torch.Size([1, 120])
Linear output shape:         torch.Size([1, 84])
Sigmoid output shape:        torch.Size([1, 84])
Linear output shape:         torch.Size([1, 10])

模型训练

现在我们已经实现了LeNet,让我们看看LeNet在Fashion-MNIST数据集上的表现。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

由于完整的数据集位于内存中,因此在模型使用GPU计算数据集之前,我们需要将其复制到显存中。

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

与全连接层一样,我们使用交叉熵损失函数和小批量随机梯度下降。

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.469, train acc 0.823, test acc 0.779
55296.6 examples/sec on cuda:0

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/86819.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/86819.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/86819.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

平地起高楼: 环境搭建

技术选型 本小册是采用纯前端的技术栈模拟实现小程序架构的系列文章,所以主要以前端技术栈为主,但是为了模拟一个App应用的效果,以及小程序包下载管理流程的实现,我们还是需要搭建一个基础的App应用。这里我们将选择 Tauri2.0 来…

langgraph学习2 - MCP编程

3中通信方式: 目前sse用的很少 3.开发mcp框架 主流框架2个: MCP skd 官方 Fast Mcp V2 ,(V1捐给MCP 官方) 大模型如何识别用哪个tools, 以及如何使用tools:

CSS 与 JavaScript 加载优化

📄 CSS 与 JavaScript 加载优化指南:位置、阻塞与性能 让你的网页飞起来!🚀 本文详细解析 CSS 和 JavaScript 标签的放置位置如何影响页面性能,涵盖阻塞原理、浏览器机制和最佳实践。掌握这些知识可显著提升用户体验…

WSL安装发行版上安装podman

WSL安装发行版上安装podman 1.WSL拉取发行版1.1 拉取2.2.修改系统拉取的镜像,可以加速软件包的更新 2.podman安装2.1.安装podman 容器工具2.2.配置podman的镜像仓库2.3.拉取n8n镜像并创建容器 本文在windows11上,使用WSL拉取并创建ubuntu24.04虚拟机&…

Excel 常用快捷键与对应 VBA 方法/属性清单

功能描述快捷键VBA 对应方法/属性 (核心逻辑)说明导航 (类似 End 方向键)这些是 End 键行为的直接对应向下到连续区域末尾Ctrl ↓ActiveCell.End(xlDown)从当前单元格向下,遇到第一个空单元格停止。向上到连续区域开头Ctrl ↑ActiveCell.End(xlUp)从当前单元格向上…

计算机组成原理与体系结构-实验四 微程序控制器 (Proteus 8.15)

一、实验目的 1、理解“微程序”设计思想,了解“指令-微指令-微命令”的微程序结构。 2、掌握微程序控制器的结构和设计方法。 二、实验内容 设计一个“最简版本”的 CPU 模型机:利用时序发生器来产生 CPU 的预定时序,通过微程序控制器的自…

安卓端某音乐类 APP 逆向分享(二)协议分析

以歌曲搜索协议为例,查看charles中歌曲搜索协议详情 拷贝出搜索协议的Curl形式 curl -H Host: interface3.music.xxx.com -H Cookie: EVNSM1.0.0; NMCIDoufhty.1667355455436.01.4; versioncode8008050; buildver221010200836; resolution2392x1440; deviceIdYDwXa…

七天学会SpringCloud分布式微服务——03——Nacos远程调用

1、微服务项目配置类放在地方 配置类型应放位置说明通用配置类(如:跨服务通用的拦截器、全局异常处理、统一响应体封装等)可放在一个**公共模块(common/config)**中,被各服务引入实现代码复用,…

基于Java+Spring Boot的校园闲置物品交易系统

源码编号:S561 源码名称:基于Spring Boot的校园闲置物品交易系统 用户类型:多角色,用户、商家、管理员 数据库表数量:12 张表 主要技术:Java、Vue、ElementUl 、SpringBoot、Maven 运行环境&#xff1…

SpringBoot 的 jar 包为什么可以直接运行?

一、普通jar包和SpringBoot jar包有什么区别?什么是jar包?? (1)什么是Jar包? 定义: JAR 包(Java Archive) 是 Java 平台标准的归档文件格式,用于将多个 Jav…

算法-基础算法-递归算法(Python)

文章目录 前言递归和数学归纳法递归三步走递归的注意点避免栈溢出避免重复运算 题目斐波那契数反转链表 前言 递归(Recursion):指的是一种通过重复将原问题分解为同类的子问题而解决的方法。在绝大数编程语言中,可以通过在函数中再…

TVFEMD-CPO-TCN-BiLSTM多输入单输出模型

47-TVFEMD-CPO-TCN-BiLSTM多输入单输出模型 适合单变量,多变量时间序列预测模型(可改进,加入各种优化算法) 时变滤波的经验模态分解TVFEMD时域卷积TCN双向长短期记忆网络BiLSTM时间序列预测模型 另外以及有 TCN-BILSTM …

深入浅出Node.js中间件机制

我们用一个实际的例子来看看中间件是如何运作的。假设我们有一个非常简单的Express应用,它只有两个中间件函数: const express require(express); const app express();app.use((req, res, next) > {console.log(第一个中间件);next(); });app.use…

Vue-15-前端框架Vue之应用基础编程式路由导航

文章目录 1 RouterLink的replace属性1.1 App.vue1.2 应用效果2 编程式路由导航2.1 场景一Home.vue2.2 场景二News.vue3 路由重定向3.1 index.ts3.2 Detail.vue3.3 About.vue1 RouterLink的replace属性 路由每次跳转都有记录,默认是push,可以改为replace。 RouterLink支持两…

android14 设置下连续点击5次Settings标题跳转到拨号界面

部分项目隐藏了拨号器,但开发者需要间距跳转到拨号界面 设置一级界面: packages/apps/Settings/src/com/android/settings/homepage/SettingsHomepageActivity.java 通过dispatchTouchEvent方法先获取Settings标题的区域X,Y数据。 import java.util.Set…

MP分页和连表常用写法

1. 分页查询 方案一&#xff1a;MyBatis XML MyBatis 内置的使用方式&#xff0c;步骤如下&#xff1a; ① 创建 AdminUserMapper.xml 文件&#xff0c;编写两个 SQL 查询语句&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <!DOCTYPE m…

使用 Spring AI Alibaba构建 AI Code Review 应用

很早的时候就想着用AI来做Code Review&#xff0c;最近也看到了一些不错的实现&#xff0c;但是没有一个使用Java来构建的&#xff0c;看的比较费劲&#xff0c;虽然说语言只是一种工具&#xff0c;但是还是想用Java重新写一遍&#xff0c;正好最近Spring AI Alibaba出了正式版…

力扣1590. 使数组和能被 P 整除

这一题的难点在于模运算&#xff0c;对模运算足够了解&#xff0c;对式子进行变换就很容易得到结果&#xff0c;本质上还是一道前缀和哈希表的题 这里重点讲一下模运算。 常见的模运算的用法 (a-b)%k0等价于 a%kb%k 而在这一题中由于多了一个len&#xff0c;&#xff08;数组的…

FPGA内部资源介绍

FPGA内部资源介绍 目录 逻辑资源块LUT&#xff08;查找表&#xff09;加法器寄存器MUX&#xff08;复用器&#xff09;时钟网络资源 全局时钟网络资源区域时钟网络资源IO时钟网络资源 时钟处理单元BLOCK RAMDSP布线资源接口资源 用户IO资源专用高速接口资源 总结 1. 逻辑资源…

CSS 列表

CSS 列表 引言 CSS 列表是网页设计中常用的一种布局方式&#xff0c;它能够帮助我们以更灵活、更美观的方式展示数据。本文将详细介绍 CSS 列表的创建、样式设置以及常用技巧&#xff0c;帮助您更好地掌握这一重要技能。 CSS 列表概述 CSS 列表主要包括两种类型&#xff1a…