本文目录:

  • 一、了解CIFAR-10数据集
  • 二、案例之导包
  • 三、案例之创建数据集
  • 四、案例之搭建神经网络(模型构建)
  • 五、案例之编写训练函数(训练模型)
  • 六、案例之编写预测函数(模型测试)

前言:此前分享了卷积神经网络相关知识,今天实战下:搭建一个卷积神经网络来实现图像分类任务。

一、了解CIFAR-10数据集

CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32×32×3。下图列举了10个类,每一类随机展示了10张图片:
在这里插入图片描述
PyTorch 中的 torchvision.datasets 计算机视觉模块封装了 CIFAR10 数据集,如果需要使用可以直接导入。

导入代码:

from torchvision.datasets import CIFAR10

二、案例之导包

import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor  # pip install torchvision -i https://mirrors.aliyun.com/pypi/simple/
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
from torchsummary import summary# 每批次样本数
BATCH_SIZE = 8

三、案例之创建数据集

# 1. 数据集基本信息
def create_dataset():# 加载数据集:训练集数据和测试数据# ToTensor: 将image(一个PIL.Image对象)转换为一个Tensortrain = CIFAR10(root='data', train=True, transform=ToTensor())valid = CIFAR10(root='data', train=False, transform=ToTensor())# 返回数据集结果return train, validif __name__ == '__main__':# 数据集加载train_dataset, valid_dataset = create_dataset()# 数据集类别print("数据集类别:", train_dataset.class_to_idx)# 数据集中的图像数据print("训练集数据集:", train_dataset.data.shape)print("测试集数据集:", valid_dataset.data.shape)# 图像展示plt.figure(figsize=(2, 2))plt.imshow(train_dataset.data[1])plt.title(train_dataset.targets[1])plt.show()

运行结果:

数据集类别: {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
训练集数据集: (50000, 32, 32, 3)
测试集数据集: (10000, 32, 32, 3)

图像:

在这里插入图片描述

四、案例之搭建神经网络(模型构建)

需要搭建的CNN网络结构如下:
在这里插入图片描述
我们要搭建的网络结构如下:

  1. 输入形状: 32x32;
  2. 第一个卷积层输入 3 个 Channel, 输出 6 个 Channel, Kernel Size 为: 3x3;
  3. 第一个池化层输入 30x30, 输出 15x15, Kernel Size 为: 2x2, Stride 为: 2;
  4. 第二个卷积层输入 6 个 Channel, 输出 16 个 Channel, Kernel Size 为 3x3;
  5. 第二个池化层输入 13x13, 输出 6x6, Kernel Size 为: 2x2, Stride 为: 2;
  6. 第一个全连接层输入 576 维, 输出 120 维;
  7. 第二个全连接层输入 120 维, 输出 84 维;
  8. 最后的输出层输入 84 维, 输出 10 维。

我们在每个卷积计算之后应用 relu 激活函数来给网络增加非线性因素。

# 模型构建
class ImageClassification(nn.Module):# 定义网络结构def __init__(self):super(ImageClassification, self).__init__()# 定义网络层:卷积层+池化层# 第一个卷积层, 输入图像为3通道,输出特征图为6通道,卷积核3*3self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)# 第一个池化层, 核宽高2*2self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# 第二个卷积层, 输入图像为6通道,输出特征图为16通道,卷积核3*3self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)# 第二个池化层, 核宽高2*2self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层# 第一个隐藏层 输入特征576(一张图像为16*6*6), 输出特征120个self.linear1 = nn.Linear(576, 120)# 第二个隐藏层self.linear2 = nn.Linear(120, 84)# 输出层self.out = nn.Linear(84, 10)# 定义前向传播def forward(self, x):# 卷积+relu+池化x = torch.relu(self.conv1(x))x = self.pool1(x)# 卷积+relu+池化x = torch.relu(self.conv2(x))x = self.pool2(x)# 将特征图做成以为向量的形式:相当于特征向量 全连接层只能接收二维数据集# 由于最后一个批次可能不够8,所以需要根据批次数量来改变形状# x[8, 16, 6, 6] --> [8, 576] -->8个样本,576个特征# x.size(0):1个值是样本数 行数# -1:第2个值由原始x剩余3个维度值相乘计算得到 列数(特征个数)x = x.reshape(x.size(0), -1)# 全连接层x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))# 返回输出结果return self.out(x)if __name__ == '__main__':# 模型实例化model = ImageClassification()summary(model, input_size=(3,32,32), batch_size=1)

运行结果:

在这里插入图片描述

五、案例之编写训练函数(训练模型)

在训练时,使用多分类交叉熵损失函数,Adam 优化器。具体实现代码如下:

def train(model, train_dataset):# 构建数据加载器dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)criterion = nn.CrossEntropyLoss() # 构建损失函数optimizer = optim.Adam(model.parameters(), lr=1e-3) # 构建优化方法epoch = 100  # 训练轮数for epoch_idx in range(epoch):sum_num = 0   # 样本数量total_loss = 0.0  # 损失总和correct = 0  # 预测正确样本数start = time.time()  # 开始时间# 遍历数据进行网络训练for x, y in dataloader:model.train()output = model(x)loss = criterion(output, y)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播optimizer.step()  # 参数更新correct += (torch.argmax(output, dim=-1) == y).sum()  # 计算预测正确样本数# 计算每次训练模型的总损失值 loss是每批样本平均损失值total_loss += loss.item()*len(y)  # 统计损失和sum_num += len(y)print('epoch:%2s loss:%.5f acc:%.2f time:%.2fs' %(epoch_idx + 1,total_loss / sum_num,correct / sum_num,time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/image_classification.pth')#联合上面代码一起运行本代码
if __name__ == '__main__':# 数据集加载train_dataset, valid_dataset = create_dataset()# 模型实例化model = ImageClassification()# 模型训练train(model,train_dataset)

运行结果:

epoch: 1 loss:1.67102 acc:0.38 time:26.23s
epoch: 2 loss:1.35650 acc:0.51 time:27.63s
epoch: 3 loss:1.22355 acc:0.57 time:31.10s
epoch: 4 loss:1.14639 acc:0.59 time:66.37s
epoch: 5 loss:1.09468 acc:0.61 time:40.38s
。。。。。。

六、案例之编写预测函数(模型测试)

当已经训练好模型(model),并保存了模型参数(model.state_dict()),可直接实例化模型,并加载训练好的模型参数,然后对测试集中的1万条样本进行预测,查看模型在测试集上的准确率。

def eval(valid_dataset):# 构建数据加载器dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)# 加载模型并加载训练好的权重model = ImageClassification()model.load_state_dict(torch.load('model/image_classification.pth'))# 模型切换评估模式, 如果网络模型中有dropout/BN等层, 评估阶段不进行相应操作model.eval()# 计算精度total_correct = 0total_samples = 0# 遍历每个batch的数据,获取预测结果,计算精度for x, y in dataloader:output = model(x)total_correct += (torch.argmax(output, dim=-1) == y).sum()total_samples += len(y)# 打印精度print('Acc: %.2f' % (total_correct / total_samples))if __name__ == '__main__':train_dataset, valid_dataset = create_dataset()eval(valid_dataset)

运行结果:

Acc: 0.57

最后,大家还可以通过调整lr(学习率)、神经元失活(dropout)、增加神经网络层数等方式来调整模型,提升acc,各看本领吧!

今天的分享到此结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/912274.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/912274.shtml
英文地址,请注明出处:http://en.pswp.cn/news/912274.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

记录多功能按键第二种写法使用定时器周期间隔判断.

逻辑是通过定时器溢出周期进行判断按下次数 比如设置定时器溢出周期为500MS,每次溢出都会判断按键按下次数,如果下个周期前没有触发按下,则结束键值判断.并确定触发键值.清空按下次数标志.测试比一个定时器周期按下按键次数判断写法要稳定... 记录STM32实现多功能按键_stm32一…

【安卓Sensor框架-1】SensorService 的启动流程

内核启动后,首个用户空间进程init(pid1)解析init.rc配置文件,启动关键服务(如Zygote和ServiceManager)。 Zygote服务配置为/system/bin/app_process --zygote --start-system-server,后续用于孵…

centos网卡绑定参考

同事整理分享: 1. 加载 Bonding 模块 modprobe bonding 获取网卡名称 ip a 找到接了网线的网卡名称,记下。 3. 配置物理网卡 创建并编辑 /etc/sysconfig/network-scripts/ifcfg-ens36(ifcfg-后面的内容根据上面找到的具体网卡名称决定&#…

mbedtls ssl handshake error,res:-0x2700

用LinkSDK.c连接第三方云平台出现现象 解决方案: 在_tls_network_establish函数中加入 mbedtls_ssl_conf_authmode(&adapter_handle->mbedtls.ssl_config, MBEDTLS_SSL_VERIFY_NONE);原因解释:用连接方式是不用证书认证/跳过服务端认证。

Spring Security 的方法级权限控制是如何利用 AOP 的?

Spring Security 的方法级权限控制是 AOP 技术在实际应用中一个极其强大的应用典范。它允许我们以声明式的方式保护业务方法,将安全规则与业务逻辑彻底解耦。 核心思想:权限检查的“门卫” 你可以把 AOP 在方法级安全中的作用想象成一个尽职尽责的“门…

一键内网穿透,无需域名和服务器,自动https访问

cloudflare能将内网web转为外网可访问的地址。(这和apiSQL有点类似,apiSQ可以将内网数据库轻松转换为外网的API,并且还支持代理内网已有API,增强安全增加API Key,以https访问等等) 但Cloudfalre tunnel这个…

Sentinel(二):Sentinel流量控制

一、Sentinel 流控规则基本介绍 1、Snetinel 流控规则配置方式 Sentinel 支持可视化的流控规则配置,使用非常简单;可以在监控服务下的“簇点链路” 或 “流控规则” 中 给指定的请求资源配置流控规则;一般推荐在 “簇点链路” 中配置流控规则…

支持PY普冉系列单片机调试工具PY32linK仿真器

PY32 Link是专为 ‌PY32系列ARM-Cortex内核单片机‌(如PY32F002A/030/071/040/403等)设计的仿真器,支持全系列芯片的‌调试和仿真‌功能。‌开发环境兼容性‌支持主流IDE:‌Keil MDK‌ 和 ‌IAR Embedded Workbench‌,…

深入解析Python多服务器监控告警系统:从原理到生产部署

深入解析Python多服务器监控告警系统:从原理到生产部署 整体架构图 核心设计思想 无代理监控:通过SSH直接获取数据,无需在目标服务器安装代理故障隔离:单台服务器故障不影响整体监控多级检测:网络层→资源层→服务层层…

JUC:10.线程、monitor管程、锁对象之间在synchronized加锁的流程(未完)

一、monitor管程工作原理: 首先,synchronized是一个对象锁,当线程运行到某个临界区,这个临界区使用synchronized对对象obj进行了上锁,此时底层发生了什么? 1.当synchronized对obj上锁后,synch…

Elasticsearch(ES)分页

Elasticsearch(简称 ES)本身不适合传统意义上的“深分页”,但提供了多种分页方式,每种适用不同场景。我们来详细讲解: 一、基本分页(from size) 最常用的分页方式,类似 SQL 的 LIM…

原生微信小程序:用 `setData` 正确修改数组中的对象项状态(附实战技巧)

📌 背景介绍 在微信小程序开发中,我们经常需要修改数组中某个对象的某个字段,比如: 列表中的某一项展开/收起多选状态切换数据列表中的临时标记等 一个常见的场景是: lists: [{ show: true }, { show: true }, { s…

Oracle 临时表空间相关操作

一、临时表空间概述 临时表空间(Temporary Tablespace)是Oracle数据库中用于存储临时数据的特殊存储区域,其数据在会话结束或事务提交后自动清除,重启数据库后彻底消失。主要用途包括: 存储排序操作(如OR…

从静态到动态:Web渲染模式的演进和突破

渲染模式有好多种,了解下web的各种渲染模式,对技术选型有很大的参考作用。 一、静态HTML时代 早期(1990 - 1995年)网页开发完全依赖手工编写HTML(HyperText Markup Language)和CSS(层叠样式表…

Flask(六) 数据库操作SQLAlchemy

文章目录 一、准备工作二、最小化可运行示例✅ 补充延迟绑定方式(推荐方式) 三、数据库基本操作(增删改查)1. 插入数据(增)2. 查询数据(查)3. 更新数据(改)4.…

PYTHON从入门到实践7-获取用户输入与while循环

# 【1】获取用户输入 # 【2】python数据类型的转换 input_res input("请输入一个数字\n") if int(input_res) % 10 0:print("你输入的数是10的倍数") else:print("你输入的数不是10的倍数") # 【3】while循环,适合不知道循环多少次…

学习笔记(C++篇)—— Day 8

1.STL简介 STL(standard template libaray-标准模板库):是C标准库的重要组成部分,不仅是一个可复用的组件库,而且是一个包罗数据结构与算法的软件框架。 2.STL的六大组件 先这样,下一部分是string的内容,内容比较多&a…

ant+Jmeter+jenkins接口自动化,如何实现把执行失败的接口信息单独发邮件?

B站讲的最好的自动化测试教程,工具框架附项目实战一套速通,零基础完全轻松掌握!自动化测试课程、web/app/接口 实现AntJMeterJenkins接口自动化失败接口邮件通知方案 要实现只发送执行失败的接口信息邮件通知,可以通过以下步骤实…

恶意Python包“psslib“实施拼写错误攻击,可强制关闭Windows系统

Socket威胁研究团队发现一个名为psslib的恶意Python包,该软件包伪装成提供密码安全功能,实则会突然关闭Windows系统。这个由化名umaraq的威胁行为者开发的软件包,是对知名密码哈希工具库passlib的拼写错误仿冒(typosquatting&…

云原生灰度方案对比:服务网格灰度(Istio ) 与 K8s Ingress 灰度(Nginx Ingress )

服务网格灰度与 Kubernetes Ingress 灰度是云原生环境下两种主流的灰度发布方案,它们在架构定位、实现方式和适用场景上存在显著差异。以下从多个维度对比分析,并给出选型建议: 一、核心区别对比 维度服务网格灰度(以 Istio 为例…