• 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

一、准备工作

数据格式:

import torch
from torch import nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")device = torch.device("cuda" if torch.cuda.is_available else "cpu")import pandas as pd# CSV 格式通常为 无表头(header=None),以制表符(sep='\t')分隔
train_data = pd.read_csv('./data/train.csv',sep='\t',header=None)
train_data.head()

# 构造数据集迭代器
def custom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,ytrain_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])

二、数据预处理

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba# 中文分词方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])label_name = list(set(train_data[1].values[:]))text_pipeline = lambda x:vocab(tokenizer(x))
label_pipeline = lambda x:label_name.index(x)

三、模型搭建

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self,vocab_size,embed_dim,num_class):super(TextClassificationModel,self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,embed_dim)self.fc = nn.Linear(embed_dim,num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange,initrange)self.fc.weight.data.uniform_(-initrange,initrange)self.fc.bias.data.zero_()def forward(self,text,offsets):embedded = self.embedding(text,offsets)return self.fc(embedded)
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size,em_size,num_class).to(device)
model

import timedef train(dataloader):model.train()total_acc,train_loss,total_count = 0,0,0log_interval = 50start_time = time.time()for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text,offsets)optimizer.zero_grad()loss = criterion(predicted_label,label)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),0.1) # 梯度裁剪optimizer.step()total_acc += (predicted_label.argmax(1)==label).sum().item()train_loss += loss.item()*label.size(0)total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()total_acc,test_loss,total_count =0,0,0with torch.no_grad():for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text,offsets)loss = criterion(predicted_label,label)total_acc += (predicted_label.argmax(1)==label).sum().item()test_loss += loss.item()*label.size(0)total_count += label.size(0)return total_acc/total_count,test_loss/total_count

四、训练模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset# 超参数
EPOCHS = 10
LR = 5
BATCH_SIZE = 64criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu = Nonetrain_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)num_train = int(len(train_dataset)*0.8)
split_train,split_valid = random_split(train_dataset,[num_train,len(train_dataset)-num_train])train_dataloader = DataLoader(split_train,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)for epoch in range(1,EPOCHS+1):epoch_start_time = time.time()train(train_dataloader)val_acc,val_loss = evaluate(valid_dataloader)lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss,lr))print('-' * 69)
def predict(text):with torch.no_grad():text = torch.tensor(text_pipeline(text)).to(device)output = model(text,torch.tensor([0]).to(device))return output.argmax(1).item()
# ex_text_str = "还有南昌到哈尔滨西的火车票吗?"
ex_text_str = "我想听TWICE的新曲"
print("该文本的类别是:%s" %label_name[predict(ex_text_str)])

总结

本次学习对中文文本实现了分类,主要代码和N1周基本一致。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/98941.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/98941.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/98941.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【代码随想录day 24】 力扣 90. 集合II

视频讲解&#xff1a;https://www.bilibili.com/video/BV1vm4y1F71J/?vd_sourcea935eaede74a204ec74fd041b917810c 文档讲解&#xff1a;https://programmercarl.com/0090.%E5%AD%90%E9%9B%86II.html#%E6%80%9D%E8%B7%AF 力扣题目&#xff1a;https://leetcode.cn/problems/su…

.NET 6 文件下载

.NET 6 API中实现文件的下载。创建HttpHeaderConstant用于指定http头。public sealed class HttpHeaderConstant{public const string RESPONSE_HEADER_CONTENTTYPE_STREAM "application/octet-stream";public const string RESPONSE_HEADER_NAME_FILENAME "f…

[数据结构——lesson6.栈]

目录 引言 1.栈的概念和结构 栈的核心概念 栈的结构 2.栈的实现 2.1栈的实现方式 2.2栈的功能 2.3栈的声明 1.顺序栈 2。链式栈 2.4栈的功能实现 1.栈的初始化 2.判断栈是否为空 3.返回栈顶元素 4.返回栈的大小 5.元素入栈 6.元素出栈 7.打印栈的元素 8.销毁…

华为HICE云计算的含金量高吗?

在数字时代的今天&#xff0c;云计算技术证飞速的发展成为企业数字化转型的重要支撑。而华为作为领先的通信和信息技术公司&#xff0c;推出的HCIE云计算认证备受关注。接下来就来说说华为HCIE云计算认证的含金量到底有多高。HCIE认证被认为是华为认证中的最高等级&#xff0c;…

OSPF协议原理讲解和实际配置(华为/思科)

OSPF&#xff08;open shorest path first&#xff0c;开放最短路径优先&#xff09;是一种动态的&#xff0c;基于链路状态的动态路由协议&#xff0c;广泛的应用在企业网络中&#xff0c;通过维护网络拓扑信息&#xff0c;利用 Dijkstra 算法实现最短路径&#xff0c;实现高效…

【开题答辩全过程】以 《黄帝内经》问答系统为例,包含答辩的问题和答案

个人简介一名14年经验的资深毕设内行人&#xff0c;语言擅长Java、php、微信小程序、Python、Golang、安卓Android等开发项目包括大数据、深度学习、网站、小程序、安卓、算法。平常会做一些项目定制化开发、代码讲解、答辩教学、文档编写、也懂一些降重方面的技巧。感谢大家的…

npm : 无法加载文件 C:\Program Files\nodejs\npm.ps1,因为在此系统上禁止运行脚

这个错误是由于 PowerShell 的执行策略限制&#xff0c;导致无法运行脚本。你可以通过以下步骤解决这个问题&#xff1a; 1. 查看当前的执行策略 打开 PowerShell&#xff0c;以管理员身份运行&#xff0c;输入以下命令查看当前的执行策略&#xff1a; Get-ExecutionPolicy如果…

macOS苹果电脑运行向日葵远程控制软件闪退

文章目录问题原因分析修复附录向日葵字太小按Ctrl键会弹出开始菜单的问题问题 向日葵是一款远程控制的应用&#xff0c;在macOS下也能运行&#xff0c; 本来用的好好的&#xff0c;有一天升级后突然就运行不起来了&#xff0c;一点开能显示几秒首界面&#xff0c;立马就自动退…

Linux dma-buf 框架原理、实现与应用详解

1. 背景与意义 1.1 异构系统与缓冲区共享的挑战 在现代 SoC、嵌入式、图形和多媒体系统中&#xff0c;CPU、GPU、VPU、ISP、DMA 控制器等多个硬件单元需要高效地共享和传递大块数据&#xff08;如图像帧、视频流、AI 张量等&#xff09;。如果每个设备都维护独立的缓冲区&…

Scikit-learn Python机器学习 - 分类算法 - 朴素贝叶斯

锋哥原创的Scikit-learn Python机器学习视频教程&#xff1a; https://www.bilibili.com/video/BV11reUzEEPH 课程介绍 ​ 本课程主要讲解基于Scikit-learn的Python机器学习知识&#xff0c;包括机器学习概述&#xff0c;特征工程(数据集&#xff0c;特征抽取&#xff0c;特…

如何免费股票数据API(第13期):沪深A股《最新分时交易》数据获取大全:附Python、Java等多语言实战教程与接口文档说明

在金融科技迅猛发展的今天&#xff0c;股票量化分析以其严谨的科学性和强大的系统性&#xff0c;正日益成为投资领域的主流方法论。任何卓越的量化模型的诞生&#xff0c;都离不开全面、精准、及时的数据支撑。无论是跃动着的实时交易数据、沉淀了历史规律的K线走势&#xff0c…

国标GB28181视频EasyGBS视频监控平台:一网联全城,交通道路可视化、视频巡检、应急指挥“三合一”。

一、方案背景​人车暴涨&#xff0c;路口告急&#xff1a;高峰堵、事故慢、取证难&#xff0c;老办法已拖不动城市交通。破局之道&#xff0c;先看摄像头——EasyGBS 严格遵循 GB28181 国标&#xff0c;一站式完成直播、存储、检索、转码&#xff0c;把万千路口秒级搬上云端&am…

单元测试(白盒测试方法)

一、单元测试1.单元测试是对软件的基本组成单元进行的测试&#xff0c;如函数、类或类的方法。单元测试是对软件的最小可测试单元&#xff08;即可独立编译或汇编的程序模块&#xff09;进行的测试活动&#xff0c;也称为模块测试二、白盒测试方法实例代码public static int te…

2010-2022 同等学力申硕国考:软件工程简答题真题汇总

2010年简答题 给出数据流图的定义&#xff0c;并举例说明数据流图的四个基本构成成份。 数据流图&#xff08;Data Flow Diagram, DFD&#xff09;是一种用于描述系统中数据流动和处理过程的图形工具。它通过直观的方式展示了系统的输入数据如何经过一系列处理变换为输出数据&a…

海外盲盒APP开发:如何用技术重构“惊喜经济”

当盲盒的神秘感遇上技术的确定性&#xff0c;一场关于消费体验的革命正在海外市场悄然发生。从概率算法的公平性到AR虚拟开箱的沉浸感&#xff0c;从跨境物流的实时追踪到多语言支持的无缝切换&#xff0c;海外盲盒APP的开发是一场技术、设计与商业逻辑的深度融合。概率算法&am…

Aosp13 手机sim卡信号格显示修改

工作中&#xff0c;客户要求对信号格显示偏弱不够友好为由&#xff0c;提出修改&#xff0c;要求使其显示信号强一些。在此记录 一问题&#xff1a;修改系统sim卡显示的信号格&#xff0c;在设备其他配置不变的情况下&#xff0c;使其信号格显示比原有的要优秀二 …

硬件开发2-汇编2(ARMv7-A)- 裸机开发

一、指令1、b&#xff08;Branch&#xff09;原型&#xff1a;B<c> <label>作用&#xff1a;实现无条件跳转&#xff0c;常用于不返回的跳转场景特点&#xff1a;仅跳转到目标地址&#xff0c;不保存返回地址示例&#xff1a;b reset ;跳转到reset标号处执…

清源 SCA 社区版更新(V4.2.0)|漏洞前置感知、精准修复、合规清晰,筑牢软件供应链安全防线!

随着数字化进程加速&#xff0c;软件供应链安全威胁日益复杂&#xff0c;公开漏洞响应滞后、0day 攻击防不胜防、组件升级编译失败、安全与合规风险混杂......这些痛点让企业安全团队、运维人员及研发团队疲于应对。自 2025 年 7 月 1 日安势清源 SCA 社区版首次正式发布以及在…

氚燃料增殖里程碑:MIT新型BABY包层技术实验验证

● 导语 5月20日&#xff0c;麻省理工学院&#xff08;MIT&#xff09;发文称&#xff0c;BABY实验首次获取了氚在装置内增殖的实测数据&#xff0c;验证了核心模型&#xff0c;并为未来核聚变电厂的燃料自循环奠定了重要基础。 原文&#x1f447;&#x1f3fb; https://m…

python+springboot+uniapp微信小程序题库系统 在线答题 题目分类 错题本管理 学习记录查询系统

目录技术栈介绍具体实现截图系统设计研究方法&#xff1a;设计步骤设计流程核心代码部分展示研究方法详细视频演示试验方案论文大纲源码获取/详细视频演示技术栈介绍 Django-SpringBoot-php-Node.js-flask 本课题的研究方法和研究步骤基本合理&#xff0c;难度适中&#xff0…