目录

config参数配置

setup_dirs创建训练文件夹

 load_data加载数据

build_model创建模型

train训练


记录一下训练代码中不理解的地方

config参数配置

config = {'data_root': r"D:\project\megnetometer\datasets\WISDM_ar_latest\organized_dataset",'train_dir': 'train','test_dir': 'test','seq_length': 300,  # 序列长度'batch_size': 32,  # 可能需减小batch_size'epochs': 60,'initial_lr': 3e-4,  # 初始学习率'max_lr': 5e-4,'patience': 20}

配置好需要用到的参数,比如数据集地址,训练轮数,批次大小,学习率等

setup_dirs创建训练文件夹

    def setup_dirs(self):self.run_dir = os.path.join(self.config['data_root'], 'run')  os.makedirs(self.run_dir, exist_ok=True)print('创建运行目录run_dir  = ', self.run_dir)# 创建带时间戳的实验目录timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")print('时间戳 = ', timestamp)self.exp_dir = os.path.join(self.run_dir, f"exp_{timestamp}")os.makedirs(self.exp_dir, exist_ok=True)# 保存当前配置with open(os.path.join(self.exp_dir, 'config.json'), 'w') as f:json.dump(self.config, f, indent=2)  # 两个字符缩进,没有则压缩成一行,把config内容存在config.json里

os.path.join(self.config['data_root'], 'run')  

用于拼接文件路径data_root的路径加上run,中间的连接符会根据系统自动调整

os.makedirs(self.exp_dir, exist_ok=True)

创建文件,exist_ok=True当文件夹存在的时候不报错

创建的文件夹用于存放后续训练生成的模型以及保存训练参数等文件

 load_data加载数据

    def load_data(self):"""从按行为分类的目录加载数据(带多级进度条)"""def load_activity_data(subset_dir):"""加载train或test子目录下的数据"""data = []subset_path = os.path.join(self.config['data_root'], subset_dir)  #在数据集路径内读取,由subset_dir决定读取的是训练集还是测试集# 获取所有活动类别目录activities = [d for d in os.listdir(subset_path)if os.path.isdir(os.path.join(subset_path, d))]#print('activities=',activities)#activities= ['Downstairs', 'Jogging', 'Sitting', 'Standing', 'Upstairs', 'Walking']# 第一层进度条:活动类别pbar_activities = tqdm(activities, desc=f"扫描{subset_dir}目录", position=0)for activity in pbar_activities:activity_lower = activity.lower()if activity_lower not in self.label_map:continueactivity_dir = os.path.join(subset_path, activity)#当前活动的目录# 获取所有用户文件user_files = [f for f in os.listdir(activity_dir)if f.endswith('.txt')]#获取所有txt结尾的文件# 第二层进度条:用户文件#pbar_users = tqdm(user_files, desc="读取用户文件", leave=False, position=1)#后面要close,但是已经把所有的进度注释掉了只留下来一个总的第一层进度#print('pbar_users=',pbar_users)for user_file in user_files:file_path = os.path.join(activity_dir, user_file)# 获取文件行数用于进度条with open(file_path, 'r') as f:num_lines = sum(1 for _ in f)# 第三层进度条:读取文件内容with open(file_path, 'r') as f:for line in f:line = line.strip()if not line:continuetry:x, y, z = map(float, line.split(','))data.append({'x': x,'y': y,'z': z,'activity': activity_lower})except ValueError:continuepbar_activities.close()return data# 调用示例print("\n" + "=" * 50)print("开始加载数据集...")train_data = load_activity_data(self.config['train_dir'])#print(train_data)#{'x': 5.33, 'y': 8.73, 'z': -0.42, 'activity': 'walking'},test_data = load_activity_data(self.config['test_dir'])
pbar_activities = tqdm(activities, desc=f"扫描{subset_dir}目录", position=0)

tqdm创建进度条,desc是进度条前面的描述,position用于多级进度条之间的嵌套,以免位置混乱,在运行完之后要关闭进度条

pbar_activities.close()
with open('data.txt', 'r') as f:打开文件夹,r为只读模式

# 转换为模型输入格式(带优化进度条)def create_sequences(data, desc="生成序列"):seq_length = self.config['seq_length']features, labels = [], []total_windows = len(data) - seq_lengthpbar = tqdm(range(total_windows),desc=desc,position=0,bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [速度:{rate_fmt}]")for i in pbar:window = data[i:i + seq_length]# 检查窗口内活动是否一致if len(set(d['activity'] for d in window)) != 1:continuefeatures.append([[d['x'], d['y'], d['z']] for d in window])labels.append(self.label_map[window[0]['activity']])# 每1000次更新一次进度信息if i % 1000 == 0:pbar.set_postfix({"有效窗口": len(features),"跳过窗口": i - len(features) + 1}, refresh=True)return np.array(features), np.array(labels)print("\n正在预处理训练集...")X_train, y_train = create_sequences(train_data, "训练集序列化")#返回的x是数据,y是标签print("\n正在预处理测试集...")X_test, y_test = create_sequences(test_data, "测试集序列化")# 标准化(显示进度)print("\n正在计算标准化参数...")self.mean = np.mean(X_train, axis=(0, 1))self.std = np.std(X_train, axis=(0, 1))print("应用标准化...")X_train = (X_train - self.mean) / (self.std + 1e-8)X_test = (X_test - self.mean) / (self.std + 1e-8)# One-hot编码# 将 NumPy 数组转为 PyTorch 张量,并指定类型为 int64(等价于 .long())y_train = torch.from_numpy(y_train).long()  # 或 .to(torch.int64)y_train = torch.nn.functional.one_hot(y_train.long(), num_classes=len(self.label_map))y_test = torch.from_numpy(y_test).long()  # 或 .to(torch.int64)y_test = torch.nn.functional.one_hot(y_test.long(), num_classes=len(self.label_map))print("\n" + "=" * 50)print("数据预处理完成!")print(f"训练集形状: X_train{X_train.shape}, y_train{y_train.shape}")print(f"测试集形状: X_test{X_test.shape}, y_test{y_test.shape}")print("=" * 50 + "\n")return (X_train, y_train), (X_test, y_test)

滑动窗口开销大,改用向量化滑动窗口(NumPy)

参数标准化全部使用训练集数据

1e-8的作用:防止除零的小常数,特别适用于某些标准差接近0的特征

axis=(0,1):假设您的数据是3D张量(样本×时间步/空间×特征),这样计算每个特征通道的统计量

消除量纲影响:当特征的单位/量纲不同时(如年龄0-100 vs 工资0-100000),标准化使所有特征具有可比性

只使用训练集统计量:测试集必须使用训练集的mean/std,这是为了避免数据泄露(data leakage)

数据泄露:是机器学习中一个常见但严重的问题,指在模型训练过程中意外地使用了测试集或未来数据的信息,导致模型评估结果被高估,无法反映真实性能。这种现象会使模型在实际应用中表现远差于预期。

将分类标签(整数形式)转换为 One-hot 编码,这是机器学习中处理分类任务的常见方法。

build_model创建模型

    def build_model(self):"""构建改进的BiLSTM分类模型"""model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(self.config['seq_length'], 3)),# 双向LSTM层tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),tf.keras.layers.BatchNormalization(),tf.keras.layers.Dropout(0.2),tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),tf.keras.layers.BatchNormalization(),# 全连接层tf.keras.layers.Dense(32, activation='relu'),tf.keras.layers.Dropout(0.3),tf.keras.layers.Dense(len(self.label_map), activation='softmax')])model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss='categorical_crossentropy',metrics=['accuracy'])return model

两个模型框架TensorFlow更早,但PyTorch的初始设计更现代,以上是TensorFlow的模型。

计算图(Computational Graph) 是描述数学运算和数据处理流程的抽象结构,而 静态图动态图 是两种不同的计算图构建和执行方式。

计算图 是一个有向无环图(DAG),用于表示计算过程:

  • 节点(Node):代表运算(如加法、矩阵乘法)或数据(如张量、变量)。

  • 边(Edge):描述数据流动方向(如张量从一层传递到下一层)。

改用PyTorch模型需要注意

PyTorch更推荐类式构建,而且保存时仅保存模型的参数(权重和偏置),不包含模型结构。如果需要测试,加载时必须先实例化一个结构完全相同的模型,再加载参数。

先创建一个模型类,再去调用 

class BiLSTMModel(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes, bidirectional=True):super(BiLSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.num_directions = 2 if bidirectional else 1# 双向LSTMself.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True,bidirectional=bidirectional)# 全连接层(双向时hidden_size需*2)self.fc = nn.Linear(hidden_size * self.num_directions, num_classes)def forward(self, x):# 初始化隐藏状态(可选,PyTorch默认全零)h0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(x.device)# LSTM前向传播out, _ = self.lstm(x, (h0, c0))  # out形状: (batch, seq_len, hidden_size * num_directions)# 取最后一个时间步的输出out = out[:, -1, :]  # 形状: (batch, hidden_size * num_directions)# 分类层out = self.fc(out)return out

此处构建的就是双向LSTM模型,然后再构建函数调用

    def build_model(self):# 使用示例model = LSTMModel(input_size=3,  # 对应x/y/z特征hidden_size=32,num_layers=2,num_classes=6,  # 类别数bidirectional=True)return model

train训练

    def train(self):"""PyTorch版本训练流程"""# 1. 数据加载与预处理(X_train, y_train), (X_test, y_test) = self.load_data()# 转换为PyTorch张量并移至设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")X_train = torch.FloatTensor(X_train).to(device)y_train = torch.LongTensor(y_train.argmax(axis=1)).to(device)  # 如果y是one-hotX_test = torch.FloatTensor(X_test).to(device)y_test = torch.LongTensor(y_test.argmax(axis=1)).to(device)# 创建DataLoadertrain_dataset = TensorDataset(X_train, y_train)# 类似zip(features, labels)train_loader = DataLoader(train_dataset,batch_size=self.config['batch_size'],shuffle=True)# 2. 模型初始化self.model = self.build_model().to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(self.model.parameters(),lr=self.config.get('lr', 0.001))# 3. 回调函数设置"""# 早停early_stopping = EarlyStopping(patience=self.config['patience'],verbose=True,path=os.path.join(self.exp_dir, 'best_model.pth'))"""# 学习率调度scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.1,patience=5,verbose=True)# TensorBoard日志writer = SummaryWriter(log_dir=os.path.join(self.exp_dir, 'logs'))print("\n开始训练...")print(f"实验目录: {self.exp_dir}")print(f"使用设备: {device}")# 4. 训练循环for epoch in range(self.config['epochs']):self.model.train()train_loss = 0.0# 训练批次for inputs, labels in train_loader:optimizer.zero_grad()outputs = self.model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item()# 验证阶段self.model.eval()with torch.no_grad():test_outputs = self.model(X_test)test_loss = criterion(test_outputs, y_test)_, predicted = torch.max(test_outputs, 1)accuracy = (predicted == y_test).float().mean()# 记录日志writer.add_scalar('Loss/train', train_loss / len(train_loader), epoch)writer.add_scalar('Loss/test', test_loss.item(), epoch)writer.add_scalar('Accuracy/test', accuracy.item(), epoch)# 打印进度print(f"Epoch {epoch + 1}/{self.config['epochs']} | "f"Train Loss: {train_loss / len(train_loader):.4f} | "f"Test Loss: {test_loss.item():.4f} | "f"Accuracy: {accuracy.item():.4f}")# 学习率调整scheduler.step(test_loss)"""# 早停检查early_stopping(test_loss, self.model)if early_stopping.early_stop:print("Early stopping triggered")break"""# 5. 保存最终结果writer.close()self.save_results(X_test, y_test)  # 需要适配PyTorch的保存方法

TensorDatasetDataLoader 都是 PyTorch 官方库中的核心组件,专门用于高效的数据加载和批处理。

torch.utils.data.TensorDataset将多个张量(如特征张量和标签张量)打包成一个数据集对象

dataset = TensorDataset(features, labels)  # 类似zip(features, labels)

torch.utils.data.DataLoader将数据集按批次加载,支持自动批处理、打乱数据、多进程加载等

shuffle=True代表打乱数据,此处是时序信号,但是由于从长序列中通过滑动窗口提取样本每个窗口本身就是一个独立样本,此时打乱窗口顺序是安全的

损失函数 criterion = nn.CrossEntropyLoss()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/90727.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/90727.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/90727.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java填充Word模板

文章目录前言一、设置word模板普通字段列表字段复选框二、代码1. 引入POM2. 模板放入项目3.代码实体类工具类三、测试四、运行结果五、注意事项前言 最近有个Java填充Word模板的需求,包括文本,列表和复选框勾选,写一个工具类,以此…

【MYSQL8】springboot项目,开启ssl证书安全连接

文章目录一、开启ssl证书1、msysql部署时默认开启ssl证书2、配置文件3、创建用户并指定ssl二、添加Java信任库1、使用 keytool 导入证书2、验证证书是否已导入三、修改连接配置一、开启ssl证书 1、msysql部署时默认开启ssl证书 可通过命令查看: SHOW VARIABLES L…

Telegraf vs. Logstash:实时数据处理架构中的关键组件对比

在现代数据基础设施中,Telegraf 和 Logstash 是两种广泛使用的开源数据收集与处理工具,但它们在设计目标、应用场景和架构角色上存在显著差异。本文将从实时数据处理架构、时序数据库集成、消息代理支持等方面对比两者的核心功能,并结合实际应…

Vue Vue-route (4)

Vue 渐进式JavaScript 框架 基于Vue2的学习笔记 - Vue-route 编程式导航和几种路由 目录 编程式导航 详情组件 创建组件 设置路由 电影列表 传参 另一种方式 动态路由 命名路由 别名 总结 编程式导航 点击电影列表 跳转电影详情 详情组件 创建组件 在views中创…

存在两个cuda环境,在conda中切换到另一个

进入 openmmlab 环境 conda activate openmmlab 设置环境变量为 CUDA 12.4(只影响当前 shell 会话) export PATH/usr/local/cuda-12.4/bin:PATHexportLDLIBRARYPATH/usr/local/cuda−12.4/lib64:PATH export LD_LIBRARY_PATH/usr/local/cuda-12.4/lib64:…

Django 视图(View)

1. 视图简介 视图负责接收 web 请求并返回 web 响应。视图就是一个 python 函数,被定义在 views.py 中。响应可以是一张网页的 HTML 内容、一个重定向、一个 404 错误等等。响应处理过程如下图: 用户在浏览器中输入网址:www.demo.com/1/100Django 获取网址信息,去除域名和端…

HarmonyOS基础概念

一、OpenHarmony、HarmonyOS和Harmony NEXT区别OpenHarmony是由开放原子开源基金会(OpenAtom Foundation)孵化及运营的开源项目,开放原子开源基金会由华为、阿里、腾讯、百度、浪潮、招商银行、360等十家互联网企业共同发起组建。目标是面向全…

spark3 streaming 读kafka写es

1. 代码 package data_import import org.apache.spark.sql.{DataFrame, Row, SparkSession, SaveMode} import org.apache.spark.sql.types.{ArrayType, DoubleType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.functions._…

【跟着PMP学习项目管理】每日一练 - 3

1、你是一个建筑项目的项目经理。电工已经开始铺设路线,此时客户带着一个变更请求来找你。他需要增加插座,你认为这会增加相关工作的成本。你要做的第一件事? A、拒绝做出变更,因为这会增加项目的成本并超出预算 B、参考项目管理计划,查看是否应当处理这个变更 C、查阅…

CentOS 安装 JDK+ NGINX+ Tomcat + Redis + MySQL搭建项目环境

目录第一步:安装JDK 1.8方法 1:安装 Oracle JDK 1.8方法 2:安装 OpenJDK 1.8第二步:使用yum安装NGINX第三步:安装Tomcat第四步:安装Redis第五步:安装MySQL第六步:MySQL版本兼容性问题…

如何设计一个登录管理系统:单点登录系统架构设计

关键词:如何设计一个登录管理系统、登录系统架构、用户认证、系统安全设计 📋 目录 开篇:为什么登录系统这么重要?整体架构设计核心功能模块安全设计要点技术实现细节性能优化策略总结与展望 开篇:为什么登录系统这么…

论迹不论心

2025年7月11日,16~26℃,阴 紧急不紧急重要 备考ing 备课不重要 遇见:免费人格测试 | 16Personalities,下面是我的结果 INFJ分析与优化建议 User: Anonymous (隐藏) Created: 2025/7/11 23:38 Updated: 2025/7/11 23:43 Exported:…

【面板数据】省级泰尔指数及城乡收入差距测算(1990-2024年)

对中国各地区1990-2024年的泰尔指数、城乡收入差距进行测算。本文参考龙海明等(2015),程名望、张家平(2019)的做法,采用泰尔指数测算城乡收入差距。参考陈斌开、林毅夫(2013)的做法&…

http get和http post的区别

HTTP GET 和 HTTP POST 是两种最常用的 HTTP 请求方法,它们在用途、数据传输方式、安全性等方面存在显著差异。以下是它们的主要区别:1. 用途GET:主要用于请求从服务器获取资源,比如获取网页内容、查询数据库等。GET 请求不应该用…

I2C集成电路总线

(摘要:空闲时,时钟线数据线都是高电平,主机发送数据前,要在时钟为高电平时,把数据线从高电平拉低,数据发送采取高位先行,时钟线低电平时可以修改数据线,时钟线高电平时要…

为了安全应该使用非root用户启动nginx

nginx基线安全,修复步骤。主要是由于使用了root用户启动nginx。为了安全应该使用非root用户启动nginx一、检查项和问题检查项分类检查项名称身份鉴别检查是否配置Nginx账号锁定策略。服务配置检查Nginx进程启动账号。服务配置Nginx后端服务指定的Header隐藏状态服务…

论文解析篇 | YOLOv12:以注意力机制为核心的实时目标检测算法

前言:Hello大家好,我是小哥谈。长期以来,改进YOLO框架的网络架构一直至关重要,但尽管注意力机制在建模能力方面已被证明具有优越性,相关改进仍主要集中在基于卷积神经网络(CNN)的方法上。这是因…

学习C++、QT---20(C++的常用的4种信号与槽、自定义信号与槽的讲解)

每日一言相信自己,你比想象中更接近成功,继续勇往直前吧!那么我们开始用这4种方法进行信号与槽的通信第一种信号与槽的绑定方式我们将按键右键后转到槽会自动跳转到这个widget.h文件里面并自动生成了定义,我们要记住我们这个按钮叫…

Anolis OS 23 架构支持家族新成员:Anolis OS 23.3 版本及 RISC-V 预览版发布

自 Anolis OS 23 版本发布之始,龙蜥社区就一直致力于探索同源异构的发行版能力,从 Anolis OS 23.1 版本支持龙芯架构同源异构开始,社区就在持续不断地寻找更多的异构可能性。 RISC-V 作为开放、模块化、可扩展的指令集架构,正成为…

4万亿英伟达,凭什么?

CUDA正是英伟达所有神话的起点。它不是一个产品,而是一个生态系统。当越多的开发者使用CUDA,就会催生越多的基于CUDA的应用程序和框架;这些杀手级应用又会吸引更多的用户和开发者投身于CUDA生态。这个正向飞轮一旦转动起来,其产生…