一、核心思想:通过概率分布惩罚错误


交叉熵损失的本质是:
比较模型预测的概率分布 vs 真实标签的概率分布,惩罚两者之间的差异。

例如:

  • 真实标签:图像 0 → 文本 0(独热编码 [1, 0, 0, ...])
  • 模型预测:[0.1, 0.2, 0.3, 0.4, ...](预测文本 0 的概率仅 0.1)

此时损失会很大,因为预测分布与真实分布差异大。

二、分步解析交叉熵惩罚机制


1. 相似度矩阵 → 概率分布


假设 sim_i2t 是一个 [3, 6] 的矩阵(3 个图像 × 6 个文本):

# 示例相似度矩阵(简化版,仅展示对角线高相似度)
sim_i2t = torch.tensor([[5.0, 1.0, 1.0, 1.0, 1.0, 1.0],  # 图像0 → 文本0是正样本[1.0, 5.0, 1.0, 1.0, 1.0, 1.0],  # 图像1 → 文本1是正样本[1.0, 1.0, 5.0, 1.0, 1.0, 1.0]   # 图像2 → 文本2是正样本
])

通过 softmax 将相似度转换为概率分布:

probs = F.softmax(sim_i2t, dim=1)  # 对每行做softmax
print(probs)

输出结果:

tensor([[0.94, 0.02, 0.02, 0.02, 0.02, 0.02],  # 预测文本0概率最高(正确)[0.02, 0.94, 0.02, 0.02, 0.02, 0.02],  # 预测文本1概率最高(正确)[0.02, 0.02, 0.94, 0.02, 0.02, 0.02]   # 预测文本2概率最高(正确)
])

2. 真实标签的概率分布


假设 targets = [0, 1, 2],转换为独热编码:

# 独热编码(简化版,仅展示核心逻辑)
one_hot = torch.zeros_like(probs)
for i, t in enumerate(targets):one_hot[i, t] = 1.0print(one_hot)

输出结果

tensor([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 图像0的正样本是文本0[0.0, 1.0, 0.0, 0.0, 0.0, 0.0],  # 图像1的正样本是文本1[0.0, 0.0, 1.0, 0.0, 0.0, 0.0]   # 图像2的正样本是文本2
])

3. 计算交叉熵损失

交叉熵损失公式:

对于上述例子:

  • 图像 0 的损失:-log(0.94) ≈ 0.06
  • 图像 1 的损失:-log(0.94) ≈ 0.06
  • 图像 2 的损失:-log(0.94) ≈ 0.06

平均损失:(0.06 + 0.06 + 0.06) / 3 ≈ 0.06

实际函数内部:

# 1. 对预测值应用softmax,转换为概率分布
probs = F.softmax(sim_i2t, dim=1)# 2. 对每个样本,取出目标类别对应的概率
# 例如:
# - 第0个样本的目标类别是0,取出probs[0, 0]
# - 第1个样本的目标类别是1,取出probs[1, 1]
# - 第2个样本的目标类别是2,取出probs[2, 2]
target_probs = probs[torch.arange(len(targets)), targets]# 3. 计算负对数似然
nll = -torch.log(target_probs)# 4. 求平均值得到最终损失
loss = nll.mean()

三、标签平滑如何调整惩罚


标签平滑(label_smoothing=0.1)会将:

  • 正样本的概率从 1.0 调整为 0.9
  • 负样本的概率从 0.0 调整为 0.1 / (类别数-1)

例如,对于图像 0(正样本是文本 0):

  • 原始标签:[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  • 平滑后标签:[0.9, 0.02, 0.02, 0.02, 0.02, 0.02]

此时损失计算变为:

实际函数内部:当使用label_smoothing=0.1时,函数内部会将目标概率分布从严格的独热编码调整为平滑分布:

def cross_entropy_with_label_smoothing(logits, targets, smoothing=0.1):num_classes = logits.size(1)# 计算平滑后的目标分布# - 正样本概率: 1.0 - smoothing + (smoothing / num_classes)# - 负样本概率: smoothing / num_classessmooth_targets = torch.full_like(logits, smoothing / (num_classes - 1))smooth_targets[torch.arange(len(targets)), targets] = 1.0 - smoothing + (smoothing / num_classes)# 对预测值应用log_softmaxlog_probs = F.log_softmax(logits, dim=1)# 计算交叉熵(等价于F.kl_div(log_probs, smooth_targets))loss = (-smooth_targets * log_probs).sum(dim=1).mean()return loss

四、惩罚机制可视化


假设模型预测错误(图像 0 预测文本 1 的概率最高):

# 错误预测的情况
bad_probs = torch.tensor([[0.1, 0.8, 0.05, 0.05, 0.0, 0.0],  # 错误:预测文本1概率最高[0.02, 0.94, 0.02, 0.02, 0.02, 0.0],  # 正确[0.02, 0.02, 0.94, 0.02, 0.02, 0.0]   # 正确
])# 计算交叉熵损失(无标签平滑)
loss = -torch.log(bad_probs[0, 0])  # 图像0的损失:-log(0.1) ≈ 2.3
print(f"错误预测的损失: {loss.item():.4f}")  # 损失远大于正确预测的0.06

输出结果:

错误预测的损失: 2.3026

五、总结


交叉熵损失的惩罚机制是:

  • 对正样本:预测概率越低,惩罚越大(损失呈对数增长)
  • 对负样本:预测概率越高,惩罚越大
  • 标签平滑:减轻对极端预测的惩罚,防止过拟合

通过这种方式,模型被强制学习到:

  • 正样本对的相似度要尽可能高
  • 负样本对的相似度要尽可能低

这就是对比学习中 “拉近正样本、推远负样本” 的核心实现方式!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/89511.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/89511.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/89511.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

测试学习之——Pytest Day3

引言Pytest 作为 Python 中最受欢迎的测试框架之一,以其简洁的语法、强大的功能和丰富的插件生态系统,极大地提升了自动化测试的效率和可维护性。在本文中,我们将深入探讨 Pytest 的两大核心特性:Fixture 和插件管理,帮…

控制Vue对话框显示隐藏

正确做法 — 使用 Vue 数据驱动控制显隐你不需要手动设置 display: block&#xff0c;因为 Element Plus 的 <el-dialog> 是基于 v-model 或 :visible.sync 控制的。&#x1f527; 修改模板部分&#xff1a;将原来的&#xff1a;<el-dialog title"报文详情"…

直播带货与开源AI智能名片链动2+1模式S2B2C商城小程序:重塑电商营销新格局

摘要&#xff1a;本文聚焦于直播带货对互联网供需关系的深刻影响&#xff0c;分析其如何改变传统电商营销模式&#xff0c;实现从“人找货”到“货找人”的转变。同时&#xff0c;引入开源AI智能名片链动21模式S2B2C商城小程序这一创新概念&#xff0c;探讨其在直播带货背景下的…

Jmeter 性能测试响应时间过长怎么办?

当 JMeter 性能测试中出现 响应时间过长 的问题时&#xff0c;需要从 测试脚本、服务器、网络、JMeter配置 等多方面排查和优化。以下是详细的解决步骤和思路&#xff1a; B站最新性能进阶&#xff0c;学会这些jmeter性能测试技能&#xff0c;更助于正确设计、执行和分析性能测…

COZE官方文档基础知识解读第三期 —— prompt(提示词)

COZE官方文档基础知识解读第三期 —— prompt&#xff08;提示词&#xff09; 对于初步接触PE&#xff08;prompt engineering&#xff09; 的小伙伴们&#xff0c;你们可以去火山方舟提供的prompt工具&#xff0c;用工具&#xff08;其余的prompt网站https://www.promptinggu…

代码随想录算法训练营第三十二天|动态规划理论基础、LeetCode 509. 斐波那契数、70. 爬楼梯、746. 使用最小花费爬楼梯

目录 LeetCode 509. 斐波那契数 70. 爬楼梯 746. 使用最小花费爬楼梯 感想 文档讲解&#xff1a;代码随想录 动态规划&#xff0c;英文&#xff1a;Dynamic Programming&#xff0c;简称DP&#xff0c;如果某一问题有很多重叠子问题&#xff0c;使用动态规划是最有效的。 …

SpringMVC3

一、JSON 与参数传递1.1JSON 是什么- JSON 是字符串&#xff1a;比如 {"name":"zhangsan","password":"123456","age":15} 就是一个 JSON 字符串&#xff0c;它用来在前后端、服务间传递数据。- JSON 库&#xff1a;Fastj…

查看.bin二进制文件的方式(HxD十六进制编辑器的安装)

文章目录Windows 系统上安装 HxD 十六进制编辑器的步骤。**HxD 是一款免费、轻量级的工具&#xff0c;适合查看和编辑 .bin 等二进制文件。****PS:实际安装过程中会发现找不到Windows11的版本&#xff0c;安装windows10的即可&#xff0c;并且没有区别setup版和portable版**安装…

Linux系统性能优化与监控

系统性能优化与监控是保障 Linux 服务器稳定运行的核心技术&#xff0c;涉及 ​​CPU、内存、磁盘 I/O、网络、进程​​ 等多维度的指标分析、问题定位与优化策略。以下从​​监控工具与指标​​、​​常见问题诊断​​、​​优化方法​​三个层面详细讲解&#xff0c;并结合​…

如何在 React + TypeScript 中实现 JSON 格式化功能

如何在 React TypeScript 中实现 JSON 格式化功能 作为前端开发者&#xff0c;我们经常需要处理 JSON 数据。无论是 API 调试、配置文件编辑还是数据转换&#xff0c;能够格式化 JSON 是一项基本但非常有用的技能。本文将详细介绍如何在 React 和 TypeScript 环境中实现 JSON…

Mac连接服务器Docker容器全攻略

苹果电脑( macOS 系统 )连接服务器、配置容器,整体思路和 Linux 终端操作更贴近,以下结合 macOS 特点,详细分步说明,以 Docker 容器 + 常见 Linux 服务器( 如 CentOS、Ubuntu )为例: 一、连接服务器(SSH 方式, macOS 终端原生支持 ) 1. 准备信息 找运维或云平台…

【字节跳动】数据挖掘面试题0019:带货直播间推荐:现在有一个带货的直播间,怎么把它精准地推送给有需要的用户

文章大纲 带货直播间推荐系统:原理、算法与实践 一、推荐系统在带货直播中的重要性 二、数据收集与处理 1. 用户数据 2. 直播间数据 3. 用户行为数据 4. 数据处理与特征工程 三、推荐算法实现 1. 基于内容的推荐 2. 基于协同过滤的推荐 3. 基于知识图谱的推荐 4. 混合推荐算法…

Windows10笔记本电脑开启BIOS

文章目录什么是BIOS一、方案一&#xff1a;快捷键进入二、方案二&#xff08;推荐&#xff09;各品牌快捷键大全什么是BIOS BIOS 全拼为 BasicInputOutputSystem, 即基本输入/输出系统,是计算机中非常基础而且重要的程序。把这一段程序存放在一个不需要电源的记忆体(芯片)中,就…

NFS、iSCSI 和lnmp部署操作

目录 &#xff08;一&#xff09;基础配置 1.NFS服务安装 2.修改配置文件 3.重载配置文件 4.查看共享目录 5.客户端挂载 6.更换共享目录 7.基础实验 &#xff08;二&#xff09;布置lnmp平台 1.php 安装软件 检测 2.连接MySQL 测试 3.软件实施 软件安装配置 &…

Redis深度解析:从缓存原理到高并发实战

第一部分&#xff1a;Redis核心概念与架构设计1.1 Redis本质解析Redis&#xff08;Remote Dictionary Server&#xff09;作为开源的内存数据结构存储系统&#xff0c;其核心价值在于&#xff1a;内存优先架构&#xff1a;数据主要存储在内存中&#xff0c;读写性能达到10万 QP…

【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 微博类别信息爬取

大家好&#xff0c;我是java1234_小锋老师&#xff0c;最近写了一套【NLP舆情分析】基于python微博舆情分析可视化系统(flaskpandasecharts)视频教程&#xff0c;持续更新中&#xff0c;计划月底更新完&#xff0c;感谢支持。今天讲解架构搭建 视频在线地址&#xff1a; 2026…

GD32/STM32嵌入CMSIS-DSP的库(基于Keil)

当你要用到三角函数、开方、矩阵运算等复杂的数学运算时&#xff0c;可以选择用C库的math.h里面的函数&#xff0c;如果要求速度快的话就得用CMSIS-DSP库里面的函数了&#xff0c;因为CMSIS-DSP库充分运用了CM4内核的浮点运算单元&#xff08;若有&#xff09;和DSP相关的指令&…

页面登录阻止浏览器提醒是否保存密码

一、原因 使用input的type"password"类型&#xff0c;浏览器会提醒是否记住密码。 二、解决 取消type"password" 三、实现输入密码*代替 通过input输入框&#xff0c;监听输入值&#xff0c;进行替换成*符号&#xff0c;避免使用input的type"password…

【iOS】dyld加载流程——应用程序的加载

目录 前言 编译过程与动静态库 编译过程 动静态库 dyld &#x1f4cc; 什么是 dyld&#xff1f; dyld_shared_cache: dyld加载流程 _dyld_start dyldbootstrap::start dyld::main() 配置环境变量 共享缓存 主程序的初始化 插入动态库 link主程序 link动态库 弱…

从零开始,手把手教你本地部署Stable Diffusion AI绘画(Win最新版)

本号之前有发过一篇win平台的教程&#xff0c;由于是去年10月发布的&#xff0c;而Al绘画技术发展很快&#xff0c;那篇教程已经有些不适用了&#xff0c;有些同学执行到第二步就出错了。 应广大同学的期望&#xff0c;我更新一版新版详细教程。 一、前言 1.为什么要本地部署…