17.1 DenseNet网络结构设计

在这里插入图片描述

import torch
from torch import nn
from torchsummary import summary
#卷积层
def conv_block(input_channels,num_channels):net=nn.Sequential(nn.BatchNorm2d(input_channels),nn.ReLU(),nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1))return net
#过渡层
def transition_block(inputs_channels,num_channels):net=nn.Sequential(nn.BatchNorm2d(inputs_channels),nn.ReLU(),nn.Conv2d(inputs_channels,num_channels,kernel_size=1),nn.AvgPool2d(kernel_size=2,stride=2))return net
#DenseNetBlock
class DenseBlock(nn.Module):def __init__(self, num_convs,input_channels,num_channels):super(DenseBlock,self).__init__()layer=[]for i in range(num_convs):layer.append(conv_block(num_channels*i+input_channels,num_channels))self.net=nn.Sequential(*layer)def forward(self,X):for blk in self.net:Y=blk(X)X=torch.cat((X,Y),dim=1)return X
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
num_channels,growth_rate=64,32
num_convs_in_dense_block=[4,4,4,4]
blks=[]
for i,num_convs in enumerate(num_convs_in_dense_block):blks.append(DenseBlock(num_convs,num_channels,growth_rate))# 上一个稠密块的输出通道数num_channels+=num_convs*growth_rate# 在稠密块之间添加一个转换层,使通道数量减半if i!=len(num_convs_in_dense_block)-1:blks.append(transition_block(num_channels,num_channels//2))num_channels=num_channels//2
model=nn.Sequential(b1,*blks,nn.BatchNorm2d(num_channels),nn.ReLU(),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(num_channels,10))
device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model.to(device)
summary(model,input_size=(1,224,224),batch_size=64)

在这里插入图片描述

17.2 DenseNet网络实现Fashion-Mnist分类

################################################################################################################
#DenseNet
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)#[bs,1,28,28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):epochs = range(1, len(train_loss_list) + 1)plt.figure(figsize=(4, 3))plt.plot(epochs, train_loss_list, label='Train Loss')plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')plt.xlabel('Epoch')plt.ylabel('Value')plt.title(title)plt.legend()plt.grid(True)plt.tight_layout()plt.show()
def train_model(model,train_data,test_data,num_epochs):train_loss_list = []train_acc_list = []test_acc_list = []for epoch in range(num_epochs):total_loss=0total_acc_sample=0total_samples=0loop=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")for X,y in loop:#X=X.reshape(X.shape[0],-1)#print(X.shape)X=X.to(device)y=y.to(device)y_hat=model(X)loss=CEloss(y_hat,y)optimizer.zero_grad()loss.backward()optimizer.step()#loss累加total_loss+=loss.item()*X.shape[0]y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数total_samples+=X.shape[0]test_acc_samples=0test_samples=0for X,y in test_data:X=X.to(device)y=y.to(device)#X=X.reshape(X.shape[0],-1)y_hat=model(X)y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数test_samples+=X.shape[0]avg_train_loss=total_loss/total_samplesavg_train_acc=total_acc_sample/total_samplesavg_test_acc=test_acc_samples/test_samplestrain_loss_list.append(avg_train_loss)train_acc_list.append(avg_train_acc)test_acc_list.append(avg_test_acc)print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")plot_metrics(train_loss_list, train_acc_list, test_acc_list)return model
def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)
################################################################################################################
#DenseNet
################################################################################################################
#卷积层
def conv_block(input_channels,num_channels):net=nn.Sequential(nn.BatchNorm2d(input_channels),nn.ReLU(),nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1))return net
#过渡层
def transition_block(inputs_channels,num_channels):net=nn.Sequential(nn.BatchNorm2d(inputs_channels),nn.ReLU(),nn.Conv2d(inputs_channels,num_channels,kernel_size=1),nn.AvgPool2d(kernel_size=2,stride=2))return net
#DenseNetBlock
class DenseBlock(nn.Module):def __init__(self, num_convs,input_channels,num_channels):super(DenseBlock,self).__init__()layer=[]for i in range(num_convs):layer.append(conv_block(num_channels*i+input_channels,num_channels))self.net=nn.Sequential(*layer)def forward(self,X):for blk in self.net:Y=blk(X)X=torch.cat((X,Y),dim=1)return X
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
num_channels,growth_rate=64,32
num_convs_in_dense_block=[4,4,4,4]
blks=[]
for i,num_convs in enumerate(num_convs_in_dense_block):blks.append(DenseBlock(num_convs,num_channels,growth_rate))# 上一个稠密块的输出通道数num_channels+=num_convs*growth_rate# 在稠密块之间添加一个转换层,使通道数量减半if i!=len(num_convs_in_dense_block)-1:blks.append(transition_block(num_channels,num_channels//2))num_channels=num_channels//2
################################################################################################################
transforms=transforms.Compose([transforms.Resize(96),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])#第一个是mean,第二个是std
train_img=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=transforms,download=True)
test_img=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=transforms,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(b1,*blks,nn.BatchNorm2d(num_channels),nn.ReLU(),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(num_channels,10))
model.to(device)
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=10)
################################################################################################################

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/88874.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/88874.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/88874.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网安系列【16】之Weblogic和jboss漏洞

文章目录一 Weblogic1.1 Weblogic相关漏洞1.2 Weblogic漏洞发现1.3 Weblogic漏洞利用二 Jboss2.1 Jboss漏洞2.2 Jboss识别与漏洞利用一 Weblogic WebLogic 是由 Oracle公司 开发的一款基于Java EE(现称Jakarta EE)的企业级应用服务器,主要用…

Unity URP + XR 自定义 Skybox 在真机变黑问题全解析与解决方案(支持 Pico、Quest 等一体机)

在使用 Unity 的 URP 渲染管线开发 XR 应用(如 Pico Neo、Pico 4、Quest 2/3 等一体机)时,很多开发者遇到一个奇怪的问题:打包后,Skybox(天空盒)在某些角度下突然变黑,只在转动头部后…

Cursor、飞算JavaAI、GitHub Copilot、Gemini CLI 等热门 AI 开发工具合集

Cursor:代码编写的智能伙伴​Cursor 是 Anysphere 公司推出的一款 AI 编程工具,它基于微软开源代码编辑器 VS Code 开发,将 AI 技术深度整合到开发人员的工作流程中。Cursor 的功能十分强大,不仅能够自动用纯英文编写代码&#xf…

如何安装历史版本或指定版本的 git

背景 有的时候,我们需要安装指定版本的git,或者希望旧一点的,毕竟我就遇到最新的2.50.1在win10安装后打开就一闪而过,而安装2.49.1就不会 下载 官网可能比较难找,但是这个github仓库:https://github.com/gi…

LaCo: Large Language Model Pruning via Layer Collapse

发表:EMNLP_FINDING_2024 机构:Shanghai Jiao Tong University 连接:LaCo: Large Language Model Pruning via Layer Collapse - ACL Anthology 代码:https://github.com/yangyifei729/LaCo Abstract 基于 Transformer 的大语…

服务器内核级故障排查

目录 **检查内核级故障(Oops/Panic)的具体操作步骤****1. 查看完整 `dmesg` 日志(含时间戳)****2. 过滤关键错误信息****3. 检查系统日志中的内核消息****4. 分析最近一次启动的日志****5. 检查是否有 `vmcore` 转储文件****常见内核错误示例及含义**补充说明:检查内核级故…

Flink学习笔记:整体架构

开一个新坑,系统性的学习下 Flink,计划从整体架构到核心概念再到调优方法,最后是相关源码的阅读。 今天就来学习 Flink 整体架构,我们先看官网的架构图图中包含三部分,分别是 Client、JobManager 和 TaskManager。其中…

【LeetCode 热题 100】105. 从前序与中序遍历序列构造二叉树——(解法二)O(n)

Problem: 105. 从前序与中序遍历序列构造二叉树 给定两个整数数组 preorder 和 inorder ,其中 preorder 是二叉树的先序遍历, inorder 是同一棵树的中序遍历,请构造二叉树并返回其根节点。 【LeetCode 热题 100】105. 从前序与中序遍历序列构…

完美卸载 Ubuntu 双系统:从规划到实施的完整指南

📖 前言 最近成功完成了一次 Ubuntu 双系统的完整卸载,从最初的分区删除到最终解决 GRUB 引导问题,整个过程虽然有些曲折,但最终完美解决。本文将详细分享整个卸载过程,希望能帮助到有类似需求的朋友。 &#x1f3af…

深入理解oracle ADG和RAC

1. 引言 本节详细介绍oracle ADG和RAC。当然这里讲得的详细是相对理论的深入,不涉及到实验,比如ADG和RAC的搭建及调优等。 RAC (Real Application Clusters) 和 ADG (Active Data Guard)是Oracle 的两大核心高可用和灾备技术。它们是 Oracle 数据库高可用…

网络安全实践:从环境搭建到漏洞复现

要求:1.搭建docker2.使用小皮面板搭建pikachu靶场3.使用BP的爆破模块破解pikachu的登陆密码步骤4.Kail的msf复现永恒之蓝一.搭建docker1. Docker介绍Docker 是容器,可以部分完全封闭。封闭意味:一个物质(放到容器)&…

车载诊断架构 --- 诊断功能开发流程

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 做到欲望极简,了解自己的真实欲望,不受外在潮流的影响,不盲从,不跟风。把自己的精力全部用在自己。一是去掉多余,凡事找规律,基础是诚信;二是…

mysql数据库知识

MySQL数据库详解MySQL是目前全球最流行的关系型数据库管理系统之一,以其开源免费、高效稳定、易于扩展等特点,被广泛应用于Web开发、企业级应用等场景。本文将从基础概念、核心特性到实际应用,对MySQL进行全面解析。一、MySQL的基本概念1. 关…

基于springboot的美食文化和旅游推广系统

博主介绍:java高级开发,从事互联网行业多年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了多年的毕业设计程序开发,开发过上千套毕业设计程序,没有什么华丽的语言&#xff0…

Rust赋能文心大模型4.5智能开发

文心大模型4.5版本概论 文心大模型4.5是百度推出的最新一代大规模预训练语言模型,属于文心大模型(ERNIE)系列。该模型在自然语言处理(NLP)、多模态理解与生成等领域表现出色,广泛应用于智能搜索、内容创作、对话交互等场景。 核心能力 语言理解与生成 支持复杂语义理解…

前端抓包(不启动前端项目就能进行后端调试)--whistle

1、安装 1.1.安装node.js 1.2.安装whistle npm install -g whistle2.安装浏览器插件【SwitchyOmega】在谷歌浏览器应用商店下载安装即可配置proxy127.0.0.1:8989是w2 start的端口号启用代理3.启动服务(每次抓包都得启动) w2 start点击链接访问网页 http:…

kettle从入门到精通 第102课 ETL之kettle xxl-job调度kettle的两种方式

之前我们一起学习过xxl-job调度carte,采用的xxl-job执行器方式,不了解的可以查看《kettle从入门到精通 第六十一课 ETL之kettle 任务调度器,轻松使用xxl-job调用kettle中的job和trans 》 今天我们一起来学习下使用xxl-job直接使用http调用…

纯前端 JavaScript 实现数据导出到 CSV 格式

日常开发中,数据导出到文件通常有两种方式: 在后端处理,以文件流或者资源路径的方式返回;后端返回数据,前端按需处理后再触发浏览器的下载事件,已保存到本地文件。 这里介绍后者的一种零依赖的实现方式。…

香港理工大学实验室定时预约

香港理工大学实验室定时预约 文章目录香港理工大学实验室定时预约简介接单价格软件界面网站预约界面代码对爬虫、逆向感兴趣的同学可以查看文章,一对一小班教学(系统理论和实战教程)、提供接单兼职渠道:https://blog.csdn.net/weixin_35770067/article/d…

Spring AI 项目实战(十七):Spring Boot + AI + 通义千问星辰航空智能机票预订系统(附完整源码)

系列文章 序号文章名称1Spring AI 项目实战(一):Spring AI 核心模块入门2Spring AI 项目实战(二):Spring Boot + AI + DeepSeek 深度实战(附完整源码)3Spring AI 项目实战(三):Spring Boot + AI + DeepSeek 打造智能客服系统(附完整源码)4