import torch
import torch.nn as nn# 定义通道注意力
class ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):"""通道注意力机制初始化参数:in_channels: 输入特征图的通道数ratio: 降维比例,用于减少参数量,默认为16"""super().__init__()# 全局平均池化,将每个通道的特征图压缩为1x1,保留通道间的平均值信息self.avg_pool = nn.AdaptiveAvgPool2d(1)# 全局最大池化,将每个通道的特征图压缩为1x1,保留通道间的最显著特征self.max_pool = nn.AdaptiveMaxPool2d(1)# 共享全连接层,用于学习通道间的关系# 先降维(除以ratio),再通过ReLU激活,最后升维回原始通道数self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // ratio, bias=False),  # 降维层nn.ReLU(),  # 非线性激活函数nn.Linear(in_channels // ratio, in_channels, bias=False)   # 升维层)# Sigmoid函数将输出映射到0-1之间,作为各通道的权重self.sigmoid = nn.Sigmoid()def forward(self, x):"""前向传播函数参数:x: 输入特征图,形状为 [batch_size, channels, height, width]返回:调整后的特征图,通道权重已应用"""# 获取输入特征图的维度信息,这是一种元组的解包写法b, c, h, w = x.shape# 对平均池化结果进行处理:展平后通过全连接网络avg_out = self.fc(self.avg_pool(x).view(b, c))# 对最大池化结果进行处理:展平后通过全连接网络max_out = self.fc(self.max_pool(x).view(b, c))# 将平均池化和最大池化的结果相加并通过sigmoid函数得到通道权重attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)# 将注意力权重与原始特征相乘,增强重要通道,抑制不重要通道return x * attention #这个运算是pytorch的广播机制
## 空间注意力模块
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):# 通道维度池化avg_out = torch.mean(x, dim=1, keepdim=True)  # 平均池化:(B,1,H,W)max_out, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化:(B,1,H,W)pool_out = torch.cat([avg_out, max_out], dim=1)  # 拼接:(B,2,H,W)attention = self.conv(pool_out)  # 卷积提取空间特征return x * self.sigmoid(attention)  # 特征与空间权重相乘

## CBAM模块
class CBAM(nn.Module):def __init__(self, in_channels, ratio=16, kernel_size=7):super().__init__()self.channel_attn = ChannelAttention(in_channels, ratio)self.spatial_attn = SpatialAttention(kernel_size)def forward(self, x):x = self.channel_attn(x)x = self.spatial_attn(x)return x
# 测试下通过CBAM模块的维度变化
# 输入卷积的尺寸为
# 假设输入特征图:batch=2,通道=512,尺寸=26x26
x = torch.randn(2, 512, 26, 26) 
cbam = CBAM(in_channels=512)
output = cbam(x)  # 输出形状不变:(2, 512, 26, 26)
print(f"Output shape: {output.shape}")  # 验证输出维度

 cnn+CBAM

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 数据预处理(与原代码一致)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加载数据集(与原代码一致)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 定义带有CBAM的CNN模型
class CBAM_CNN(nn.Module):def __init__(self):super(CBAM_CNN, self).__init__()# ---------------------- 第一个卷积块(带CBAM) ----------------------self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(32) # 批归一化self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2)self.cbam1 = CBAM(in_channels=32)  # 在第一个卷积块后添加CBAM# ---------------------- 第二个卷积块(带CBAM) ----------------------self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(64)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2)self.cbam2 = CBAM(in_channels=64)  # 在第二个卷积块后添加CBAM# ---------------------- 第三个卷积块(带CBAM) ----------------------self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.bn3 = nn.BatchNorm2d(128)self.relu3 = nn.ReLU()self.pool3 = nn.MaxPool2d(kernel_size=2)self.cbam3 = CBAM(in_channels=128)  # 在第三个卷积块后添加CBAM# ---------------------- 全连接层 ----------------------self.fc1 = nn.Linear(128 * 4 * 4, 512)self.dropout = nn.Dropout(p=0.5)self.fc2 = nn.Linear(512, 10)def forward(self, x):# 第一个卷积块x = self.conv1(x)x = self.bn1(x)x = self.relu1(x)x = self.pool1(x)x = self.cbam1(x)  # 应用CBAM# 第二个卷积块x = self.conv2(x)x = self.bn2(x)x = self.relu2(x)x = self.pool2(x)x = self.cbam2(x)  # 应用CBAM# 第三个卷积块x = self.conv3(x)x = self.bn3(x)x = self.relu3(x)x = self.pool3(x)x = self.cbam3(x)  # 应用CBAM# 全连接层x = x.view(-1, 128 * 4 * 4)x = self.fc1(x)x = self.relu3(x)x = self.dropout(x)x = self.fc2(x)return x# 初始化模型并移至设备
model = CBAM_CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
# 训练函数
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):model.train()all_iter_losses = []iter_indices = []train_acc_history = []test_acc_history = []train_loss_history = []test_loss_history = []for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totaltrain_acc_history.append(epoch_train_acc)train_loss_history.append(epoch_train_loss)# 测试阶段model.eval()test_loss = 0correct_test = 0total_test = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_testtest_acc_history.append(epoch_test_acc)test_loss_history.append(epoch_test_loss)scheduler.step(epoch_test_loss)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc# 绘图函数
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('每个 Iteration 的训练损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('训练和测试准确率')plt.legend()plt.grid(True)plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('训练和测试损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 执行训练
epochs = 50
print("开始使用带CBAM的CNN训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")# # 保存模型
# torch.save(model.state_dict(), 'cifar10_cbam_cnn_model.pth')
# print("模型已保存为: cifar10_cbam_cnn_model.pth")

@浙大疏锦行 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/87601.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/87601.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/87601.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在小程序中实现实时聊天:WebSocket最佳实践

前言 在当今互联网应用中,实时通信已经成为一个标配功能,特别是对于需要即时响应的场景,如在线客服、咨询系统等。本文将分享如何在小程序中实现一个高效稳定的WebSocket连接,以及如何处理断线重连、消息发送与接收等常见问题。 W…

Python网络爬虫编程新手篇

网络爬虫是一种自动抓取互联网信息的脚本程序,广泛应用于搜索引擎、数据分析和内容聚合。这次我将带大家使用Python快速构建一个基础爬虫,为什么使用python做爬虫?主要就是支持的库很多,而且同类型查询文档多,在同等情…

LeetCode.283移动零

题目链接:283. 移动零 - 力扣(LeetCode) 题目描述: 给定一个数组 nums,编写一个函数将所有 0 移动到数组的末尾,同时保持非零元素的相对顺序。 请注意 ,必须在不复制数组的情况下原地对数组进行…

2025年7月4日漏洞文字版表述一句话版本(漏洞危害以及修复建议),通常用于漏洞通报中简洁干练【持续更新中】,漏洞通报中对于各类漏洞及修复指南

漏洞及修复指南 一、暗链 危害:攻击者通过技术手段在用户网页中插入隐藏链接或代码,并指向恶意网站,可导致用户信息泄露、系统感染病毒,用户访问被劫持至恶意网站,泄露隐私或感染恶意软件,被黑客利用进行…

python --飞浆离线ocr使用/paddleocr

依赖 # python3.7.3 paddleocr2.7.0.2 paddlepaddle2.5.2 loguru0.7.3from paddleocr import PaddleOCR import cv2 import numpy as npif __name__ __main__:OCR PaddleOCR(use_doc_orientation_classifyFalse, # 检测文档方向use_doc_unwarpingFalse, # 矫正扭曲文档use…

数据结构与算法:贪心(三)

前言 感觉开始打cf了以后贪心的能力有了明显的提升,让我们谢谢cf的感觉场。 一、跳跃游戏 II class Solution { public:int jump(vector<int>& nums) {int n=nums.size();//怎么感觉这个题也在洛谷上刷过(?)int cur=0;//当前步最远位置int next=0;//多跳一步最远…

【Redis篇】数据库架构演进中Redis缓存的技术必然性—高并发场景下穿透、击穿、雪崩的体系化解决方案

&#x1f4ab;《博主主页》&#xff1a;    &#x1f50e; CSDN主页__奈斯DB    &#x1f50e; IF Club社区主页__奈斯、 &#x1f525;《擅长领域》&#xff1a;擅长阿里云AnalyticDB for MySQL(分布式数据仓库)、Oracle、MySQL、Linux、prometheus监控&#xff1b;并对…

Docker 实践与应用案例

引言 在当今的软件开发和部署领域&#xff0c;高效、可移植且一致的环境搭建与应用部署是至关重要的。Docker 作为一款轻量级的容器化技术&#xff0c;为解决这些问题提供了卓越的方案。Docker 通过容器化的方式&#xff0c;将应用及其依赖项打包成一个独立的容器&#xff0c;…

《论三生原理》以非共识路径实现技术代际跃迁‌?

AI辅助创作&#xff1a; 《论三生原理》以颠覆传统数学范式的非共识路径驱动多重技术代际跃迁&#xff0c;其突破性实践与争议并存&#xff0c;核心论证如下&#xff1a; 一、技术代际跃迁的实证突破‌ ‌芯片架构革新‌ 为华为三进制逻辑门芯片提供理论支撑&#xff0c;通过对…

一体机电脑为何热度持续上升?消费者更看重哪些功能?

一体机电脑&#xff08;AIO&#xff0c;All-in-One&#xff09;将主机硬件与显示器集成于单一机身。通常仅需连接电源线&#xff0c;配备无线键盘、鼠标即可启用。相比传统台式电脑和笔记本电脑&#xff0c;选购一体机的客户更看重一体机的以下特点。 一体机凭借其节省空间、简…

无人机载重模块技术要点分析

一、技术要点 1. 结构设计创新 双电机卷扬系统&#xff1a;采用主电机&#xff08;张力控制&#xff09;和副电机&#xff08;卷扬控制&#xff09;协同工作&#xff0c;解决绳索缠绕问题&#xff0c;支持30米绳长1.2m/s高速收放&#xff0c;重载稳定性提升。 轴双桨布局…

【大模型推理】工作负载的弹性伸缩

基于Knative的LLM推理场景弹性伸缩方案 1.QPS 不是一个好的 pod autoscaling indicator 在LLM推理中&#xff0c; 为什么 2. concurrency适用于单次请求资源消耗大且处理时间长的业务&#xff0c;而rps则适合较短处理时间的业务。 3.“反向弹性伸缩”的概念 4。 区分两种不同的…

STM32F103_Bootloader程序开发12 - IAP升级全流程

导言 本教程使用正点原子战舰板开发。 《STM32F103_Bootloader程序开发11 - 实现 App 安全跳转至 Bootloader》上一章节实现App跳转bootloader&#xff0c;接着&#xff0c;跳转到bootloader后&#xff0c;下位机要发送报文‘C’给IAP上位机&#xff0c;表示我准备好接收固件数…

AI驱动的未来软件工程范式

引言&#xff1a;迈向智能驱动的软件工程新范式 本文是一份关于构建和实施“AI驱动的全生命周期软件工程范式”的简要集成指南。它旨在提供一个独立、完整、具体的框架&#xff0c;指导组织如何将AI智能体深度融合到软件开发的每一个环节&#xff0c;实现从概念到运维的智能化…

Hawk Insight|美国6月非农数据点评:情况远没有看上去那么好

7月3日&#xff0c;美国近期最重要的劳动力数据——6月非农数据公布。在ADP遇冷之后&#xff0c;市场对这份报告格外期待。 根据美国劳工统计局公布报告&#xff0c;美国6月非农就业人口增加 14.7万人&#xff0c;预期 10.6万人&#xff0c;4月和5月非农就业人数合计上修1.6万人…

Python 的内置函数 reversed

Python 内建函数列表 > Python 的内置函数 reversed Python 的内置函数 reversed() 是一个用于序列反转的高效工具函数&#xff0c;它返回一个反向迭代器对象。以下是关于该函数的详细说明&#xff1a; 基本用法 语法&#xff1a;reversed(seq)参数&#xff1a;seq 可以是…

沟通-交流-说话-gt-jl-sh-goutong-jiaoliu-shuohua

沟通,先看|问状态(情绪) 老婆下班回家,我说,到哪儿了,买点玉米哦;她说你为啥不买, 我说怎么如此大火气, 她说你安排我&#xff0c;我不情愿;你怎么看 和女人沟通不能目标优先 先问状态并表达关心 用感谢代替要求&#xff08;“你上次买的玉米特别甜&#xff0c;今天突然又馋了…

Ubuntu20.04运DS-5

准备工作&#xff1a; cd /home/rlk/rlk/runninglinuxkernel_5.0 #make clean mkdir _install_arm64/dev sudo mknod _install_arm64/dev/console c 5 1 ./build_ds5_arm64.sh git checkout boot-wrapper-aarch64/fvp-base-gicv3-psci.dtb ./build_ds5_arm64.sh创建工程步骤2.5…

区块链网络P2P通信原理

目录 区块链网络P2P通信原理引言:去中心化的网络基石1. P2P网络基础架构1.1 区块链网络拓扑1.2 节点类型对比2. 节点发现与连接2.1 初始引导过程2.2 节点发现协议3. 网络通信协议3.1 消息结构3.2 核心消息类型4. 数据传播机制4.1 交易传播流程4.2 Gossip协议实现4.3 区块传播优…

RNN和Transformer区别

RNN&#xff08;循环神经网络&#xff09;和 Transformer 是两种广泛应用于自然语言处理&#xff08;NLP&#xff09;和其他序列任务的深度学习架构。它们在设计理念、性能特点和应用场景上存在显著区别。以下是它们的详细对比&#xff1a;1. 基本架构RNN&#xff08;循环神经网…