引言

在计算机视觉领域,UNet架构因其在图像分割任务中的卓越表现而广受欢迎。近年来,注意力机制的引入进一步提升了UNet的性能。本文将深入分析一个结合了线性注意力机制的UNet实现,探讨其设计原理、代码实现以及在医学图像分割等任务中的应用潜力。

UNet架构概述

UNet最初由Ronneberger等人提出,主要用于生物医学图像分割。其独特的U形结构由编码器(下采样路径)和解码器(上采样路径)组成,通过跳跃连接将低层特征与高层特征相结合,既保留了空间信息又利用了深层的语义信息。

传统的UNet结构简单有效,但随着研究的深入,人们发现引入注意力机制可以显著提升模型性能,特别是在处理复杂场景和微小结构时。

线性注意力机制

注意力机制的基本概念

注意力机制的核心思想是让模型能够"关注"输入数据中最相关的部分。在传统的自注意力机制中,计算复杂度通常是O(N²),这对于高分辨率图像来说计算成本很高。

线性注意力实现

在我们的实现中,采用了线性注意力机制来降低计算复杂度。以下是关键的LinearAttention类实现:

class LinearAttention(nn.Module):def __init__(self, channels):super(LinearAttention, self).__init__()self.query = nn.Conv2d(channels, channels // 8, kernel_size=1)self.key = nn.Conv2d(channels, channels // 8, kernel_size=1)self.value = nn.Conv2d(channels, channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):batch_size, C, height, width = x.size()# 计算query, key, valueq = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, N, C')k = self.key(x).view(batch_size, -1, height * width)  # (B, C', N)v = self.value(x).view(batch_size, -1, height * width)  # (B, C, N)# 线性注意力计算kv = torch.bmm(k, v)  # (B, C', C)z = 1 / (torch.bmm(q, k.sum(dim=2, keepdim=True)) + 1e-6)  # (B, N, 1)attn = torch.bmm(q, kv)  # (B, N, C)out = attn * z  # (B, N, C)out = out.view(batch_size, C, height, width)return self.gamma * out + x

这个实现有几个关键特点:

  1. 通道缩减:通过将通道数减少到1/8来降低计算复杂度

  2. 线性复杂度:通过矩阵乘法的重新排列,将复杂度从O(N²)降低到O(N)

  3. 可学习的gamma参数:控制注意力特征与原始特征的混合比例

网络组件详解

双卷积块

双卷积块是UNet的基本构建模块,包含两个连续的3x3卷积层,每个卷积层后接批量归一化和ReLU激活函数。我们的实现增加了可选的注意力机制:

class DoubleConv(nn.Module):def __init__(self, in_channels, out_channels, use_attention=False):super(DoubleConv, self).__init__()self.use_attention = use_attentionself.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))if use_attention:self.attention = LinearAttention(out_channels)def forward(self, x):x = self.double_conv(x)if self.use_attention:x = self.attention(x)return x

下采样模块

下采样模块由最大池化层和双卷积块组成:

class Down(nn.Module):def __init__(self, in_channels, out_channels, use_attention=False):super(Down, self).__init__()self.downsampling = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),DoubleConv(in_channels, out_channels, use_attention))def forward(self, x):return self.downsampling(x)

上采样模块

上采样模块使用转置卷积进行上采样,然后与编码路径的特征图拼接,最后通过双卷积块:

class Up(nn.Module):def __init__(self, in_channels, out_channels, use_attention=False):super(Up, self).__init__()self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels, use_attention)def forward(self, x1, x2):x1 = self.upsampling(x1)x = torch.cat([x2, x1], dim=1)return self.conv(x)

完整的UNet架构

结合上述组件,我们构建了完整的UNet模型:

class UNet(nn.Module):def __init__(self, in_channels=1, num_classes=1):super(UNet, self).__init__()self.in_channels = in_channelsself.num_classes = num_classes# 编码器部分self.in_conv = DoubleConv(in_channels, 64, use_attention=True)self.down1 = Down(64, 128, use_attention=True)self.down2 = Down(128, 256, use_attention=True)self.down3 = Down(256, 512, use_attention=True)self.down4 = Down(512, 1024)# 解码器部分self.up1 = Up(1024, 512, use_attention=True)self.up2 = Up(512, 256, use_attention=True)self.up3 = Up(256, 128, use_attention=True)self.up4 = Up(128, 64, use_attention=True)self.out_conv = OutConv(64, num_classes)def forward(self, x):# 编码路径x1 = self.in_conv(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)# 解码路径x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)return self.out_conv(x)

这个架构有几个值得注意的特点:

  1. 对称结构:编码器和解码器基本对称,但最深层的下采样块没有使用注意力机制

  2. 渐进式通道变化:通道数从64开始,每次下采样翻倍,直到1024

  3. 广泛的注意力应用:除了最深层的下采样,其他所有层都应用了注意力机制

注意力机制的应用策略

在我们的实现中,注意力机制的应用策略值得关注:

  1. 编码路径:前四个下采样块中,前三个使用了注意力机制

  2. 解码路径:所有上采样块都使用了注意力机制

  3. 输入输出:输入卷积和最终输出卷积没有使用注意力机制

这种策略基于以下考虑:

  • 深层特征已经具有高度抽象性,可能不需要额外的注意力

  • 解码路径需要精确的定位,注意力机制尤为重要

  • 输入输出层结构简单,注意力机制的收益可能不明显

性能优化考虑

  1. 内存效率:线性注意力显著降低了内存消耗

  2. 计算效率:通过通道缩减和线性复杂度计算保持高效

  3. 数值稳定性:在注意力计算中添加了小常数(1e-6)防止除零错误

实际应用建议

  1. 医学图像分割:这种结构特别适合CT/MRI图像分割任务

  2. 参数调整:可以根据任务复杂度调整注意力层的位置和数量

  3. 输入通道:当前设置为1通道输入,适用于灰度医学图像

扩展可能性

  1. 多模态输入:修改输入通道数以适应RGB或多模态医学图像

  2. 深度监督:在解码路径中添加辅助输出

  3. 注意力变体:尝试其他类型的注意力机制如通道注意力

结论

本文详细分析了一个结合线性注意力机制的UNet实现。这种架构在保持UNet原有优势的同时,通过精心设计的注意力机制提升了模型对重要特征的关注能力。线性注意力的引入使得模型在高分辨率图像上也能高效运行,为医学图像分割等任务提供了有力的工具。

代码实现展示了如何将现代注意力机制与传统UNet架构有机结合,这种模式也可以应用于其他视觉任务的网络设计中。读者可以根据具体任务需求调整注意力层的位置和数量,找到最佳的性能平衡点。

随着注意力机制的不断发展,我们期待看到更多高效、精准的UNet变体出现,推动医学图像分析和其他视觉任务的进步。

完整代码

如下:

import torch.nn as nn
import torch
import mathclass LinearAttention(nn.Module):def __init__(self, channels):super(LinearAttention, self).__init__()self.query = nn.Conv2d(channels, channels // 8, kernel_size=1)self.key = nn.Conv2d(channels, channels // 8, kernel_size=1)self.value = nn.Conv2d(channels, channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):batch_size, C, height, width = x.size()# 计算query, key, valueq = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, N, C')k = self.key(x).view(batch_size, -1, height * width)  # (B, C', N)v = self.value(x).view(batch_size, -1, height * width)  # (B, C, N)# 线性注意力计算kv = torch.bmm(k, v)  # (B, C', C)z = 1 / (torch.bmm(q, k.sum(dim=2, keepdim=True)) + 1e-6)  # (B, N, 1)attn = torch.bmm(q, kv)  # (B, N, C)out = attn * z  # (B, N, C)out = out.view(batch_size, C, height, width)return self.gamma * out + xclass DoubleConv(nn.Module):def __init__(self, in_channels, out_channels, use_attention=False):super(DoubleConv, self).__init__()self.use_attention = use_attentionself.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))if use_attention:self.attention = LinearAttention(out_channels)def forward(self, x):x = self.double_conv(x)if self.use_attention:x = self.attention(x)return xclass Down(nn.Module):def __init__(self, in_channels, out_channels, use_attention=False):super(Down, self).__init__()self.downsampling = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),DoubleConv(in_channels, out_channels, use_attention))def forward(self, x):return self.downsampling(x)class Up(nn.Module):def __init__(self, in_channels, out_channels, use_attention=False):super(Up, self).__init__()self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels, use_attention)def forward(self, x1, x2):x1 = self.upsampling(x1)x = torch.cat([x2, x1], dim=1)return self.conv(x)class OutConv(nn.Module):def __init__(self, in_channels, num_classes):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)def forward(self, x):return self.conv(x)class UNet(nn.Module):def __init__(self, in_channels=1, num_classes=1):super(UNet, self).__init__()self.in_channels = in_channelsself.num_classes = num_classes# 编码器部分self.in_conv = DoubleConv(in_channels, 64, use_attention=True)self.down1 = Down(64, 128, use_attention=True)self.down2 = Down(128, 256, use_attention=True)self.down3 = Down(256, 512, use_attention=True)self.down4 = Down(512, 1024)# 解码器部分self.up1 = Up(1024, 512, use_attention=True)self.up2 = Up(512, 256, use_attention=True)self.up3 = Up(256, 128, use_attention=True)self.up4 = Up(128, 64, use_attention=True)self.out_conv = OutConv(64, num_classes)def forward(self, x):# 编码路径x1 = self.in_conv(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)# 解码路径x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)return self.out_conv(x)model = UNet(in_channels=1, num_classes=1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/85901.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/85901.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/85901.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity技能编辑器深度构建指南:打造专业级战斗系统

本文为技术团队提供完整的技能编辑器开发指南,涵盖核心架构设计、资源管线搭建和协作工作流实现,帮助您构建专业级的战斗技能系统。 一、核心架构设计 1. 基础框架搭建 专用场景模板: 创建SkillEditorTemplate.unity场景 核心节点&#xff…

《游戏工业级CI/CD实战:Jenkins+Node.js自动化构建与本地网盘部署方案》

核心架构图 一、游戏开发CI/CD全流程设计 工作流时序图 二、Jenkins分布式构建配置 1. 节点管理(支持Win/Linux/macOS) // Jenkinsfile 分布式配置示例 pipeline {agent {label game-builder // 匹配带标签的构建节点}triggers {pollSCM(H/5 * * * *)…

Python内存使用分析工具深度解析与实践指南(上篇)

文章目录 引言1. sys.getsizeof()功能程序示例适用场景 2. pandas.Series.memory_usage()功能程序示例适用场景 3. pandas.Series.memory_usage(deepTrue)功能程序示例适用场景注意事项 4. pympler.asizeof()功能安装程序示例适用场景 5. tracemalloc(标准库&#x…

Python 使用 Requests 模块进行爬虫

目录 一、请求数据二、获取并解析数据四、保存数据1. 保存为 CSV 文件2. 保存为 Excel 文件打开网页图片并将其插入到 Excel 文件中 五、加密参数逆向分析1. 定位加密位置2. 断点调试分析3. 复制相关 js 加密代码,在本地进行调试(难)4. 获取 …

MySQL行转列、列转行

要达到的效果: MySQL不支持动态行转列 原始数据: 以行的方式存储 CREATE TABLE product_sales (id INT AUTO_INCREMENT PRIMARY KEY,product_name VARCHAR(50) NOT NULL,category VARCHAR(50) NOT NULL,sales_volume INT NOT NULL,sales_date DATE N…

云创智称YunCharge充电桩互联互通平台使用说明讲解

云创智称YunCharge充电桩互联互通平台使用说明讲解 一、云创智称YunCharge互联互通平台简介 云创智称YunCharge(YunCharge)互联互通平台,旨在整合全国充电桩资源,实现多运营商、多平台、多用户的统一接入和管理,打造开…

HTML+JS实现类型excel的纯静态页面表格,同时单元格内容可编辑

<!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>在线表格</title><style>table {border…

Gartner金融AI应用机会雷达-学习心得

一、引言 在当今数字化时代,人工智能(AI)技术正以前所未有的速度改变着各个行业,金融领域也不例外。财务团队面临着如何从AI投资中获取最大价值的挑战。许多首席财务官(CFO)和财务领导者期望在未来几年增加对AI的投入并从中获得更多收益。据调查,90%的CFO和财务领导者预…

像素着色器没有绘制的原因

背景 directX调用了 draw&#xff0c;顶点着色器运行&#xff0c;但是像素着色器没有运行。 原因 光栅化阶段被剔除 说明&#xff1a;如果几何图元&#xff08;如三角形&#xff09;在光栅化阶段被剔除&#xff0c;像素着色器就不会被调用。常见剔除原因&#xff1a; 背面…

jenkins对接、jenkins-rest

https://www.bilibili.com/video/BV1RqNRz5Eo6 Jenkins是一款常见的构建管理工具&#xff0c;配置好后操作也很简单&#xff0c;只需去控制台找到对应的项目&#xff0c;再输入分支名即可 如果每次只发个位数的项目到也还好&#xff0c;一个个进去点嘛。但如果一次要发几十个项…

北斗导航深度接入小程序打车:高精度定位如何解决定位漂移难题?

你有没有遇到过这样的尴尬&#xff1a; 在写字楼、地下车库或密集楼群中叫车&#xff0c;系统显示的位置和你实际所在位置差了几十米甚至上百米&#xff1b;司机因为找不到你而绕圈&#xff0c;耽误时间还多花平台费用&#xff1b;有时明明站在A出口&#xff0c;司机却跑到B口…

MySQL 主要集群解决方案

MySQL 主要有以下几种集群解决方案&#xff0c;每种方案针对不同的应用场景和需求设计&#xff1a; 1. MySQL Replication&#xff08;主从复制&#xff09; 类型&#xff1a;异步/半同步复制架构&#xff1a;单主多从特点&#xff1a; 读写分离&#xff0c;主库写&#xff0c…

基于vue3+express的非遗宣传网站

​ 一个课程大作业&#xff0c;需要源码可联系&#xff0c;可以在http://8.138.189.55:3001/浏览效果 前端技术 Vue.js 3&#xff1a;我选择了Vue 3作为核心前端框架&#xff0c;并采用了其最新的Composition API开发模式&#xff0c;这使得代码组织更加灵活&#xff0c;逻辑…

【7】图像变换(上)

本节偏难,不用过于深究 考纲 文章目录 可考题【简答题】补充第三版内容:图像金字塔2023甄题【压轴题】习题7.1【第三版】1 基图像2 与傅里叶相关的变换2.1 离散哈特利变换(DHT)可考题【简答题】2.2 离散余弦变换(DCT)2021甄题【简答题】2.3 离散正弦变换(DST)可考题【简…

WinUI3入门9:自制SplitPanel

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github&#xff1a;codetoys&#xff0c;所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的&#xff0c;可以在任何平台上使用。 源码指引&#xff1a;github源…

【面板数据】上市公司投资者保护指数(2010-2023年)

上市公司投资者保护指数是基于上市公司年报中公开披露的多项内容&#xff0c;从信息透明度、公司治理结构、关联交易披露、控股股东行为规范等多个维度&#xff0c;评估企业是否在制度上和实际操作中有效保障投资者&#xff0c;特别是中小投资者的合法权益。本分享数据基于我国…

如何解决USB远距离传输难题?一文了解POE USB延长器及其行业应用

在日常办公、教学、医疗和工业系统中&#xff0c;USB接口设备扮演着越来越关键的角色。无论是视频采集设备、键盘鼠标&#xff0c;还是打印机、条码枪&#xff0c;USB早已成为主流连接标准。然而&#xff0c;USB原生传输距离的限制&#xff08;通常在5米以内&#xff09;常常成…

PostgreSQL(TODO)

(TODO) 功能MySQLPostgreSQLJSON 支持支持&#xff0c;但功能相对弱非常强大&#xff0c;支持 JSONB、索引、函数等并发控制行级锁&#xff08;InnoDB&#xff09;&#xff0c;不支持 MVCC多版本并发控制&#xff08;MVCC&#xff09;&#xff0c;性能更好存储过程/触发器支持&…

LINUX 623 FTP回顾

FTP 权限 /etc/vsftpd/vsftpd.conf anonymous_enableNO local_enableNO 服务器 .20 [rootweb vsftpd]# grep -v ^# vsftpd.conf anonymous_enableNO local_enableYES local_root/data/kefu2 chroot_local_userYES allow_writeable_chrootYES write_enableYES local_umask02…

leetcode:77. 组合

学习要点 学习回溯思想&#xff0c;学习回溯技巧&#xff1b;大家应当先看一下下面这几道题 leetcode&#xff1a;46. 全排列-CSDN博客leetcode&#xff1a;78. 子集-CSDN博客leetcode&#xff1a;90. 子集 II-CSDN博客 题目链接 77. 组合 - 力扣&#xff08;LeetCode&#x…