点击 “AladdinEdu,同学们用得起的【H卡】算力平台”,H卡级别算力,按量计费,灵活弹性,顶级配置,学生专属优惠。

决策树/SVM/KNN算法对比 × 模型评估指标解析
读者收获:掌握经典机器学习全流程

当80%的机器学习问题可用Scikit-learn解决,掌握其核心流程将成为你的核心竞争力。本文通过对比实验揭示算法本质,带你一站式打通机器学习任督二脉。

一、Scikit-learn全景图:3大核心模块解析

在这里插入图片描述

1.1 算法选择矩阵

在这里插入图片描述

1.2 环境极速配置

# 创建专用环境  
conda create -n sklearn_env python=3.10  
conda activate sklearn_env  # 安装核心库  
pip install numpy pandas matplotlib seaborn scikit-learn  # 验证安装  
import sklearn  
print(f"Scikit-learn version: {sklearn.__version__}")  

二、分类实战:鸢尾花识别

2.1 数据探索与预处理

from sklearn.datasets import load_iris  
import pandas as pd  # 加载数据集  
iris = load_iris()  
df = pd.DataFrame(iris.data, columns=iris.feature_names)  
df['target'] = iris.target  # 数据概览  
print(f"样本数: {df.shape[0]}")  
print(f"特征数: {df.shape[1]-1}")  
print(f"类别分布:\n{df['target'].value_counts()}")  # 可视化分析  
import seaborn as sns  
sns.pairplot(df, hue='target', palette='viridis')  

2.2 三大分类器对比实验

from sklearn.model_selection import train_test_split  
from sklearn.tree import DecisionTreeClassifier  
from sklearn.svm import SVC  
from sklearn.neighbors import KNeighborsClassifier  # 划分数据集  
X = df.drop(columns='target')  
y = df['target']  
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  # 初始化分类器  
models = {  "决策树": DecisionTreeClassifier(max_depth=3),  "SVM": SVC(kernel='rbf', probability=True),  "KNN": KNeighborsClassifier(n_neighbors=5)  
}  # 训练与评估  
results = {}  
for name, model in models.items():  model.fit(X_train, y_train)  y_pred = model.predict(X_test)  results[name] = y_pred  

2.3 分类结果可视化

import matplotlib.pyplot as plt  
from sklearn.metrics import confusion_matrix  # 绘制混淆矩阵  
fig, axes = plt.subplots(1, 3, figsize=(18, 5))  
for i, (name, y_pred) in enumerate(results.items()):  cm = confusion_matrix(y_test, y_pred)  sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[i])  axes[i].set_title(f"{name} 混淆矩阵")  
plt.show()  

三、回归实战:波士顿房价预测

3.1 数据解析与特征工程

from sklearn.datasets import fetch_openml  # 加载数据集  
boston = fetch_openml(name='boston', version=1)  
df = pd.DataFrame(boston.data, columns=boston.feature_names)  
df['PRICE'] = boston.target  # 关键特征分析  
corr = df.corr()['PRICE'].sort_values(ascending=False)  
print(f"与房价相关性最高的特征:\n{corr.head(5)}")  # 特征工程  
df['RM_LSTAT'] = df['RM'] / df['LSTAT']  # 创造新特征  

3.2 回归模型对比

from sklearn.linear_model import LinearRegression  
from sklearn.tree import DecisionTreeRegressor  
from sklearn.svm import SVR  # 划分数据集  
X = df.drop(columns='PRICE')  
y = df['PRICE']  
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  # 初始化回归器  
regressors = {  "线性回归": LinearRegression(),  "决策树回归": DecisionTreeRegressor(max_depth=5),  "支持向量回归": SVR(kernel='rbf')  
}  # 训练与预测  
predictions = {}  
for name, reg in regressors.items():  reg.fit(X_train, y_train)  pred = reg.predict(X_test)  predictions[name] = pred  

3.3 回归结果可视化

# 绘制预测值与真实值对比  
plt.figure(figsize=(15, 10))  
for i, (name, pred) in enumerate(predictions.items(), 1):  plt.subplot(3, 1, i)  plt.scatter(y_test, pred, alpha=0.7)  plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')  plt.xlabel('真实价格')  plt.ylabel('预测价格')  plt.title(f'{name} 预测效果')  
plt.tight_layout()  

四、模型评估指标深度解析

4.1 分类指标四维分析

在这里插入图片描述
鸢尾花分类评估实例:

from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score  metrics = []  
for name, y_pred in results.items():  metrics.append({  "模型": name,  "准确率": accuracy_score(y_test, y_pred),  "精确率": precision_score(y_test, y_pred, average='macro'),  "召回率": recall_score(y_test, y_pred, average='macro'),  "F1": f1_score(y_test, y_pred, average='macro')  })  metrics_df = pd.DataFrame(metrics)  
print(metrics_df)  

在这里插入图片描述

4.2 回归指标三维对比

波士顿房价评估实例:

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score  reg_metrics = []  
for name, pred in predictions.items():  reg_metrics.append({  "模型": name,  "MSE": mean_squared_error(y_test, pred),  "MAE": mean_absolute_error(y_test, pred),  "R²": r2_score(y_test, pred)  })  reg_metrics_df = pd.DataFrame(reg_metrics)  
print(reg_metrics_df)  

在这里插入图片描述

五、算法原理对比揭秘

5.1 决策树:可解释性之王

核心参数调优指南:

params = {  'max_depth': [3, 5, 7, None],  'min_samples_split': [2, 5, 10],  'criterion': ['gini', 'entropy']  
}  best_tree = GridSearchCV(  DecisionTreeClassifier(),  param_grid=params,  cv=5,  scoring='f1_macro'  
)  
best_tree.fit(X_train, y_train)  

5.2 SVM:高维空间的分割大师

核函数选择策略:
在这里插入图片描述

5.3 KNN:简单高效的惰性学习

距离度量对比:

distance_metrics = [  ('euclidean', '欧氏距离'),  ('manhattan', '曼哈顿距离'),  ('cosine', '余弦相似度')  
]  for metric, name in distance_metrics:  knn = KNeighborsClassifier(n_neighbors=5, metric=metric)  knn.fit(X_train, y_train)  score = knn.score(X_test, y_test)  print(f"{name} 准确率: {score:.4f}")  

六、模型优化实战技巧

6.1 特征工程:性能提升关键

波士顿房价特征优化:

from sklearn.preprocessing import PolynomialFeatures  # 创建多项式特征  
poly = PolynomialFeatures(degree=2, include_bias=False)  
X_poly = poly.fit_transform(X)  # 新特征训练  
lr_poly = LinearRegression()  
lr_poly.fit(X_train_poly, y_train)  
r2 = lr_poly.score(X_test_poly, y_test)  
print(f"R²提升: {reg_metrics_df.loc[0,'R²']:.2f}{r2:.2f}")  

6.2 交叉验证:防止过拟合

from sklearn.model_selection import cross_val_score  # 5折交叉验证  
scores = cross_val_score(  SVC(),  X, y,  cv=5,  scoring='accuracy'  
)  
print(f"平均准确率: {scores.mean():.4f}{scores.std():.4f})")  

6.3 网格搜索:自动化调参

from sklearn.model_selection import GridSearchCV  # 定义参数网格  
param_grid = {  'C': [0.1, 1, 10, 100],  'gamma': [1, 0.1, 0.01, 0.001],  'kernel': ['rbf', 'linear']  
}  # 执行搜索  
grid = GridSearchCV(SVC(), param_grid, refit=True, verbose=3)  
grid.fit(X_train, y_train)  
print(f"最优参数: {grid.best_params_}")  

七、工业级部署方案

7.1 模型持久化

import joblib  # 保存模型  
joblib.dump(best_model, 'iris_classifier.pkl')  # 加载模型  
clf = joblib.load('iris_classifier.pkl')  # 在线预测  
new_data = [[5.1, 3.5, 1.4, 0.2]]  
prediction = clf.predict(new_data)  
print(f"预测类别: {iris.target_names[prediction[0]]}")  

7.2 构建预测API

from flask import Flask, request, jsonify  app = Flask(__name__)  
model = joblib.load('iris_classifier.pkl')  @app.route('/predict', methods=['POST'])  
def predict():  data = request.get_json()  features = [data['sepal_length'], data['sepal_width'],  data['petal_length'], data['petal_width']]  prediction = model.predict([features])  return jsonify({'class': iris.target_names[prediction[0]]})  if __name__ == '__main__':  app.run(host='0.0.0.0', port=5000)  

7.3 性能监控仪表盘

from sklearn.metrics import plot_roc_curve, plot_precision_recall_curve  # 分类性能可视化  
fig, ax = plt.subplots(1, 2, figsize=(15, 6))  
plot_roc_curve(model, X_test, y_test, ax=ax[0])  
plot_precision_recall_curve(model, X_test, y_test, ax=ax[1])  

八、避坑指南:常见错误解决方案

8.1 数据预处理陷阱

问题:测试集出现未知类别
解决方案

from sklearn.preprocessing import OneHotEncoder  # 训练阶段  
encoder = OneHotEncoder(handle_unknown='ignore')  
encoder.fit(X_train_categorical)  # 测试阶段自动忽略未知类别  
X_test_encoded = encoder.transform(X_test_categorical)  

8.2 特征尺度问题

症状:SVM/KNN性能异常
处方

from sklearn.preprocessing import StandardScaler  scaler = StandardScaler()  
X_train_scaled = scaler.fit_transform(X_train)  
X_test_scaled = scaler.transform(X_test)  # 注意:只变换不拟合  

8.3 样本不均衡处理

解决方案对比
在这里插入图片描述

结语:机器学习工程师的成长之路

当你在Scikit-learn中完整实现从数据加载到模型部署的全流程,已超越70%的入门者。但真正的进阶之路刚刚开始。

下一步行动指南:

# 1. 复现经典论文算法  
from sklearn.linear_model import LogisticRegression  
model = LogisticRegression(penalty='l1', solver='liblinear')  # 2. 参加Kaggle竞赛  
from kaggle import api  
api.competitions_list(search='getting started')  # 3. 构建个人项目组合  
projects = [  {"name": "鸢尾花分类器", "type": "分类", "accuracy": 0.97},  {"name": "房价预测", "type": "回归", "R2": 0.85}  
]  

记住:在机器学习领域,理论认知的深度=代码实践的厚度。现在运行你的第一个完整流程,让Scikit-learn成为你AI旅程中最可靠的伙伴。

附录:Scikit-learn速查表

任务类型导入路径核心参数
分类from sklearn.ensemble import RandomForestClassifiern_estimators, max_depth
回归from sklearn.linear_model import LinearRegressionfit_intercept, normalize
聚类from sklearn.cluster import KMeansn_clusters, init
降维from sklearn.decomposition import PCAn_components
模型选择from sklearn.model_selection import GridSearchCVparam_grid, cv
数据预处理from sklearn.preprocessing import StandardScalerwith_mean, with_std

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/93835.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/93835.shtml
英文地址,请注明出处:http://en.pswp.cn/web/93835.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

rsync + inotify 数据实时同步

rsync inotify 数据实时同步 一、rsync简介 rsync是linux系统下的数据镜像备份工具。使用快速增量备份工具Remote Sync可以远程同步, 支持本地复制,或者与其他SSH、rsync主机同步 二、rsync三种命令 Rsync的命令格式常用的有以下三种:&#…

Linux基础介绍-3——第一阶段

文章目录一、进程管理1.1 进程的基本概念1.2 常见管理命令1.3 进程优先级调整:nice 与 renice二、软件包管理三、防火墙管理四、shell脚本五、xshell链接kali一、进程管理 1.1 进程的基本概念 进程是程序的动态执行实例,每个进程都有唯一的 PID&#x…

python 可迭代对象相关知识点

1. 什么是可迭代对象 (Iterable) 在 Python 里,可迭代对象指的是: 👉 能够一次返回一个元素的对象,可以被 for 循环遍历。 常见的可迭代对象有: 序列类型:list、tuple、str集合类型:set、dict&a…

ijkplayer Android 编译

一、下载编译库文件1.1 编译库文件环境:ubuntu 20.04 版本liangtao:ffmpeg$lsb_release -a No LSB modules are available. Distributor ID: Ubuntu Description: Ubuntu 20.04.6 LTS Release: 20.04 Codename: focal1.2 项目源码下载使用 git 下载 ijkplayer&#…

snn前向推理时间计算(处理器实现)

公式 Tinf(1−sparsity)number of synapsesnumber of sub-processorsSIMD ways T_{\text{inf}} \frac{(1-\text{sparsity})\times \text{number of synapses}} {\text{number of sub-processors}\times \text{SIMD ways}} Tinf​number of sub-processorsSIMD ways(1−sparsity…

Linux------《操作系统全景速览:Windows·macOS·Linux·Unix 对比及 Linux 发行版实战指南》

(一)常见操作系统(system)电脑:Windows,Macos,Linux,UnixWindows:微软公司开发的一款桌面操作系统(闭源系统)。版本有dos,win98,win NT,win XP , …

Three.js 初级教程大全

本文档旨在为初学者提供一个全面的 Three.js 入门指南。我们将从 Three.js 的基本概念开始,逐步介绍如何创建场景、添加物体、设置材质、使用光照和相机,以及如何实现简单的动画和交互。通过本教程,你将能够掌握 Three.js 的核心知识&#xf…

遥感领域解决方案丨高光谱、无人机多光谱、空天地数据识别与计算

一:AI智慧高光谱遥感实战:手撕99个案例项目、全覆盖技术链与应用场景一站式提升方案在遥感技术飞速发展的今天,高光谱数据以其独特的光谱分辨率成为环境监测、精准农业、地质勘探等领域的核心数据源。然而,海量的波段数据、复杂的…

(LeetCode 面试经典 150 题) 114. 二叉树展开为链表 (深度优先搜索dfs+链表)

题目:114. 二叉树展开为链表 思路:深度优先搜索dfs链表,时间复杂度0(n)。 C版本: /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : …

《线程状态转换深度解析:从阻塞到就绪的底层原理》

目录 一、线程的五种基本状态 二、线程从 RUNNABLE 进入阻塞 / 等待状态的三种典型场景 1. 调用sleep(long millis):进入 TIMED_WAITING 状态 2. 调用wait():进入 WAITING/TIMED_WAITING 状态 3. 等待 I/O 资源或获取锁失败:进入 BLOCKE…

面经整理-猿辅导-内容服务后端-java实习

部门管理系统设计 题目要求 设计部门 MySQL 数据表实现接口:根据中间部门 ID 获取其下属叶子部门 ID设计包含子节点列表的 Java 数据对象,并实现批量获取功能 一、MySQL 部门表设计 表结构 CREATE TABLE department (id BIGINT PRIMARY KEY AUTO_INCREME…

Openharmony之window_manager子系统源码、需求定制详解

1. 模块概述 Window Manager 模块是 OpenHarmony 操作系统的核心窗口管理系统,负责窗口的创建、销毁、布局、焦点管理、动画效果以及与硬件显示的交互。该模块采用客户端-服务端架构,提供完整的窗口生命周期管理和用户界面交互支持。 1.1架构总览 Window Manager Client 应…

《CDN加速的安全隐患与解决办法:如何构建更安全的网络加速体系》

CDN(内容分发网络)作为提升网站访问速度的关键技术,被广泛应用于各类互联网服务中。然而,在享受加速优势的同时,CDN也面临诸多安全隐患。本文将解析常见的CDN安全问题,并提供实用的解决办法,帮助…

【Linux指南】GCC/G++编译器:庖丁解牛——从源码到可执行文件的奇幻之旅

不只是简单的 gcc hello.c 每一位Linux C/C++开发者敲下的第一行编译命令,几乎都是 gcc hello.c -o hello 或 g++ hello.cpp -o hello。这像一句神奇的咒语,将人类可读的源代码变成了机器可执行的二进制文件。但在这条简单的命令背后,隐藏着一个如同精密钟表般复杂的多步流…

地区电影市场分析:用Python爬虫抓取猫眼_灯塔专业版各地区票房

在当今高度数据驱动的影视行业,精准把握地区票房表现是制片方、宣发团队和影院经理做出关键决策的基础。一部电影在北上广深的表现与二三线城市有何差异?哪种类型的电影在特定区域更受欢迎?回答这些问题,不能再依赖“拍脑袋”和经…

Spark03-RDD02-常用的Action算子

一、常用的Action算子 1-1、countByKey算子 作用:统计key出现的次数,一般适用于K-V型的RDD。 【注意】: 1、collect()是RDD的算子,此时的Action算子,没有生成新的RDD,所以,没有collect()&…

[Android] 显示的内容被导航栏这挡住

上图中弹出的对话框的按钮“Cancel/Save”被导航栏遮挡了部分显示&#xff0c;影响了使用。Root cause: Android 应用的主题是 Theme.AppCompat.Light1. 修改 AndroidManifest.xml 将 application 标签的 android:theme 属性指向新的自定义主题&#xff1a;<applicationandr…

分贝单位全指南:从 dB 到 dBm、dBc

引言在射频、音频和通信工程中&#xff0c;我们经常会在示波器、频谱仪或测试报告里看到各种各样的dB单位&#xff0c;比如 dBm、dBc、dBV、dBFS 等。它们看起来都带个 dB&#xff0c;实则各有不同的定义和参考基准&#xff1a;有的表示相对功率&#xff0c;有的表示电压电平&a…

怎么确定mysql 链接成功了呢?

asyncio.run(test_connection()) ✗ Connection failed: cryptography package is required for sha256_password or caching_sha2_password auth methods 根据你提供的错误信息,问题出现在 MySQL 的认证插件和加密连接配置上。以下是几种解决方法: 1. 安装 cryptography 包…

(5)软件包管理器 yum | Vim 编辑器 | Vim 文本批量化操作 | 配置 Vim

Ⅰ . Linux 软件包管理器 yum01 安装软件在 Linux 下安装软件并不像 Windows 下那么方便&#xff0c;最通常的方式是去下载程序的源代码并进行编译&#xff0c;从而得到可执行程序。正是因为太麻烦&#xff0c;所以有些人就把一些常用的软件提前编译好并做成软件包&#xff0c;…