本文是对DISTS图像质量评价指标的代码解读,原文解读请看DISTS文章讲解。
本文的代码来源于IQA-Pytorch工程。

1、原文概要

以前的一些IQA方法对于捕捉纹理上的感知一致性有所欠缺,鲁棒性不足。基于此,作者开发了一个能够在图像结构和图像纹理上都具有与人类相同感知判断的指标,在此之上,还希望纹理能够resample(不需要像素级对齐)之后也是一样的,另外区分开退化(JPEG,JPEG会损失纹理)。实现该指标可以分为4个步骤:

  1. 对图像进行一个初始的变换,从像素空间变换到特征空间。
  2. 对特征提取所谓纹理的表示,对特征提取所谓结构的表示。
  3. 利用纹理和结构的表示,加入一些可学习的权重综合计算一个评价指标。
  4. 利用这个评价指标,进一步优化权重得到纹理区域resample不敏感的指标,且能够有结构和纹理上做感知相似度的模型。

实现后的指标作为优化指标对比其他IQA指标有明显优势,如下图所示。
在这里插入图片描述

2、代码结构

代码实现位于pyiqa/archs/dists_arch.py中
在这里插入图片描述

3 、核心代码模块

L2pooling

这个类实现了我们前面提到的预处理部分替换max-pool的操作。

class L2pooling(nn.Module):def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):super(L2pooling, self).__init__()self.padding = (filter_size - 2) // 2self.stride = strideself.channels = channelsa = np.hanning(filter_size)[1:-1]g = torch.Tensor(a[:, None] * a[None, :])g = g / torch.sum(g)self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1)))def forward(self, input):input = input**2out = F.conv2d(input,self.filter,stride=self.stride,padding=self.padding,groups=input.shape[1],)return (out + 1e-12).sqrt()

这里可以看到前向的过程中作者先是进行了一个平方,然后使用了一个self.filter的滤波器,kernel_size为3的hanning窗,stride=2,且是一个深度可分离的卷积,groups与输入通道一致,这代替max-pool完成了一次抗混叠的下采样,最后进行一个sqrt,这与讲解中展示的公式一致,如下所示:
P(x)=g∗(x∗x)P(x)=\sqrt{g*(x*x)}P(x)=g(xx)这个ggg在初始化时被复制了self.channels次,实际它一个通道的数值,读者可以打印如下所示:
[0.06250.1250.06250.1250.250.1250.06250.1250.0625]\begin{bmatrix} 0.0625 & 0.125 & 0.0625 \\ 0.125 & 0.25 & 0.125 \\ 0.0625 & 0.125 & 0.0625 \end{bmatrix} 0.06250.1250.06250.1250.250.1250.06250.1250.0625一个典型的低通滤波器,做了一个空间上根据距离的平均。

DISTS

存放着跟实际计算指标相关的代码。

@ARCH_REGISTRY.register()
class DISTS(torch.nn.Module):r"""DISTS model.Args:pretrained_model_path (String): Pretrained model path."""def __init__(self, pretrained=True, pretrained_model_path=None, **kwargs):"""Refer to official code https://github.com/dingkeyan93/DISTS"""super(DISTS, self).__init__()vgg_pretrained_features = models.vgg16(weights='IMAGENET1K_V1').featuresself.stage1 = torch.nn.Sequential()self.stage2 = torch.nn.Sequential()self.stage3 = torch.nn.Sequential()self.stage4 = torch.nn.Sequential()self.stage5 = torch.nn.Sequential()for x in range(0, 4):self.stage1.add_module(str(x), vgg_pretrained_features[x])self.stage2.add_module(str(4), L2pooling(channels=64))for x in range(5, 9):self.stage2.add_module(str(x), vgg_pretrained_features[x])self.stage3.add_module(str(9), L2pooling(channels=128))for x in range(10, 16):self.stage3.add_module(str(x), vgg_pretrained_features[x])self.stage4.add_module(str(16), L2pooling(channels=256))for x in range(17, 23):self.stage4.add_module(str(x), vgg_pretrained_features[x])self.stage5.add_module(str(23), L2pooling(channels=512))for x in range(24, 30):self.stage5.add_module(str(x), vgg_pretrained_features[x])for param in self.parameters():param.requires_grad = Falseself.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))self.chns = [3, 64, 128, 256, 512, 512]self.register_parameter('alpha', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))self.register_parameter('beta', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))self.alpha.data.normal_(0.1, 0.01)self.beta.data.normal_(0.1, 0.01)if pretrained_model_path is not None:load_pretrained_network(self, pretrained_model_path, False)elif pretrained:load_pretrained_network(self, default_model_urls['url'], False)def forward_once(self, x):h = (x - self.mean) / self.stdh = self.stage1(h)h_relu1_2 = hh = self.stage2(h)h_relu2_2 = hh = self.stage3(h)h_relu3_3 = hh = self.stage4(h)h_relu4_3 = hh = self.stage5(h)h_relu5_3 = hreturn [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]def forward(self, x, y):r"""Compute IQA using DISTS model.Args:- x: An input tensor with (N, C, H, W) shape. RGB channel order for colour images.- y: An reference tensor with (N, C, H, W) shape. RGB channel order for colour images.Returns:Value of DISTS model."""feats0 = self.forward_once(x)feats1 = self.forward_once(y)dist1 = 0dist2 = 0c1 = 1e-6c2 = 1e-6w_sum = self.alpha.sum() + self.beta.sum()alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)beta = torch.split(self.beta / w_sum, self.chns, dim=1)for k in range(len(self.chns)):x_mean = feats0[k].mean([2, 3], keepdim=True)y_mean = feats1[k].mean([2, 3], keepdim=True)S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_meanS2 = (2 * xy_cov + c2) / (x_var + y_var + c2)dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)score = 1 - (dist1 + dist2)return score.squeeze(-1).squeeze(-1)

3个重点如下:

  1. 初始化中首先会插入前面讲到的L2_Pooling,来替换原始的max-pool,其他的就是初始化必要的标准化变量和用于各层结构和纹理的加权系数α\alphaαβ\betaβ,最后导入预训练的网络即可。
  2. 前向中调用的forward_once,可以看到总共有6个输出,第一个输出是输入x,即我们讲解中提到的identity的变换,其他5层是事先定义好的输出位置。
  3. dists的计算:首先根据权重的大小对alpha和beta进行归一化,随后分层计算我们前面定义好的纹理特征和结构特征的相关性公式,针对于纹理的部分代码中是S1,可以看到S1是利用了特征的在空间上的均值计算的参考图像和待评估图像的相关系数,然后利用alpha对计算好的S1进行加权,得到纹理上相似度dist1;针对于结构的部分代码中是S2,S2是利用了参考图像和待评估图像两个特征的协方差和方差,由于是全局的窗口所以在计算后会求取空间上的一个均值,这样得到了结构上的相似度dist2。最后结合dist1和dist2得到最终的score。dists计算的公式如下,可以对照着公式来查看:
    l(x~j(i),y~j(i))=2μx~j(i)μy~j(i)+c1(μx~j(i))2+(μy~j(i))2+c1l(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) = \frac{2\mu_{\tilde{x}_j}^{(i)}\mu_{\tilde{y}_j}^{(i)} + c_1}{(\mu_{\tilde{x}_j}^{(i)})^2 + (\mu_{\tilde{y}_j}^{(i)})^2 + c_1}l(x~j(i),y~j(i))=(μx~j(i))2+(μy~j(i))2+c12μx~j(i)μy~j(i)+c1 s(x~j(i),y~j(i))=2σx~jy~j(i)+c2(σx~j(i))2+(σy~j(i))2+c2,s(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) = \frac{2\sigma_{\tilde{x}_j\tilde{y}_j}^{(i)} + c_2}{(\sigma_{\tilde{x}_j}^{(i)})^2 + (\sigma_{\tilde{y}_j}^{(i)})^2 + c_2},s(x~j(i),y~j(i))=(σx~j(i))2+(σy~j(i))2+c22σx~jy~j(i)+c2, D(x,y;α,β)=1−∑i=0m∑j=1ni(αijl(x~j(i),y~j(i))+βijs(x~j(i),y~j(i)))D(x, y; \alpha, \beta) = 1 - \sum_{i = 0}^{m} \sum_{j = 1}^{n_i} \left( \alpha_{ij} l(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) + \beta_{ij} s(\tilde{x}_j^{(i)}, \tilde{y}_j^{(i)}) \right)D(x,y;α,β)=1i=0mj=1ni(αijl(x~j(i),y~j(i))+βijs(x~j(i),y~j(i)))其中,lllsss分别代表纹理和结构。

3、总结

代码实现核心的部分讲解完毕,DISTS作为一个可以同时捕获结构和纹理相似度的全参考IQA指标,在很多比赛和论文的引用中都可以见到它的身影,实用性是毋庸置疑的。
大家有涉及到数据集筛选、纹理分类、纹理搜索类的任务可以尝试使用DISTS指标,或者是在算法评估中利用它来做一个方面的对比评估。


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/91405.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/91405.shtml
英文地址,请注明出处:http://en.pswp.cn/web/91405.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024年SEVC SCI2区,一致性虚拟领航者跟踪群集算法GDRRT*-PSO+多无人机路径规划,深度解析+性能实测

目录1.摘要2.算法背景3.GDRRT*-PSO与虚拟领航者跟踪算法4.结果展示5.参考文献6.算法辅导应用定制读者交流1.摘要 随着无人机技术的快速发展及其卓越的运动和机动性能,无人机在社会和军事等诸多领域得到了广泛应用。多无人机协同作业,能够显著提升任务执…

链特异性文库是什么?为什么它在转录组测序中越来越重要?

链特异性文库是什么?为什么它在转录组测序中越来越重要? 在现代分子生物学研究中,RNA测序(RNA-seq) 是一种广泛应用的技术,用于分析基因在不同条件下的表达情况。而在RNA-seq的众多技术细节中,有…

ClickHouse vs PostgreSQL:数据分析领域的王者之争,谁更胜一筹?

文章概要 作为一名数据架构师,我经常被问到一个问题:在众多数据库选择中,ClickHouse和PostgreSQL哪一个更适合我的项目?本文将深入探讨这两种数据库系统的核心差异、性能对比、适用场景以及各自的优缺点,帮助您在技术选…

面向对象系统的单元测试层次

面向对象系统的单元测试层次面向对象(Object-Oriented, OO)编程范式引入了封装、继承和多态等核心概念,这使得传统的、基于函数的单元测试方法不再充分。面向对象系统的单元测试必须适应其独特的结构和行为特性,从单一方法扩展到类…

如何用USRP捕获手机信号波形(上)系统及知识准备

目录: 如何用USRP捕获手机信号波形(上)系统及知识准备 如何用USRP捕获手机信号波形(中)手机/基站通信 如何用USRP捕获手机信号波形(下)协议分析 一、手机通信参数获取 首先用Cellular-z网络…

C语言-数组:数组(定义、初始化、元素的访问、遍历)内存和内存地址、数组的查找算法和排序算法;

本章概述思维导图:C语言数组在C语言中,数组是一种固定大小的、相同类型元素的有序集合,通过索引(下标)访问。数组数组:是一种容器,可以用来存储同种数据类型的多个值;数组特点&#…

河南萌新联赛2025第(二)场:河南农业大学(补题)

文章目录前言A.约数个数和整除分块(相当于约数求和)相关例题:取模B.异或期望的秘密二进制的规律相关例题累加器小蓝的二进制询问乘法逆元1. 概念2.基本定义3.费马小定理1.定理内容2.重要推论D.开罗尔网络的备用连接方案E.咕咕嘎嘎!!!(easy)I.猜数游戏(easy)K.打瓦M.…

常见中间件漏洞

一、TomcatTomcat put方法任意文件写入漏洞环境搭建,启动时端口被占用就改yml配置文件,改成8081端口。(我这里是8080)cd vulhub-master/tomcat/CVE-2017-12615 docker-compose up -d 去抓包,改成put提交。下面的内容是用哥斯拉生成的木马文件…

27.(vue3.x+vite)以pinia为中心的开发模板(监听watch)

效果截图 代码实现: HelloWorld.vue <template><div style="padding: 20px">介绍:<br />1:使用统一的 watch 来监听store的值。<br

Jenkins 详解

Jenkins 是一个开源的持续集成和持续交付(CI/CD)工具&#xff0c;用于自动化软件开发过程中的构建、测试和部署阶段。以下是关于 Jenkins 的详细介绍&#xff1a; 1. Jenkins 核心概念 1.1 持续集成(CI) 开发人员频繁地将代码变更提交到共享仓库每次提交都会触发自动构建和测试…

动态配置实现过程

查看DCCValueBeanFactory类的完整实现&#xff0c;了解动态配置的实现过程 动态配置实现过程 1. 自定义注解 使用DCCValue注解标记需要动态配置的字段&#xff0c;格式为key:defaultValue&#xff1a; DCCValue("downgradeSwitch:0") private String downgradeSw…

【大模型理论篇】跨语言AdaCOT

参考&#xff1a;AdaCoT: Rethinking Cross-Lingual Factual Reasoning throughAdaptive Chain-of-ThoughtAdaCoT&#xff08;Adaptive Chain-of-Thought&#xff0c;自适应思维链&#xff09;是一项提升大型语言模型&#xff08;LLMs&#xff09;跨语言事实推理能力的新框架。…

vue3项目搭建

前一段时间招聘前端开发,发现好多开发连基本的创建项目都不会,这里总结一下 在Vue 3中,使用Webpack和Vite创建的项目文件结构及语言(JS/TS)的选择有以下主要区别: 1. 创建方式与文件结构差异 方式一、Webpack(Vue CLI) 创建命令: vue create project-name 典型文件结构…

企业签名的多种形式

企业签名有多种形式&#xff0c;可分为企业签名独立版、企业签名稳定版、企业签名共享版等。每一种形式的企业签名都有其独特的特点&#xff0c;其中&#xff1a;  企业签名独立版&#xff1a;其特性主要为稳定性较高&#xff0c;使用者可以通过控制APP的下载量来保证APP的稳…

解构远程智能系统的视频能力链:从RTSP|RTMP协议接入到Unity3D头显呈现全流程指南

在人工智能奔腾的2025年&#xff0c;WAIC&#xff08;世界人工智能大会&#xff09;释放出一个明确信号&#xff1a;视频能力已经成为通往“远程智能”的神经中枢。在无人机、四足机器人、远程施工、巡检等新兴场景中&#xff0c;一套可靠、低延迟、可嵌入头显设备的视频传输系…

Less Less基础

1.lessless是一种动态样式语言&#xff0c;属于CSS预处理器的范畴&#xff0c;它扩展了CSS语言&#xff0c;增加了变量&#xff0c;Mixin&#xff0c;函数等特性&#xff0c;使CSS更易维护和扩展。Less既可以在客户端上运行&#xff0c;也可以借助Node.js在服务端运行。2.Less中…

如何使用 Redis 实现 API 网关或单个服务的请求限流?

使用 Redis 高效实现 API 网关与服务的请求限流 在微服务架构中&#xff0c;对 API 网关或单个服务的请求进行速率限制至关重要&#xff0c;以防止恶意攻击、资源滥用并确保系统的稳定性和可用性。 Redis 凭借其高性能、原子操作和丰富的数据结构&#xff0c;成为实现请求限流的…

图片查重从设计到实现(7) :使用 Milvus 实现高效图片查重功能

使用 Milvus 实现高效图片查重功能本文将介绍如何利用 Milvus 向量数据库构建一个高效的图片查重系统&#xff0c;通过传入图片就能快速从已有数据中找出匹配度高的相似图片。一.什么是图片查重&#xff1f; 图片查重指的是通过算法识别出内容相同或高度相似的图片&#xff0c;…

诱导多能干细胞(iPSC)的自述

自十七年前诱导多能干细胞&#xff08;也称iPS细胞或iPSC&#xff09;技术出现以来&#xff0c;干细胞生物学和再生医学取得了巨大进展。人类iPSC已广泛用于疾病建模、药物发现和细胞疗法开发。新的病理机制已被阐明&#xff0c;源自iPSC筛选的新药正在研发中&#xff0c;并且首…

基于深度学习的医学图像分析:使用DeepLabv3+实现医学图像分割

前言 医学图像分析是计算机视觉领域中的一个重要应用&#xff0c;特别是在医学图像分割任务中&#xff0c;深度学习技术已经取得了显著的进展。医学图像分割是指从医学图像中识别和分割出特定的组织或器官&#xff0c;这对于疾病的诊断和治疗具有重要意义。近年来&#xff0c;D…