一、KNN算法介绍

K最近邻(K-Nearest Neighbor, KNN)算法是机器学习中最简单、最直观的分类算法之一。它既可以用于分类问题,也可以用于回归问题。KNN是一种基于实例的学习(instance-based learning)或懒惰学习(lazy learning)算法,因为它不会从训练数据中学习一个明确的模型,而是直接使用训练数据本身进行预测。

1.1 KNN算法原理

KNN算法的核心思想可以概括为:"物以类聚,人以群分"。对于一个待分类的样本,算法会在训练集中找到与之最相似的K个样本(即"最近邻"),然后根据这K个样本的类别来决定待分类样本的类别。

具体步骤如下:

  1. 计算待分类样本与训练集中每个样本的距离(通常使用欧氏距离)

  2. 选取距离最近的K个训练样本

  3. 统计这K个样本中每个类别出现的频率

  4. 将频率最高的类别作为待分类样本的预测类别

1.2 KNN算法的特点

优点:

  • 简单直观,易于理解和实现

  • 无需训练过程,新数据可以直接加入训练集

  • 对数据分布没有假设,适用于各种形状的数据分布

缺点:

  • 计算复杂度高,预测时需要计算与所有训练样本的距离

  • 对高维数据效果不佳(维度灾难)

  • 对不平衡数据敏感

  • 需要选择合适的K值和距离度量方式

二、Scikit-learn中的KNN实现

Scikit-learn提供了KNeighborsClassifier和KNeighborsRegressor两个类分别用于KNN分类和回归。下面我们重点介绍KNeighborsClassifier。

2.1 KNeighborsClassifier API详解

class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,          # K值,默认5weights='uniform',      # 权重函数algorithm='auto',       # 计算最近邻的算法leaf_size=30,           # KD树或球树的叶子节点大小p=2,                    # 距离度量参数(1:曼哈顿距离,2:欧氏距离)metric='minkowski',     # 距离度量类型metric_params=None,     # 距离度量的额外参数n_jobs=None             # 并行计算数
)
主要参数说明:
  1. n_neighbors (int, default=5)

    • K值,即考虑的最近邻的数量

    • 较小的K值会使模型对噪声更敏感,较大的K值会使决策边界更平滑

    • 通常通过交叉验证来选择最佳K值

  2. weights ({'uniform', 'distance'} or callable, default='uniform')

    • 'uniform': 所有邻居的权重相同

    • 'distance': 权重与距离成反比,距离越近的邻居影响越大

    • 也可以自定义权重函数

  3. algorithm ({'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto')

    • 计算最近邻的算法:

      • 'brute': 暴力搜索,计算所有样本的距离

      • 'kd_tree': KD树,适用于低维数据

      • 'ball_tree': 球树,适用于高维数据

      • 'auto': 自动选择最合适的算法

  4. leaf_size (int, default=30)

    • KD树或球树的叶子节点大小

    • 影响树的构建和查询速度

  5. p (int, default=2)

    • 距离度量的参数:

      • p=1: 曼哈顿距离

      • p=2: 欧氏距离

    • 仅当metric='minkowski'时有效

  6. metric (str or callable, default='minkowski')

    • 距离度量类型,可以是:

      • 'euclidean': 欧氏距离

      • 'manhattan': 曼哈顿距离

      • 'chebyshev': 切比雪夫距离

      • 'minkowski': 闵可夫斯基距离

      • 或自定义距离函数

  7. n_jobs (int, default=None)

    • 并行计算数

    • -1表示使用所有处理器

2.2 常用方法

  • fit(X, y): 拟合模型,只需要存储训练数据

  • predict(X): 预测X的类别

  • predict_proba(X): 返回X属于各类别的概率

  • kneighbors([X, n_neighbors, return_distance]): 查找点的K近邻

  • score(X, y): 返回给定测试数据和标签的平均准确率

三、KNN分类实战示例

3.1 基础示例:鸢尾花分类

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix# 加载数据集
iris = load_iris()
X = iris.data  # 特征 (150, 4)
y = iris.target  # 标签 (150,)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=5,      # 使用5个最近邻weights='uniform',   # 均匀权重algorithm='auto',    # 自动选择算法p=2                  # 欧氏距离
)# 训练模型(实际上只是存储数据)
knn.fit(X_train, y_train)# 预测测试集
y_pred = knn.predict(X_test)# 评估模型
print("分类报告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))
print("\n混淆矩阵:")
print(confusion_matrix(y_test, y_pred))

3.2 进阶示例:手写数字识别 

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
import numpy as np# 加载手写数字数据集
digits = load_digits()
X = digits.data  # (1797, 64)
y = digits.target  # (1797,)# 可视化一些样本
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
for i, ax in enumerate(axes.flat):ax.imshow(X[i].reshape(8, 8), cmap='gray')ax.set_title(f"Label: {y[i]}")ax.axis('off')
plt.show()# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建管道:先标准化数据,再应用KNN
pipe = Pipeline([('scaler', StandardScaler()),  # 标准化特征('knn', KNeighborsClassifier())  # KNN分类器
])# 设置参数网格进行网格搜索
param_grid = {'knn__n_neighbors': [3, 5, 7, 9],  # 不同的K值'knn__weights': ['uniform', 'distance'],  # 两种权重方式'knn__p': [1, 2]  # 曼哈顿距离和欧氏距离
}# 创建网格搜索对象
grid = GridSearchCV(pipe, param_grid, cv=5,  # 5折交叉验证scoring='accuracy',  # 评估指标n_jobs=-1  # 使用所有CPU核心
)# 执行网格搜索
grid.fit(X_train, y_train)# 输出最佳参数和得分
print(f"最佳参数: {grid.best_params_}")
print(f"最佳交叉验证准确率: {grid.best_score_:.4f}")# 在测试集上评估最佳模型
best_model = grid.best_estimator_
test_score = best_model.score(X_test, y_test)
print(f"测试集准确率: {test_score:.4f}")# 可视化一些预测结果
sample_indices = np.random.choice(len(X_test), 10, replace=False)
sample_images = X_test[sample_indices]
sample_labels = y_test[sample_indices]
predicted_labels = best_model.predict(sample_images)plt.figure(figsize=(12, 3))
for i, (image, true_label, pred_label) in enumerate(zip(sample_images, sample_labels, predicted_labels)):plt.subplot(1, 10, i+1)plt.imshow(image.reshape(8, 8), cmap='gray')plt.title(f"True: {true_label}\nPred: {pred_label}", fontsize=8)plt.axis('off')
plt.tight_layout()
plt.show()

3.3 自定义距离度量示例 

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import make_classification
import numpy as np# 自定义距离函数:余弦相似度
def cosine_distance(x, y):return 1 - np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))# 创建模拟数据
X, y = make_classification(n_samples=1000, n_features=20, n_classes=3, random_state=42
)# 使用自定义距离度量的KNN
knn_custom = KNeighborsClassifier(n_neighbors=5,metric=cosine_distance,  # 使用自定义距离algorithm='brute'         # 自定义距离需要暴力搜索
)# 使用标准KNN作为对比
knn_standard = KNeighborsClassifier(n_neighbors=5)# 划分训练集和测试集
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 训练并评估两个模型
knn_custom.fit(X_train, y_train)
knn_standard.fit(X_train, y_train)print(f"自定义距离KNN准确率: {knn_custom.score(X_test, y_test):.4f}")
print(f"标准KNN准确率: {knn_standard.score(X_test, y_test):.4f}")

四、KNN回归实战示例

总结:一句话区分核心差异

  • 分类:回答 “这是什么” 的问题,输出离散类别(如 “是垃圾邮件”);
  • 回归:回答 “这有多少” 的问题,输出连续数值(如 “房价 300 万”)。

虽然本文主要介绍分类问题,但KNN也可以用于回归。下面是一个简单的回归示例:

from sklearn.neighbors import KNeighborsRegressor
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt# 创建回归数据集
X, y = make_regression(n_samples=200, n_features=1, noise=10, random_state=42
)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 创建KNN回归器
knn_reg = KNeighborsRegressor(n_neighbors=5,weights='distance',  # 距离加权p=1                  # 曼哈顿距离
)# 训练模型
knn_reg.fit(X_train, y_train)# 预测
y_pred = knn_reg.predict(X_test)# 评估
mse = mean_squared_error(y_test, y_pred)
print(f"均方误差(MSE): {mse:.2f}")# 可视化结果
plt.scatter(X_test, y_test, color='blue', label='Actual')
plt.scatter(X_test, y_pred, color='red', label='Predicted')
plt.title('KNN Regression')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.legend()
plt.show()

五、KNN调优技巧

5.1 K值选择

K值的选择对KNN性能有很大影响:

  • K值太小:模型对噪声敏感,容易过拟合

  • K值太大:模型过于简单,可能欠拟合

常用方法:

  1. 使用交叉验证选择最佳K值

  2. 经验法则:K通常取3-10之间的奇数(避免平票)

5.2 数据预处理

KNN对数据尺度敏感,通常需要:

  • 标准化:将特征缩放到均值为0,方差为1

  • 归一化:将特征缩放到[0,1]范围

5.3 维度灾难

高维空间中,所有点都趋向于远离彼此,导致距离度量失效。解决方法:

  • 特征选择:选择最相关的特征

  • 降维:使用PCA等方法降低维度

5.4 距离度量选择

不同距离度量适用于不同场景:

  • 欧氏距离:适用于连续变量

  • 曼哈顿距离:适用于高维数据或稀疏数据

  • 余弦相似度:适用于文本数据

六、总结

KNN是一种简单但强大的机器学习算法,特别适合小规模数据集和低维问题。通过Scikit-learn的KNeighborsClassifier,我们可以方便地实现KNN算法,并通过调整各种参数来优化模型性能。

关键点回顾:

  1. KNN是一种懒惰学习算法,没有显式的训练过程

  2. K值和距离度量的选择对模型性能至关重要

  3. 数据预处理(特别是标准化)对KNN非常重要

  4. 高维数据中KNN可能表现不佳,需要考虑降维

希望本教程能帮助你理解和应用KNN算法。在实际应用中,记得结合交叉验证和网格搜索来找到最佳参数组合。

 

 

 

 

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/89424.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/89424.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/89424.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PMP】项目管理入门:从基础到环境的体系化拆解

不少技术管理者都有过这样的困惑: 明明按流程做了项目管理,结果还是延期、超预算?需求变更多到炸,到底是客户无理还是自己没管好?跨部门协作像“推皮球”,资源总不到位? 其实,项目失…

【Web前端】简易轮播图显示(源代码+解析+知识点)

一、简易轮播图源代码 <!DOCTYPE html> <html><head><meta charset"utf-8"><title>简易轮播图显示</title><style type"text/css">*{margin: 0 auto;padding: 0;/* 全局重置边距 */}p{text-align: center;fon…

电机试验平台的用途及实验范围

电机试验平台是一种专门设计用来对各种类型的电机进行测试和分析的设备。在现代工业中&#xff0c;电机作为驱动力的重要组成部分&#xff0c;在各个领域发挥着至关重要的作用。而为了确保电机的性能、效率和可靠性达到最佳水平&#xff0c;需要进行各种试验和测试。电机试验平…

自主/智能的本质内涵及其相互关系

论文地址&#xff1a;无人机的自主与智能控制 - 中国知网 (cnki.net) 自主/智能的本质内涵及其相互关系准则是无人机设计的基本原则。从一般意义上讲。自主与智能是两个不同范畴的概念。自主表达的是行为方式&#xff0c;由自身决策完成某行为则称之为“自主”。“智能”…

nignx+Tomcat+NFS负载均衡加共享储存服务脚本

本次使有4台主机&#xff0c;系统均为centos7&#xff0c;1台nignx&#xff0c;2台tomcat&#xff0c;1台nfs 第一台配置nignx脚本 #!/bin/bash #xiexianghu 2025.6.24 #nignx配置脚本&#xff0c;centos7#关闭防火墙和SElinux systemctl stop firewalld && system…

zabbix监控Centos 服务器

1.2&#xff1a;本地安装 先使用wget下载RPM安装包 然后解压安装 >>wget https://repo.zabbix.com/zabbix/6.4/rhel/8/x86_64/zabbix-agent2-6.4.21-release1.el8.x86_64.rpm ##### CENTOS 8 使用这一条>>rpm -ivh zabbix-agent2-6.4.21-release1.el8.x86_64.r…

中科米堆三维扫描仪3D扫描压力阀抄数设计

三维扫描技术正以数字化手段重塑传统制造流程。以压力阀这类精密流体控制元件为例&#xff0c;其内部流道结构的几何精度直接影响设备运行稳定性与使用寿命。 在传统设计模式下&#xff0c;压力阀的逆向工程需经历手工测绘、二维图纸绘制、三维建模转换等多个环节。技术人员需…

Python pytz 时区使用举例

Python pytz 时区使用举例 ⏰ 一、Python代码实现&#xff1a;时区转换与时间比较 import pytz from datetime import datetime# 1. 获取当前UTC时间 utc_now datetime.now(pytz.utc)# 2. 转换为目标时区&#xff08;示例&#xff1a;上海和纽约&#xff09; shanghai_tz py…

vue中ref()和reactive()区别

好的&#xff0c;这是 Vue 3 中 ref() 和 reactive() 这两个核心响应式 API 之间区别的详细解释。 简单来说&#xff0c;它们是创建响应式数据的两种方式&#xff0c;主要区别在于处理的数据类型和访问数据的方式。 核心区别速查表 特性ref()reactive()适用类型✅ 任何类型 …

目标检测数据集——交通信号灯红绿灯检测数据集

在智能交通系统迅速发展的背景下&#xff0c;准确且实时地识别交通信号灯的状态对于提升道路安全和优化交通流量管理至关重要。 无论是自动驾驶汽车还是辅助驾驶技术&#xff0c;可靠地检测并理解交通信号灯的指示——特别是红灯与绿灯的区别——是确保交通安全、避免交通事故…

哪款即时通讯服务稳定性靠谱?18家对比

本文将深入对比18家 IM 服务商&#xff1a;1.网易云信; 2. 有度即时通; 3. 环信; 4. 小天互连; 5. 企达即时通讯; 6. 敏信即时通讯; 7. 360织语; 8. 容联云通讯; 9. 云之讯 UCPaaS等。 在如今的数字化时代&#xff0c;即时通讯&#xff08;IM&#xff09;软件已经成为企业日常运…

【Android】Flow学习及使用

目录 前言基础基本用法概念与核心特点Android中使用与LiveData对比热流StateFlow、SharedFlow 搜索输入流实现实时搜索 前言 ​ Flow是kotlin协程库中的一个重要组成部分&#xff0c;它可以按顺序发送多个值&#xff0c;用于对异步数据流进行处理。所谓异步数据流就是连续的异…

idea常做的配置改动和常用插件

IDEA 使用 最强教程&#xff0c;不多不杂。基于idea旗舰版 2019.2.3左右的版本&#xff0c;大多数是windows的&#xff0c;少数是mac版的 一、必改配置 1、ctrl滚轮 调整字体大小 全局立即生效&#xff1a;settings -> Editor -> General -> Change font size with …

3. 物理信息神经网络(PINNs)和偏微分方程(PDE),用物理定律约束神经网络

导言&#xff1a;超越时间&#xff0c;拥抱空间 在前两篇章中&#xff0c;我们已经走过了漫长而深刻的旅程。我们学会了用常微分方程&#xff08;ODE&#xff09;来描述事物如何随时间演化&#xff0c;从一个初始状态出发&#xff0c;描绘出一条独一无二的生命轨迹。我们还学会…

Flutter基础(基础概念和方法)

概念比喻StatefulWidget会变魔术的电视机State电视机的小脑袋&#xff08;记信息&#xff09;build 方法电视机变身显示新画面setState按遥控器按钮改变状态Scaffold电视机的外壳 StatefulWidget&#xff1a;创建一个按钮组件。State&#xff1a;保存点赞数&#xff08;比如 i…

K8s——Pod(1)

目录 基本概念 ‌一、Pod 的原理‌ ‌二、Pod 的特性‌ ‌三、Pod 的意义‌ 状态码详解 ‌一、Pod 核心状态详解‌ ‌二、其他关键状态标识‌ ‌三、状态码运维要点‌ 探针 ‌一、探针的核心原理‌ ‌二、三大探针的特性与作用‌ ‌参数详解‌ ‌三、探针的核心意义…

MySQL 存储过程面试基础知识总结

文章目录 MySQL 存储过程面试基础知识总结一、存储过程基础&#xff08;一&#xff09;概述1.优点2.缺点 &#xff08;二&#xff09;创建与调用1.创建存储过程2.调用存储过程3.查看存储过程4.修改存储过程5.存储过程权限管理 &#xff08;三&#xff09;参数1.输入参数2.输出参…

NLP文本数据增强

文章目录 文本数据增强同义词替换示例Python代码示例 随机插入示例Python代码示例 随机删除示例Python代码示例 回译&#xff08;Back Translation&#xff09;示例Python代码示例 文本生成模型应用方式示例Python代码示例 总结 文本数据增强 数据增强通过对原始数据进行变换、…

(LeetCode 每日一题) 594. 最长和谐子序列 (哈希表)

题目&#xff1a;594. 最长和谐子序列 思路&#xff1a;哈希表&#xff0c;时间复杂度0(n)。 用哈希表mp来记录每个元素值出现的次数&#xff0c;然后枚举所有值x&#xff0c;看其x1是否存在&#xff0c;存在的话就可以维护最长的子序列长度mx。 C版本&#xff1a; class Sol…

FreePDF:让看英文文献像喝水一样简单

前言 第一次看英文文献&#xff0c;遇到不少看不懂的英文单词&#xff0c;一个个查非常费劲。 后来&#xff0c;学会了使用划词翻译&#xff0c;整段整段翻译查看&#xff0c;极大提升看文献效率。 最近&#xff0c;想到了一种更快的看文献的方式&#xff0c;那就是把英文PD…