@浙大疏锦行 Python day51

复习日,DDPM

class DenoiseDiffusion():def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):super().__init__()self.eps_model = eps_modelself.n_steps = n_stepsself.device = deviceself.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)        # beta值self.alpha = 1. - self.beta                                         # alpha值self.alpha_bar = torch.cumprod(self.alpha, dim=0)                   # alpha_bar值   self.sigma2 = self.beta                                             # sampling中的sigma_tself.tools = Tools()# forward-diffusion process 获得 xt 所服从的高斯分布的mean和vardef q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:mean = self.tools.gather(self.alpha_bar, t) ** 0.5 * x0var = 1 - self.tools.gather(self.alpha_bar, t)return mean, var# forward-diffusion process,生成xtdef q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):if eps is None:eps = torch.randn_like(x0)mean, var = self.q_xt_x0(x0, t)return mean + (var ** 0.5) * eps   # return xt 第t时刻加完噪声的图片# 只有 sampling时才会用到的函数,执行Denoise Process# sampling,根据xt和t推出x_{t-1}         抽象出来的一步,可以用于循环n次def p_sample(self, xt: torch.Tensor, t: torch.Tensor):eps_theta = self.eps_model(xt, t)alpha_bar = self.tools.gather(self.alpha_bar, t)alpha = self.tools.gather(self.alpha, t)eps_coef = (1 - alpha) / (1 - alpha_bar) ** 0.5mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)var = self.tools.gather(self.sigma2, t)eps = torch.randn(xt.shape, device=xt.device)return mean + (var ** 0.5) * eps        # sigma_t * eps + mean# 会更新哪些模型的参数呢?# loss functiondef loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):batch_size = x0.shape[0]t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)if noise is None:noise = torch.randn_like(x0)xt = self.q_sample(x0, t, eps=noise)        # 传入的值为随机噪声 -- 高斯分布eps_theta = self.eps_model(xt, t)           # 模型预测值return F.mse_loss(noise, eps_theta)  # mse loss
# 激活函数
class Swish(nn.Module):def forward(self, x):return x* torch.sigmoid(x)class ResidualBlock(nn.Module):"""每一个Residual block都有两层CNN做特征提取"""def __init__(self, in_channels: int, out_channels: int, time_channels: int,n_groups: int = 32, dropout: float = 0.1):"""Params:in_channels:  输入图片的channel数量out_channels: 经过residual block后输出特征图的channel数量time_channels:time_embedding的向量维度,例如t原来是个整型,值为1,表示时刻1,现在要将其变成维度为(1, time_channels)的向量n_groups:     Group Norm中的超参dropout:      dropout rate"""super().__init__()# 第一层卷积 = Group Norm + CNNself.norm1 = nn.GroupNorm(n_groups, in_channels)self.act1 = Swish()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))# 第二层卷积 = Group Norm + CNNself.norm2 = nn.GroupNorm(n_groups, out_channels)self.act2 = Swish()self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))# 当in_c = out_c时,残差连接直接将输入输出相加;# 当in_c != out_c时,对输入数据做一次卷积,将其通道数变成和out_c一致,再和输出相加if in_channels != out_channels:self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))            # 使用 1x1卷积修改通道数else:self.shortcut = nn.Identity()      # 占位 # t向量的维度time_channels可能不等于out_c,所以我们要对起做一次线性转换self.time_emb = nn.Linear(time_channels, out_channels)self.time_act = Swish()self.dropout = nn.Dropout(dropout)def forward(self, x: torch.Tensor, t: torch.Tensor):"""Params:x: 输入数据xt,尺寸大小为(batch_size, in_channels, height, width)t: 输入数据t,尺寸大小为(batch_size, time_c)【配合图例进行阅读】"""# 1.输入数据先过一层卷积h = self.conv1(self.act1(self.norm1(x)))# 2. 对time_embedding向量,通过线性层使time_c变为out_c,再和输入数据的特征图相加h += self.time_emb(self.time_act(t))[:, :, None, None]# 3、过第二层卷积h = self.conv2(self.dropout(self.act2(self.norm2(h))))# 4、返回残差连接后的结果return h + self.shortcut(x)# Attention Block
# 通道注意力机制class AttentionBlock(nn.Module):"""Attention模块和Transformer中的multi-head attention原理及实现方式一致"""def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):"""Params:n_channels:等待做attention操作的特征图的channel数n_heads:   attention头数d_k:       每一个attention头处理的向量维度n_groups:  Group Norm超参数"""super().__init__()# 一般而言,d_k = n_channels // n_heads,需保证n_channels能被n_heads整除if d_k is None:d_k = n_channels# 定义Group Normself.norm = nn.GroupNorm(n_groups, n_channels)# Multi-head attention层: 定义输入token分别和q,k,v矩阵相乘后的结果self.projection = nn.Linear(n_channels, n_heads * d_k * 3)# MLP层self.output = nn.Linear(n_heads * d_k, n_channels)self.scale = d_k ** -0.5self.n_heads = n_headsself.d_k = d_kdef forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):"""Params:x: 输入数据xt,尺寸大小为(batch_size, in_channels, height, width)t: 输入数据t,尺寸大小为(batch_size, time_c)【配合图例进行阅读】"""# t并没有用到,但是为了和ResidualBlock定义方式一致,这里也引入了t_ = t# 获取shapebatch_size, n_channels, height, width = x.shape# 将输入数据的shape改为(batch_size, height*weight, n_channels)# 这三个维度分别等同于transformer输入中的(batch_size, seq_length, token_embedding)x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)# 计算输入过矩阵q,k,v的结果,self.projection通过矩阵计算,一次性把这三个结果出出来 也就是qkv矩阵是三个结果的拼接# 其shape为:(batch_size, height*weight, n_heads, 3 * d_k)qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)# 将拼接结果切开,每一个结果的shape为(batch_size, height*weight, n_heads, d_k)q, k, v = torch.chunk(qkv, 3, dim=-1)# 以下是正常计算attention score的过程,不再做说明attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scaleattn = attn.softmax(dim=2)res = torch.einsum('bijh,bjhd->bihd', attn, v)# 将结果reshape成(batch_size, height*weight,, n_heads * d_k)# 复习一下:n_heads * d_k = n_channelsres = res.view(batch_size, -1, self.n_heads * self.d_k)# MLP层,输出结果shape为(batch_size, height*weight,, n_channels)res = self.output(res)# 残差连接res += x# 将输出结果从序列形式还原成图像形式,# shape为(batch_size, n_channels, height, width)res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)return res
class DownBlock(nn.Module):def __init__(self, in_channels: int, out_channels: int, time_channels: int,use_attention: bool = False):super().__init__()self.res_block = ResidualBlock(in_channels, out_channels, time_channels)if use_attention:self.attn_block = AttentionBlock(out_channels)else:self.attn_block = nn.Identity()def forward(self, x: torch.Tensor, t: torch.Tensor):x = self.res_block(x, t)x = self.attn_block(x)return xclass UpBlock(nn.Module):def __init__(self, in_channels: int, out_channels: int, time_channels: int,use_attention: bool = False):super.__init__()self.res_block = ResidualBlock(in_channels + out_channels, out_channels, time_channels)if use_attention:self.attn = AttentionBlock(out_channels)else:self.attn = nn.Identity()def forward(self, x: torch.Tensor, t: torch.Tensor):x = self.res_block(x, t)x = self.attn(x)return x
class TimeEmbedding(nn.Module):def __init__(self, n_channels: int):"""Params:n_channels:即time_channel"""super().__init__()self.n_channels = n_channelsself.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)self.act = Swish()self.lin2 = nn.Linear(self.n_channels, self.n_channels)def forward(self, t: torch.Tensor):"""Params:t: 维度(batch_size),整型时刻t"""# 以下转换方法和Transformer的位置编码一致# 【强烈建议大家动手跑一遍,打印出每一个步骤的结果和尺寸,更方便理解】half_dim = self.n_channels // 8emb = math.log(10_000) / (half_dim - 1)emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)emb = t[:, None] * emb[None, :]emb = torch.cat((emb.sin(), emb.cos()), dim=1)# Transform with the MLPemb = self.act(self.lin1(emb))emb = self.lin2(emb)# 输出维度(batch_size, time_channels)return emb
class Upsample(nn.Module):"""上采样"""def __init__(self, n_channels):super().__init__()self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))def forward(self, x: torch.Tensor, t: torch.Tensor):_ = treturn self.conv(x)class Downsample(nn.Module):"""下采样"""def __init__(self, n_channels):super().__init__()self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))def forward(self, x: torch.Tensor, t: torch.Tensor):_ = treturn self.conv(x)class MiddleBlock(nn.Module):def __init__(self, n_channels: int, time_channels: int):super.__init__()self.res1 = ResidualBlock(n_channels, n_channels, time_channels)self.attn = AttentionBlock(n_channels)self.res2 = ResidualBlock(n_channels, n_channels, time_channels)def forward(self, x: torch.Tensor, t: torch.Tensor):x = self.res1(x, t)x = self.attn(x)x = self.res2(x, t)return x
class UNet(Module):"""DDPM UNet去噪模型主体架构"""def __init__(self, image_channels: int = 3, n_channels: int = 64,ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),n_blocks: int = 2):"""Params:image_channels:原始输入图片的channel数,对RGB图像来说就是3n_channels:    在进UNet之前,会对原始图片做一次初步卷积,该初步卷积对应的out_channel数,也就是图中左上角的第一个墨绿色箭头ch_mults:      在Encoder下采样的每一层的out_channels倍数,例如ch_mults[i] = 2,表示第i层特征图的out_channel数,是第i-1层的2倍。Decoder上采样时也是同理,用的是反转后的ch_multsis_attn:       在Encoder下采样/Decoder上采样的每一层,是否要在CNN做特征提取后再引入attention(会在下文对该结构进行详细说明)n_blocks:      在Encoder下采样/Decoder下采样的每一层,需要用多少个DownBlock/UpBlock(见图),Deocder层最终使用的UpBlock数=n_blocks + 1     """super().__init__()# 在Encoder下采样/Decoder上采样的过程中,图像依次缩小/放大,# 每次变动都会产生一个新的图像分辨率# 这里指的就是不同图像分辨率的个数,也可以理解成是Encoder/Decoder的层数n_resolutions = len(ch_mults)# 对原始图片做预处理,例如图中,将32*32*3 -> 32*32*64self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))# time_embedding,TimeEmbedding是nn.Module子类,我们会在下文详细讲解它的属性和forward方法self.time_emb = TimeEmbedding(n_channels * 4)# --------------------------# 定义Encoder部分# --------------------------# down列表中的每个元素表示Encoder的每一层down = []# 初始化out_channel和in_channelout_channels = in_channels = n_channels# 遍历每一层for i in range(n_resolutions):# 根据设定好的规则,得到该层的out_channelout_channels = in_channels * ch_mults[i]# 根据设定好的规则,每一层有n_blocks个DownBlockfor _ in range(n_blocks):down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))in_channels = out_channels# 对Encoder来说,每一层结束后,我们都做一次下采样,但Encoder的最后一层不做下采样if i < n_resolutions - 1:down.append(Downsample(in_channels))# self.down即是完整的Encoder部分self.down = nn.ModuleList(down)# --------------------------# 定义Middle部分# --------------------------self.middle = MiddleBlock(out_channels, n_channels * 4, )# --------------------------# 定义Decoder部分# --------------------------# 和Encoder部分基本一致,可对照绘制的架构图阅读up = []in_channels = out_channelsfor i in reversed(range(n_resolutions)):# `n_blocks` at the same resolutionout_channels = in_channelsfor _ in range(n_blocks):up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))out_channels = in_channels // ch_mults[i]up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))in_channels = out_channelsif i > 0:up.append(Upsample(in_channels))# self.up即是完整的Decoder部分self.up = nn.ModuleList(up)# 定义group_norm, 激活函数,和最后一层的CNN(用于将Decoder最上一层的特征图还原成原始尺寸)self.norm = nn.GroupNorm(8, n_channels)self.act = Swish()self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))def forward(self, x: torch.Tensor, t: torch.Tensor):"""Params:x: 输入数据xt,尺寸大小为(batch_size, in_channels, height, width)t: 输入数据t,尺寸大小为(batch_size)"""# 取得time_embeddingt = self.time_emb(t)# 对原始图片做初步CNN处理x = self.image_proj(x)# -----------------------# Encoder# -----------------------h = [x]# First half of U-Netfor m in self.down:x = m(x, t)h.append(x)# -----------------------# Middle# -----------------------x = self.middle(x, t)# -----------------------# Decoder# -----------------------for m in self.up:if isinstance(m, Upsample):x = m(x, t)else:s = h.pop()# skip_connectionx = torch.cat((x, s), dim=1)x = m(x, t)return self.final(self.act(self.norm(x)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/93850.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/93850.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/93850.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构:生成 (Generating) 一棵 AVL 树

目录 搭建“创世”的舞台 注入序列&#xff0c;观察演化 注入 10 注入 20 注入 30 注入 40 注入 50 注入 25 再次审视 上一讲&#xff0c;我们已经从最根本的逻辑出发&#xff0c;推导出了 AVL 树失衡时所必需的修复操作——旋转 (Rotation)。 现在&#xff0c;我们将…

github 上传代码步骤

登录GitHub → 点击右上角 ​​ → New Repository​​。填写仓库名称&#xff08;建议与本地项目同名&#xff09;&#xff0c;选择 ​​Public/Private​​。​​关键&#xff1a;不要勾选​​ “Initialize with README”&#xff08;避免与本地仓库冲突&#xff09;。点击 …

陪诊小程序系统开发:开启智慧就医新时代

在数字化浪潮的推动下&#xff0c;智慧医疗正逐渐成为现实。陪诊小程序系统的开发&#xff0c;作为智慧医疗领域的一次重要创新&#xff0c;正以其独特的魅力与优势&#xff0c;引领着就医新时代的到来。它不仅改变了传统就医模式&#xff0c;更以科技的力量&#xff0c;让医疗…

朝花夕拾(七)--------从混淆矩阵到分类报告全面解析​

目录 ​​机器学习模型评估指南&#xff1a;从混淆矩阵到分类报告全面解析​​ ​​1. 引言​​ ​​2. 混淆矩阵&#xff1a;模型评估的基石​​ ​​2.1 什么是混淆矩阵&#xff1f;​​ 2.2二分类问题的混淆矩阵 ​​二分类场景下的具体案例​ ​分析案例: 1.​​案例…

Python读取和设置PNG图片的像素值

在Python中&#xff0c;可以使用Pillow库或OpenCV库来读取和写入PNG图片的像素值。以下是两种方法的详细说明&#xff1a;1. 使用Pillow库Pillow是Python中常用的图像处理库&#xff0c;支持多种图像格式&#xff0c;包括PNG。读取像素值from PIL import Imageimg Image.open(…

SkyWalking + Elasticsearch8 容器化部署指南:国内镜像加速与生产级调优

SkyWalking Elasticsearch8 Docker 部署文档本文提供在 Ubuntu 服务器上&#xff0c;使用 Docker Compose 部署 SkyWalking&#xff08;OAPUI&#xff09;与 Elasticsearch 8 的完整步骤&#xff0c;数据/日志落地到 /media/disk2 前置条件 Ubuntu&#xff0c;已具备 sudo 权限…

有符号和无符号的区别

有符号&#xff08;Signed&#xff09;和无符号&#xff08;Unsigned&#xff09;是计算机编程中用来描述整数数据类型能否表示负数的两个概念。它们的主要区别在于能否表示负数以及数值的表示范围。以下是它们的核心区别&#xff1a;1. 能否表示负数有符号&#xff08;Signed&…

8月21日作业

1、Makefile中头文件发生过修改的解决&#xff1a; 处插入*.h依赖&#xff0c;对.h文件打的时间戳进行检查2、头删和输出//五、头删 void delete_head(seq_p s) {empty(s);for(int i1;i<s->len;i){s->data[i-1]s->data[i];}s->len--; }//六、输出 void output(s…

Lucene 8.5.0 的 `.pos` 文件**逻辑结构**

Lucene 8.5.0 的 .pos 文件**逻辑结构**&#xff08;按真实实现重新整理&#xff09; .pos 文件 ├─ Header (CodecHeader) ├─ TermPositions TermCount ← 每个 term 一段&#xff0c;顺序由词典隐式决定 │ ├─ PackedPosDeltaBlock N ← 仅当 **无 payl…

基于Matlab多技术融合的红外图像增强方法研究

红外图像在低照度、强干扰和复杂环境下具有较强的成像能力&#xff0c;但受传感器噪声、成像条件及大气衰减等因素影响&#xff0c;原始红外图像往往存在对比度低、细节模糊及光照不均等问题。本文针对红外图像质量退化的特点&#xff0c;提出了一种基于多算法融合的红外图像增…

【时时三省】集成测试 简介

山不在高,有仙则名。水不在深,有龙则灵。 ----CSDN 时时三省 目录 1,集成测试含义 2,集成测试 验证方法 3,集成测试 用例设计方法 4,集成测试输出物 5,集成测试注意点 1,集成测试含义 单元测试在以V模型的流程中,对应的是架构设计阶段。在 单元测试 和 架构设计…

leetcode 76 最小覆盖子串

一、题目描述二、解题思路整体思路&#xff1a;模拟寻找最小覆盖子集的过程&#xff0c;由于可借助同向双指针且可以做到指针不回退&#xff0c;所以可以用滑动窗口的思想来解决这个问题。具体思路&#xff1a;(1)数组hash1用于统计t中每一个字符出现的频次&#xff0c;变量kin…

阿里云ECS服务器的公网IP地址

文章目录环境背景查询公网IP地址阿里云控制台阿里云客户端工具&#xff08;图形界面&#xff09;阿里云CLI工具&#xff08;命令行&#xff09;其它方法元数据服务器ipinfo.io参考注&#xff1a;本文介绍了如何获取阿里云ECS服务器的公网IP地址&#xff0c;可以顺便了解一下和阿…

IPSec 与 IKE 核心知识点总结

一、IPSec 安全基础IPSec 是保障 IP 数据传输安全的核心协议&#xff0c;其核心围绕密钥管理和安全策略约定展开&#xff0c;具体包括以下关键内容&#xff1a;1. 对称密钥的作用与要求对称密钥是 IPSec 实现加密、验证的基础&#xff0c;主要用于三个场景&#xff1a;加密 / 解…

C2ComponentStore

1. C2ComponentStore这是 Codec 2.0 HAL 的抽象接口&#xff08;frameworks/av/media/codec2/core/include/C2ComponentStore.h&#xff09;。代表一个「组件工厂」&#xff0c;负责&#xff1a;枚举当前可用的 Codec2 组件&#xff08;解码器、编码器&#xff09;。创建组件&a…

AI 在医疗领域的应用与挑战

引言介绍 AI 技术迅猛发展的大背景&#xff0c;引出其在医疗领域的重要应用。阐述研究 AI 医疗应用及挑战对推动医疗行业进步的重要意义。AI 在医疗领域的应用现状疾病诊断辅助&#xff1a;描述 AI 影像识别技术在识别 X 光、CT、MRI 影像中疾病特征的应用&#xff0c;如对肺癌…

【GPT入门】第51课 Conda环境迁移教程:将xxzh环境从默认路径迁移到指定目录

【GPT入门】第51课 Conda环境迁移教程&#xff1a;将xxzh环境从默认路径迁移到指定目录步骤1&#xff1a;创建目标目录&#xff08;若不存在&#xff09;步骤2&#xff1a;克隆环境到新路径步骤3&#xff1a;验证新环境可用性步骤4&#xff1a;删除旧环境&#xff08;可选&…

应急响应-模拟服务器挂马后的应急相关操作

工具&#xff1a;攻击机&#xff1a; kail:192.168.108.131 kail下载地址&#xff1a;https://mirrors.aliyun.com/kali-images/kali-2021.3/kali-linux-2021.3-live-i386.iso靶机&#xff1a;windows 7: 192.168.108.1321、在kali中制作木马文件&#xff1a;vhost.exe&#xf…

记一次 .NET 某光谱检测软件 内存暴涨分析

一&#xff1a;背景 1. 讲故事 训练营里的一位学员找到我&#xff0c;说他们的系统会出现内存暴涨的情况&#xff0c;看了下也不是托管堆的问题&#xff0c;让我协助一下到底怎么回事&#xff1f;既然有dump了&#xff0c;那就开始分析之旅吧。 二&#xff1a;内存暴涨分析 1. …

基于OpenCV的物体识别与计数

在计算机视觉领域&#xff0c;利用图像处理技术进行物体识别和计数是一项基础且重要的任务。本文将介绍一种使用OpenCV库实现的高效物体识别与计数方法&#xff0c;并提供一些代码片段以帮助理解各个步骤。 这是前几年做过传统图像处理计数的项目&#xff0c;通过传统图像处理之…