本文目录:

  • 一、导入工具包
  • 二、数据集
  • 三、 构建词表
  • 四、 构建数据集对象
  • 五、 构建网络模型
  • 六、 构建训练函数
  • 七、构建预测函数

前言:上篇文章讲解了RNN,这篇文章分享文本生成任务案例:文本生成是一种常见的自然语言处理任务,输入一个开始词能够预测出后面的词序列。本案例将会使用循环神经网络来实现周杰伦歌词生成任务。
在这里插入图片描述

一、导入工具包

import torch
import jieba
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import time

二、数据集

我们收集了周杰伦从第一张专辑《Jay》到第十张专辑《跨时代》中的歌词,来训练神经网络模型,当模型训练好后,我们就可以用这个模型来创作歌词。数据集如下:

想要有直升机
想要和你飞到宇宙去
想要和你融化在一起
融化在宇宙里
我每天每天每天在想想想想著你
这样的甜蜜
让我开始相信命运
感谢地心引力
让我碰到你
漂亮的让我面红的可爱女人
...

该数据集共有 5819 行文本。

三、 构建词表

在进行自然语言处理任务之前,首要做的就是构建词表。

所谓的词表就是将语料进行分词,然后给每一个词分配一个唯一的编号,便于我们送入词嵌入层。

在这里插入图片描述
接下来, 我们对周杰伦歌词的数据进行处理构建词表,具体流程如下:

  • 获取文本数据
  • 分词,并进行去重
  • 构建词表
# 获取数据,并进行分词,构建词表
def build_vocab():# 数据集位置file_name = 'data/jaychou_lyrics.txt'# 分词结果存储位置# 唯一词列表unique_words = []# 每行文本分词列表all_words = []# 遍历数据集中的每一行文本for line in open(file_name, 'r', encoding='utf-8'):# 使用jieba分词,分割结果是一个列表words = jieba.lcut(line)# print(words)# 所有的分词结果存储到all_words,其中包含重复的词组all_words.append(words)# 遍历分词结果,去重后存储到unique_wordsfor word in words:if word not in unique_words:unique_words.append(word)# 语料中词的数量word_count = len(unique_words)# 词到索引映射word_to_index = {word: idx for idx, word in enumerate(unique_words)}# 歌词文本用词表索引表示corpus_idx = []# 遍历每一行的分词结果for words in all_words:temp = []# 获取每一行的词,并获取相应的索引for word in words:temp.append(word_to_index[word])# 在每行词之间添加空格隔开temp.append(word_to_index[' '])# 获取当前文档中每个词对应的索引corpus_idx.extend(temp)return unique_words, word_to_index, word_count, corpus_idxif __name__ == "__main__":# 获取数据unique_words, word_to_index, unique_word_count, corpus_idx = build_vocab()print("词的数量:\n",unique_word_count)print("去重后的词:\n",unique_words)print("每个词的索引:\n",word_to_index)print("当前文档中每个词对应的索引:\n",corpus_idx)

我们的词典主要包含了:

  • unique_words: 存储了每个词

  • word_to_index: 存储了词到编号的映射

在这里插入图片描述

四、 构建数据集对象

我们在训练的时候,为了便于读取语料,并送入网络,所以我们会构建一个Dataset对象。

class LyricsDataset(torch.utils.data.Dataset):def __init__(self, corpus_idx, num_chars):# 文档数据中词的索引self.corpus_idx = corpus_idx# 每个句子中词的个数self.num_chars = num_chars# 文档数据中词的数量,不去重self.word_count = len(self.corpus_idx)# 句子数量self.number = self.word_count // self.num_chars# len(obj)时自动调用此方法def __len__(self):# 返回句子数量return self.number# obj[idx]时自动调用此方法def __getitem__(self, idx):# idx指词的索引,并将其修正索引值到文档的范围里面"""我 爱你 中国 , 亲爱 的 母亲word_count: 7num_chars: 2 一个句子由num_chars个词组成word_count-num_chars-2: 7-2-1=4  -1:网络预测结果y在x上后移一个词取值-1idx=5->start=4"""start = min(max(idx, 0), self.word_count - self.num_chars - 1)end = start + self.num_chars# 输入值x = self.corpus_idx[start: end]# 网络预测结果(目标值)y = self.corpus_idx[start + 1: end + 1]# 返回结果return torch.tensor(x), torch.tensor(y)if __name__ == "__main__":# 获取数据unique_words, word_to_index, unique_word_count, corpus_idx = build_vocab()# 数据获取实例化dataset = LyricsDataset(corpus_idx, 5)# 查看句子数量print('句子数量:', len(dataset))# x, y = dataset.__getitem__(0)x, y = dataset[0]print("网络输入值:", x)print("目标值:", y)

运行结果:

句子数量: 9827
网络输入值: tensor([ 0,  1,  2,  3, 40])
目标值: tensor([ 1,  2,  3, 40,  0])

五、 构建网络模型

我们用于实现《歌词生成》的网络模型,主要包含了三个层:

  • 词嵌入层: 用于将语料转换为词向量

  • 循环网络层: 提取句子语义

  • 全连接层: 输出对词典中每个词的预测概率

# 模型构建
class TextGenerator(nn.Module):def __init__(self, unique_word_count):super(TextGenerator, self).__init__()# 初始化词嵌入层: 语料中词的数量, 词向量的维度为128self.ebd = nn.Embedding(unique_word_count, 128)# 循环网络层: 词向量维度128, 隐藏向量维度256, 网络层数1self.rnn = nn.RNN(128, 256, 1)# 输出层: 特征向量维度256与隐藏向量维度相同, 词表中词的个数self.out = nn.Linear(256, unique_word_count)def forward(self, inputs, hidden):# 输出维度: (batch, seq_len, 词向量维度128)# batch:句子数量# seq_len: 句子长度, 每个句子由多少个词 词数量embed = self.ebd(inputs)# rnn层x的表示形式为(seq_len, batch, 词向量维度128)# output的表示形式与输入x类似,为(seq_len, batch, 词向量维度256)# 前后的hidden形状要一样, 所以DataLoader加载器的batch数要能被整数output, hidden = self.rnn(embed.transpose(0, 1), hidden)# 全连接层输入二维数据, 词数量*词维度# 输入维度: (seq_len*batch, 词向量维度256) # 输出维度: (seq_len*batch, 语料中词的数量)# output: 每个词的分值分布,后续结合softmax输出概率分布output = self.out(output.reshape(shape=(-1, output.shape[-1])))# 网络输出结果return output, hiddendef init_hidden(self, bs):# 隐藏层的初始化:[网络层数, batch, 隐藏层向量维度]return torch.zeros(1, bs, 256)if __name__ == "__main__":# 获取数据unique_words, word_to_index, unique_word_count, corpus_idx = build_vocab()model = TextGenerator(unique_word_count)for named, parameter in model.named_parameters():print(named)print(parameter)

六、 构建训练函数

前面的准备工作完成之后, 我们就可以编写训练函数。训练函数主要负责编写数据迭代、送入网络、计算损失、反向传播、更新参数,其流程基本较为固定。

由于我们要实现文本生成,文本生成本质上,输入一串文本,预测下一个文本,也属于分类问题,所以,我们使用多分类交叉熵损失函数。优化方法我们学习过 SGB、AdaGrad、Adam 等,在这里我们选择学习率、梯度自适应的 Adam 算法作为我们的优化方法。

训练完成之后,我们使用 torch.save 方法将模型持久化存储。

def train():# 构建词典unique_words, word_to_index, unique_word_count, corpus_idx = build_vocab()# 数据集 LyricsDataset对象,并实现了 __getitem__ 方法lyrics = LyricsDataset(corpus_idx=corpus_idx, num_chars=32)# 查看句子数量# print(lyrics.number)# 初始化模型model = TextGenerator(unique_word_count)# 数据加载器 DataLoader对象,并将lyrics dataset对象传递给它lyrics_dataloader = DataLoader(lyrics, shuffle=True, batch_size=5)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.Adam(model.parameters(), lr=1e-3)# 训练轮数epoch = 10for epoch_idx in range(epoch):# 训练时间start = time.time()iter_num = 0  # 迭代次数# 训练损失total_loss = 0.0# 遍历数据集 DataLoader 会在后台调用 dataset.__getitem__(index) 来获取每个样本的数据和标签,并将它们组合成一个 batchfor x, y in lyrics_dataloader:# 隐藏状态的初始化hidden = model.init_hidden(bs=5)# 模型计算output, hidden = model(x, hidden)# 计算损失# y形状为(batch, seq_len), 需要转换成一维向量->160个词的下标索引# output形状为(seq_len, batch, 词向量维度)# 需要先将y进行维度交换(和output保持一致)再改变形状y = torch.transpose(y, 0, 1).reshape(shape=(-1,))loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()iter_num += 1  # 迭代次数加1total_loss += loss.item()# 打印训练信息print('epoch %3s loss: %.5f time %.2f' % (epoch_idx + 1, total_loss / iter_num, time.time() - start))# 模型存储torch.save(model.state_dict(), 'model/lyrics_model_%d.pth' % epoch)if __name__ == "__main__":train()

运行结果:

epoch   1 loss: 1.84424 time 5.75
epoch   2 loss: 0.21154 time 5.91
epoch   3 loss: 0.12014 time 5.85
epoch   4 loss: 0.10625 time 5.73
epoch   5 loss: 0.10226 time 5.58
epoch   6 loss: 0.10009 time 5.65
epoch   7 loss: 0.09942 time 5.66
epoch   8 loss: 0.09783 time 5.66
epoch   9 loss: 0.09663 time 5.75
epoch  10 loss: 0.09568 time 5.77

七、构建预测函数

从磁盘加载训练好的模型,进行预测。预测函数,输入第一个指定的词,我们将该词输入网路,预测出下一个词,再将预测的出的词再次送入网络,预测出下一个词,以此类推,知道预测出我们指定长度的内容。

def predict(start_word, sentence_length):# 构建词典unique_words, word_to_index, unique_word_count, _ = build_vocab()# 构建模型model = TextGenerator(unique_word_count)# 加载参数model.load_state_dict(torch.load('model/lyrics_model_10.pth'))# 隐藏状态hidden = model.init_hidden(bs=1)# 将起始词转换为索引word_idx = word_to_index[start_word]# 产生的词的索引存放位置generate_sentence = [word_idx]# 遍历到句子长度,获取每一个词for _ in range(sentence_length):# 模型预测output, hidden = model(torch.tensor([[word_idx]]), hidden)# 获取预测结果word_idx = torch.argmax(output)generate_sentence.append(word_idx)# 根据产生的索引获取对应的词,并进行打印for idx in generate_sentence:print(unique_words[idx], end='')if __name__ == '__main__':# 调用预测函数predict('分手', 50)

运行结果:

分手的话像语言暴力我已无能为力再提起 决定中断熟悉周杰伦 周杰伦一步两步三步四步望著天 看星星一颗两颗三颗四颗 连成线背著背默默许下心愿看远方的星

今天的分享到此结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/88882.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/88882.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/88882.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI时代的接口自动化优化实践:如何突破Postman的局限性

编者语:本文作者为某非银金融测试团队负责人。其团队自 2024 年起局部试用 Apipost,目前已在全团队正式投入使用 。在推进微服务 API 自动化测试的过程中,研发和测试人员常常需要在接口请求中动态构造带有特定业务规则的数据。我们团队就遇到…

动态规划题解_将一个数字表示成幂的和的方案数【LeetCode】

2787. 将一个数字表示成幂的和的方案数 给你两个正整数 n 和 x 。 请你返回将 n 表示成一些 互不相同 正整数的 x 次幂之和的方案数。换句话说,你需要返回互不相同整数 [n1, n2, ..., nk] 的集合数目,满足 n n1x n2x ... nkx 。 由于答案可能非常…

C#常用的LinQ方法

LINQ(Language Integrated Query)是 .NET 中用于处理集合的强大工具,它提供了多种方法来简化数据查询和操作。以下是一些常用的 LINQ 方法及其功能:Where: 根据指定的条件筛选集合中的元素。var filteredResults matchResults.Wh…

目标检测之数据增强

数据翻转,需要把bbox相应的坐标值也进行交换代码:import random from torchvision.transforms import functional as Fclass Compose(object):"""组合多个transform函数"""def __init__(self, transforms):self.transform…

DiffDet4SAR——首次将扩散模型用于SAR图像目标检测,来自2024 GRSL(ESI高被引1%论文)

一. 论文摘要 合成孔径雷达(SAR)图像中的飞机目标检测是一项具有挑战性的任务,由于离散的散射点和严重的背景杂波干扰。目前,基于卷积或基于变换的方法不能充分解决这些问题。 本文首次探讨了SAR图像飞机目标检测的扩散模型&#…

html案例:编写一个用于发布CSDN文章时,生成有关缩略图

CSDN博客文章缩略图生成器起因:之前注意到CSDN可以随机选取文章缩略图,但后来这个功能似乎取消了。于是我想调整一下缩略图的配色方案。html制作界面 界面分上下两块区域,上面是参数配置,下面是效果预览图。参数配置: …

lightgbm算法学习

主要组件 Boosting #mermaid-svg-1fiqPsJfErv6AV82 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-1fiqPsJfErv6AV82 .error-icon{fill:#552222;}#mermaid-svg-1fiqPsJfErv6AV82 .error-text{fill:#552222;stroke:#…

安卓基于 FirebaseAuth 实现 google 登录

安卓基于 FirebaseAuth 实现 google 登录 文章目录安卓基于 FirebaseAuth 实现 google 登录1. 前期准备1.1 创建 Firebase 项目1.2 将 Android 应用连接到 Firebase1.3 在 Firebase 控制台中启用 Google 登录2. 在 Android 应用中实现 Google 登录2.1 初始化 GoogleSignInClien…

李宏毅(Deep Learning)--(三)

一.前向传播与反向传播的理解:二.模型训练遇到的问题在模型训练中,我们可能会遇到效果不好的情况,那么我们应该怎么思考切入,找到问题所在呢?流程图如下:第一个就是去看训练的损失函数值情况。如果损失较大…

android studio 运行,偶然会导致死机,设置Memory Settings尝试解决

1、android studio导致死机 鼠标不能动,键盘没有反应,只能硬重启,但是内存并没有用完,cpu也不是100% 2、可能的原因 android studio内存设置的问题,为了限制占用内存,所以手工设置内存最小的一个&#x…

HTB 赛季8靶场 - Outbound

Rustscan扫描我们开局便拥有账号 tyler / LhKL1o9Nm3X2,我们使用rustscan进行扫描 rustscan -a 10.10.11.77 --range 1-65535 --scan-order "Random" -- -A Web服务漏洞探查 我们以账号tyler / LhKL1o9Nm3X2登录webmail,并快速确认版本信息。该…

动态组件和插槽

[Vue2]动态组件和插槽 动态组件和插槽来实现外部传入自定义渲染 组件 <template><!-- 回复的处理进度 --><div v-if"steps.length > 0" class"gain-box-header"><el-steps direction"vertical"><div class"l…

Unreal5从入门到精通之如何实现UDP Socket通讯

文章目录 一.前言二.什么是FSocket1. FSocket的作用2. FSocket关键特性三.创建Socket四.数据传输五.线程安全六.UDPSocketComponentUDPSocketComponent.hUUDPSocketComponent.cpp七.SocketTest测试八.最后一.前言 我们在开发UE 的过程中,会经常使用到Socket通讯,包括TCP,UD…

UI前端大数据处理新趋势:基于边缘计算的数据处理与响应

hello宝子们...我们是艾斯视觉擅长ui设计、前端开发、数字孪生、大数据、三维建模、三维动画10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩!一、引言&#xff1a;前端大数据的 “云端困境” 与边缘计算的破局当用户在在线文档中实时协作…

Reading and Writing to a State Variable

本节是《Solidity by Example》的中文翻译与深入讲解&#xff0c;专为零基础或刚接触区块链开发的小白朋友打造。我们将通过“示例 解说 提示”的方式&#xff0c;带你逐步理解每一段 Solidity 代码的实际用途与背后的逻辑。Solidity 是以太坊等智能合约平台使用的主要编程语…

c# 深度解析:实现一个通用配置管理功能,打造高并发、可扩展的配置管理神器

文章目录深入分析 ConfigManager<TKey, TValue> 类1. 类设计概述2. 核心成员分析2.1 字段和属性2.2 构造函数3. 数据加载机制4. CRUD 操作方法4.1 添加数据4.2 删除数据4.3 更新数据4.4 查询数据4.5 清空数据5. 数据持久化6. 设计亮点7. 使用示例ConfigManager<TKey, …

运维打铁: Python 脚本在运维中的常用场景与实现

文章目录引言思维导图常用场景与代码实现1. 服务器监控2. 文件管理3. 网络管理4. 自动化部署总结注意事项引言 在当今的 IT 运维领域&#xff0c;自动化和效率是至关重要的。Python 作为一种功能强大且易于学习的编程语言&#xff0c;已经成为运维人员不可或缺的工具。它可以帮…

【零基础入门unity游戏开发——unity3D篇】3D光源之——unity反射和反射探针技术

文章目录 前言实现天空盒反射1、新建一个cube2、全反射材质3、增加环境反射分辨率反射探针1、一样把小球材质调成全反射2、在小球身上加添加反射探针3、设置静态物体4、点击烘培5、效果6、可以修改反射探针区域大小7、实时反射专栏推荐完结前言 当对象收到直接和间接光照后,它…

React Three Fiber 实现 3D 模型点击高亮交互的核心技巧

在 WebGL 3D 开发中&#xff0c;模型交互是提升用户体验的关键功能之一。本文将基于 React Three Fiber&#xff08;R3F&#xff09;和 Three.js&#xff0c;总结 3D 模型点击高亮&#xff08;包括模型本身和边框&#xff09;的核心技术技巧&#xff0c;帮助开发者快速掌握复杂…

卷积神经网络实战:MNIST手写数字识别

夜渐深&#xff0c;我还在&#x1f618; 老地方 睡觉了&#x1f64c; 文章目录&#x1f4da; 卷积神经网络实战&#xff1a;MNIST手写数字识别&#x1f9e0; 4.1 预备知识⚙️ 4.1.1 torch.nn.Conv2d() 三维卷积操作&#x1f4cf; 4.1.2 nn.MaxPool2d() 池化层的作用&#x1f4…