PyTorch RNN 名字分类器详解

使用PyTorch实现的字符级RNN(循环神经网络)项目,用于根据人名预测其所属的语言/国家。该模型通过学习不同语言名字的字符模式,够识别名字的语言起源。

环境设置

import torch
import string
import unicodedata
import glob
import os
import time
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

1. 数据预处理

1.1 字符编码处理

# 定义允许的字符集(ASCII字母 + 标点符号 + 占位符)
allowed_characters = string.ascii_letters + " .,;'" + "_"
n_letters = len(allowed_characters)  # 58个字符def unicodeToAscii(s):"""将Unicode字符串转换为ASCII"""return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn' and c in allowed_characters)

关键点:

  • 使用One-hot编码表示每个字符
  • 将非ASCII字符规范化(如 ‘Ślusàrski’ → ‘Slusarski’)
  • 未知字符用 “_” 表示

1.2 张量转换

def letterToIndex(letter):"""将字母转换为索引"""if letter not in allowed_characters:return allowed_characters.find("_")return allowed_characters.find(letter)def lineToTensor(line):"""将名字转换为张量 <line_length x 1 x n_letters>"""tensor = torch.zeros(len(line), 1, n_letters)for li, letter in enumerate(line):tensor[li][0][letterToIndex(letter)] = 1return tensor

张量维度说明:

  • 每个名字表示为3D张量:[序列长度, 批次大小=1, 字符数=58]
  • 使用One-hot编码:每个字符位置只有一个1,其余为0

2. 数据集构建

2.1 自定义Dataset类

class NamesDataset(Dataset):def __init__(self, data_dir):self.data = []           # 原始名字self.data_tensors = []   # 名字的张量表示self.labels = []         # 语言标签self.labels_tensors = [] # 标签的张量表示# 读取所有.txt文件(每个文件代表一种语言)text_files = glob.glob(os.path.join(data_dir, '*.txt'))for filename in text_files:label = os.path.splitext(os.path.basename(filename))[0]lines = open(filename, encoding='utf-8').read().strip().split('\n')for name in lines:self.data.append(name)self.data_tensors.append(lineToTensor(name))self.labels.append(label)

2.2 数据集划分

# 85/15 训练/测试集划分
train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024)
)

3. RNN模型架构

3.1 模型定义

class CharRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(CharRNN, self).__init__()# RNN层:输入大小 → 隐藏层大小self.rnn = nn.RNN(input_size, hidden_size)# 输出层:隐藏层 → 输出类别self.h2o = nn.Linear(hidden_size, output_size)# LogSoftmax用于分类self.softmax = nn.LogSoftmax(dim=1)def forward(self, line_tensor):rnn_out, hidden = self.rnn(line_tensor)output = self.h2o(hidden[0])output = self.softmax(output)return output

模型参数:

  • 输入大小:58(字符数)
  • 隐藏层大小:128
  • 输出大小:18(语言类别数)

4. 训练过程

4.1 训练函数

def train(rnn, training_data, n_epoch=10, n_batch_size=64, learning_rate=0.2, criterion=nn.NLLLoss()):rnn.train()optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)for iter in range(1, n_epoch + 1):# 创建小批量batches = list(range(len(training_data)))random.shuffle(batches)batches = np.array_split(batches, len(batches)//n_batch_size)for batch in batches:batch_loss = 0for i in batch:label_tensor, text_tensor, label, text = training_data[i]output = rnn.forward(text_tensor)loss = criterion(output, label_tensor)batch_loss += loss# 反向传播和优化batch_loss.backward()nn.utils.clip_grad_norm_(rnn.parameters(), 3)  # 梯度裁剪optimizer.step()optimizer.zero_grad()

训练技巧:

  • 使用SGD优化器,学习率0.15
  • 梯度裁剪防止梯度爆炸
  • 批量大小:64

5. 模型评估

5.1 混淆矩阵可视化

def evaluate(rnn, testing_data, classes):confusion = torch.zeros(len(classes), len(classes))rnn.eval()with torch.no_grad():for i in range(len(testing_data)):label_tensor, text_tensor, label, text = testing_data[i]output = rnn(text_tensor)guess, guess_i = label_from_output(output, classes)label_i = classes.index(label)confusion[label_i][guess_i] += 1# 归一化并可视化# ...

6. 训练结果

  • 训练样本数:17,063
  • 测试样本数:3,011
  • 训练轮数:27
  • 最终损失:约0.43

损失曲线显示模型收敛良好,从初始的0.88降至0.43。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/917941.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/917941.shtml
英文地址,请注明出处:http://en.pswp.cn/news/917941.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面向对象之类方法,成员变量和局部变量

1.类的方法必须包含几个部分&#xff1f;2.成员变量和局部变量类的方法必须包含哪几个部分&#xff1f;.方法名&#xff1a;用于标识方法的名称&#xff0c;遵循标识符命名规则&#xff0c;通常采用驼峰命名法。返回值类型&#xff1a;指定方法返回的数据类型。如果方法不返回任…

古法笔记 | 通过查表进行ASCII字符编码转换

ASCII字符集是比较早期的一种字符编码&#xff0c;只能表示英文字符&#xff0c;最多能表示128个字符。 字符集规定了每个字符和二进制数之间的对应关系&#xff0c;可以通过查表完成二进制数到字符的转换ASCII字符占用的存储空间是定长的1字节 ASCII字符的官方码点表见下图&…

Linux C实现单生产者多消费者环形缓冲区

使用C11里的原子变量实现&#xff0c;没有用互斥锁&#xff0c;效率更高。ring_buffer.h:/*** file ring_buffer.h* author tl* brief 单生产者多消费者环形缓冲区&#xff0c;每条数据被所有消费者读后才释放。读线程安全&#xff0c;写仅单线程。* version* date 2025-08-06*…

复杂场景识别率↑31%!陌讯多模态融合算法在智慧环卫的实战解析

摘要&#xff1a;针对边缘计算优化的垃圾堆放识别场景&#xff0c;本文解析了基于动态决策机制的视觉算法如何提升复杂环境的鲁棒性。实测数据显示在遮挡/光照干扰下&#xff0c;mAP0.5较基线提升28.3%&#xff0c;误报率降低至行业1/5水平。一、行业痛点&#xff1a;智慧环卫的…

MyBatis-Plus Service 接口:如何在 MyBatis-Plus 中实现业务逻辑层??

全文目录&#xff1a;开篇语前言1. MyBatis-Plus 的 IService 接口1.1 基本使用示例&#xff1a;创建实体类 User 和 UserService1.2 创建 IService 接口1.3 创建 ServiceImpl 类1.4 典型的数据库操作方法1.4.1 save()&#xff1a;保存数据1.4.2 remove()&#xff1a;删除数据1…

[激光原理与应用-168]:光源 - 常见光源的分类、特性及应用场景的详细解析,涵盖技术原理、优缺点及典型应用领域

一、半导体光源1. LED光源&#xff08;发光二极管&#xff09;原理&#xff1a;通过半导体PN结的电子-空穴复合发光&#xff0c;波长由材料带隙决定&#xff08;如GaN发蓝光、AlGaInP发红光&#xff09;。特性&#xff1a;优点&#xff1a;寿命长&#xff08;>5万小时&#…

Metronic v.7.1.7企业级Web应用前端框架全攻略

本文还有配套的精品资源&#xff0c;点击获取 简介&#xff1a;Metronic是一款专注于构建响应式、高性能企业级Web应用的前端开发框架。最新版本v.7.1.7引入了多种功能和优化&#xff0c;以增强开发效率和用户体验。详细介绍了其核心特性&#xff0c;包括响应式设计、多种模…

鸿蒙开发--Notification Kit(用户通知服务)

通知是手机系统中很重要的信息展示方式&#xff0c;通知不仅可以展示文字&#xff0c;也可以展示图片&#xff0c;甚至可以将组件加到通知中&#xff0c;只要用户不清空&#xff0c;通知的信息可以永久保留在状态栏上通知的介绍 通知 Notification通知&#xff0c;即在一个应用…

鸿蒙 - 分享功能

文章目录一、背景二、app发起分享1. 通过分享面板进行分享2. 使用其他应用打开二、处理分享的内容1. module.json5 配置可接收分享2. 解析分享的数据一、背景 在App开发中&#xff0c;分享是常用功能&#xff0c;这里介绍鸿蒙开发中&#xff0c;其他应用分享到自己的app中&…

【Agent 系统设计】基于大语言模型的智能Agent系统

一篇阿里博文引发的思考和探索。基于大语言模型的智能Agent系统 1. 系统核心思想 核心思想是构建一个以大语言模型&#xff08;LLM&#xff09;为“大脑”的智能代理&#xff08;Agent&#xff09;&#xff0c;旨在解决将人类的自然语言指令高效、准确地转化为机器可执行的自动…

企业级Web框架性能对决:Spring Boot、Django、Node.js与ASP.NET深度测评

企业级Web应用的开发效率与运行性能直接关系到业务的成败。本文通过构建标准化的待办事项&#xff08;Todo&#xff09;应用&#xff0c;对四大主流框架——Spring Boot、Django、Node.js和ASP.NET展开全面的性能较量。我们将从底层架构特性出发&#xff0c;结合实测数据与数据…

为什么 `source ~/.bashrc` 在 systemd 或 crontab 中不生效

摘要&#xff1a;你是否遇到过这样的问题&#xff1a;在终端里运行脚本能正常工作&#xff0c;但用 systemd 或 crontab 自动启动时却报错“命令找不到”、“模块导入失败”&#xff1f; 本文将揭示一个深藏在 ~/.bashrc 中的“陷阱”&#xff1a;非交互式 shell 会直接退出&am…

Linux 磁盘中的文件

1.磁盘结构 Linux中的文件加载到内存上之前是放到哪的&#xff1f; 放在磁盘上的文件——>访问文件&#xff0c;打开它——>找到这个文件——>路径 但文件是怎样存储在磁盘上的 1.1物理结构磁盘可以理解为上百亿个小磁铁&#xff08;如N为1&#xff0c;S为0&#xff0…

【方法】Git本地仓库的文件夹不显示红色感叹号、绿色对号等图标

文章目录前言开始操作winr&#xff0c;输入regedit&#xff0c;打开注册表重启资源管理器前言 这个绿色对号图标表示本地仓库和远程的GitHub仓库内容保持一致&#xff0c;红色则是相反咯&#xff0c;给你们瞅一下。 首先这两个东西你一定要安装配置好了&#xff0c;安装顺序不…

量化交易与主观交易:哪种方式更胜一筹?

文章概要 在投资的世界里&#xff0c;量化交易和主观交易如同冰与火&#xff0c;各自拥有独特的优势与挑战。作为一名投资者&#xff0c;了解这两种交易方式的差异和各自的优缺点至关重要。本文将从决策依据、执行方式、风险管理等方面深入探讨量化交易的精确性与主观交易的灵活…

【JS】扁平树数据转为树结构

扁平数据转为最终效果[{"label":"疼逊有限公司","code":"1212","disabled":false,"parentId":"none","children":[{"label":"财务部","code":"34343&quo…

数据结构4-栈、队列

摘要&#xff1a;本文系统介绍了栈和队列两种基础数据结构。栈采用"先进后出"原则&#xff0c;分为顺序栈和链式栈&#xff0c;详细说明了压栈、出栈等基本操作及其实现方法。队列遵循"先进先出"规则&#xff0c;同样分为顺序队列和链式队列&#xff0c;重…

大数据spark、hasdoop 深度学习、机器学习算法的音乐平台用户情感分析系统设计与实现

大数据spark、hasdoop 深度学习、机器学习算法的音乐平台用户情感分析系统设计与实现

视频汇聚系统EasyCVR调用设备录像保活时视频流不连贯问题解决方案

在使用EasyCVR过程中&#xff0c;有用户反馈调用设备录像保活功能时&#xff0c;出现视频流不连贯的情况。针对这一问题&#xff0c;我们经过排查与测试&#xff0c;整理出如下解决步骤&#xff0c;供开发者参考&#xff1a;具体解决步骤1&#xff09;先调用登录接口完成鉴权确…

【保姆级喂饭教程】python基于mysql-connector-python的数据库操作通用封装类(连接池版)

目录项目环境一、db_config.py二、mysql_executor.py三、test/main.py在使用mysql-connector-python连接MySQL数据库的时候&#xff0c;如同Java中的jdbc一般&#xff0c;每条sql需要创建和删除连接&#xff0c;很自然就想到写一个抽象方法&#xff0c;但是找了找没有官方标准的…