最近因为工作需求,接触了Latent Diffusion中VAE训练的相关代码,其中损失函数是由名为LPIPSWithDiscriminator的类进行计算的,包括像素级别的重建损失(rec_loss)、感知损失(p_loss)和基于判别器(g_loss)的对抗损失等。在阅读源码过程中有很多疑问,因此用该博客记录一下学习过程和相关思考,如有不对的地方欢迎评论区批评指正。


LPIPSWithDiscriminator

LPIPSWithDiscriminator主要包含以下三个方法,本文主要关注前向过程forward()和自适应权重计算方法calculate_adaptive_weight():

class LPIPSWithDiscriminator(nn.Module):# 初始化相关属性和方法def __init__(self, ):pass# 根据负对数损失(nll_loss)和对抗损失的梯度自适应计算对抗损失的权重def calculate_adaptive_weight(self, ):pass# 前向传播过程,计算各种损失def forward():pass

首先附上前向传播方法的源码,建议先结合注释把源码大概顺一遍:

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,global_step, last_layer=None, cond=None, split="train",weights=None):# 图像级别重建损失rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())if self.perceptual_weight > 0:# 感知损失p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())# 问题1:为什么要把感知损失加到每个像素位置的重建损失上???rec_loss = rec_loss + self.perceptual_weight * p_loss# negative log-likelihood loss# 问题2:为什么要进行这种处理?nll_loss = rec_loss / torch.exp(self.logvar) + self.logvarweighted_nll_loss = nll_lossif weights is not None:weighted_nll_loss = weights*nll_lossweighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]# 和标准正态分布计算KL散度kl_loss = posteriors.kl()kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]# now the GAN partif optimizer_idx == 0:# generator updateif cond is None:assert not self.disc_conditionallogits_fake = self.discriminator(reconstructions.contiguous())else:assert self.disc_conditionallogits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))g_loss = -torch.mean(logits_fake)if self.disc_factor > 0.0:try:# 问题3:为什么计算自适应权重d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)except RuntimeError:assert not self.trainingd_weight = torch.tensor(0.0)else:d_weight = torch.tensor(0.0)disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_losslog = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),"{}/rec_loss".format(split): rec_loss.detach().mean(),"{}/d_weight".format(split): d_weight.detach(),"{}/disc_factor".format(split): torch.tensor(disc_factor),"{}/g_loss".format(split): g_loss.detach().mean(),}return loss, logif optimizer_idx == 1:# second pass for discriminator updateif cond is None:logits_real = self.discriminator(inputs.contiguous().detach())logits_fake = self.discriminator(reconstructions.contiguous().detach())else:logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),"{}/logits_real".format(split): logits_real.detach().mean(),"{}/logits_fake".format(split): logits_fake.detach().mean()}return d_loss, log

我看完源码,有三个问题:

  • 问题1(未解决):为什么要把感知损失p_loss和未求均值的重建损失rec_loss相加?因为通过debug可知,p_loss没有空间维度,是单个值,但rec_loss是保留空间维度的,例如输入数据是一张三通道512×512大小的图像, 则rec_loss的维度为(1,3,512,512),由于广播机制,两者相加会导致感知损失p_loss会被添加到每个像素位置的重建损失rec_loss上。这个问题我到现在还没有理解,如果有懂的大佬可以在评论区指出。
  • 问题2:为什么负对数损失的计算公式是那样?它的意义是什么?
  • 问题3:为什么要自适应的计算对抗损失的权重系数?

问题2

为什么要使用负对数损失?

负对数损失只是最后结果的形式,我们还需要了解具体的过程(我个人的理解,不一定正确):我们通常假设数据服从某种概率分布(通常为高斯分布),因此我们也可以把重建误差看作服从均值为0,方差为\sigma^{2}的高斯分布:

x-\hat{x}\sim \mathcal{N}(0,\sigma^{2})

 其概率密度函数为:

p(x-\hat{x}|0,\sigma^{2})=\frac{1}{\sqrt{2\pi}\sigma^{2}}\text{exp}(-\frac{[(x-\hat{x})-0]^{2}}{2\sigma^{2}})

对两边取负对数:

-\log{p(x-\hat{x}|0,\sigma^{2})}=\frac{(x-\hat{x})^{2}}{2\sigma^{2}}+\frac{1}{2}\log{2\pi\sigma^{2}}

不考虑常数部分的话进一步可化简为:

-\log{p(x-\hat{x}|0,\sigma^{2})}=\frac{(x-\hat{x})^{2}}{\sigma^{2}}+\log{\sigma^{2}}

上述公式的具体含义是每个像素位置重建误差,但注意,代码实现里假设每个像素位置的方差是一样的,即默认图像所有区域重建的不确定性(难度)是一样的,这显然是一种简易的方法。我们可以将其看作是对重建误差的缩放,一方面,第二项正则项要求方差越小越好;另一方面,也进一步约束第一项中分子部分的重建误差越小越好。对数方差作为模型的参数是可学习的。

但这里还有个小问题,按照上述公式,重建误差应该使用L2损失,但代码中则使用了L1损失,具体为什么我没有搞懂。。。。

问题3

为什么要自适应的计算对抗损失的权重系数?

首先说答案:目的是为了平衡负对数损失和对抗损失,避免由单个损失主导模型优化的方向,稳定训练过程。结合下面的源码来了解动态权重是如何计算的:

def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):if last_layer is not None:# 得到nll_loss对于last_layer参数的梯度nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]else:nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]# 根据两者梯度的比值计算对抗损失的权重d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)# 限制权重的大小范围d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()d_weight = d_weight * self.discriminator_weightreturn d_weight

在讲解源码具体作用前,我们首先需要回忆一下损失函数在模型训练过程中的作用:

在模型优化过程中,起决定作用的是损失函数的梯度,而不是具体的损失函数数值。以爬山举例:损失函数的数值表示你现在的海拔,梯度则表示你所处位置有多陡峭;如果你在山顶(loss很大),但脚下很平缓(梯度很小),那么对应的模型参数更新缓慢;如果你在半山腰(loss不大),但坡很陡峭(梯度很大),对应模型参数更新较快。如果存在多个损失函数,那么这多个损失函数的梯度共同决定了模型参数优化的方向(想象向量相加的结果)和大小(不考虑学习率的话)。结合具体任务,重建损失(图像像素差异)和对抗损失(真实性)相当于从不同角度评价了生成图像的质量。

回到代码,语句torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]可得到负对数损失nll_loss对于last_layer模型参数的梯度nll_grads,在实际训练过程中,last_layer传入的是VAE-decoder最后一层参数,因为这一层离图像空间最近;同理也可得到对抗损失g_loss对于last_layer的梯度g_grads。对抗损失权重计算公式为d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)。如果没有g_grads很大,那么模型优化方向主要由判别损失主导,反之则会由负对数损失主导,因此需要在训练过程中动态的计算判别损失的权重,避免让某个损失主导模型优化方向。当然这相当于一个训练的trick,并不是必须的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/86880.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/86880.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/86880.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MIT 6.824学习心得(1) 浅谈分布式系统概论与MapReduce

一个月前机缘巧合,有朋友向我推荐了麻省理工学院非常著名的分布式系统课程MIT 6.824,是由世界五大黑客之一,蠕虫病毒之父Robert Morris教授进行授课。由于我自己也在做基于分布式微服务架构的业务项目,所以对构建分布式系统这个课…

PCL点云库入门(第21讲)——PCL库点云特征之RSD特征描述Radius-based Surface Descriptor(RSD)

一、算法原理 RSD: Radius-based Surface Descriptor由 Marton Zsolt et al. 于 2010 年提出,主要用于 点云中物体的几何形状识别(如球形、柱面、平面等),广泛用于机器人抓取、点云分割和物体识别等任务中。 1.1、RSD 特征的核心…

zookeeper Curator(4):分布式锁

文章目录 分布式锁分布式锁的实现zookeeper 分布式锁原理Curator 实现分布式锁API1. InterProcessMutex(分布式可重入互斥锁)2. InterProcessSemaphoreMutex(分布式非可重入互斥锁)3. InterProcessReadWriteLock(分布式…

设置方法区内存的大小

方法区内存配置 方法区(Method Area)是JVM内存模型的一部分,用于存储类信息、常量、静态变量等数据。在HotSpot虚拟机中,方法区的具体实现为永久代(PermGen)或元空间(Metaspace)&am…

用Flink打造实时数仓:生产环境中的“坑”与“解药”

目录 一、实时数仓的“野心”与“现实” 二、数据采集与接入:别让“源头”卡脖子 2.1 问题1:Kafka数据乱序与延迟 2.2 问题2:MySQL CDC数据同步异常 三、数据处理与计算:别让“算力”成瓶颈 3.1 问题3:多表Join性能低下 3.2 问题4:窗口计算触发延迟 四、状态管理与…

linux 下 Doris 单点部署

目录 1. Doris 下载 2. 环境准备 2.1 Linux 操作系统版本需求 2.2 部署依赖 3. Doris 部署 3.1 修改系统配置 3.1.1 修改系统句柄数 3.1.2 关闭swap分区 3.1.3 修改最大内存映射区域数量 3.2 开放端口 3.3 fe 部署 3.4 be 部署 3.5 be添加到Doris集群 4 验证 4.…

mysql 小版本升级实战分享

环境说明 当前版本:5.6.51 升级目标版本 mysql 5.7.41 服务启停通过systemd管理 升级准备: 环境检查 首先查看当前MySQL的版本信息,执行命令mysql -V,如图: 备份数据 备份所有数据库: 当数据量不是特别大的时候…

Python Ai语音识别教程

语音识别是将人类语音转换为文本的技术,在现代应用中非常有用。本教程将介绍如何使用Python实现基本的AI语音识别功能。 一、文字转语音 #文字转语音 #安装第三方库 pip install pyttsx3 #导包 : import pyttsx3import pyttsx3#创建语音引擎 a1 pytts…

Day11 制作窗口

文章目录 1. 显示窗口(harib08d)2. 消除闪烁1(harib08g)3. 消除闪烁2(harib08h) 本章的前三节做了如下修改: 解决了鼠标无法隐藏在屏幕右侧和下侧的问题。当鼠标隐藏在右侧时会在屏幕最左侧产生…

python+uniapp基于微信小程序蜀味道江湖餐饮管理系统nodejs+java

文章目录 具体实现截图本项目支持的技术路线源码获取详细视频演示:文章底部获取博主联系方式!!!!本系统开发思路进度安排及各阶段主要任务java类核心代码部分展示主要参考文献:源码获取/详细视频演示 ##项目…

postgresql增量备份系列二 pg_probackup

已经很久没有发文章了,主要是最近工作上的内容都不适合发文章公开。可能往后文章发表也不这么频繁了,不过大家有问题我们可以交流。之前有写过PG增量备份的其他工具使用方法,pg_probackup也是应用比较多的PG备份工具。 一. pg_probackup pg_probackup 是一个用于管理 Postg…

云手机主要是指什么?

云手机是指一种可以运行在云服务器中的手机,主要是将云计算技术运用于网络终端服务,通过云服务器来实现云服务的手机,也是一款深度结合了网络服务的手机,通过自带的系统和网络终端可以通过网络实现众多功能。 那么,下面…

CAU数据挖掘 支持向量机

SVM大致思想 线性分类问题 在一群点中用线性函数分类: 但也有线性不可分问题: 线性不可分问题: 最大间隔法 两个平行超平面间隔距离最大 软间隔 部分难以区分的点忽略 升维 通过升维将非线性变为线性 计算统计理论基础 学习过…

探索理解 Spring AI Advisors:构建可扩展的 AI 应用

Spring AI Advisors API 提供了一种灵活且强大的方式来拦截、修改和增强 Spring 应用程序中的 AI 驱动交互。其核心思想类似于 Spring AOP(面向切面编程)中的“通知”(Advice),允许开发者在不修改核心业务逻辑的情况下…

Linux SSH服务全面配置指南:从基础到安全加固

Linux SSH服务全面配置指南:从基础到安全加固 概述 作为网络安全工程师,SSH(Secure Shell)服务的安全配置是我们日常工作中不可忽视的重要环节。本文将从基础配置到高级安全加固,全面解析SSH服务的各项参数&#xff…

.NET测试工具Parasoft dotTEST内置安全标准,编码合规更高效

在追求开发速度的时代,确保代码安全并满足严苛的行业合规标准如OWASP、CWE、PCI DSS、ISO 26262等已成为开发者的核心挑战,但开发人员常因复杂的编码标准和漏洞排查而效率低下。.NET测试工具Parasoft dotTEST内置安全标准,实现即插即用&#…

对象的finalization机制Test

Java语言提供了对象终止(finalization)机制来允许开发人员自定义对象被销毁之前的处理逻辑。当垃圾回收器发现没有引用指向一个对象时,通常接下来要做的就是垃圾回收,即清除该对象,而finalization机制使得在清除此对象之前,总会先…

AI初学者如何对大模型进行微调?——零基础保姆级实战指南

仅需8GB显存,三步完成个人专属大模型训练 四步实战:从环境配置到模型发布 步骤1:云端环境搭建(10分钟) 推荐使用阿里魔塔ModelScope免费GPU资源: # 注册后执行环境初始化 pip3 install --upgrade pip pi…

“单一职责”模式之装饰器模式

目录 “单一职责”模式装饰器模式 Decorator引例动机 Motivation模式定义结构 Structure要点总结 “单一职责”模式 在软件组件的设计中,如果责任划分的不清晰,使用继承得到的结果往往是随着需求的变化,子类急剧膨胀,同时充斥着重…

idea, CreateProcess error=206, 文件名或扩展名太长

idea, CreateProcess error206, 文件名或扩展名太长 解决 “CreateProcess error206, 文件名或扩展名太长” 错误 CreateProcess error206 是 Windows 系统特有的错误,表示命令行参数超出了 Windows 的 32767 字符限制。这个问题在 Java 开发中尤其常见&#xff0c…