损失函数的调用

import torch
from torch import nn
from torch.nn import L1Lossinputs = torch.tensor([1.0,2.0,3.0])
target = torch.tensor([1.0,2.0,5.0])inputs = torch.reshape(inputs, (1, 1, 1, 3))
target = torch.reshape(target, (1, 1, 1, 3))
#损失函数
loss = L1Loss(reduction='sum')
#MSELoss均值方差
loss_mse = nn.MSELoss()
result1 = loss(inputs, target)
result2 = loss_mse(inputs, target)
print(result1, result2)

 实际应用

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2ddataset = torchvision.datasets.CIFAR10(root='./data_CIF', train=False, download=True, transform=torchvision.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)class Tudui(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)self.maxpool1 = nn.MaxPool2d(kernel_size=2)self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)self.maxpool2 = nn.MaxPool2d(kernel_size=2)self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)self.maxpool3 = nn.MaxPool2d(kernel_size=2)self.flatten = nn.Flatten()self.linear1 = nn.Linear(in_features=1024, out_features=64)self.linear2 = nn.Linear(in_features=64, out_features=10)self.model1 = nn.Sequential(Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self, x):x = self.model1(x)return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:imgs,targets = dataoutputs = tudui(imgs)result1 = loss(outputs, targets)print(result1)#反向传播result1.backward()#梯度grad会改变,从而通过grad来降低loss

torch.nn.CrossEntropyLoss 

🧩 CrossEntropyLoss 是什么?

本质上是:

Softmax + NLLLoss(负对数似然) 的组合。

公式:

\text{Loss} = - \sum_{i} y_i \log(\hat{p}_i)

  • \hat{p}_i​:模型预测的概率(通过 softmax 得到)

  • y_i​:真实类别的 one-hot 标签

PyTorch 不需要你手动做 softmax,它会直接从 logits(未经过 softmax 的原始输出)算起,防止数值不稳定。


🏷️ 常用参数

torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean')

参数含义
weight给不同类别加权(处理类别不均衡)
ignore_index忽略某个类别(常见于 NLP 的 padding)
reductionmean(默认平均)、sum(求和)、none(逐个样本返回 loss)


🎨 最小使用例子

import torch
import torch.nn as nncriterion = nn.CrossEntropyLoss()# 假设 batch_size=3, num_classes=5
outputs = torch.tensor([[1.0, 2.0, 0.5, -1.0, 0.0],[0.1, -0.2, 2.3, 0.7, 1.8],[2.0, 0.1, 0.0, 1.0, 0.5]])  # logits
labels = torch.tensor([1, 2, 0])  # 真实类别索引loss = criterion(outputs, labels)
print(loss.item())
  • outputs:模型输出 logits,不需要 softmax;

  • labels:真实类别(索引型),如 0, 1, 2,...

  • loss.item():输出标量值。


💡 你需要注意:

⚠️ 重点📌 说明
logits 直接输入不要提前做 softmax
label 是类别索引不是 one-hot,而是整数(如 [1, 3, 0]
自动求 batch 平均默认 reduction='mean'
多分类用它最合适二分类也能用,但 BCEWithLogitsLoss 更常见


🎁 总结

优点缺点
✅ 简单强大,适合分类❌ 不适合回归任务
✅ 内置 softmax + log❌ label 不能是 one-hot
✅ 数值稳定性强❌ 类别极度不均衡需额外加 weight


🎯 一句话总结

CrossEntropyLoss 是深度学习中分类问题的“首选痛点衡量尺”,帮你用“正确标签”去教训“错误预测”,模型越聪明 loss 越小。

 公式:

 

1️⃣ 第一部分:

- \log \left( \frac{\exp(x[\text{class}])}{\sum_j \exp(x[j])} \right)

这是经典 负对数似然(Negative Log-Likelihood):

  • 分子:你模型对正确类别 class 输出的得分(logits),取 exp;

  • 分母:所有类别的 logits 做 softmax 归一化;

  • 再取负 log —— 意思是“你对正确答案预测得越自信,loss 越小”。


2️⃣ 推导为:

= - x[\text{class}] + \log \left( \sum_j \exp(x[j]) \right)

log(a/b) = log(a) - log(b) 的变形:

  • - x[\text{class}]:你对正确类输出的分值直接扣掉;

  • +\log(\sum_j \exp(x[j])):对所有类别的总分值做归一化。

这是交叉熵公式最常用的“log-sum-exp”形式。


📌 为什么这么写?

  • 避免直接用 softmax(softmax+log 合并后可以避免数值不稳定 🚀)

  • 计算量更高效(框架底层可以优化)


🌟 直观理解:

场景解释
正确类分数高x[\text{class}]越大,loss 越小
错误类分数高\sum \exp(x[j])越大,loss 越大
目标压低 log-sum-exp,拉高正确类别 logits

🎯 一句话总结:

交叉熵 = “扣掉正确答案得分” + “对所有类别归一化”,越接近正确答案,loss 越小。
这就是你训练神经网络时 模型越来越聪明的数学依据 😎

举例:

logits = torch.tensor([1.0, 2.0, 0.1])  # 模型输出 (C=3)
label = torch.tensor([1])  # 真实类别索引 = 1

其中:

  • N=1(batch size)

  • C=3(类别数)

  • 正确类别是索引1,对应第二个值:2.0

🎁 完整公式回顾

\text{loss}(x, y) = -x[y] + \log \sum_{j} \exp(x[j])


🟣 第一步:Softmax + log 逻辑

softmax 本质上是:

p = \frac{\exp(x[\text{class}])}{\sum_j \exp(x[j])}

但是 PyTorch 的 CrossEntropyLoss 内部直接用:

\text{loss} = - \log p


🧮 你这个例子手动算:

logits = [1.0, 2.0, 0.1],class = 1,对应 logit = 2.0

第一部分:

- x[\text{class}] = -2.0

第二部分:

\log \sum_{j=1}^{3} \exp(x[j]) = \log (\exp(1.0) + \exp(2.0) + \exp(0.1))

先算:

  • exp(1.0)≈2.718

  • exp(2.0)≈7.389

  • exp(0.1)≈1.105

加起来:

∑=2.718+7.389+1.105=11.212

取对数:

log⁡(11.212)≈2.418

最终 loss:

loss=−2.0+2.418=0.418

🌟 你可以这样理解

部分含义
−x[class]- x[\text{class}]−x[class]惩罚正确答案打分太低
log⁡∑exp⁡(x)\log \sum \exp(x)log∑exp(x)考虑所有类别的对比,如果错误类别打分高也被惩罚
最终目标“提升正确答案打分、降低错误答案打分”

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/914873.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/914873.shtml
英文地址,请注明出处:http://en.pswp.cn/news/914873.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用 Ray 跨节点调用 GPU 部署 DeepSeek 大模型,实现分布式高效推理

在大模型时代,单节点 GPU 资源往往难以满足大模型(如 7B/13B 参数模型)的部署需求。借助 Ray 分布式框架,我们可以轻松实现跨节点 GPU 资源调度,让大模型在多节点间高效运行。本文将以 DeepSeek-llm-7B-Chat 模型为例&…

快速了解 HTTPS

1. 引入 在 HTTP 协议 章节的 reference 段,曾提到过 HTTPS。这里对HTTPS进行详细介绍。 HTTPS 是在 HTTP 的基础上,引入了一个加密层 (SSL)。HTTP 是明文传输的 (不安全)。当下所见到的大部分网站都是 HTTPS 的。 起初是拜运营商劫持所赐(…

mysql备份与视图

要求:1.将mydb9_stusys数据库下的student、sc 和course表,备份到本地主机保存为st_msg_bak.sql文件,然后将数据表恢复到自建的db_test数据库中;2.在db_test数据库创建一视图 stu_info,查询全体学生的姓名,性别,课程名&…

【数据结构】 链表 + 手动实现单链表和双链表的接口(图文并茂附完整源码)

文章目录 一、 链表的概念及结构 二、链表的分类 ​编辑 三、手动实现单链表 1、定义单链表的一个节点 2、打印单链表 3、创建新节点 4、单链表的尾插 5、单链表的头插 6、单链表的尾删 7、单链表的头删 8、单链表的查找 9、在指定位置之前插入一个新节点 10、在指…

Go语言时间控制:定时器技术详细指南

1. 定时器基础:从 time.Sleep 到 time.Timer 的进化为什么 time.Sleep 不够好?在 Go 编程中,很多人初学时会用 time.Sleep 来实现时间控制。比如,想让程序暂停 2 秒,代码可能是这样:package mainimport (&q…

C# 转换(显式转换和强制转换)

显式转换和强制转换 如果要把短类型转换为长类型,让长类型保存短类型的所有位很简单。然而,在其他情况下, 目标类型也许无法在不损失数据的情况下容纳源值。 例如,假设我们希望把ushort值转化为byte。 ushort可以保存任何0~65535的…

浅谈自动化设计最常用的三款软件catia,eplan,autocad

笔者从上半年开始接触这三款软件,掌握了基础用法,但是过了一段时间不用,发现再次用,遇到的问题短时间解决不了,忘记的有点多,这里记录一下,防止下次忘记Elpan:问题1QF01是柜安装板上的一个部件&…

网络编程7.17

练习&#xff1a;服务器&#xff1a;#include <stdio.h> #include <string.h> #include <unistd.h> #include <stdlib.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <pthread.h> #include &…

c++ 模板元编程

听说模板元编程能在编译时计算出常量&#xff0c;简单测试下看看&#xff1a;template<int N> struct Summation {static constexpr int value N Summation<N - 1>::value; // 计算 1 2 ... N 的值 };template<> struct Summation<1> { // 递归终…

【深度学习】神经网络过拟合与欠拟合-part5

八、过拟合与欠拟合训练深层神经网络时&#xff0c;由于模型参数较多&#xff0c;数据不足的时候容易过拟合&#xff0c;正则化技术就是防止过拟合&#xff0c;提升模型的泛化能力和鲁棒性 &#xff08;对新数据表现良好 对异常数据表现良好&#xff09;1、概念1.1过拟合在训练…

JavaScript的“硬件窥探术”:浏览器如何读取你的设备信息?

JavaScript的“硬件窥探术”&#xff1a;浏览器如何读取你的设备信息&#xff1f; 在Web开发的世界里&#xff0c;JavaScript一直扮演着“幕后魔术师”的角色。从简单的页面跳转到复杂的实时数据处理&#xff0c;它似乎总能用最轻巧的方式解决最棘手的问题。但你是否想过&#…

论安全架构设计(层次)

安全架构设计&#xff08;层次&#xff09; 摘要 2021年4月&#xff0c;我有幸参与了某保险公司的“优车险”项目的建设开发工作&#xff0c;该系统以车险报价、车险投保和报案理赔为核心功能&#xff0c;同时实现了年检代办、道路救援、一键挪车等增值服务功能。在本项目中&a…

滚珠导轨常见的故障有哪些?

在自动化生产设备、精密机床等领域&#xff0c;滚珠导轨就像是设备平稳运行的 “轨道”&#xff0c;为机械部件的直线运动提供稳准导向。但导轨使用时间长了&#xff0c;难免会出现这样那样的故障。滚珠脱落&#xff1a;可能由安装不当、导轨损坏、超负荷运行、维护不当或恶劣环…

机器视觉的包装盒丝印应用

在包装盒丝网印刷领域&#xff0c;随着消费市场对产品外观精细化要求的持续提升&#xff0c;传统印刷工艺面临多重挑战&#xff1a;多色套印偏差、曲面基材定位困难、异形结构印刷失真等问题。双翌光电科技研发的WiseAlign视觉系统&#xff0c;通过高精度视觉对位技术与智能化操…

Redis学习-03重要文件及作用、Redis 命令行客户端

Redis 重要文件及作用 启动/停止命令或脚本 /usr/bin/redis-check-aof -> /usr/bin/redis-server /usr/bin/redis-check-rdb -> /usr/bin/redis-server /usr/bin/redis-cli /usr/bin/redis-sentinel -> /usr/bin/redis-server /usr/bin/redis-server /usr/libexec/red…

SVN客户端(TortoiseSVN)和SVN-VS2022插件(visualsvn)官网下载

SVN服务端官网下载地址&#xff1a;https://sourceforge.net/projects/win32svn/ SVN客户端工具(TortoiseSVN):https://plan.io/tortoise-svn/ SVN-VS2022插件(visualsvn)官网下载地址&#xff1a;https://www.visualsvn.com/downloads/

990. 等式方程的可满足性

题目&#xff1a;第一次思考&#xff1a; 经典并查集 实现&#xff1a;class UnionSet{public:vector<int> parent;public:UnionSet(int n) {parent.resize(n);}void init(int n) {for (int i 0; i < n; i) {parent[i] i;}}int find(int x) {if (parent[x] ! x) {pa…

HTML--教程

<!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>菜鸟教程(runoob.com)</title> </head> <body><h1>我的第一个标题</h1><p>我的第一个段落。</p> </body> </html&g…

Leetcode刷题营第二十七题:二叉树的最大深度

104. 二叉树的最大深度 给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;3示例 2&#xff1a; 输入&#xff…