目录

一、实验目的

二、实验环境

三、实验内容

3.1 完成解压数据集相关操作

3.2分析代码结构并运行代码查看结果

3.3修改超参数(批量大小、学习率、Epoch)并对比分析不同结果

3.4修改网络结构(隐藏层数、神经元个数)并对比分析不同结果

四、实验小结


一、实验目的

  1. 了解python语法
  2. 了解全连接神经网络结构
  3. 调整超参数、修改网络结构并对比分析其结果

二、实验环境

Baidu 飞桨AI Studio

三、实验内容

3.1 完成解压数据集相关操作

输入以下两行命令解压数据集

(1)cd ./data/data230     

(2)unzip Minst.zip

运行后结果如图1所示

图 1 解压数据集

3.2分析代码结构并运行代码查看结果

代码结构:

import torchfrom torch import nn, optimfrom torch.autograd import Variablefrom torch.utils.data import DataLoaderfrom torchvision import datasets, transformsbatch_size = 64learning_rate = 0.02class Batch_Net(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Batch_Net, self).__init__()self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return xdata_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])train_dataset = datasets.MNIST(root='./data/data230', train=True, transform=data_tf, download=True)test_dataset = datasets.MNIST(root='./data/data230', train=False, transform=data_tf)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)#model = net.simpleNet(28 * 28, 300, 100, 10)# model = Activation_Net(28 * 28, 300, 100, 10)model = Batch_Net(28 * 28, 300, 100, 10)if torch.cuda.is_available():model = model.cuda()criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=learning_rate)epoch = 0for data in train_loader:img, label = dataimg = img.view(img.size(0), -1)if torch.cuda.is_available():img = img.cuda()label = label.cuda()else:img = Variable(img)label = Variable(label)out = model(img)loss = criterion(out, label)print_loss = loss.data.item()optimizer.zero_grad()loss.backward()optimizer.step()epoch+=1if epoch%100 == 0:print('epoch: {}, loss: {:.4}'.format(epoch, loss.data.item()))model.eval()eval_loss = 0eval_acc = 0for data in test_loader:img, label = dataimg = img.view(img.size(0), -1)if torch.cuda.is_available():img = img.cuda()label = label.cuda()out = model(img)loss = criterion(out, label)eval_loss += loss.data.item()*label.size(0)_, pred = torch.max(out, 1)num_correct = (pred == label).sum()eval_acc += num_correct.item()print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_dataset)),eval_acc / (len(test_dataset))))

代码分析

代码实现了通过使用批标准化的神经网络模型对MNIST数据集进行分类。

1.首先定义一些模型中会用到的超参数,在实验中设置批量大小(batch_size)为64,学习率(learning_rate)为0.02

2.定义Batch_Net神经网络类继承自torch库的nn.Module。

(1)定义的初始化函数_init_接收输入维度(in_dim)、两个隐藏层神经元数量(n_hidden_1和n_hidden_2)和输出维度(out_dim)作为参数,定义第一层和第二层网络结构,包括线性变换、批标准化和ReLU激活函数,第三层网络结构只包括线性变换。

(2)定义的forward前向传播函数接受输入x,依次经过三层网络结构处理后返回处理结果。

3.将数据进行预处理。包括使用transforms.ToTensor()将图片转换成PyTorch中处理的对象Tensor,并进行标准化操作,通过transforms.Normalize()做归一化(减均值,再除以标准差)操作,通过transforms.Compose()函数组合各种预处理的操作。

4.下载并加载MNIST训练数据(train_dataset)以及测试数据(test_dataset);创建数据加载器,用于批量加载训练数据(train_loader)和测试数据(test_loader)。

5.选择训练的模型,实验中选择Batch_Net模型,判断GPU是否可用,如果GPU可用,则将模型放到GPU上运行。在定义损失函数(criterion)和优化器(optimizer),损失函数使用交叉熵损失,优化器使用随机梯度下降优化器。

6.开始训练模型,遍历训练数据,通过img.view(img.size(0), -1)将图像数据调整为一维张量,以便与模型的输入匹配。计算损失值(loss)并进行反向传播和参数更新,每100个epoch打印一次训练损失,最后评估模型在测试集上的性能,计算并打印总损失(将损失值乘以当前批次的样本数量,累加到eval_loss中)和准确率(eval_acc)。

运行代码后的结果如图所示:

图 2 批量大小=64,学习率=0.02,Epoch次数=100

3.3修改超参数(批量大小、学习率、Epoch)并对比分析不同结果

1.当只修改批量大小batch_size时(学习率=0.02,Epoch次数=100)

(1)batch_size=16,结果如图3所示

图 3

(2)batch_size=64,结果如图4所示

图 4

(3)batch_size=128,结果如图5所示

图 5

对比总结:增大batch_size后,数据的处理速度加快,运行时间变短,跑完一次 epoch(全数据集)所迭代的次数减少。

2.当只修改学习率learning_rate时(批量大小=64,Epoch次数=100)

(1)learning_rate=0.005,结果如图6所示。

图 6

(2)learning_rate=0.02,结果如图7所示。

图 7

(3)learning_rate=0.1,结果如图8所示。

图 8

(4)learning_rate=0.4,结果如图9所示。

图 9

对比总结:当学习率适当增大时可能会有助于降低损失函数,提高模型的精确度,但是当学习率超出一定范围则会降低模型的精确率。

3.当只修改Epoch次数时(批量大小=64,学习率=0.02)

(1)Epoch=50,结果如图10所示。

图 10

(2)Epoch=100,结果如图11所示。

图 11

(3)Epoch=200,结果如图12所示。

图 12

对比总结:不同Epoch次数导致损失函数不同、模型的准确率有所差别。

3.4修改网络结构(隐藏层数、神经元个数)并对比分析不同结果

1.修改神经网络的隐藏层数

(1)隐藏层数为两层时的结果(n_hidden_1、n_hidden_2),结果如图13所示。

图 13

(2)隐藏层数为三层时的结果(n_hidden_1、n_hidden_2、n_hidden_3),结果如图14所示。

图 14

对比结果:可以看到,增加隐藏层个数后,运行时间增加,损失函数有所降低,三层隐藏层网络结构的模型准确率较高于两层隐藏层模型的准确率。

2.修改神经元个数(这里基于三层隐藏层数进行实验)

(1)神经元个数参数设定为model = Batch_Net(28 * 28, 300, 200,100, 10),结果如图15所示。

图 15

(2)神经元个数参数设定为model = Batch_Net(28 * 28, 400, 300,200, 10),结果如图16所示。

图 16

(3)神经元个数参数设定为model = Batch_Net(28 * 28, 200, 100,50, 10),结果如图17所示。

图 17

实验小

超参数是在训练神经网络之前设置的,而不是通过训练过程中学习得出的。常见的超参数包括学习率、批量大小、隐藏层的数量、神经元的个数等等。在使用深度神经网络时,正确地调整超参数是提高模型性能的关键。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/84152.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/84152.shtml
英文地址,请注明出处:http://en.pswp.cn/web/84152.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

openEuler安装BenchmarkSQL

BenchmarkSQL是一个用于评估数据库性能的开源工具。它模拟TPC-C(Transaction Processing Performance Council)基准测试场景,该场景主要用于衡量数据库在处理大量并发事务时的能力。TPC-C测试场景模拟了一个典型的批发分销商的业务环境&#…

分库分表之优缺点分析

大家好,我是工藤学编程 🦉一个正在努力学习的小博主,期待你的关注实战代码系列最新文章😉C实现图书管理系统(Qt C GUI界面版)SpringBoot实战系列🐷【SpringBoot实战系列】Sharding-Jdbc实现分库…

【2025年超详细】Git 系列笔记-4 git版本号及git相关指令运用。

系列笔记 【2025年超详细】Git 系列笔记-1 Git简述、Windows下git安装、Linux下git安装_displaying 2e144 commits. adjust this setting in -CSDN博客 【2025年超详细】Git 系列笔记-2 github连接超时问题解决_2025访问github-CSDN博客 【2025年超详细】Git 系列笔记-3 Git…

图像特征检测算法SuperPoint和SuperGlue

SuperPoint 背景与概述 :SuperPoint 是一个自监督的全卷积神经网络,用于提取图像中的兴趣点及其描述子。它在 2018 年由 Magic Leap 提出,通过在合成数据集上预训练一个基础检测器 MagicPoint,然后利用同胚适应技术对真实图像数据…

nginx 和 springcloud gateway cors 跨域如何设置

在跨域资源共享(CORS)配置中,Nginx 和 API Gateway(如Spring Cloud Gateway、Kong等)是两种常见的解决方案,它们的配置逻辑和适用场景有所不同。以下是详细对比和配置示例: 一、核心区别 维度NginxAPI Gateway定位反向代理/Web服务器微服务流量入口配置位置基础设施层应…

电路笔记(信号):一阶低通RC滤波器 一阶线性微分方程推导 拉普拉斯域表达(传递函数、频率响应)分析

目录 RC 低通滤波器电路一阶线性微分方程推导拉普拉斯域表达(传递函数)传递函数 H ( s ) H(s) H(s)频率响应(令 s j ω s j\omega sjω)幅频特性:相位特性:Bode 图(线性系统频率响应&#x…

【Git】删除远程分支时,本地分支还能看到

当远程仓库的分支被删除后,本地通过 git branch -a 或 git remote show origin 仍能看到这些分支的引用,是因为本地存储的远程跟踪分支(位于 refs/remotes/origin/)未被同步更新。以下是解决方法: 解决方案&#xff1…

Cubase 通过 MIDIPLUS MIDI 键盘进行走带控制的设置方法

第一步,在官网下载xml配置文件。 https://midiplus.com/upload/202101/29/Xpro & Xpro_mini控制脚本(Cubase).zip 第二步,Cubase中按如图步骤添加映射。 将MIDI键盘连接到电脑后打开Cubase软件,点选菜单“工作室”->“工作室设置”&…

第十八章 Linux之Python定制篇——Python开发平台Ununtu

1. Ubuntu介绍 Ubuntu(友帮拓、优般图、乌班图)是一个以桌面应用为主的开源GUN/Linux操作系统,Ubuntu基于GUN/Linux,支持x86、amd64(即x64)和ppc架构,有全球专业开发团队(Canonical…

推荐轻量级文生视频模型(Text-to-Video)

1. ModelScope T2V by 阿里达摩院(推荐) 模型名:damo/text-to-video-synthesis 输入:一句文字描述(如:"a panda is dancing") 输出:2秒视频(16帧&#xff0c…

流编辑器sed

sed简介 sed是一种流编辑器,处理时,把当前处理的行存储在临时缓冲区中,称为模式空间,接着用sed命令处理缓冲区中的内容,处理完成后,把缓冲区的内容送往屏幕。接着处理下行,这样不断重复&#xf…

商用密码基础知识介绍(上)

一、密码的基础知识 1、密码分类 根据《中华人民共和国密码法》,国家对密码实行分类管理,分为密码分为核心密码、普通密码和商用密码。 (1)核心密码、普通密码 核心密码、普通密码用于保护国家秘密信息,核心密码保护…

PROFINET主站S7-1500通过协议网关集成欧姆龙NJ系列TCP/IP主站

一、项目背景 某大型新能源电池生产企业,致力于提升电池生产的自动化水平和智能化程度。其生产线上,部分关键设备采用了不同的通信协议。在电池生产的前段工序,如原材料搅拌、涂布等环节,使用了西门子S7-1500系列PLC作为ROFINET协…

Vue3 + TypeScript + Element Plus + el-input 输入框列表按回车聚焦到下一行

应用效果:从第一行输入1,按回车,聚焦到第二行输入2,按回车,聚焦到第三行…… 一、通过元素 id,聚焦到下一行的输入框 关键技术点: 1、动态设置元素 id 属性为::id"input-appl…

FramePack 全面测评:革新视频生成体验

在 AI 视频生成领域,FramePack 自问世便备受瞩目,它凭借独特的技术架构,号称能打破传统视频生成对高端硬件的依赖,让普通电脑也能产出高质量视频。此次测评,我们将全方位剖析 FramePack,探究它在实际应用中…

html中的table标签以及相关标签

表格标签可以通过指定的标签完成数据展示 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>表格标签</title> </head> <body><table border"2"><!-- tr是表行 r…

springboot+vue3+vue-simple-uploader轻松实现大文件分片上传Minio

最近在写视频课程的上传&#xff0c;需要上传的视频几百MB到几个G不等&#xff0c;普通的上传都限制了文件的大小&#xff0c;况且上传的文件太大的话会超时、异常等。所以这时候需要考虑分片上传了&#xff0c;把需要上传的视频分成多个小块上传到&#xff0c;最后再合并成一个…

AI 重构代码实战:如何用飞算 JavaAI 快速升级遗留系统?

在企业数字化进程中&#xff0c;遗留系统如同陈旧的基础设施&#xff0c;虽承载着重要业务逻辑&#xff0c;但因技术落后、架构复杂&#xff0c;升级维护困难重重。飞算 JavaAI 的出现&#xff0c;为遗留系统的二次开发带来了新的转机&#xff0c;其基于智能分析与关联项目的技…

鸿蒙运动开发实战:打造专属运动视频播放器

##鸿蒙核心技术##运动开发##Media Kit&#xff08;媒体服务&#xff09;# 在当今数字化时代&#xff0c;运动健身已经成为许多人生活的一部分。今天我将在应用中添加视频播放器&#xff0c;帮助用户在运动前、运动后更好地进行热身和拉伸。这篇文章将从代码核心点入手&#xf…

一个包含15个界面高质量的电商APP客户端UI解决方案

一个包含15个界面高质量的电商APP客户端UI解决方案 您可以将其用于电商APP应用项目。包含一系列完整的界面设计元素&#xff0c;包括欢迎页、登录、注册、首页、产品分类、产品详情、尺码选择、购物车、订单、支付&#xff0c;覆盖电商APP的大部分界面。每个部分都精心设计&…