学习笔记(29):训练集与测试集划分详解:train_test_split 函数深度解析

一、为什么需要划分训练集和测试集?

在机器学习中,模型需要经历两个核心阶段:

  1. 训练阶段:用训练集数据学习特征与目标值的映射关系(如线性回归的权重)。
  2. 测试阶段:用测试集评估模型在未见过的数据上的表现,避免 “过拟合”(模型只记住训练数据的噪声,无法泛化到新数据)。

类比场景:学生通过 “练习题”(训练集)学习知识,再通过 “考试题”(测试集)检验真实水平。

二、train_test_split 函数的核心参数与逻辑
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42
)
1. 输入参数解析
  • X_scaled:特征矩阵(已标准化的面积、房龄等特征)。
  • y:目标变量(房价)。
  • test_size=0.2:测试集占总数据的比例(20%),也可设为整数(如 test_size=20 表示取 20 个样本)。
  • random_state=42:随机种子,确保每次划分结果一致(与 np.random.seed(42) 作用类似)。
2. 划分逻辑
  • 随机抽样:按 test_size 比例从原始数据中随机抽取样本作为测试集,剩余作为训练集。
  • 数据对齐:确保 X 和 y 的样本顺序一一对应(如第 i 个特征向量对应第 i 个房价标签)。
三、划分结果的维度与含义

假设原始数据有 100 个样本(n_samples=100):

  • 训练集:80 个样本(X_train.shape=(80, 2)y_train.shape=(80,)),用于模型学习。
  • 测试集:20 个样本(X_test.shape=(20, 2)y_test.shape=(20,)),用于评估模型泛化能力。
四、关键参数深度解析
1. test_size:平衡训练与测试的样本量
  • 取值建议
    • 小数据集(<1000 样本):常用 test_size=0.2~0.3(20%-30% 作为测试集)。
    • 大数据集(>10000 样本):可设 test_size=0.1 甚至更低(因少量样本已足够评估)。
  • 极端案例:若 test_size=1.0,则所有数据都是测试集,无训练集;若 test_size=0,则全是训练集。
2. random_state:确保可复现的 “随机” 划分
  • 作用:固定随机种子后,每次运行代码时,训练集和测试集的样本索引完全相同。
  • 示例对比
    • 不设置 random_state:每次划分结果不同,导致模型评估指标波动。
    • 设置 random_state=42:多次运行代码,划分结果一致,便于对比不同模型效果。
3. shuffle=True(默认参数):打乱数据顺序
  • 为什么需要打乱?
    若数据按顺序排列(如前 50 个是小户型,后 50 个是大户型),不打乱会导致训练集和测试集样本分布不均(如测试集全是大户型)。
  • 参数设置train_test_split 默认为 shuffle=True,即先打乱数据再划分;若数据已随机排列,可设 shuffle=False
五、进阶应用:分层抽样(Stratified Sampling)

当目标变量是分类变量(如二分类 “是否违约”)时,普通随机划分可能导致训练 / 测试集的类别比例失衡(如测试集全是 “违约” 样本)。此时需用 StratifiedShuffleSplit 实现分层抽样:

from sklearn.model_selection import StratifiedShuffleSplit# 4. 使用分层抽样(确保类别比例平衡)
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_idx in sss.split(X_scaled, y_binary):X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y_binary[train_idx], y_binary[test_idx]print("===== 分类模型结果 =====")
print(f"原始数据类别比例:{np.bincount(y_binary)/len(y_binary)}")
print(f"训练集类别比例:{np.bincount(y_train)/len(y_train)}")
print(f"测试集类别比例:{np.bincount(y_test)/len(y_test)}")
六、实战误区与注意事项
  1. 禁止在测试集上训练:测试集只能用于评估,若根据测试集结果调整模型参数(如调优正则化系数),本质上是 “偷看答案”,会导致评估结果过于乐观。
  2. 数据标准化的顺序
    • 正确流程:先划分训练测试集,再对训练集拟合标准化器(scaler.fit(X_train)),最后用训练集的标准化参数转换测试集(scaler.transform(X_test))。
    • 错误操作:对全量数据标准化后再划分,会导致测试集 “偷看到” 全量数据的统计特征,违反 “未知数据” 假设。
  3. 多轮划分与交叉验证:当数据量较小时,可使用 K 折交叉验证(如 10 折),将数据分成 10 份,每次用 9 份训练、1 份测试,重复 10 次取平均,减少单次划分的随机性误差。
七、总结:划分训练测试集的核心原则
  1. 独立性:测试集数据必须是模型未见过的,模拟真实应用场景。
  2. 代表性:训练集和测试集的样本分布应尽可能一致(如特征取值范围、类别比例)。
  3. 可复现性:通过设置随机种子,确保实验结果可重复验证。

通过合理划分训练集与测试集,你可以更准确地评估模型的实际能力,避免被 “过拟合” 的假象误导 —— 这是机器学习工程化中至关重要的一步!

二分类问题(房价是否高于中位数)-全代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score# 配置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 1. 生成模拟数据(假设房价与面积、房龄的关系)
np.random.seed(42)
n_samples = 100
# 面积(平方米),房龄(年)
X = np.random.rand(n_samples, 2) * 100
X[:, 0] = X[:, 0]  # 面积范围:0-100
X[:, 1] = X[:, 1]  # 房龄范围:0-100# 真实房价 = 5000*面积 + 1000*房龄 + 随机噪声(模拟真实场景)
y = 5000 * X[:, 0] + 1000 * X[:, 1] + np.random.randn(n_samples) * 10000# 2. 将连续的房价y转换为分类标签(例如分为低、中、高3个类别)
y_category = pd.qcut(y, q=3, labels=[0, 1, 2])  # 使用pandas的qcut进行分位数切割# 3. 数据预处理:标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 4. 使用分层抽样
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_idx in sss.split(X_scaled, y_category):X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y[train_idx], y[test_idx]  # 注意:这里仍然使用原始的连续房价作为目标_# 确保训练集和测试集的类别比例与原始数据一致
print(f"原始数据类别比例:{np.bincount(y_category)/len(y_category)}")
print(f"训练集类别比例:{np.bincount(y_category[train_idx])/len(y_category[train_idx])}")
print(f"测试集类别比例:{np.bincount(y_category[test_idx])/len(y_category[test_idx])}")# 后续回归模型训练和评估代码保持不变
model = LinearRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)# 评估模型
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"均方误差: {mse:.2f}")
print(f"决定系数R²: {r2:.2f}")

打印:

原始数据类别比例:[0.34 0.32 0.34]
训练集类别比例:[0.3375 0.325  0.3375]
测试集类别比例:[0.35 0.3  0.35]
均方误差: 101112597.45
决定系数R²: 1.00

代码解析:
核心步骤解析
  1. 数据准备与二分类转换

    • 生成与方案 1 相同的模拟数据(面积、房龄 → 房价)。
    • 将连续的房价y转换为二分类标签:
threshold = np.median(y)  # 使用中位数作为阈值
y_binary = (y > threshold).astype(int)  # 0=低于中位数,1=高于中位数
  1. 这样做的目的是将 “预测具体房价” 转化为 “判断房价高低”。

分层抽样(Stratified Sampling)

  • 使用StratifiedShuffleSplit确保训练集和测试集中高低房价的比例与原始数据一致:
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_idx in sss.split(X_scaled, y_category):X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y[train_idx], y[test_idx]  # 注意:这里仍然使用原始的连续房价作为目标_

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/87357.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/87357.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/87357.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【全网唯一】自动化编辑器 Windows版纯本地离线文字识别插件

目的 自动化编辑器超轻量级RPA工具&#xff0c;零代码制作RPA自动化任务&#xff0c;解放双手&#xff0c;释放双眼&#xff0c;轻松玩游戏&#xff0c;刷任务。本篇文章主要讲解下自动化编辑器的TomatoOCR纯本地离线文字识别Windows版插件如何使用和集成。 准备工作 1、下载自…

GitHub 2FA绑定

GitHub 2FA绑定 作为全球最大的代码托管平台&#xff0c;GitHub对账号安全的重视程度不断提升——自2023年3月起&#xff0c;GitHub已要求所有在GitHub.com上贡献代码的用户必须启用双因素身份验证&#xff08;2FA&#xff09;。如果你是符合条件的用户&#xff0c;会收到一封…

pytest fixture基础大全详解

一、介绍 作用 fixture主要有两个作用&#xff1a; 复用测试数据和环境&#xff0c;可以减少重复的代码&#xff1b;可以在测试用例运行前和运行后设置和清理资源&#xff0c;避免对测试结果产生影响&#xff0c;同时也可以提高测试用例的运行效率。 优势 pytest框架的fix…

Unity知识点-Renderer常用材质变量

本篇总结了Unity中renderer的3种常用的材质相关的变量&#xff1a;renderer.material,renderer.sharedMaterial,renderer.MaterialPropertyBlock。以及三者对SRPBatcher的影响。 一.介绍及对比 1.概念介绍 1.material 定义&#xff1a;material 是Render组件&#xff08;如…

【算法】​​如何判断时间复杂度?

文章目录 1. 什么是时间复杂度&#xff1f;为什么需要时间复杂度&#xff1f; 2. 常见时间复杂度对比3. 如何分析时间复杂度&#xff1f;&#xff08;Java版&#xff09;&#x1f539; 步骤1&#xff1a;找出基本操作&#x1f539; 步骤2&#xff1a;分析循环结构&#xff08;1…

MySQL使用C语言连接

文章目录 版本查看以及编译mysql接口介绍初始化链接数据库下发mysql命令mysql_query获取执行结果mysql_store_result获取结果行数mysql_num_rows获取结果列数mysql_num_fields获取列名mysql_fetch_fields获取结果内容mysql_fetch_row关闭mysql链接mysql_closeC语言操作mysql查看…

坚持每日Codeforces三题挑战:Day 7 - 题目详解(2025-06-11,难度:1200,1300,1500)

每天坚持写三道题第七天&#xff1a; Problem - A - Codeforces 1200 Problem - B - Codeforces 1300 Problem - A - Codeforces 1500 目录 题目一: 题目大意: 解题思路: 代码(C): 题目二: 题目大意: 解题思路: 代码(C): 题目三: 题目大意: 解题思路: 代码(C): …

洛谷 P4305:[JLOI2011] 不重复数字 ← unordered_set

【题目来源】 https://www.luogu.com.cn/problem/P4305 【题目描述】 给定 n 个数&#xff0c;要求把其中重复的去掉&#xff0c;只保留第一次出现的数。 【输入格式】 第一行一个整数 T&#xff0c;表示数据组数。 对于每组数据&#xff0c;第一行一个整数 n。第二行 n 个数…

STM32固件升级设计——SPIFLASH模拟U盘升级固件

目录 概述 一、功能描述 1、BootLoader部分&#xff1a; 2、APP部分&#xff1a; 二、BootLoader程序制作 1、分区定义 2、 主函数 3、配置USB 4、配置fatfs文件系统 5、程序跳转 三、APP程序制作 四、工程配置&#xff08;默认KEIL5&#xff09; 五、运行测试 六…

解锁阿里云日志服务SLS:云时代的日志管理利器

引言&#xff1a;开启日志管理新篇 在云计算时代&#xff0c;数据如同企业的血液&#xff0c;源源不断地产生并流动。从用户的每一次点击&#xff0c;到系统后台的每一个操作&#xff0c;数据都在记录着企业运营的轨迹。而在这些海量的数据中&#xff0c;日志数据占据着至关重…

Keye-VL-8B-Preview:由快手 Kwai Keye 团队精心打造的尖端多模态大语言模型

&#x1f525; News 2025.06.26 &#x1f31f; 我们非常自豪地推出Kwai Keye-VL&#xff0c;这是快手Kwai Keye团队精心打造的前沿多模态大语言模型。作为快手先进技术生态中的核心AI产品&#xff0c;Keye在视频理解、视觉感知和推理任务方面表现卓越&#xff0c;树立了新的性…

Web前端之JavaScript实现图片圆环、圆环元素根据角度指向圆心、translate、rotate

MENU 前言效果HtmlStyleJavaScript 前言 代码段创建了一个由6个WiFi图标组成的圆形排列&#xff0c;每个图标均匀分布在圆周上。 效果 Html 代码 <div class"ring"><div class"item"><img class"img" src"../image/icon/W…

1 Studying《Computer Vision: Algorithms and Applications 2nd Edition》11-15

目录 Chapter 11 Structure from motion and SLAM 11.1 几何内禀校准 11.2 姿态估计 11.3 从运动中获得的双帧结构 11.4 从运动中提取多帧结构 11.5 同步定位与建图&#xff08;SLAM&#xff09; 11.6 额外阅读 Chapter 12 Depth estimation 12.1 极点几何 12.2 稀疏…

phpstudy 可以按照mysql 数据库

phpstudy 可以按照mysql 数据库 PHPStudy&#xff08;小皮面板&#xff09;是一款专为开发者设计的集成环境工具&#xff0c;涵盖服务器配置、开发环境搭建、网站部署等多项功能。以下是其核心用途及优势的详细解析&#xff1a; 一、开发环境快速搭建 一站式集成环境集成Apa…

Python搭建HTTP服务,如何用内网穿透快速远程访问?

Python的内置HTTP服务模块是开发者工具箱中的瑞士军刀&#xff0c;只需一行命令即可启动一个功能完备的Web服务器。无论是前端工程师调试页面、数据科学家共享Jupyter Notebook&#xff0c;还是后端开发者快速验证API原型&#xff0c;Python HTTP服务都能以零配置的方式满足需求…

拨号音识别系统的设计与实现

拨号音识别系统的设计与实现 摘要 本文设计并实现了一个完整的拨号音识别系统&#xff0c;该系统能够自动识别电话号码中的数字。系统基于双音多频(DTMF)技术原理&#xff0c;使用MATLAB开发&#xff0c;包含GUI界面展示处理过程和结果。系统支持从麦克风实时录音或加载音频文…

数据结构-树详解

树简介 树存储和组织具有层级结构的数据&#xff08;例&#xff1a;公司职级&#xff09;&#xff0c;就是一颗倒立生长的树。 属性&#xff1a; 递归n个节点有n-1个连接节点x的深度&#xff1a;节点x到根节点的最长路径节点x的高度&#xff1a;节点x到叶子节点的最长路径 …

【安卓Sensor框架-2】应用注册Sensor 流程

注册传感器的核心流程为如下&#xff1a;应用层调用 SensorManager注册传感器&#xff0c;framework层创建SensorEventQueue对象&#xff08;事件队列&#xff09;&#xff0c;通过JNI调用Native方法nativeEnableSensor()&#xff1b;SensorService服务端createEventQueue()创建…

新版本没有docker-desktop-data分发 | docker desktop 镜像迁移

在新版本的docker desktop中&#xff08;如4.42版本&#xff09;&#xff0c;镜像迁移只需要更改路径即可。如下&#xff1a; 打开docker desktop的设置&#xff08;图1&#xff09;&#xff0c;将图2的原来的地址C:\Users\用户\AppData\Local\Docker\wsl修改为你想要的空文件…

EtherCAT SOEM源码分析 - ec_init

ec_init SOEM主站一切开始的地方始于ec_init, 它是EtherCAT主站初始化的入口。初始化SOEM 主站&#xff0c;并绑定到socket到ifname。 /** Initialise lib in single NIC mode* param[in] ifname Dev name, f.e. "eth0"* return >0 if OK* see ecx_init*/ in…