在这里插入图片描述

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

博主简介:努力学习的22级本科生一枚 🌟​;探索AI算法,C++,go语言的世界;在迷茫中寻找光芒​🌸
博客主页:羊小猪~~-CSDN博客
内容简介:这一篇是NLP的入门项目,以AG_NEW新闻数据为例。
🌸箴言🌸:去寻找理想的“天空“”之城
上一篇内容:【NLP入门系列三】NLP文本嵌入(以Embedding和EmbeddingBag为例)-CSDN博客
​💁​​💁​​💁​​💁​: 如果在conda安装环境,由于nlp的核心包是torchtext,所以如果把握不好就重新创建一虚拟环境(小编的“难忘”经历)

文章目录

    • 1、准备
      • 数据加载
      • 构建词表
    • 2、生成数据批次和迭代器
    • 3、定义与模型
      • 模型定义
      • 创建模型
    • 4、创建训练和评估函数
      • 训练函数
      • 评估函数
      • 创建超参数
    • 5、模型训练
    • 6、结果展示
    • 7、预测

🤔 思路

在这里插入图片描述

1、准备

AG News 数据集(也叫 AG’s Corpus or AG News Dataset),这是一个广泛用于自然语言处理(NLP)任务中的文本分类数据集


基本信息:

  • 全称:AG News
  • 来源:来源于 AG’s corpus,由 A. Godin 在 2005 年构建。
  • 用途:主要用于短文本多类别分类任务
  • 语言:英文
  • 类别数:4 类新闻主题
  • 训练样本数:120,000 条
  • 测试样本数:7,600 条

类别标签(共 4 类)

标签含义
1World (世界)
2Sports (体育)
3Business (商业)
4Science and Technology (科技)

数据加载

import torch
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader 
import torchtext 
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator# 检查设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')
# 加载本地数据
train_df = pd.read_csv("./data/train.csv")
test_df = pd.read_csv("./data/test.csv")# 合并标题和描述数据
train_df["text"] = train_df["Title"] + " " + train_df["Description"]
test_df["text"] = test_df["Title"] + " " + test_df["Description"]# 查看数据格式
train_df
Class IndexTitleDescriptiontext
03Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...Wall St. Bears Claw Back Into the Black (Reute...
13Carlyle Looks Toward Commercial Aerospace (Reu...Reuters - Private investment firm Carlyle Grou...Carlyle Looks Toward Commercial Aerospace (Reu...
23Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\ab...Oil and Economy Cloud Stocks' Outlook (Reuters...
33Iraq Halts Oil Exports from Main Southern Pipe...Reuters - Authorities have halted oil export\f...Iraq Halts Oil Exports from Main Southern Pipe...
43Oil prices soar to all-time record, posing new...AFP - Tearaway world oil prices, toppling reco...Oil prices soar to all-time record, posing new...
...............
1199951Pakistan's Musharraf Says Won't Quit as Army C...KARACHI (Reuters) - Pakistani President Perve...Pakistan's Musharraf Says Won't Quit as Army C...
1199962Renteria signing a top-shelf dealRed Sox general manager Theo Epstein acknowled...Renteria signing a top-shelf deal Red Sox gene...
1199972Saban not going to Dolphins yetThe Miami Dolphins will put their courtship of...Saban not going to Dolphins yet The Miami Dolp...
1199982Today's NFL gamesPITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ...Today's NFL games PITTSBURGH at NY GIANTS Time...
1199992Nets get Carter from RaptorsINDIANAPOLIS -- All-Star Vince Carter was trad...Nets get Carter from Raptors INDIANAPOLIS -- A...

120000 rows × 4 columns

构建词表

# 定义 Dataset
class AGNewsDataset(Dataset):def __init__(self, dataframe):self.labels = dataframe['Class Index'].tolist()  # 列表数据self.texts = dataframe['text'].tolist()def __len__(self):return len(self.labels)def __getitem__(self, idx):return self.labels[idx], self.texts[idx]# 加载数据
train_dataset = AGNewsDataset(train_df)
test_dataset = AGNewsDataset(test_df)# 构建词表
tokenizer = get_tokenizer("basic_english")  # 英文数据,设置英文分词def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)  # 构建词表# 构建词表,设置索引
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])print("Vocab size:", len(vocab))
Vocab size: 95804
# 查看这些单词所在词典的索引
vocab(['here', 'is', 'an', 'example'])  
[475, 21, 30, 5297]
'''  
标签,原始是字符串类型,现在要转换成 数字 类型
文本数字化,需要一个函数进行转换(vocab)
'''
text_pipline = lambda x : vocab(tokenizer(x))  # 先分词。在数字化
label_pipline = lambda x : int(x) - 1   # 标签转化为数字# 举例
text_pipline('here is the an example')
[475, 21, 2, 30, 5297]

2、生成数据批次和迭代器

# 采用embeddingbag嵌入方式,故需要构建数据,包括长度、标签、偏移量
''' 
数据格式:长度(~, 1)
标签:一维
偏移量:一维
'''
def collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_label, _text) in batch:# 标签列表,注意字符串转换成数字label_list.append(label_pipline(_label))# 文本列表,注意要转入tensro数据temp_text = torch.tensor(text_pipline(_text), dtype=torch.int64)text_list.append(temp_text)# 偏移量offsets.append(temp_text.size(0))# 全部转变成tensor变量label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)return label_list.to(device), text_list.to(device), offsets.to(device)# 数据加载
batch_size = 16
train_dl = DataLoader(train_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch
)test_dl = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch
)

3、定义与模型

模型定义

class TextModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super().__init__()self.embeddingBag = nn.EmbeddingBag(vocab_size,  # 词典大小embed_dim,   # 嵌入维度sparse=False)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()# 初始化权重def init_weights(self):initrange = 0.5self.embeddingBag.weight.data.uniform_(-initrange, initrange)  # 初始化权重范围self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()  # 偏置置为0def forward(self, text, offsets):embedding = self.embeddingBag(text, offsets)return self.fc(embedding)

创建模型

# 查看数据类别
train_df.groupby('Class Index').count()
TitleDescriptiontext
Class Index
1300003000030000
2300003000030000
3300003000030000
4300003000030000
class_num = 4
vocab_len = len(vocab)
embed_dim = 64  # 嵌入到64维度中
model = TextModel(vocab_size=vocab_len, embed_dim=embed_dim, num_class=class_num).to(device=device)

4、创建训练和评估函数

训练函数

def train(model, dataset, optimizer, loss_fn):size = len(dataset.dataset)num_batch = len(dataset)train_acc = 0train_loss = 0for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict_label = model(text, offset)loss = loss_fn(predict_label, label)# 求导与反向传播optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (predict_label.argmax(1) == label).sum().item()train_loss += loss.item()train_acc /= size train_loss /= num_batchreturn train_acc, train_loss

评估函数

def test(model, dataset, loss_fn):size = len(dataset.dataset)batch_size = len(dataset)test_acc, test_loss = 0, 0with torch.no_grad():for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict = model(text, offset)loss = loss_fn(predict, label) test_acc += (predict.argmax(1) == label).sum().item()test_loss += loss.item()test_acc /= size test_loss /= batch_sizereturn test_acc, test_loss

创建超参数

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.01)  # 动态调整学习率

5、模型训练

import copyepochs = 10train_acc, train_loss, test_acc, test_loss = [], [], [], []best_acc = 0for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(model, train_dl, optimizer, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)model.eval()epoch_test_acc, epoch_test_loss = test(model, test_dl, loss_fn)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)if best_acc is not None and epoch_test_acc > best_acc:# 动态调整学习率scheduler.step()best_acc = epoch_test_accbest_model = copy.deepcopy(model)  # 保存模型# 当前学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,  epoch_test_acc*100, epoch_test_loss, lr))# 保存最佳模型到文件
path = './best_model.pth'
torch.save(best_model.state_dict(), path) # 保存模型参数
Epoch: 1, Train_acc:79.9%, Train_loss:0.562, Test_acc:86.9%, Test_loss:0.392, Lr:5.00E-01
Epoch: 2, Train_acc:89.7%, Train_loss:0.313, Test_acc:88.9%, Test_loss:0.346, Lr:5.00E-01
Epoch: 3, Train_acc:91.2%, Train_loss:0.269, Test_acc:89.6%, Test_loss:0.329, Lr:5.00E-01
Epoch: 4, Train_acc:92.0%, Train_loss:0.243, Test_acc:89.8%, Test_loss:0.319, Lr:5.00E-01
Epoch: 5, Train_acc:92.6%, Train_loss:0.224, Test_acc:90.2%, Test_loss:0.315, Lr:5.00E-03
Epoch: 6, Train_acc:93.3%, Train_loss:0.207, Test_acc:90.6%, Test_loss:0.297, Lr:5.00E-03
Epoch: 7, Train_acc:93.4%, Train_loss:0.204, Test_acc:90.7%, Test_loss:0.295, Lr:5.00E-03
Epoch: 8, Train_acc:93.4%, Train_loss:0.203, Test_acc:90.7%, Test_loss:0.294, Lr:5.00E-03
Epoch: 9, Train_acc:93.4%, Train_loss:0.202, Test_acc:90.8%, Test_loss:0.293, Lr:5.00E-03
Epoch:10, Train_acc:93.4%, Train_loss:0.201, Test_acc:90.7%, Test_loss:0.293, Lr:5.00E-03

6、结果展示

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epoch_length = range(epochs)plt.figure(figsize=(12, 3))plt.subplot(1, 2, 1)
plt.plot(epoch_length, train_acc, label='Train Accuaray')
plt.plot(epoch_length, test_acc, label='Test Accuaray')
plt.legend(loc='lower right')
plt.title('Accurary')plt.subplot(1, 2, 2)
plt.plot(epoch_length, train_loss, label='Train Loss')
plt.plot(epoch_length, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.show()


在这里插入图片描述

7、预测

model.load_state_dict(torch.load("./best_model.pth"))
model.eval() # 模型评估# 测试句子
test_sentence = "This is a news about Technology"# 转换为 token
token_ids = vocab(tokenizer(test_sentence))   # 切割分词--> 词典序列
text = torch.tensor(token_ids, dtype=torch.long).to(device)  # 转化为tensor
offsets = torch.tensor([0], dtype=torch.long).to(device)# 测试,注意:不需要反向求导
with torch.no_grad():output = model(text, offsets)predicted_label = output.argmax(1).item()# 输出结果
class_names = ["World", "Sports", "Business", "Science and Technology"]
print(f"预测类别: {class_names[predicted_label]}")
预测类别: Science and Technology

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/87528.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/87528.shtml
英文地址,请注明出处:http://en.pswp.cn/web/87528.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu安装ClickHouse

注&#xff1a;本文章的ubuntu的版本为&#xff1a;ubuntu-20.04.6-live-server-amd64。 Ubuntu&#xff08;在线版&#xff09; 更新软件源 sudo apt-get update 安装apt-transport-https 允许apt工具通过https协议下载软件包。 sudo apt-get install apt-transport-htt…

C++26 下一代C++标准

C++26 将是继 C++23 之后的下一个 C++ 标准。这个新标准对 C++ 进行了重大改进,很可能像 C++98、C++11 或 C++20 那样具有划时代的意义。 一:C++标准回顾 C++ 已经有 40 多年的历史了。过去这些年里发生了什么?这里给出一个简化版的答案,直到即将到来的 C++26。 1. C++9…

【MySQL】十六,MySQL窗口函数

在 MySQL 8.0 及以后版本中&#xff0c;窗口函数&#xff08;Window Functions&#xff09;为数据分析和处理提供了强大的工具。窗口函数允许在查询结果集上执行计算&#xff0c;而不必使用子查询或连接&#xff0c;这使得某些类型的计算更加高效和简洁。 语法结构 function_…

微型气象仪在城市环境的应用

微型气象仪凭借其体积小、成本低、部署灵活、数据实时性强等特点&#xff0c;在城市环境中得到广泛应用&#xff0c;能够为城市规划、环境管理、公共安全、居民生活等领域提供精细化气象数据支持。一、核心应用场景1. 城市微气候监测与优化热岛效应研究场景&#xff1a;在城市不…

【仿muduo库实现并发服务器】eventloop模块

仿muduo库实现并发服务器一.eventloop模块1.成员变量std::thread::id _thread_id;//线程IDPoller _poll;int _event_fd;std::vector<Function<Function>> _task;TimerWheel _timer_wheel2.EventLoop构造3.针对eventfd的操作4.针对poller的操作5.针对threadID的操作…

Redis 加锁、解锁

Redis 加锁和解锁的应用 上代码 应用调用示例 RedisLockEntity lockEntityYlb RedisLockEntity.builder().lockKey(TradeConstants.HP_APP_AMOUNT_LOCK_PREFIX appUser.getAccount()).value(orderId).build();boolean isLockedYlb false;try {if (redisLock.tryLock(lockE…

在 Windows 上为 WSL 增加 root 账号密码并通过 Shell 工具连接

1. 为 WSL 设置 root 用户密码 在 Windows 上使用 WSL&#xff08;Windows Subsystem for Linux&#xff09;时&#xff0c;默认情况下并没有启用 root 账号的密码。为了通过 SSH 或其他工具以 root 身份连接到 WSL&#xff0c;我们需要为 root 用户设置密码。 设置 root 密码步…

2730、找到最长的半重复子字符穿

题目&#xff1a; 解答&#xff1a; 窗口为[left&#xff0c;right]&#xff0c;ans为窗口长度&#xff0c;same为子串长度&#xff0c;窗口满足题设条件&#xff0c;即只含一个连续重复字符&#xff0c;则更新ans&#xff0c;否则从左边开始一直弹出&#xff0c;直到满足条件…

MCP Java SDK源码分析

MCP Java SDK源码分析 一、引言 在当今人工智能飞速发展的时代&#xff0c;大型语言模型&#xff08;LLMs&#xff09;如GPT - 4、Claude等展现出了强大的语言理解和生成能力。然而&#xff0c;这些模型面临着一个核心限制&#xff0c;即无法直接访问外部世界的数据和工具。M…

[Linux]内核如何对信号进行捕捉

要理解Linux中内核如何对信号进行捕捉&#xff0c;我们需要很多前置知识的理解&#xff1a; 内核态和用户态的区别CPU指令集权限内核态和用户态之间的切换 由于文章的侧重点不同&#xff0c;上面这些知识我会在这篇文章尽量详细提及&#xff0c;更详细内容还得请大家查看这篇…

设计模式-观察者模式、命令模式

观察者模式Observer&#xff08;观察者&#xff09;—对象行为型模式定义&#xff1a;定义了一种一对多的依赖关系,让多个观察者对象同时监听某一主题对象,在它的状态发生变化时,会通知所有的观察者.先将 Observer A B C 注册到 Observable &#xff0c;那么当 Observable 状态…

【Unity笔记01】基于单例模式的简单UI框架

单例模式的UIManagerusing System.Collections; using System.Collections.Generic; using UnityEngine;public class UIManager {private static UIManager _instance;public Dictionary<string, string> pathDict;public Dictionary<string, GameObject> prefab…

深入解析 OPC UA:工业自动化与物联网的关键技术

在当今快速发展的工业自动化和物联网&#xff08;IoT&#xff09;领域&#xff0c;数据的无缝交换和集成变得至关重要。OPC UA&#xff08;Open Platform Communications Unified Architecture&#xff09;作为一种开放的、跨平台的工业通信协议&#xff0c;正在成为这一领域的…

MCP 协议的未来发展趋势与学习路径

MCP 协议的未来发展趋势 6.1 MCP 技术演进与更新 MCP 协议正在快速发展&#xff0c;不断引入新的功能和改进。根据 2025 年 3 月 26 日发布的协议规范&#xff0c;MCP 的最新版本已经引入了多项重要更新&#xff1a; 1.HTTP Transport 正式转正&#xff1a;引入 Streamable …

硬件嵌入式学习路线大总结(一):C语言与linux。内功心法——从入门到精通,彻底打通你的任督二脉!

嵌入式工程师学习路线大总结&#xff08;一&#xff09; 引言&#xff1a;C语言——嵌入式领域的“屠龙宝刀”&#xff01; 兄弟们&#xff0c;如果你想在嵌入式领域闯出一片天地&#xff0c;C语言就是你手里那把最锋利的“屠龙宝刀”&#xff01;它不像Python那样优雅&#xf…

MCP server资源网站去哪找?国内MCP服务合集平台有哪些?

在人工智能飞速发展的今天&#xff0c;AI模型与外部世界的交互变得愈发重要。一个好的工具不仅能提升开发效率&#xff0c;还能激发更多的创意。今天&#xff0c;我要给大家介绍一个宝藏平台——AIbase&#xff08;<https://mcp.aibase.cn/>&#xff09;&#xff0c;一个…

修改Spatial-MLLM项目,使其专注于无人机航拍视频的空间理解

修改Spatial-MLLM项目&#xff0c;使其专注于无人机航拍视频的空间理解。以下是修改方案和关键代码实现&#xff1a; 修改思路 输入处理&#xff1a;将原项目的视频文本输入改为单一无人机航拍视频/图像输入问题生成&#xff1a;自动生成空间理解相关的问题&#xff08;无需用户…

攻防世界-Reverse-insanity

知识点 1.ELF文件逆向 2.IDApro的使用 3.strings的使用 步骤 方法一&#xff1a;IDA 使用exeinfo打开&#xff0c;发现是32位ELF文件&#xff0c;然后用ida32打开。 找到main函数&#xff0c;然后F5反编译&#xff0c;得到flag。 tip&#xff1a;该程序是根据随机函数生成…

【openp2p】 学习1:P2PApp和优秀的go跨平台项目

P2PApp下面给出一个基于 RESTful 风格的 P2PApp 管理方案示例,供二次开发或 API 对接参考。核心思路就是把每个 P2PApp 当成一个可创建、查询、修改、启动/停止、删除的资源来管理。 一、P2PApp 资源模型 P2PApp:id: string # 唯一标识name: string # …

边缘设备上部署模型的限制之一——显存占用:模型的参数量只是冰山一角

边缘设备上部署模型的限制之一——显存占用&#xff1a;模型的参数量只是冰山一角 在边缘设备上部署深度学习模型已成为趋势&#xff0c;但资源限制是其核心挑战之一。其中&#xff0c;显存&#xff08;或更广义的内存&#xff09;占用是开发者们必须仔细考量的重要因素。许多…