训练模型是机器学习和深度学习中的核心过程,旨在通过大量数据学习模型参数,以便模型能够对新的、未见过的数据做出准确的预测。

训练模型通常包括以下几个步骤:

1.数据准备:
收集和处理数据,包括清洗、标准化和归一化。
将数据分为训练集、验证集和测试集。

2.定义模型:
选择模型架构,例如决策树、神经网络等。
初始化模型参数(权重和偏置)。

3.选择损失函数:
根据任务类型(如分类、回归)选择合适的损失函数。

4.选择优化器:
选择一个优化算法,如SGD、Adam等,来更新模型参数。

5.前向传播:
在每次迭代中,将输入数据通过模型传递,计算预测输出。

6.计算损失:
使用损失函数评估预测输出与真实标签之间的差异。

7.反向传播:
利用自动求导计算损失相对于模型参数的梯度。

8.参数更新:
根据计算出的梯度和优化器的策略更新模型参数。

9.迭代优化:
重复步骤5-8,直到模型在验证集上的性能不再提升或达到预定的迭代次数。

10.评估和测试:
使用测试集评估模型的最终性能,确保模型没有过拟合。

11.模型调优:
根据模型在测试集上的表现进行调参,如改变学习率、增加正则化等。

12.部署模型:
将训练好的模型部署到生产环境中,用于实际的预测任务。


一、PyTorch 数据处理与加载

PyTorch 提供了Dataset 和 DataLoader,帮助管理数据集、批量加载和数据增强等任务。

PyTorch 数据处理与加载:
自定义 Dataset:通过继承 torch.utils.data.Dataset 来加载自己的数据集。
DataLoader:使用DataLoader 按批次加载数据,支持多线程加载并进行数据打乱。(torch.utils.data.DataLoader

(一)自定义 Dataset

torch.utils.data.Dataset 是一个抽象类,允许你自己的数据源中创建数据集。

使用时需要继承该类并实现以下两个方法:
len(self):返回数据集中的样本数量。
getitem(self, idx):通过索引返回一个样本。

import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from models.utils import match_seq_lenDATASET_DIR = "D:\EMDKT\datasets\ASSIST2009"class ASSIST2009(Dataset):def __init__(self, seq_len, dataset_dir=DATASET_DIR) -> None:super().__init__()self.dataset_dir = dataset_dirself.dataset_path = os.path.join(self.dataset_dir, "skill_builder_data.csv")# 调用预处理self.q_seqs, self.r_seqs, self.q_list, self.u_list, self.q2idx, \self.u2idx = self.preprocess()self.num_u = self.u_list.shape[0]  # 用户总数self.num_q = self.q_list.shape[0]  # 题目总数if seq_len:self.q_seqs, self.r_seqs = match_seq_len(self.q_seqs, self.r_seqs, seq_len)self.len = len(self.q_seqs)def __getitem__(self, index):return self.q_seqs[index], self.r_seqs[index]def __len__(self):return self.lendef preprocess(self):# 数据加载与清洗df = pd.read_csv(self.dataset_path, encoding="ISO-8859-15").dropna(subset=["skill_name"]).drop_duplicates(subset=["order_id", "skill_name"]).sort_values(by=["order_id"])u_list = np.unique(df["user_id"].values)q_list = np.unique(df["skill_name"].values)u2idx = {u: idx for idx, u in enumerate(u_list)}q2idx = {q: idx for idx, q in enumerate(q_list)}# 生成序列数据q_seqs = []r_seqs = []for u in u_list:df_u = df[df["user_id"] == u]q_seq = np.array([q2idx[q] for q in df_u["skill_name"]])r_seq = df_u["correct"].valuesq_seqs.append(q_seq)r_seqs.append(r_seq)# 返回结果return q_seqs, r_seqs, q_list, u_list, q2idx, u2idx

(二)使用 DataLoader 加载数据

DataLoader 是 PyTorch 提供的一个重要工具,用于从 Dataset 中按批次(batch)加载数据。
DataLoader 允许批量读取数据并进行多线程加载,从而提高训练效率。

from torch.utils.data import DataLoader, random_split
from models.utils import collate_fn# 加载数据集
dataset = ASSIST2009(seq_len)  # seq_len 是序列长度参数# 划分数据集
train_size = int(len(dataset) * train_ratio)  # train_ratio 是训练集比例
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size]
)# 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=batch_size,  # batch_size 是批处理大小shuffle=True,collate_fn=collate_fn  # 使用自定义的collate函数
)test_loader = DataLoader(test_dataset,batch_size=test_size,  # 测试集使用整个测试集作为一个批次shuffle=True,collate_fn=collate_fn  # 使用自定义的collate函数
)

注释:
batch_size: 每次加载的样本数量。
shuffle: 是否对数据进行洗牌,通常训练时需要将数据打乱。

二、模型架构实现

通过继承 nn.Module 来定义模型

class DKT(Module):def __init__(self, num_q, emb_size, hidden_size):super().__init__()self.num_q = num_qself.emb_size = emb_sizeself.hidden_size = hidden_sizeself.interaction_emb = Embedding(self.num_q * 2, self.emb_size)self.lstm_layer = LSTM(self.emb_size, self.hidden_size, batch_first=True)self.out_layer = Linear(self.hidden_size, self.num_q)self.dropout_layer = Dropout()def forward(self, q, r):'''q: [batch_size, n]r: [batch_size, n]'''x = q + self.num_q * rh, _ = self.lstm_layer(self.interaction_emb(x))y = self.out_layer(h)y = self.dropout_layer(y)y = torch.sigmoid(y)return y# 创建模型实例
model = DKT()

三、训练配置

(一)初始化模型与设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = DKT(dataset.num_q, emb_size=100,hidden_size=100).to(device)

(二)定义损失函数与优化器

损失函数用于衡量预测值与真实值之间的差异。PyTorch 中提供了现成的损失函数。
将使用 SGD(随机梯度下降) 或 Adam 优化器来最小化损失函数。

(1)损失函数

from torch.nn.functional import binary_cross_entropy
criterion = nn.binary_cross_entropy()

(2)优化器

from torch.optim import SGD, Adamif optimizer == "sgd":opt = SGD(model.parameters(), learning_rate, momentum=0.9)elif optimizer == "adam":opt = Adam(model.parameters(), learning_rate)

(三)训练模型评估模型

在训练过程中,将执行以下步骤:
使用输入数据 X 进行前向传播,得到预测值。
计算损失(预测值与实际值之间的差异)。
使用反向传播计算梯度。
更新模型参数(权重和偏置)。

# 训练模型
num_epochs = 1000  # 训练 1000 轮
for epoch in range(num_epochs):model.train()  # 设置模型为训练模式# 前向传播predictions = model(X)  # 模型输出预测值loss = criterion(predictions.squeeze(), Y)  # 计算损失# 反向传播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 计算梯度optimizer.step()  # 更新模型参数# 打印损失if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item():.4f}')

注释:
optimizer.zero_grad():每次反向传播前需要清空之前的梯度。
loss.backward():计算梯度。
optimizer.step():更新权重和偏置。

(四)评估模型

训练完成后,可以通过查看模型的权重和偏置来评估模型的效果

with torch.no_grad():  # 评估时不需要计算梯度predictions = model(X)

(五)训练循环实现

import os
import numpy as np
import torch
from torch.nn.functional import one_hot, binary_cross_entropy
from sklearn import metricsdef train_dkt_model(model, train_loader, test_loader, num_epochs, optimizer, ckpt_path):"""训练DKT模型的独立函数参数:model: 要训练的DKT模型实例train_loader: 训练数据加载器test_loader: 测试数据加载器num_epochs: 训练轮数optimizer: 优化器实例ckpt_path: 模型检查点保存路径"""aucs = []       # 存储每轮测试AUCloss_means = [] # 存储每轮平均训练损失max_auc = 0     # 记录最佳AUC# 开始训练循环for epoch in range(1, num_epochs + 1):epoch_losses = []  # 存储当前epoch的训练损失# 训练阶段model.train()for data in train_loader:# 解包数据q, r, qshft, rshft, mask = data# 前向传播y_pred = model(q.long(), r.long())y_pred = (y_pred * one_hot(qshft.long(), model.num_q)).sum(-1)# 应用掩码选择有效预测valid_pred = torch.masked_select(y_pred, mask)valid_target = torch.masked_select(rshft, mask)# 计算损失loss = binary_cross_entropy(valid_pred, valid_target)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录损失epoch_losses.append(loss.detach().cpu().item())# 计算本轮平均训练损失epoch_loss_mean = np.mean(epoch_losses)loss_means.append(epoch_loss_mean)# 验证阶段model.eval()all_preds = []all_targets = []with torch.no_grad():for data in test_loader:q, r, qshft, rshft, mask = data# 预测并选择有效结果y_pred = model(q.long(), r.long())y_pred = (y_pred * one_hot(qshft.long(), model.num_q)).sum(-1)valid_pred = torch.masked_select(y_pred, mask).cpu().numpy()valid_target = torch.masked_select(rshft, mask).cpu().numpy()all_preds.extend(valid_pred)all_targets.extend(valid_target)# 计算整体AUCauc = metrics.roc_auc_score(all_targets, all_preds)aucs.append(auc)# 打印训练信息print(f"Epoch: {epoch}, AUC: {auc:.4f}, Loss Mean: {epoch_loss_mean:.4f}")# 保存最佳模型if auc > max_auc:torch.save(model.state_dict(), os.path.join(ckpt_path, "model.ckpt"))max_auc = aucprint(f"保存最佳模型,AUC = {auc:.4f}")return aucs, loss_means

理解 y = self(q.long(), r.long())

这行代码是知识追踪模型的核心,表示将输入数据传入模型进行前向传播。

1. 代码结构解析

y = self(q.long(), r.long())

self: 指当前模型实例(EM_DKT)

q.long(): 将问题ID序列转换为长整型(整数类型)

r.long(): 将响应序列(0/1)转换为长整型

y: 模型输出(预测概率)

2. 数据流分析

输入数据
变量 含义 维度 示例
q 问题ID序列 (batch_size, seq_len) [[101, 102, 0], [201, 0, 0]]
r 响应序列 (batch_size, seq_len) [[1, 0, 0], [0, 0, 0]]

3. 模型内部处理(在EM_DKT.forward()中)

步骤1: 交互编码
x = q + self.num_q * r

目的: 创建唯一的交互ID
逻辑:
-正确响应: ID = q + num_q * 1
-错误响应: ID = q + num_q * 0 = q

示例:
问题101正确: 101 + 100 * 1 = 201
问题101错误: 101 + 100 * 0= 101

步骤2: 嵌入层
emb = self.interaction_emb(x)

输入: 交互ID (batch_size, seq_len)
输出: 嵌入向量 (batch_size, seq_len, emb_size)
作用: 将离散ID映射为连续向量表示

步骤3: XLSTM处理
for t in range(seq_len):x_t = emb[:, t, :]  # 当前时间步h_t, states = self.xlstm(x_t, states)y_t = self.out_layer(h_t)

XLSTM结构:
-7层MLSTM: 处理长期知识状态
-1层ELSTM: 处理近期动态

状态传递: 每个时间步更新内部状态

输出: 每个时间步的隐藏表示 (hidden_size)

步骤4: 输出层
y = torch.stack(outputs, dim=1)
y = torch.sigmoid(y)

维度变化: (batch_size, seq_len, num_q)
sigmoid激活: 将输出转换为概率[0,1]

4. 输出

输出 y 的结构[batch_size,seq_len,num_q]
输出示例 y[0, 2, 101] = 0.85
表示:批次0中,第2个时间步后,学生答对问题101的概率是85%

5. 实际应用场景

训练时
# 预测下一个问题的正确概率
y_next = (y * one_hot(qshft)).sum(-1)
预测时
# 获取学生当前知识状态
current_state = self.xlstm.states# 预测下一个问题
next_q = 105
next_input = create_input(next_q)
next_pred = self(next_input, current_state)

6. 数学表示

模型本质上学习了一个函数:

P ( r t + 1 = 1 ∣ q 1 : t , r 1 : t ) = f ( q 1 : t , r 1 : t ) P(r_{t+1}=1 | q_{1:t}, r_{1:t}) = f(q_{1:t}, r_{1:t}) P(rt+1=1∣q1:t,r1:t)=f(q1:t,r1:t)
其中:

q 1 : t q_{1:t} q1:t: 到时间t为止的问题序列

r 1 : t r_{1:t} r1:t: 到时间t为止的响应序列

f f f: 由EM_DKT模型参数化的复杂非线性函数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/86575.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/86575.shtml
英文地址,请注明出处:http://en.pswp.cn/web/86575.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity_导航操作(鼠标控制人物移动)_运动动画

文章目录 前言一、Navigation 智能导航地图烘焙1.创建Plan和NavMesh Surface2.智能导航地图烘焙 二、MouseManager 鼠标控制人物移动1.给场景添加人物,并给人物添加导航组件2.编写脚本管理鼠标控制3.给人物编写脚本,订阅事件(添加方法给Mouse…

6. 接口分布式测试pytest-xdist

pytest-xdist实战指南:解锁分布式测试的高效之道 随着测试规模扩大,执行时间成为瓶颈。本文将带你深入掌握pytest-xdist插件,利用分布式测试将执行速度提升300%。 一、核心命令解析 加速安装(国内镜像) pip install …

预训练语言模型

预训练语言模型 1.1Encoder-only PLM ​ Transformer结构主要由Encoder、Decoder组成,根据特点引入了ELMo的预训练思路。 ELMo(Embeddings from Language Models)是一种深度上下文化词表示方法, 该模型由一个**前向语言模型&…

Altera PCI IP target设计分享

最近调试也有关于使用Altera 家的PCI IP,然后分享一下代码: 主要实现:主控作为主设备,FPGA作为从设备,主控对FPGA IO读写的功能 后续会分享FPGA作为主设备, 从 FPGA通过 memory写到主控内存,会…

基于机器学习的智能文本分类技术研究与应用

在当今数字化时代,文本数据的爆炸式增长给信息管理和知识发现带来了巨大的挑战。从新闻文章、社交媒体帖子到企业文档和学术论文,海量的文本数据需要高效地分类和管理,以便用户能够快速找到所需信息。传统的文本分类方法主要依赖于人工规则和…

前端项目3-01:登录页面

一、效果图 二、全部代码 <!DOCTYPE html> <html><head><meta charset"utf-8"><title>码农魔盒</title><style>.bg{position: fixed;top: 0;left:0;object-fit: cover;width: 100vw;height: 100vh;}.box{width: 950px;he…

Nexus CLI:简化你的分布式计算贡献之旅

探索分布式证明网络的力量&#xff1a;Nexus CLI 项目深入解析 在今天的数字时代&#xff0c;分布式计算和去中心化技术正成为互联网发展的前沿。Nexus CLI 是一个为 Nexus 网络提供证明的高性能命令行界面&#xff0c;它不仅在概念上先进&#xff0c;更是在具体实现中为开发者…

IBW 2025: CertiK首席商务官出席,探讨AI与Web3融合带来的安全挑战

6月26日至27日&#xff0c;全球最大的Web3安全公司CertiK亮相伊斯坦布尔区块链周&#xff08;IBW 2025&#xff09;&#xff0c;首席商务官Jason Jiang出席两场圆桌论坛&#xff0c;分享了CertiK在AI与Web3融合领域的前沿观察与安全见解。他与普华永道土耳其网络安全服务主管Nu…

Vivado 五种仿真类型的区别

Vivado 五种仿真类型的区别 我们还是用“建房子”的例子来类比。您已经有了“建筑蓝图”&#xff08;HLS 生成的 RTL 代码&#xff09;&#xff0c;现在要把它建成真正的房子&#xff08;FPGA 电路&#xff09;。这五种仿真就是在这个过程中不同阶段的“质量检查”。 1. 行为…

小程序快速获取url link方法,短信里面快速打开链接

获取小程序链接方法 uni.request({url:https://api.weixin.qq.com/cgi-bin/token?grant_typeclient_credential&appidwxxxxxxxxxxxx&secret111111111111111111111111111111111,method:GET,success(res) {console.log(res.data)let d {"path": "/xxx/…

Spring 框架(1-4)

第一章&#xff1a;Spring 框架概述 1.1 Spring 框架的定义与背景 Spring 是一个开源的轻量级 Java 开发框架&#xff0c;于 2003 年由 Rod Johnson 创立&#xff0c;旨在解决企业级应用开发的复杂性。其核心设计思想是面向接口编程和松耦合架构&#xff0c;通过分层设计&…

RabitQ 量化:既省内存又提性能

突破高维向量内存瓶颈:Mlivus Cloud RaBitQ量化技术的工程实践与调优指南 作为大禹智库高级研究员,拥有三十余年向量数据库与AI系统架构经验的我发现,在当今多模态AI落地的核心场景中,高维向量引发的内存资源消耗问题已成为制约系统规模化部署的“卡脖子”因素。特别是在大…

创客匠人:创始人 IP 打造的得力助手

在当今竞争激烈的商业环境中&#xff0c;创始人 IP 的打造对于企业的发展愈发重要。一个鲜明且具有影响力的创始人 IP&#xff0c;能够为企业带来独特的竞争优势&#xff0c;提升品牌知名度与美誉度。创客匠人在创始人 IP 打造过程中扮演着不可或缺的角色&#xff0c;为创始人提…

如何为虚拟机上的 Manjaro Linux启用 VMware 拖放功能

如果你的Manjaro 发行版本是安装在 VMware Workstation Player 上使用的 &#xff0c;而且希望可以通过拖放功能将文件或文件夹从宿主机复制到客户端的Manjaro 里面&#xff0c;那么可以按照以下的步骤进行操作&#xff0c;开启拖放功能。 在 VMware 虚拟机上安装 Manjaro 后&…

【C/C++】单元测试实战:Stub与Mock框架解析

C 单元测试中的 Stub/Mock 框架详解 在单元测试中&#xff0c;Stub&#xff08;打桩&#xff09;和Mock都是替代真实依赖以简化测试的技术。通常&#xff0c;Stub&#xff08;或 Fake&#xff09;提供了一个简化实现&#xff0c;用于替代生产代码中的真实对象&#xff08;例如…

工厂 + 策略设计模式(实战教程)

在软件开发中&#xff0c;设计模式是解决特定问题的通用方案&#xff0c;而工厂模式与策略模式的结合使用&#xff0c;能在特定业务场景下发挥强大的威力。本文将基于新增题目&#xff08;题目类型有单选、多选、判断、解答&#xff09;这一业务场景&#xff0c;详细阐述如何运…

Nuxt3中使用 Ant-Design-Vue 的BackTop 组件实现自动返回页面顶部

在现代 Web 应用中&#xff0c;提供一个方便用户返回页面顶部的功能是非常重要的。Ant Design Vue 提供了 BackTop 组件&#xff0c;可以轻松实现这一功能。本文将详细介绍如何在 Nuxt 3 项目中使用 <a-back-top/> 组件&#xff0c;并通过按需引入的方式加载组件及其样式…

在统信UOS(Linux)中构建SQLite3桌面应用笔记

目录 1 下载lazarus 2 下载sqlite3源码编译生成库文件 3 新建项目 4 设置并编译 一次极简单的测试&#xff0c;记录一下。 操作系统&#xff1a;统信UOS&#xff0c; 内核&#xff1a;4.19.0-arm64-desktop 处理器&#xff1a;D3000 整个流程难点是生成so库文件并正确加…

Host ‘db01‘ is not allowed to connect to this MariaDB server 怎么解决?

出现错误 ERROR 1130 (HY000): Host db01 is not allowed to connect to this MariaDB server&#xff0c;表示当前用户 test 没有足够的权限从主机 db01 连接到 MariaDB 服务器。以下是逐步解决方案&#xff1a; 1. 检查用户权限 登录 MariaDB 服务器&#xff08;需本地或通过…

打造高可用的大模型推理服务:基于 DeepSeek 的企业级部署实战

&#x1f4dd;个人主页&#x1f339;&#xff1a;一ge科研小菜鸡-CSDN博客 &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; 一、引言&#xff1a;从“能部署”到“可用、好用、能扩展” 近年来&#xff0c;随着 DeepSeek、Qwen、Yi 等开源大模型的持续发…