决策树分两种分类和回归,这篇博客我将对两种方法进行实战讲解

一、分类决策树

代码的核心任务是预测 “电信客户流失状态”,这是一个典型的分类任务

数据集附在该博客上,可以直接下载

代码整体结构整理

代码主要分为以下几个部分:

  1. 导入必要的库
  2. 数据读取与预处理
  3. 数据平衡处理
  4. 决策树模型参数调优
  5. 最佳模型训练与评估
  6. 阈值调整以优化召回率

详细代码讲解

1. 导入必要的库
from imblearn.over_sampling import RandomOverSampler
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
from sklearn.model_selection import cross_val_score
import numpy as np

这部分代码导入了后续需要用到的各类库:

  • 数据处理:pandas、numpy
  • 模型相关:DecisionTreeClassifier(决策树分类器)
  • 数据划分与评估:train_test_split、cross_val_score、metrics
  • 可视化:matplotlib.pyplot、plot_tree
  • 不平衡数据处理:RandomOverSampler
2. 数据读取与初步处理
# 读取数据
data = pd.read_excel('电信客户流失数据.xlsx')
x = data.iloc[:, :-1]  # 提取特征变量(除最后一列外的所有列)
y = data['流失状态']   # 提取目标变量(流失状态列)# 将数据分为训练集和测试集,测试集占20%
xtr, xte, ytr, yte = train_test_split(x, y, test_size=0.2, random_state=42)
  • 读取 Excel 格式的电信客户流失数据
  • 分割特征变量 (x) 和目标变量 (y),其中目标变量是 "流失状态"
  • 使用 train_test_split 将数据分为训练集 (80%) 和测试集 (20%)
  • random_state=42 确保结果可重现
3. 数据平衡处理
# 手动实现过采样以平衡数据集
new_date = xtr.copy()
new_date['流失状态'] = ytr  # 将训练集特征和标签合并# 分离正例和负例(假设0表示未流失,1表示流失)
positive_eg = new_date[new_date['流失状态'] == 0]
negative_eg = new_date[new_date['流失状态'] == 1]# 对多数类进行下采样,使其数量与少数类相同
positive_eg = positive_eg.sample(len(negative_eg))# 合并处理后的正负例,得到平衡的训练集
date_c = pd.concat([positive_eg, negative_eg])
xtr = date_c.iloc[:, :-1]  # 平衡后的特征
ytr = date_c['流失状态']   # 平衡后的标签

这部分通过下采样方法处理数据不平衡问题:

  • 将多数类 (未流失客户) 的样本数量减少到与少数类 (流失客户) 相同
  • 这样处理是因为在客户流失预测中,我们通常更关注少数类 (流失客户) 的识别
  • 平衡数据有助于模型更好地学习少数类的特征
4. 决策树模型参数调优
# 网格搜索寻找最佳参数组合
scores = []
# 遍历不同的参数组合
for d in range(3, 10):          # 树的最大深度for l in range(2, 8):       # 叶节点的最小样本数for s in range(2, 6):   # 分裂所需的最小样本数for n in range(2, 8):# 最大叶节点数# 创建决策树模型model = DecisionTreeClassifier(max_depth=d,min_samples_split=s,min_samples_leaf=l,max_leaf_nodes=n,random_state=42)# 5折交叉验证,评估指标为召回率score = cross_val_score(model, xtr, ytr, cv=5, scoring='recall')score_mean = sum(score) / len(score)  # 计算平均召回率scores.append([d, l, s, n, score_mean])# 找到召回率最高的参数组合
best_params = max(scores, key=lambda x: x[4])
d, l, s, n, best_recall = best_params
print(f"最佳参数: 最大深度={d}, 最小叶节点样本数={l}, 最小分裂样本数={s}, 最大叶节点数={n}, 最佳召回率={best_recall}")

这部分通过网格搜索进行参数调优:

  • 遍历决策树的四个关键参数的不同取值组合
  • 使用 5 折交叉验证评估每个组合的性能,重点关注召回率 (recall)
  • 召回率在客户流失预测中很重要,因为我们希望尽可能识别出所有可能流失的客户
  • 选择召回率最高的参数组合作为最佳参数
5. 最佳模型训练与评估
# 使用最佳参数创建并训练模型
dtr_best = DecisionTreeClassifier(max_depth=d,min_samples_split=s,min_samples_leaf=l,max_leaf_nodes=n,random_state=42
)
dtr_best.fit(xtr, ytr)# 在训练集和测试集上进行预测
test_predicted = dtr_best.predict(xte)
train_predicted = dtr_best.predict(xtr)# 输出分类报告
print("训练集分类报告:")
print(metrics.classification_report(ytr, train_predicted))
print("测试集分类报告:")
print(metrics.classification_report(yte, test_predicted))

这部分使用最佳参数训练模型并评估:

  • 用最佳参数组合构建决策树模型并在训练集上拟合
  • 在训练集和测试集上分别进行预测
  • 输出详细的分类报告,包括精确率、召回率、F1 分数等指标
  • 通过对比训练集和测试集的表现,可以初步判断模型是否过拟合
6. 阈值调整以优化召回率
# 尝试不同的阈值以优化召回率
thresholds = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5]
recalls = []for i in thresholds:# 获取测试集的预测概率y_predict_proba = dtr_best.predict_proba(xte)y_predict_proba = pd.DataFrame(y_predict_proba)# 根据当前阈值调整预测结果y_predict_proba[y_predict_proba[[1]] > i] = 1y_predict_proba[y_predict_proba[[1]] <= i] = 0# 计算并记录召回率recall = metrics.recall_score(yte, y_predict_proba[1])recalls.append(recall)print(f"阈值: {i}, 召回率: {recall}")# 找到最佳阈值
best_threshold = thresholds[np.argmax(recalls)]
print(f'最佳阈值: {best_threshold}')
print(f'调整阈值后的最高recall为: {max(recalls)}')# 使用最佳阈值生成最终预测结果
y_predict_proba = dtr_best.predict_proba(xte)
y_predict_proba = pd.DataFrame(y_predict_proba)
y_predict_proba[y_predict_proba[[1]] > best_threshold] = 1
y_predict_proba[y_predict_proba[[1]] <= best_threshold] = 0# 输出最佳阈值下的分类报告
print(f"\n最佳阈值 {best_threshold} 对应的分类报告:")
print(metrics.classification_report(yte, y_predict_proba[1]))

这部分通过调整分类阈值进一步优化模型:

  • 决策树默认使用 0.5 作为分类阈值
  • 降低阈值会增加预测为流失 (1) 的样本比例,通常会提高召回率
  • 尝试多个阈值,计算每个阈值对应的召回率
  • 选择召回率最高的阈值作为最佳阈值
  • 输出最佳阈值下的分类报告,此时模型在识别流失客户方面表现最优

二、分类决策树

下面会创建一个使用回归决策树解决问题的示例,我们将使用 scikit-learn 自带的波士顿房价数据集(或其替代数据集)来预测房价。(这个代码不用下载数据集,python的sklearn库自带这个数据集

1. 导入必要的库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing  # 加州房价数据集
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import cross_val_score
from sklearn.tree import plot_tree

功能说明
这部分导入了后续分析所需的各类工具库:

  • 数值计算:numpy 用于数学运算,pandas 用于数据处理与分析
  • 可视化:matplotlib.pyplot 用于绘图,plot_tree 用于决策树可视化
  • 数据集:fetch_california_housing 获取 sklearn 自带的加州房价数据集
  • 模型相关:DecisionTreeRegressor 是回归决策树模型类
  • 数据划分与评估:train_test_split 分割训练集和测试集;cross_val_score 用于交叉验证;mean_squared_error等是回归任务的评估指标

回归 vs 分类:此处使用 DecisionTreeRegressor(回归决策树),而非分类任务的 DecisionTreeClassifier,因为我们要预测的是连续数值(房价)。

2. 加载并探索数据集

# 1. 加载数据集
housing = fetch_california_housing()
X = housing.data  # 特征数据
y = housing.target  # 目标变量(房价,单位:10万美元)# 将数据转换为DataFrame以便查看
df = pd.DataFrame(X, columns=housing.feature_names)
df['MedHouseVal'] = y  # 添加目标变量列(房价中位数)# 查看数据集基本信息
print("数据集形状:", df.shape)
print("\n数据集前5行:")
print(df.head())
print("\n数据集描述统计:")
print(df.describe())

功能说明

  • 加载加州房价数据集:该数据集包含加州各地区的房价数据,替代了已移除的波士顿房价数据集
  • 数据结构:X 是特征矩阵(8 个特征),y 是目标变量(房价中位数,单位为 10 万美元)
  • 特征说明:包括平均收入(MedInc)、房龄(HouseAge)、平均房间数(AveRooms)等 8 个特征
  • 数据探索:通过打印数据集形状、前 5 行和描述统计量,快速了解数据分布特征(均值、标准差、最值等)

关键输出

  • 数据集形状:查看样本数量和特征数量
  • 描述统计:了解各特征的数值范围和分布,为后续建模提供参考

3. 划分训练集和测试集

# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42
)

功能说明

  • 使用 train_test_split 将数据分为训练集(80%)和测试集(20%)
    • X_train/y_train:用于模型训练的特征和目标变量
    • X_test/y_test:用于评估模型泛化能力的特征和目标变量
  • test_size=0.2 表示测试集占总数据的 20%
  • random_state=42 固定随机种子,确保每次运行结果可重现

为什么要划分数据集
训练集用于模型学习数据规律,测试集用于模拟真实场景,评估模型对新数据的预测能力,避免过拟合问题。

4. 模型参数调优(寻找最佳树深度

# 3. 训练回归决策树模型并进行参数调优
best_score = -np.inf
best_depth = 1# 尝试不同的树深度,寻找最佳参数
for depth in range(1, 21):# 创建回归决策树模型regressor = DecisionTreeRegressor(max_depth=depth,random_state=42)# 使用交叉验证评估模型scores = cross_val_score(regressor, X_train, y_train, cv=5, scoring='neg_mean_squared_error')# 计算平均MSE(注意交叉验证返回的是负的MSE)avg_mse = -np.mean(scores)# 计算R²分数r2_scores = cross_val_score(regressor, X_train, y_train, cv=5, scoring='r2')avg_r2 = np.mean(r2_scores)if avg_r2 > best_score:best_score = avg_r2best_depth = depthprint(f"树深度: {depth}, 平均MSE: {avg_mse:.4f}, 平均R²: {avg_r2:.4f}")print(f"\n最佳树深度: {best_depth}, 最佳R²分数: {best_score:.4f}")

功能说明
这是模型调优的核心步骤,通过遍历不同树深度寻找最佳参数:

  • 参数选择:重点优化 max_depth(树的最大深度),范围从 1 到 20
    • 树深度过小:模型简单,可能欠拟合(无法捕捉数据规律)
    • 树深度过大:模型复杂,可能过拟合(过度拟合训练数据噪声)
  • 交叉验证:使用 5 折交叉验证(cv=5)评估每个深度的模型性能
    • 将训练集分为 5 份,轮流用 4 份训练、1 份验证,最终取平均值
    • 避免单次划分的随机性影响,更稳定地评估模型性能
  • 评估指标
    • neg_mean_squared_error:负均方误差(sklearn 习惯返回负值,需取反)
    • r2:决定系数,衡量模型对目标变量变异的解释能力(越接近 1 越好)
  • 最佳参数选择:保留 R² 分数最高的树深度作为最佳参数

调优目的:找到泛化能力最强的模型参数,平衡欠拟合和过拟合。

5. 训练最佳模型并评估

# 4. 使用最佳参数训练模型
best_regressor = DecisionTreeRegressor(max_depth=best_depth,random_state=42
)
best_regressor.fit(X_train, y_train)# 5. 模型评估
# 在测试集上进行预测
y_pred = best_regressor.predict(X_test)# 计算评估指标
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)  # 均方根误差
mae = mean_absolute_error(y_test, y_pred)  # 平均绝对误差
r2 = r2_score(y_test, y_pred)  # R²分数print("\n测试集评估指标:")
print(f"均方误差 (MSE): {mse:.4f}")
print(f"均方根误差 (RMSE): {rmse:.4f}")  # 单位与目标变量相同(10万美元)
print(f"平均绝对误差 (MAE): {mae:.4f}")
print(f"R²分数: {r2:.4f}")  # 越接近1表示模型拟合越好

功能说明

  • 用最佳树深度创建模型并在训练集上拟合(fit方法)
  • 在测试集上生成预测结果(predict方法)
  • 计算核心评估指标:
    • MSE(均方误差):预测值与实际值差的平方的平均值,衡量误差大小
    • RMSE(均方根误差):MSE 的平方根,单位与目标变量一致(此处为 10 万美元),更易解释
    • MAE(平均绝对误差):预测值与实际值差的绝对值的平均值,对异常值更稳健
    • R² 分数:表示模型解释的房价变异比例,R²=0.6 意味着模型解释了 60% 的房价变异

评估意义:通过测试集指标判断模型的实际预测能力,若测试集与训练集指标差距过大,可能存在过拟合。

6. 决策树可视化(可选

# 6. 可视化部分决策树(如果树不太大)
if best_depth <= 5:  # 只可视化较浅的树,太深的树可视化效果不好plt.figure(figsize=(15, 10))plot_tree(best_regressor,feature_names=housing.feature_names,filled=True,rounded=True,precision=2)plt.title(f"回归决策树 (深度: {best_depth})")plt.show()
else:print(f"\n由于最佳树深度({best_depth})较大,未进行可视化")

功能说明

  • 当树深度较小时(≤5),使用 plot_tree 可视化决策树结构
  • 可视化参数:
    • feature_names:显示特征名称,增强可读性
    • filled=True:根据节点的目标值范围填充颜色
    • rounded=True:圆角矩形显示节点
  • 深层树不可视化的原因:深度过大的树结构复杂,节点繁多,可视化后难以解读

可视化价值:直观展示决策树的分裂规则,帮助理解模型如何通过特征判断房价。

7. 特征重要性分析

# 7. 分析特征重要性
feature_importance = pd.DataFrame({'特征': housing.feature_names,'重要性': best_regressor.feature_importances_
})
feature_importance = feature_importance.sort_values('重要性', ascending=False)
print("\n特征重要性排序:")
print(feature_importance)

功能说明

  • 回归决策树通过 feature_importances_ 属性输出各特征的重要性
  • 重要性定义:特征在树中所有分裂点减少的误差总和(归一化到 0-1 之间,总和为 1)
  • 结果排序:按重要性降序排列,清晰展示哪些特征对房价预测影响最大

实际意义:在房价预测中,通常 "平均收入(MedInc)" 会是最重要的特征,这符合现实逻辑 —— 收入水平直接影响购房能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/91977.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/91977.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/91977.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL154 插入记录(一)

描述牛客后台会记录每个用户的试卷作答记录到exam_record表&#xff0c;现在有两个用户的作答记录详情如下&#xff1a;用户1001在2021年9月1日晚上10点11分12秒开始作答试卷9001&#xff0c;并在50分钟后提交&#xff0c;得了90分&#xff1b;用户1002在2021年9月4日上午7点1分…

BeanFactory 和 ApplicationContext 的区别?

口语化答案好的&#xff0c;面试官。BeanFactory和ApplicationContext都是用于管理Bean的容器接口。BeanFactory功能相对简单。提供了Bean的创建、获取和管理功能。默认采用延迟初始化&#xff0c;只有在第一次访问Bean时才会创建该Bean。因为功能较为基础&#xff0c;BeanFact…

VNC连接VirtualBox中的Ubuntu24.04 desktop图形化(GUI)界面

测试环境&#xff1a;VirtualBox 7,Ubuntu24.04 desktop,Ubuntu24.04 server(no desktop) 一、下载和安装dRealVNC viewer。 二、配置 VirtualBox 网络&#xff1a;NAT 模式 端口转发 1、打开 VirtualBox&#xff0c;选择您的 Ubuntu 虚拟机&#xff0c;点击 设置。 选择 网…

浮动路由和BFD配置

拓扑图 前期的拓扑图没有交换机配置步骤 1、配置IP地址 终端IP地址的配置 路由器IP地址的配置 配置router的对应接口的IP地址 <Huawei>sys [Huawei]sysname router [router]interface Ethernet 0/0/0 [router-Ethernet0/0/0]ip address 192.168.10.254 24 [router-Ethern…

Docker 实战 -- Nextcloud

文章目录前言1. 创建 docker-compose.yml2. 启动 Nextcloud3. 访问 Nextcloud4. 配置优化&#xff08;可选&#xff09;使用 PostgreSQL使用 redis添加 Cron 后台任务5. 常用命令6. 反向代理&#xff08;Nginx/Apache&#xff09;前言 当你迷茫的时候&#xff0c;请点击 Docke…

【计算机网络 | 第2篇】计算机网络概述(下)

文章目录七.因特网服务提供商&#x1f95d;八.接入网&#x1f95d;主流的家庭宽带接入方式介入网工作原理&#x1f9d0;DSL技术&#xff1a;铜线上的“三通道”通信DSL的速率标准呈现出显著的"不对称"特征&#x1f914;电缆互联网接入技术&#x1f34b;‍&#x1f7e…

SpringMVC 6+源码分析(四)DispatcherServlet实例化流程 3--(HandlerAdapter初始化)

一、概述 HandlerAdapter 是 Spring MVC 框架中的一个核心组件&#xff0c;它在 DispatcherServlet 和处理程序&#xff08;handler&#xff09;之间扮演适配器的角色。DispatcherServlet 接收到 HTTP 请求后&#xff0c;需要调用对应的 handler 来处理请求&#xff08;如控制器…

【lucene】FastVectorHighlighter案例

下面给出一套可直接拷贝运行的 Lucene 8.5.0 FastVectorHighlighter 完整示例&#xff08;JDK 8&#xff09;&#xff0c;演示从建索引、查询到高亮的全过程。 > 关键点&#xff1a;字段必须 1. 存储原始内容&#xff08;setStored(true)&#xff09; 2. 开启 TermVecto…

C++返回值优化(RVO):高效返回对象的艺术

在C开发中&#xff0c;按值返回对象的场景十分常见&#xff08;如运算符重载、工厂函数等&#xff09;&#xff0c;但开发者常因担忧“构造/析构的性能开销”而陷入纠结&#xff1a;该不该返回对象&#xff1f;如何避免额外成本&#xff1f;本文将剖析痛点、拆解错误思路&#…

用 PyTorch 实现一个简单的神经网络:从数据到预测

PyTorch 是目前最流行的深度学习框架之一&#xff0c;以其灵活性和易用性受到开发者的喜爱。本文将带你从零开始&#xff0c;用 PyTorch 实现一个简单的神经网络&#xff0c;用于解决经典的 MNIST 手写数字分类问题。我们将涵盖数据准备、模型构建、训练和预测的完整流程&#…

四级页表通俗讲解与实践(以 64 位 ARM Cortex-A 为例)

&#x1f4d6; &#x1f3a5; B 站博文精讲视频&#xff1a;点击链接&#xff0c;配合视频深度学习 四级页表通俗讲解与实践&#xff08;以 64 位 ARM Cortex-A 为例&#xff09; 本文面向希望彻底理解现代 64 位架构下四级页表的开发者&#xff0c;结合 ARM Cortex-A 系列处理…

AI模型整合包上线!一键部署ComfyUI,2.19TB模型全解析

最近体验了AIStarter平台上线的AI模型整合包&#xff0c;包含2.19TB ComfyUI大模型&#xff0c;整合市面主流模型&#xff0c;一键部署ComfyUI&#xff0c;省去重复下载烦恼&#xff01;以下是使用心得和部署步骤&#xff0c;适合AI开发者参考。工具亮点这款AI模型整合包由熊哥…

灰色优选模型及算法MATLAB代码

电子装备试验方案优选是一个典型的多属性决策问题&#xff0c;通常涉及指标复杂、信息不完整、数据量少且存在不确定性的特点。灰色系统理论&#xff08;Grey System Theory&#xff09;特别擅长处理“小样本、贫信息”的不确定性问题&#xff0c;因此非常适合用于此类方案的优…

AI框架工具FastRTC快速上手6——视频流案例之物体检测(下)

一 前言 上一篇,我们实现了用YOLO对图片上的物体进行检测,并在图片上框出具体的对象并打出标签。但只是应用在单张图片,且还没用上FastRTC。 本篇,我们希望结合FastRTC的能力,实现基于YOLO的实时视频流的物体检测。 本篇文字将不会太多。学习完本篇,对比前面的文章,你…

PHP常见中高面试题汇总

一、 PHP部分 1、PHP如何实现静态化 PHP的静态化分为&#xff1a;纯静态和伪静态。其中纯静态又分为&#xff1a;局部纯静态和全部纯静态。 PHP伪静态&#xff1a;利用Apache mod_rewrite实现URL重写的方法&#xff1b; PHP纯静态&#xff0c;就是生成HTML文件的方式&#xff0…

基于Java AI(人工智能)生成末日题材的实践

Java AI 生成《全球末日》文章的实例 使用Java结合AI技术生成《全球末日》题材的文章可以通过多种方式实现,包括调用预训练模型、使用自然语言处理库或结合生成式AI框架。以下是30个实例的生成方法和示例代码片段。 调用预训练模型(如GPT-3或GPT-4) 使用OpenAI API生成末日…

针对软件定义车载网络的动态服务导向机制

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 做到欲望极简,了解自己的真实欲望,不受外在潮流的影响,不盲从,不跟风。把自己的精力全部用在自己。一是去掉多余,凡事找规律,基础是诚信;二是…

Pytorch实现婴儿哭声检测和识别

Pytorch实现婴儿哭声检测和识别 目录 Pytorch实现婴儿哭声检测识别 1. 项目说明 2. 数据说明 &#xff08;1&#xff09;婴儿哭声语音数据集 &#xff08;2&#xff09;自定义数据集 3. 模型训练 &#xff08;1&#xff09;项目安装 &#xff08;2&#xff09;准备Tra…

海信IP810N/海信IP811N_海思MV320-安卓9.0主板-TTL烧录包-可救砖

海信IP810N&#xff0f;海信IP811N_海思MV320处理器-安卓9主板-TTL烧录包-可救砖准备工作&#xff1a;TTL线自备跑码工具【putty跑码中文版】路径&#xff1a;【工具大全】-【putty跑码中文版】测试跑码以后将跑码窗口关闭&#xff1b;然后到下方下载烧录工具并大致看下教程烧录…

Go 中的 interface{} 与 Java 中的 Object:相似之处与本质差异

在软件系统开发中&#xff0c;“通用类型”的处理是各语言设计中不可忽视的一部分。Java 使用 Object&#xff0c;Go 使用 interface{}&#xff0c;它们都可以容纳任意类型的值&#xff0c;是实现动态行为或通用容器的基础类型。然而&#xff0c;虽然两者在使用层面看似相似&am…