•         🍨 本文为🔗365天深度学习训练营中的学习记录博客
  •         🍖 原作者:K同学啊

一、前期准备

1.加载数据
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")
#忽略警告信息#win10系统,调用GPU运行
#device = torch.device("cuda" if torch.cuda.is_available()else "cpu")
#devicedevice = torch.device("cpu")
device
device(type='cpu')
from torchtext.datasets import AG_NEWStrain_iter = list(AG_NEWS(split='train')) #加载 AG_News 数据集
num_class = len(set([label for (label, text) in train_iter]))
2.构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iteratortokenizer = get_tokenizer('basic_english') # 返回分词器函数,训练营内“get tokenizer函数详解def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) #设置默认索引,如果找不到单词,则会选择默认索引
vocab(['here','is','an','example'])

 [475, 21, 30, 5297]

text_pipeline= lambda x:vocab(tokenizer(x))
label_pipeline=lambda x:int(x)-1text_pipeline('here is the an example')

 [475, 21, 2, 30, 5297]

label_pipeline('10')

3.生成数据批次和迭代器
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list,text_list,offsets =[],[],[0]for(_label, _text) in batch:#标签列表label_list.append(label_pipeline(_label))#文本列表processed_text = torch.tensor(text_pipeline(_text),dtype=torch.int64)text_list.append(processed_text)#偏移量,即语句的总词汇量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  #返回维度dim中输入元素的累计和return label_list.to(device),text_list.to(device),offsets.to(device)#数据加载器
dataloader =DataLoader(train_iter,batch_size=8,shuffle =False,collate_fn=collate_batch)

二、准备模型

1.定义模型
from torch import nnclass TextclassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextclassificationModel,self).__init__()self.embedding =nn.EmbeddingBag(vocab_size, #词典大小embed_dim,  #嵌入的维度sparse=False) #self.fc =nn.Linear(embed_dim,num_class)self.init_weights()def init_weights(self):initrange =0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self,text, offsets):embedded =self.embedding(text,offsets)return self.fc(embedded)
2.定义实例
num_class = len(set([label for(label,text)in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextclassificationModel(vocab_size,em_size,num_class).to(device)
3.定义训练函数和评估函数
import timedef train(dataloader, model, optimizer, criterion, epoch):model.train()total_acc, train_loss, total_count = 0, 0, 0log_interval = 500start_time = time.time()for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()loss = criterion(predicted_label, label)loss.backward()optimizer.step()total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc / total_count, train_loss / total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader, model, criterion):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

三、训练模型

1.拆分数据集并运行模型
import timedef train(dataloader, model, optimizer, criterion, epoch):model.train()total_acc, train_loss, total_count = 0, 0, 0log_interval = 500start_time = time.time()for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()loss = criterion(predicted_label, label)loss.backward()optimizer.step()total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc / total_count, train_loss / total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader, model, criterion):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count
| epoch 1 |  500/1782 batches | train_acc 0.904 train_loss 0.00450
| epoch 1 | 1000/1782 batches | train_acc 0.903 train_loss 0.00455
| epoch 1 | 1500/1782 batches | train_acc 0.904 train_loss 0.00443
---------------------------------------------------------------------
| epoch 1 | time:11.72s | valid_acc 0.901 valid_loss 0.005
---------------------------------------------------------------------
| epoch 2 |  500/1782 batches | train_acc 0.918 train_loss 0.00379
| epoch 2 | 1000/1782 batches | train_acc 0.920 train_loss 0.00377
| epoch 2 | 1500/1782 batches | train_acc 0.913 train_loss 0.00399
---------------------------------------------------------------------
| epoch 2 | time:11.52s | valid_acc 0.907 valid_loss 0.005
---------------------------------------------------------------------
| epoch 3 |  500/1782 batches | train_acc 0.930 train_loss 0.00323
| epoch 3 | 1000/1782 batches | train_acc 0.925 train_loss 0.00345
| epoch 3 | 1500/1782 batches | train_acc 0.925 train_loss 0.00350
---------------------------------------------------------------------
| epoch 3 | time:11.77s | valid_acc 0.915 valid_loss 0.004
---------------------------------------------------------------------
| epoch 4 |  500/1782 batches | train_acc 0.937 train_loss 0.00294
| epoch 4 | 1000/1782 batches | train_acc 0.931 train_loss 0.00317
| epoch 4 | 1500/1782 batches | train_acc 0.927 train_loss 0.00332
---------------------------------------------------------------------
| epoch 4 | time:11.81s | valid_acc 0.914 valid_loss 0.004
---------------------------------------------------------------------
| epoch 5 |  500/1782 batches | train_acc 0.951 train_loss 0.00243
| epoch 5 | 1000/1782 batches | train_acc 0.950 train_loss 0.00243
| epoch 5 | 1500/1782 batches | train_acc 0.949 train_loss 0.00245
---------------------------------------------------------------------
| epoch 5 | time:11.94s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
| epoch 6 |  500/1782 batches | train_acc 0.951 train_loss 0.00236
| epoch 6 | 1000/1782 batches | train_acc 0.951 train_loss 0.00241
| epoch 6 | 1500/1782 batches | train_acc 0.951 train_loss 0.00241
---------------------------------------------------------------------
| epoch 6 | time:11.69s | valid_acc 0.918 valid_loss 0.004
---------------------------------------------------------------------
| epoch 7 |  500/1782 batches | train_acc 0.952 train_loss 0.00233
| epoch 7 | 1000/1782 batches | train_acc 0.952 train_loss 0.00236
| epoch 7 | 1500/1782 batches | train_acc 0.952 train_loss 0.00235
---------------------------------------------------------------------
| epoch 7 | time:11.88s | valid_acc 0.920 valid_loss 0.004
---------------------------------------------------------------------
| epoch 8 |  500/1782 batches | train_acc 0.953 train_loss 0.00233
| epoch 8 | 1000/1782 batches | train_acc 0.954 train_loss 0.00226
| epoch 8 | 1500/1782 batches | train_acc 0.953 train_loss 0.00229
---------------------------------------------------------------------
| epoch 8 | time:11.92s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
| epoch 9 |  500/1782 batches | train_acc 0.956 train_loss 0.00223
| epoch 9 | 1000/1782 batches | train_acc 0.955 train_loss 0.00219
| epoch 9 | 1500/1782 batches | train_acc 0.955 train_loss 0.00223
---------------------------------------------------------------------
| epoch 9 | time:11.78s | valid_acc 0.919 valid_loss 0.004
---------------------------------------------------------------------
| epoch 10 |  500/1782 batches | train_acc 0.955 train_loss 0.00226
| epoch 10 | 1000/1782 batches | train_acc 0.954 train_loss 0.00223
| epoch 10 | 1500/1782 batches | train_acc 0.955 train_loss 0.00221
---------------------------------------------------------------------
| epoch 10 | time:11.82s | valid_acc 0.919 valid_loss 0.004
---------------------------------------------------------------------
2.使用测试数据集评估模型 
print('checking the results of test dataset.')
test_acc,test_loss = evaluate(test_dataloader,model, criterion)
print('test accuracy{:8.3f}'.format(test_acc))

四、学习心得

       本周额外安装了 portalocker 库,并且下载了AG_News数据集,并TextClassificationModel模型,首先对文本进行嵌入,然后对句子嵌入之后的结果进行均值聚合,从而最终实现了文本分类的任务。在训练过程出现一些问题得到有效解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/86835.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/86835.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/86835.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniappx 安卓app项目本地打包运行,腾讯地图报错:‘鉴权失败,请检查你的key‘

根目录下添加 AndroidManifest.xml 文件&#xff0c; <application><meta-data android:name"TencentMapSDK" android:value"腾讯地图申请的key" /> </application> manifest.json 文件中添加&#xff1a; "app": {"…

【向上教育】结构化面试开口秘籍.pdf

向 上 教 育 XI A N G S H A N G E D U C A T I O N 结构化 面试 开口秘笈 目 录 第一章 自我认知类 ........................................................................................................................... 2 第二章 工作关系处理类 .......…

Webpack 热更新(HMR)原理详解

&#x1f525; Webpack 热更新&#xff08;HMR&#xff09;原理详解 &#x1f4cc; 本文适用于 Vue、React 等使用 Webpack 的项目开发者&#xff0c;适配 Vue CLI / 自定义 Webpack 项目。 &#x1f3af; 一、什么是 HMR&#xff1f; Hot Module Replacement 是 Webpack 提供的…

MySQL索引完全指南

一、索引是什么&#xff1f;为什么这么重要&#xff1f; 索引就像字典的目录 想象一下&#xff0c;你要在一本1000页的字典里找"程序员"这个词&#xff0c;你会怎么做&#xff1f; 没有目录&#xff1a;从第1页开始一页一页翻&#xff0c;可能要翻500页才能找到有…

学习使用dotnet-dump工具分析.net内存转储文件(2)

运行ShenNiusModularity项目&#xff0c;使用createdump工具dump完整的进程内存映射文件&#xff0c;然后运行dotnet-dump analyze命令加载dump文件。   可以先使用dumpheap命令显示有关垃圾回收堆的信息和有关对象的收集统计信息。dumpheap支持多类参数&#xff08;如下所示…

Oracle BIEE 交互示例(一)同一分析内

Oracle BIEE 交互示例(一)同一分析内 1 示例背景2 实践目标3 实操步骤3.1 创建数据集3.1.1 TEST_TABLE3.1.2 保存名字为【01 TEST_TABLE】3.2 创建分析3.2.1 创建列3.2.2 创建视图3.2.2.1 数据透视表3.2.2.2 图形3.2.2.3 表3.3 设置交互4 结果示例1 示例背景 版本:OBIEE 12…

使用API有效率地管理Dynadot域名,出售账户中的域名

关于Dynadot Dynadot是通过ICANN认证的域名注册商&#xff0c;自2002年成立以来&#xff0c;服务于全球108个国家和地区的客户&#xff0c;为数以万计的客户提供简洁&#xff0c;优惠&#xff0c;安全的域名注册以及管理服务。 Dynadot平台操作教程索引&#xff08;包括域名邮…

Vite 打包原理详解 + Webpack 对比

&#x1f680; Vite 打包原理详解 Webpack 对比 &#x1f44b; 本文适合&#xff1a;Vite 使用者、Vue/React 工程师、希望搞清楚打包流程及与 Webpack 区别的开发者 &#x1f310; 技术背景&#xff1a;Vite 采用 ES Modules 原生浏览器能力驱动开发体验&#xff0c;Webpack…

区块链RWA(Real World Assets)系统开发全栈技术架构与落地实践指南

一、技术架构设计&#xff1a;分层架构与模块协同 1. 核心区块链层 区块链选型策略&#xff1a; 公链&#xff1a;以太坊主网&#xff08;安全性高&#xff0c;DeFi生态完备&#xff09; Polygon CDK&#xff08;Layer2定制化合规链&#xff0c;Gas费低至$0.003&#xff09;…

GBDT:梯度提升决策树——集成学习中的预测利器

核心定位&#xff1a;一种通过串行集成弱学习器&#xff08;决策树&#xff09;、以梯度下降方式逐步逼近目标函数的机器学习算法&#xff0c;在结构化数据预测任务中表现出色。 本文由「大千AI助手」原创发布&#xff0c;专注用真话讲AI&#xff0c;回归技术本质。拒绝神话或妖…

Redis持久化机制深度解析:RDB与AOF全面指南

摘要 本文深入剖析Redis的持久化机制&#xff0c;全面讲解RDB和AOF两种持久化方式的原理、配置与应用场景。通过详细的操作步骤和原理分析&#xff0c;您将掌握如何配置Redis持久化策略&#xff0c;确保数据安全性与性能平衡。文章包含思维导图概览、命令实操演示、核心原理图…

CentOS7升级openssh10.0p2和openssl3.5.0详细操作步骤

背景 近期漏洞扫描时&#xff0c;发现有很多关于openssh的相关高危漏洞&#xff0c;因此需要升级openssh的版本 升级步骤 由于openssh和openssl的版本是需要相匹配的&#xff0c;这次计划将openssh升级至10.0p2版本&#xff0c;将openssl升级至3.5.0版本&#xff0c;都是目前…

fishbot随身系统安装nvidia显卡驱动

小鱼的fishbot是已经配置好的ubuntu22.04,我听说在预先配置系统时需要勾选安装第三方图形化软件&#xff0c;不然直接安装会有进不去图形化界面的风险&#xff0c;若没有勾选&#xff0c;建议使用其他安装方法&#xff0c;比如禁用系统自带的驱动那套安装流程 1.打开设置->关…

学习昇腾开发的第十天--ffmpeg推拉流

1、FFmpeg推流 注意&#xff1a;在推流之前先运行rtsp-simple-server&#xff08;mediamtx&#xff09; ./mediamtx 1.1 UDP推流 ffmpeg -re -i input.mp4 -c copy -f rtsp rtsp://127.0.0.1:8554/stream 1.2 TCP推流 ffmpeg -re -i input.mp4 -c copy -rtsp_transport t…

成为一名月薪 2 万的 web 安全工程师需要掌握哪些技能??

现在 web 安全工程师比较火&#xff0c;岗位比较稀缺&#xff0c;现在除了一些大公司对学历要求严格&#xff0c;其余公司看中的大部分是能力。 有个亲戚的儿子已经工作 2 年了……当初也是因为其他的行业要求比较高&#xff0c;所以才选择的 web 安全方向。 资料免费分享给你…

Pytorch8实现CNN卷积神经网络

CNN卷积神经网络 本章提供一个对CNN卷积网络的快速实现 全连接网络 VS 卷积网络 全连接神经网络之所以不太适合图像识别任务&#xff0c;主要有以下几个方面的问题&#xff1a; 参数数量太多 考虑一个输入10001000像素的图片(一百万像素&#xff0c;现在已经不能算大图了)&…

平地起高楼: 环境搭建

技术选型 本小册是采用纯前端的技术栈模拟实现小程序架构的系列文章&#xff0c;所以主要以前端技术栈为主&#xff0c;但是为了模拟一个App应用的效果&#xff0c;以及小程序包下载管理流程的实现&#xff0c;我们还是需要搭建一个基础的App应用。这里我们将选择 Tauri2.0 来…

langgraph学习2 - MCP编程

3中通信方式&#xff1a; 目前sse用的很少 3.开发mcp框架 主流框架2个&#xff1a; MCP skd 官方 Fast Mcp V2 &#xff0c;&#xff08;V1捐给MCP 官方&#xff09; 大模型如何识别用哪个tools&#xff0c; 以及如何使用tools&#xff1a;

CSS 与 JavaScript 加载优化

&#x1f4c4; CSS 与 JavaScript 加载优化指南&#xff1a;位置、阻塞与性能 让你的网页飞起来&#xff01;&#x1f680; 本文详细解析 CSS 和 JavaScript 标签的放置位置如何影响页面性能&#xff0c;涵盖阻塞原理、浏览器机制和最佳实践。掌握这些知识可显著提升用户体验…

WSL安装发行版上安装podman

WSL安装发行版上安装podman 1.WSL拉取发行版1.1 拉取2.2.修改系统拉取的镜像&#xff0c;可以加速软件包的更新 2.podman安装2.1.安装podman 容器工具2.2.配置podman的镜像仓库2.3.拉取n8n镜像并创建容器 本文在windows11上&#xff0c;使用WSL拉取并创建ubuntu24.04虚拟机&…