目录

用 SVM 实现鸢尾花数据集分类:从代码到可视化全解析

一、算法原理简述

二、完整代码实现

三、代码解析

1. 导入所需库

2. 加载并处理数据

3. 划分训练集和测试集

4. 训练 SVM 模型

5. 计算决策边界参数

6. 生成决策边界数据

7. 绘制样本点

8. 绘制决策边界

9. 设置坐标轴范围

10. 标记支持向量

11. 显示图像

用 SVM 实现鸢尾花数据集分类:从代码到可视化全解析

支持向量机(SVM)是一种经典的机器学习算法,特别适合处理小样本、高维空间的分类问题。本文将通过鸢尾花(Iris)数据集,从零开始实现基于 SVM 的分类任务,并通过可视化直观展示分类效果。

一、算法原理简述

SVM 的核心思想是寻找最大间隔超平面,通过这个超平面将不同类别的数据分开。对于线性可分的数据,存在无数个可分超平面,SVM 会选择距离两类数据最近点(支持向量)距离最大的那个超平面,从而获得更好的泛化能力。

当数据线性不可分时,SVM 可以通过核函数将低维数据映射到高维空间,使其在高维空间中线性可分。本文使用线性核(kernel='linear')进行演示,适合处理线性可分的鸢尾花数据集。

二、完整代码实现

下面是基于鸢尾花数据集的 SVM 分类完整代码,包含数据加载、模型训练、决策边界可视化等功能:部分数据集如下:

import pandas as pd
from sklearn.svm import SVC
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn import metrics# 1. 加载数据
f = pd.read_csv('iris.csv')  # 读取鸢尾花数据集# 2. 数据划分(按类别拆分用于可视化)
data = f.iloc[:50,:]   # 第一类数据(前50条)
data1 = f.iloc[50:,:]  # 后两类数据(第50条之后)# 3. 准备特征和标签
x = f.iloc[:,[1,3]]    # 选择第2列和第4列作为特征(萼片宽度和花瓣宽度)
y = f.iloc[:,-1]       # 最后一列为标签(花的类别)# 4. 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)  # 20%数据作为测试集# 5. 初始化并训练SVM模型
svm = SVC(kernel='linear', C=1, random_state=0)  # 线性核,正则化参数C=1
svm.fit(x_train, y_train)# 6. 获取模型参数(用于绘制决策边界)
w = svm.coef_[0]  # 权重系数
b = svm.intercept_[0]  # 偏置项# 7. 生成决策边界数据
x1 = np.linspace(0,7,300)  # 生成300个从0到7的均匀点
x2 = -(w[0]*x1 + b)/w[1]   # 决策边界公式:w0*x1 + w1*x2 + b = 0 → 求解x2
x3 = 1 + x2  # 边界1(决策边界+1)
x4 = -1 + x2 # 边界2(决策边界-1)# 8. 绘制散点图(样本点)
plt.scatter(data.iloc[:,1], data.iloc[:,3], marker='+', color='b', label='第一类')
plt.scatter(data1.iloc[:,1], data1.iloc[:,3], marker='*', color="r", label='其他类别')# 9. 绘制决策边界
plt.plot(x1, x3, linewidth=1, color='r', linestyle='--', label='边界1')
plt.plot(x1, x2, linewidth=2, color='r', label='决策边界')
plt.plot(x1, x4, linewidth=1, color='r', linestyle='--', label='边界2')# 10. 设置坐标轴范围
plt.xlim(4,7)
plt.ylim(0,5)# 11. 标记支持向量
vets = svm.support_vectors_  # 获取支持向量
plt.scatter(vets[:,0], vets[:,1], c='b', marker='x', label='支持向量')# 12. 添加图例和标题
plt.legend()
plt.title('SVM分类鸢尾花数据集(线性核)')
plt.show()# 13. 模型评估
y_pred = svm.predict(x_test)
print("模型准确率:", metrics.accuracy_score(y_test, y_pred))

三、代码解析

1. 导入所需库

import pandas as pd                  # 用于数据处理和分析
from sklearn.svm import SVC          # 从sklearn库导入支持向量分类器
import numpy as np                   # 用于数值计算
import matplotlib.pyplot as plt      # 用于数据可视化
from sklearn.model_selection import train_test_split  # 用于划分训练集和测试集
from sklearn import metrics          # 用于模型评估

2. 加载并处理数据

f = pd.read_csv('iris.csv')  # 读取鸢尾花数据集(CSV格式)

鸢尾花数据集包含 100 条样本,分为 2类鸢尾花,每条样本有 4 个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度)和 1 个标签(花的类别)。

data = f.iloc[:50,:]   # 取前50条数据(第一类鸢尾花,通常是setosa)
data1 = f.iloc[50:,:]  # 取第50条之后的数据(后两类鸢尾花,通常是versicolor和virginica)

这里按行索引拆分数据,用于后续可视化时区分不同类别。

x = f.iloc[:,[1,3]]    # 选择特征:取所有行的第2列(索引1)和第4列(索引3)
y = f.iloc[:,-1]       # 选择标签:取所有行的最后一列(花的类别)
  • 特征选择第 2 列和第 4 列(通常对应花萼宽度和花瓣宽度),便于二维可视化
  • 标签为最后一列(花的种类)

3. 划分训练集和测试集

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)  # 划分数据集
  • x_train:训练集特征(80% 的数据)
  • x_test:测试集特征(20% 的数据)
  • y_train:训练集标签
  • y_test:测试集标签
  • test_size=0.2:测试集占比 20%
  • random_state=0:随机种子,保证每次运行划分结果一致

4. 训练 SVM 模型

svm = SVC(kernel='linear', C=1, random_state=0)  # 初始化SVM模型
svm.fit(x_train, y_train)  # 用训练集训练模型
  • kernel='linear':使用线性核函数(适用于线性可分数据)
  • C=1:正则化参数,控制对误分类的惩罚程度(值越大惩罚越重)
  • random_state=0:随机种子,保证结果可复现
  • fit():训练模型,通过训练集学习特征与标签的关系

5. 计算决策边界参数

w = svm.coef_[0]  # 获取权重系数(对于线性核,shape为[特征数])
b = svm.intercept_[0]  # 获取偏置项(截距)

对于线性 SVM,决策边界是一个超平面,二维情况下是一条直线,公式为:

(其中(w0, w1)是权重,b是偏置,(x1, x2)是两个特征)

6. 生成决策边界数据

x1 = np.linspace(0,7,300)  # 生成300个从0到7的均匀点(作为x轴数据)
x2 = -(w[0]*x1 + b)/w[1]   # 计算决策边界的y值(由超平面公式推导)
x3 = 1 + x2  # 决策边界上方的辅助线(间隔边界)
x4 = -1 + x2 # 决策边界下方的辅助线(间隔边界)
  • x1是横轴坐标,x2是决策边界在对应x1处的纵轴坐标
  • x3x4是决策边界两侧的间隔边界,用于展示 SVM 的 "最大间隔" 特性

7. 绘制样本点

# 绘制第一类样本(蓝色+号)
plt.scatter(data.iloc[:,1], data.iloc[:,3], marker='+', color='b')
# 绘制后两类样本(红色*号)
plt.scatter(data1.iloc[:,1], data1.iloc[:,3], marker='*', color="r")
  • scatter():绘制散点图
  • data.iloc[:,1]data.iloc[:,3]:分别取第一类样本的第 2 列和第 4 列特征作为 x、y 坐标
  • marker:指定点的形状(+ 号和 * 号区分不同类别)

8. 绘制决策边界

plt.plot(x1, x3, linewidth=1, color='r', linestyle='--')  # 上方间隔边界(虚线)
plt.plot(x1, x2, linewidth=2, color='r')                  # 决策边界(实线)
plt.plot(x1, x4, linewidth=1, color='r', linestyle='--')  # 下方间隔边界(虚线)
  • plot():绘制直线
  • 红色实线是 SVM 找到的最优决策边界,虚线是间隔边界,两条虚线之间的距离是 "最大间隔"

9. 设置坐标轴范围

plt.xlim(4,7)  # x轴范围设置为4到7
plt.ylim(0,5)  # y轴范围设置为0到5

限制坐标轴范围,使图像聚焦在样本密集区域,更清晰地展示分类效果。

10. 标记支持向量

vets = svm.support_vectors_  # 获取支持向量(距离决策边界最近的样本点)
plt.scatter(vets[:,0], vets[:,1], c='b', marker='x')  # 用x标记支持向量
  • 支持向量是决定决策边界位置的关键样本,SVM 的决策仅由这些点决定
  • vets[:,0]vets[:,1]:支持向量的两个特征值

11. 显示图像

plt.show()  # 显示绘制的图像

总结

这段代码的核心逻辑是:

  1. 加载鸢尾花数据集并选择特征
  2. 划分训练集和测试集
  3. 训练线性 SVM 模型
  4. 计算并绘制决策边界、间隔边界
  5. 可视化样本点和支持向量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/92381.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/92381.shtml
英文地址,请注明出处:http://en.pswp.cn/web/92381.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度虚值期权合约有什么特点?

本文主要介绍深度虚值期权合约有什么特点?深度虚值期权合约是期权市场中一类特殊且风险收益特征鲜明的合约,其核心特点可归纳为以下六点。深度虚值期权合约有什么特点?一、定义:执行价与标的价差距极大深度虚值期权是指执行价&…

(LeetCode 面试经典 150 题) 86. 分隔链表(链表+双指针)

题目:86. 分隔链表 思路:双指针,时间复杂度0(n)。 双指针来维护小于x的链表和不小于x的链表即可,后面将两个链表连起来即可。 C版本: /*** Definition for singly-linked list.* struct ListNode {* int val;* …

安全扫描:检测到目标站点存在javascript框架库漏洞问题(vue)

如果升级Vue版本有限制或者时间比较紧急,可以暂时用下面方式来,规避检测到目标站点存在javascript框架库vue漏洞。 在 vue.config.js 中配置: module.exports {configureWebpack: {optimization: {minimizer: [new (require(terser-webpack-plugin))({t…

【性能测试】-2- JMeter工具的使用

一、基础脚本搭建🎯 核心目标:用 JMeter 搭性能测试 “骨架”(测试计划、线程组、取样器、监听器)📋 测试计划:整个测试的 “容器”,所有测试组件都放这里🧵 线程组:模拟…

从零构建桌面写作软件的书籍管理系统:Electron + Vue 3 实战指南

📚 从零构建桌面写作软件的书籍管理系统:Electron Vue 3 实战指南 💡 本文深入探讨了基于 Electron Vue 3 技术栈的桌面写作软件中书籍管理系统的设计与实现,涵盖了书籍的创建、编辑、删除等核心功能的完整技术方案,…

2025华数杯数学建模A题【 多孔膜光反射性能的优化与控制】原创论文讲解(含完整python代码)

大家好呀,从发布赛题一直到现在,总算完成了2025华数杯数学建模A题【 多孔膜光反射性能的优化与控制】完整的成品论文。本论文可以保证原创,保证高质量。绝不是随便引用一大堆模型和代码复制粘贴进来完全没有应用糊弄人的垃圾半成品论文。A题论…

网络相关(AI回答)

网管交换机都支持带tag的以太帧吗DeepSeek-R1 回答完成深度思考完成 用时24秒,搜索全网20篇资料是的,所有网管交换机都支持处理和转发带Tag(802.1Q VLAN标签)的以太网帧,这是其核心功能之一。以下是关键要点&#xff1…

每日五个pyecharts可视化图表-bars(5)

探索pyecharts库中条形图的高级用法与定制技巧,打造精美数据可视化图表 在数据可视化中,条形图是最常用的图表类型之一。它能够清晰地展示不同类别的数据对比,帮助我们快速理解数据特征。本文将为您介绍pyecharts库中条形图的5种高级用法&…

分布式版本控制工具Git

一.开发中为什么需要Git因为在多人开发中Git可以管理代码,而且每个人都可以从库里面下载代码进行修改,每个人上传和修改Git都会有记录,如果出现大错误,还可以回退到正常版本。二.Git原理我们首先从代码库(Remote)下载代码到工作区…

OpenAI重磅开源GPT-oss:首款支持商用的AI Agent专属模型

今日凌晨,OpenAI宣布开源两款全新大模型——GPT-oss-120B(1168亿参数)与GPT-oss-20B(209亿参数),成为全球首个支持商业化应用的开放权重推理模型。该模型专为AI智能体(Agent)设计&am…

【STM32】GPIO的输入输出

GPIO是通用的输入输出接口,可配置8种输入模式,输出模式下可控制端口输出高低电平,用于点亮LED、控制蜂鸣器、模拟通信协议等;输入模式下可以读取端口的高低电平或者电压,用于读取按键、外接模块的电平信号、ADC的电压采…

5分钟了解OpenCV

在数字化时代,图像和视频已经成为信息传递的核心载体。从手机拍照的美颜功能到自动驾驶的路况识别,从医学影像分析到安防监控系统,视觉技术正深刻改变着我们的生活。而在这背后,OpenCV 作为一款强大的开源计算机视觉库&#xff0c…

Oracle 关闭 impdp任务

Oracle 关闭 impdp任务 执行 impdp system/123456 attachSYS_EXPORT_TABLE_01 执行 stop_jobimmediate

数据结构——链表2

1.2 实现单链表 在上一篇文章中&#xff0c;单链表的实现只有一少部分&#xff0c;这一篇接着来了解单链表剩下的接口实现。 SList.h#pragma once #include<stdio.h> #include<stdlib.h> #include<assert.h>//定义单链表就是定义节点&#xff0c;因为单链表…

Windows和Linux应急响应以及IP封堵

目录 1、Windows入侵排查思路 1.1 检查系统账号安全 1.2 检查异常端口、进程 1.3 检查启动项、计划任务、服务 1.4 检查系统相关信息 1.5 自动化查杀 1.6 日志分析 系统日志分析 Web 访问日志 2、Linux 入侵排查思路 2.1 账号安全 2.1.1、基本使用 2.1.2、入侵排查…

MIT成果登上Nature!液态神经网络YYDS

2025深度学习发论文&模型涨点之——液态神经网络液态神经网络&#xff08;Liquid Neural Networks&#xff0c;LNN&#xff09;是一种受生物神经系统启发的连续时间递归神经网络&#xff08;RNN&#xff09;&#xff0c;其核心创新在于将静态神经网络转化为由微分方程驱动的…

AI 对话高效输入指令攻略(四):AI+Apache ECharts:生成各种专业图表

- **AI与数据可视化的革命性结合**:介绍AI如何降低数据可视化门槛,提升效率。 - **Apache ECharts:专业可视化的利器**:使用表格对比展示ECharts的特点、优势和适用场景。 - **四步实现AI驱动图表生成**:通过分步指南讲解从环境准备到图表优化的全流程,包含多个代码示例及…

vue2 基础学习 day04 (结构/样式/逻辑、组件通信、进阶语法)下

一、非父子通信-event bus 事件总线1.作用非父子组件之间&#xff0c;进行简易消息传递。(复杂场景→ Vuex)2.步骤创建一个都能访问的事件总线 &#xff08;空Vue实例&#xff09;import Vue from vue const Bus new Vue() export default BusA组件&#xff08;接受方&#xf…

ubuntu 20.04 C和C++的标准头文件都放在哪个目录?

在 Ubuntu 20.04 中&#xff0c;C 和 C 标准头文件的存放目录主要由编译器&#xff08;如 GCC&#xff09;的安装路径决定&#xff0c;通常分为以下两类&#xff1a;​1. C 标准头文件​C 语言的标准头文件&#xff08;如 <stdio.h>、<stdlib.h> 等&#xff09;默认…

change和watch

是的&#xff0c;你理解得很对&#xff01; change 与 v-model 的结合&#xff1a;change 事件通常用于监听 表单元素的变化&#xff0c;但它并不一定意味着值发生了变化。它主要是当 用户与输入框交互时&#xff08;如点击选项、选择文本框内容、提交表单等&#xff09;触发的…