LightGBM、XGBoost和CatBoost自定义损失函数和评估指标

    • 函数(缩放误差)
    • 数学原理
      • 损失函数定义
      • 梯度计算
      • 评估指标
    • LightGBM实现
      • 自定义损失函数
      • 自定义评估指标
      • 使用方式
    • XGBoost实现
      • 自定义损失函数
      • 自定义评估指标
      • 使用方式
    • CatBoost实现
      • 自定义损失函数
      • 自定义评估指标
      • 使用方式
    • 框架对比
    • 实际应用
      • 适用场景
    • 常见问题
      • 1. 为什么要设置最小阈值?
      • 2. 梯度和Hessian计算错误怎么办?
      • 3. 不同框架的性能差异
      • 4. 超参数调优建议

函数(缩放误差)

传统的均方误差(MSE)和平均绝对误差(MAE)对所有预测值给予相同的权重,但在某些场景下,更关心相对误差而非绝对误差。缩放误差通过将误差除以真实值来实现这一目标:

缩放误差 = (真实值 - 预测值) / max(真实值, 阈值)

这样设计的优势:

  • 对于大数值和小数值的预测给予相对平等的权重
  • 避免大数值主导损失函数
  • 更适合预测范围变化很大的场景

数学原理

损失函数定义

设损失函数为:

L(y, ŷ) = ((y - ŷ) / max(y, threshold))²

其中:

  • y 是真实值
  • ŷ 是预测值
  • threshold 是防止除零的最小阈值

梯度计算

对于梯度提升算法,我们需要计算损失函数对预测值的一阶导数(梯度)和二阶导数(Hessian):

d = max(y, threshold)e = (y - ŷ) / d

  • 一阶导数(梯度)∂L/∂ŷ = -2e/d
  • 二阶导数(Hessian)∂²L/∂ŷ² = 2/d²

评估指标

配套的评估指标使用缩放平均绝对误差(Scaled MAE):

Scaled MAE = mean(|y - ŷ| / max(y, threshold))

LightGBM实现

自定义损失函数

def custom_loss_squared_lgb(y_pred, train_data):"""LightGBM自定义缩放均方误差损失函数参数:y_pred: 预测值数组train_data: LightGBM的Dataset对象返回:tuple: (梯度数组, Hessian数组)"""y_true = train_data.get_label()  # 获取真实标签# 计算分母,防止除零denominator = np.maximum(y_true, threshold)# 计算缩放误差error = (y_true - y_pred) / denominator# 计算梯度和Hessiangrad = -2 * error / denominatorhess = 2 / (denominator ** 2)return grad, hess

自定义评估指标

def mae_metric_lgb(preds, train_data):"""LightGBM自定义缩放MAE评估指标参数:preds: 预测值数组train_data: LightGBM的Dataset对象返回:tuple: (指标名称, 指标值, 是否越大越好)"""y_true = train_data.get_label()denominator = np.maximum(y_true, threshold)error = np.abs(preds - y_true) / denominatorreturn 'scaled_mae', np.mean(error), False

使用方式

import lightgbm as lgb
import numpy as np# 参数配置
params = {'objective': custom_loss_squared_lgb,  # 使用自定义损失函数'boosting_type': 'gbdt','num_leaves': 31,'learning_rate': 0.01,'verbosity': -1
}# 训练模型
model = lgb.train(params, train_set, valid_sets=[train_set, valid_set],feval=mae_metric_lgb,  # 使用自定义评估指标num_boost_round=1000,callbacks=[lgb.early_stopping(100)]
)

XGBoost实现

自定义损失函数

def custom_loss_squared_xgb(y_pred, train_data):"""XGBoost自定义缩放均方误差损失函数参数:y_pred: 预测值数组train_data: XGBoost的DMatrix对象返回:tuple: (梯度数组, Hessian数组)"""y_true = train_data.get_label()  # 获取真实标签# 计算分母,防止除零denominator = np.maximum(y_true, threshold)# 计算缩放误差error = (y_true - y_pred) / denominator# 计算梯度和Hessiangrad = -2 * error / denominatorhess = 2 / (denominator ** 2)return grad, hess

自定义评估指标

def mae_metric_xgb(y_pred, train_data):"""XGBoost自定义缩放MAE评估指标参数:y_pred: 预测值数组train_data: XGBoost的DMatrix对象返回:tuple: (指标名称, 指标值)"""y_true = train_data.get_label()denominator = np.maximum(y_true, threshold)error = np.abs(y_true - y_pred) / denominatorreturn 'custom_mae', np.mean(error)

使用方式

import xgboost as xgb
import numpy as np# 参数配置
params = {'booster': 'gbtree','learning_rate': 0.01,'max_depth': 6,'random_state': 42
}# 训练模型
model = xgb.train(params,train_matrix,num_boost_round=1000,evals=[(train_matrix, 'train'), (valid_matrix, 'valid')],obj=custom_loss_squared_xgb,  # 自定义损失函数feval=mae_metric_xgb,         # 自定义评估指标early_stopping_rounds=100,verbose_eval=50
)

CatBoost实现

CatBoost的自定义函数需要用类的形式实现。

自定义损失函数

class CustomCatBoostObjective(object):"""CatBoost自定义缩放均方误差损失函数"""def calc_ders_range(self, approxes, targets, weights):"""计算梯度和Hessian参数:approxes: 当前预测值列表targets: 真实标签列表weights: 样本权重列表(可选)返回:list: [(梯度, Hessian), ...] 的列表"""assert len(approxes) == len(targets)if weights is not None:assert len(weights) == len(approxes)result = []for index in range(len(targets)):y_true = targets[index]y_pred = approxes[index]# 计算分母,防止除零denominator = max(y_true, threshold)# 计算缩放误差error = (y_true - y_pred) / denominator# 计算梯度和Hessiangrad = -2 * error / denominatorhess = 2 / (denominator ** 2)# 应用样本权重if weights is not None:grad *= weights[index]hess *= weights[index]result.append((grad, hess))return result

自定义评估指标

class CustomCatBoostEval(object):"""CatBoost自定义缩放MAE评估指标"""def is_max_optimal(self):"""指标是否越大越好"""return Falsedef evaluate(self, approxes, targets, weights):"""计算评估指标参数:approxes: 预测值列表的列表 [[pred1, pred2, ...]]targets: 真实标签列表weights: 样本权重列表(可选)返回:tuple: (误差总和, 权重总和)"""assert len(approxes) == 1assert len(targets) == len(approxes[0])error_sum = 0.0weight_sum = 0.0for i in range(len(targets)):y_true = targets[i]y_pred = approxes[0][i]# 计算缩放误差denominator = max(y_true, threshold)error = abs(y_true - y_pred) / denominator# 应用样本权重if weights is not None:error *= weights[i]weight_sum += weights[i]else:weight_sum += 1.0error_sum += errorreturn error_sum, weight_sumdef get_final_error(self, error, weight):"""计算最终的评估指标值"""return error / (weight + 1e-38)

使用方式

from catboost import CatBoostRegressor, Pool
import numpy as np# 创建数据池
train_pool = Pool(X_train, y_train)
valid_pool = Pool(X_valid, y_valid)# 参数配置
params = {'objective': CustomCatBoostObjective(),'eval_metric': CustomCatBoostEval(),'iterations': 1000,'learning_rate': 0.01,'depth': 6,'random_state': 42,'verbose': False
}# 训练模型
model = CatBoostRegressor(**params)
model.fit(train_pool,eval_set=valid_pool,early_stopping_rounds=100,verbose_eval=50,use_best_model=True
)

框架对比

特性LightGBMXGBoostCatBoost
损失函数形式函数函数类方法
参数名称objectiveobjobjective
数据获取train_data.get_label()dtrain.get_label()直接传入 targets
评估指标形式函数函数类方法
评估返回格式(name, value, is_higher_better)(name, value)error_sum, weight_sum
权重支持自动处理自动处理需手动处理
实现复杂度简单简单中等

实际应用

适用场景

  1. 新能源功率预测:风电、光伏功率预测范围从0到满功率
  2. 金融风险评估:不同规模公司的风险评估
  3. 销售预测:不同产品类别的销售额预测
  4. 网络流量预测:不同时段流量变化很大

常见问题

1. 为什么要设置最小阈值?

问题:直接用真实值作为分母会遇到什么问题?

答案

  • 当真实值为0或接近0时,会导致除零错误或梯度爆炸
  • 设置最小阈值可以保证数值稳定性
  • 阈值的选择应根据数据的实际分布来确定

2. 梯度和Hessian计算错误怎么办?

问题:如何验证梯度计算的正确性?

答案:可以用数值微分验证:

def verify_gradients(y_true, y_pred, eps=1e-6):"""验证梯度计算的正确性"""# 解析梯度denominator = np.maximum(y_true, threshold)error = (y_true - y_pred) / denominatorgrad_analytical = -2 * error / denominator# 数值梯度loss_plus = ((y_true - (y_pred + eps)) / denominator) ** 2loss_minus = ((y_true - (y_pred - eps)) / denominator) ** 2grad_numerical = (loss_plus - loss_minus) / (2 * eps)# 比较diff = np.abs(grad_analytical - grad_numerical)print(f"最大梯度差异: {np.max(diff)}")return np.allclose(grad_analytical, grad_numerical, atol=1e-5)

3. 不同框架的性能差异

问题:三个框架在使用自定义损失函数时的性能如何?

答案

  • LightGBM:通常最快,内存效率高
  • XGBoost:稳定性好,文档完善
  • CatBoost:对类别特征处理好,但自定义函数实现相对复杂

4. 超参数调优建议

# LightGBM调优示例
from optuna import create_studydef objective(trial):params = {'objective': custom_loss_squared_lgb,'boosting_type': 'gbdt','num_leaves': trial.suggest_int('num_leaves', 10, 100),'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.1),'feature_fraction': trial.suggest_float('feature_fraction', 0.4, 1.0),'bagging_fraction': trial.suggest_float('bagging_fraction', 0.4, 1.0),'verbosity': -1}model = lgb.train(params,train_data,valid_sets=[valid_data],feval=mae_metric_lgb,num_boost_round=1000,callbacks=[lgb.early_stopping(100)],verbose_eval=False)y_pred = model.predict(X_valid)scaled_mae = np.mean(np.abs(y_valid - y_pred) / np.maximum(y_valid, threshold))return scaled_maestudy = create_study(direction='minimize')
study.optimize(objective, n_trials=100)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/922207.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/922207.shtml
英文地址,请注明出处:http://en.pswp.cn/news/922207.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025-09-08升级问题记录: 升级SDK从Android11到Android12

将 Android 工程的 targetSdkVersion 从 30 (Android 11)升级到 31(Android 12)需要关注一些重要的行为变更和适配点。 主要适配要点: 适配类别关键变更点适配紧迫性简要说明组件导出属性声明了 Intent Filter 的组件…

利用OpenCV实现模板与多个对象匹配

代码实现:import cv2 import numpy as npimg_rgb cv2.imread(mobanpipei.jpg) img_gray cv2.cvtColor(img_rgb, cv2.COLOR_BGR2GRAY) template cv2.imread(jianto.jpg, flags0) h, w template.shape[:2]# 读取图像# # 顺时针旋转 90 度(k1&#xff0…

OS28.【Linux】自制简单的Shell的修bug记录

目录 1.问题代码 2.排查 前期检查 查找是谁修改了environ[0] 使用gdb下断点 查看后续的影响 分析出问题的split_commandline函数 3.反思 4.正确代码 5.结论 6.除此之外...... ★提示: 此bug非常隐蔽,不仔细分析很难查出问题,非常锻炼调试能力! 1.问题代码 #includ…

Debian 系统上安装与配置 MediaMTX

🎯 在 Debian 系统上安装与配置 MediaMTX(原 rtsp-simple-server):打造轻量级流媒体服务器 作者:远在太平洋 环境:Debian 10/11/12 | Ubuntu 可参考 关键词:MediaMTX、rtsp-simple-server、RTSP…

分布式专题——10.4 ShardingSphere-Proxy服务端分库分表

1 为什么要有服务端分库分表? ShardingSphere-Proxy 是 ShardingSphere 提供的服务端分库分表工具,定位是“透明化的数据库代理”。 它模拟 MySQL 或 PostgreSQL 的数据库服务,应用程序(Application)只需像访问单个数据…

Mysql相关的面试题1

什么是聚集索引(聚簇索引)?什么是二级索引(非聚簇索引)? 聚集索引就是叶子节点关联行数据的索引,二级索引就是叶子节点关联主键的索引,聚集索引必须有且仅有一个,二级索引…

电涌保护器:为现代生活筑起一道隐形防雷网

何为电涌保护器?电涌保护器(Surge Protective Device,简称SPD)主要用于控制信号系统,保护电气电子设备信号线路免受雷电电磁脉冲、感应过电压、操作过电压的影响,广泛应用于工控、消防、安防监控、交通、电…

【uniapp微信小程序】扫普通链接二维码打开小程序

需求:用户A保存自己的邀请码海报,用户B扫描该普通连接二维码,打开微信小程序,并且携带用户A的邀请码信息,用户B登录时,跟用户A关联,成为用户A的下级。 tips:保存海报到手机相册可以参…

LeetCode 378 - 有序矩阵中第 K 小的元素

文章目录摘要描述题解答案题解代码分析代码解析示例测试及结果输出结果时间复杂度空间复杂度总结摘要 在开发中,我们经常遇到需要处理大规模有序数据的场景,比如数据库分页、排行榜查询、或者处理排序过的矩阵。LeetCode 第 378 题“有序矩阵中第 K 小的…

【Lua】Windows 下编写 C 扩展模块:VS 编译与 Lua 调用全流程

▒ 目录 ▒🛫 导读需求环境1️⃣ 核心原理:Windows下Lua与C的交互逻辑2️⃣ Windows下编写步骤:以mymath模块为例2.1 步骤1:准备Windows开发环境方式1:官网下载Lua源码并编译(可控性高)方式2&am…

Python快速入门专业版(二十九):函数返回值:多返回值、None与函数嵌套调用

目录引一、多返回值:一次返回多个结果的优雅方式1. 多返回值的本质:隐式封装为元组示例1:返回多个值的函数及接收方式2. 多返回值的接收技巧技巧1:用下划线_忽略不需要的返回值技巧2:用*接收剩余值(Python …

python使用pip安装的包与卸载

1:基本卸载命令 # 卸载单个包 pip uninstall package_name# 示例:卸载requests包 pip uninstall requests2:卸载多个包 # 一次性卸载多个包 pip uninstall package1 package2 package3# 示例 pip uninstall requests numpy pandas3&#xff1…

超级流水线和标量流水线的原理

一、什么是流水线?要理解这两个概念,首先要明白流水线(Pipelining) 的基本思想。想象一个汽车装配工厂:* 没有流水线:一个工人负责组装一整辆汽车,装完一辆再装下一辆。效率很低。* 有了流水线&…

【Ansible】管理复杂的Play和Playbook知识点

1.什么是主机模式?答:主机模式是Ansible中用于从Inventory中筛选目标主机的规则,通过灵活的模式定义可精准定位需要执行任务的主机。2.主机模式的作用答:筛选目标:从主机清单中选择一个或多个主机/组,作为P…

FastGPT源码解析 Agent 智能体应用创建流程和代码分析

FastGPT对话智能体创建流程和代码分析 平台作为agent平台,平台所有功能都是围绕Agent创建和使用为核心的。平台整合各种基础能力,如大模型、知识库、工作流、插件等模块,通过可视化,在界面上创建智能体,使用全部基础能…

缺失数据处理全指南:方法、案例与最佳实践

如何处理缺失数据:方法、案例与最佳实践 1. 引言 在数据分析和机器学习中,缺失数据是一个普遍存在的问题。如何处理缺失值,往往直接影响到后续分析和建模的效果。处理不当,不仅会浪费数据,还可能导致模型预测结果的不准…

为什么Cesium不使用vue或者react,而是 保留 Knockout

1. Knockout-ES5 插件的语法简化优势 自动深度监听:Cesium 通过集成 Knockout-ES5 插件,允许开发者直接使用普通变量语法(如 viewModel.property newValue)替代繁琐的 observable() 包装,无需手动声明每个可观察属性。…

Word怎么设置页码总页数不包含封面和目录页

有时候使用页码格式是[第x页/共x页]或[x/x]时会遇到word总页数和实际想要的页数不一致,导致显示不统一,这里介绍一个简单的办法,适用于比较简单的情况。 一、wps版本 文章分节 首先将目录页与正文页进行分节:在目录页后面选择插入…

突破机器人通讯架构瓶颈,CAN/FD、高速485、EtherCAT,哪种总线才是最优解?

引言: 从协作机械臂到人形机器人,一文拆解主流总线技术选型困局 在机器人技术飞速发展的今天,从工厂流水线上的协作机械臂到科技展会上的人形机器人,它们的“神经系统”——通讯总线,正面临着前所未有的挑战。特斯拉O…

Java核心概念详解:JVM、JRE、JDK、Java SE、Java EE (Jakarta EE)

1. Java是什么? Java首先是一种编程语言。它拥有特定的语法、关键字和结构,开发者可以用它来编写指令,让计算机执行任务。核心特点: Java最著名的特点是“一次编写,到处运行”(Write Once, Run Anywhere - …