1.模型结构

本案例整体采用transformer论文中提出的结构,部分设置做了调整。transformer网络结构介绍可参考博客——入门级别的Transformer模型介绍,这里着重介绍其代码实现。
模型的整体结构,包括词嵌入层,位置编码,编码器,解码器、输出层部分。

2.词嵌入层

词嵌入层用于将token转化为词向量,该层可直接调用nn模块中的Embedding方法。该方法主要包括两个参数,分别表示词表的大小(vocab_size)和词嵌入的维度(emb_size),同时为了训练更稳定,加入了缩放因子dk\sqrt {d_k}dk,代码如下:

class TokenEmbedding(nn.Module):def __init__(self, vocab_size: int, emb_size):super(TokenEmbedding, self).__init__()# 词嵌入层:将词索引映射到emb_size维的向量self.embedding = nn.Embedding(vocab_size, emb_size)# 记录嵌入维度(用于缩放)self.emb_size = emb_sizedef forward(self, tokens: Tensor):# 将词索引转换为词向量,并乘以√emb_size(缩放,稳定梯度)return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

3.位置编码

位置编码层用于给序列添加位置信息,解决自注意力机制无法感知序列顺序的问题。公式为:
PE(pos,2i)=sin(pos10002id)PE(pos,2i)=sin(\frac{pos}{1000\frac{2i}{d}})PE(pos,2i)=sin(1000d2ipos)
PE(pos,2i+1)=cos(pos10002id)PE(pos,2i+1)=cos(\frac{pos}{1000\frac{2i}{d}})PE(pos,2i+1)=cos(1000d2ipos)
代码表示如下:

class PositionalEncoding(nn.Module):def __init__(self, emb_size: int, dropout, maxlen: int = 5000):super(PositionalEncoding, self).__init__()# 计算位置编码的衰减因子(控制正弦/余弦函数的频率)den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)# 位置索引(0到maxlen-1)pos = torch.arange(0, maxlen).reshape(maxlen, 1)# 初始化位置编码矩阵(形状:[maxlen, emb_size])pos_embedding = torch.zeros((maxlen, emb_size))# 偶数列用正弦函数填充(pos * den)pos_embedding[:, 0::2] = torch.sin(pos * den)# 奇数列用余弦函数填充(pos * den)pos_embedding[:, 1::2] = torch.cos(pos * den)# 调整维度(添加批次维度,便于与词嵌入向量相加)pos_embedding = pos_embedding.unsqueeze(-2)# Dropout层(正则化,防止过拟合)self.dropout = nn.Dropout(dropout)# 注册为缓冲区(模型保存/加载时自动处理)self.register_buffer('pos_embedding', pos_embedding)def forward(self, token_embedding: Tensor):# 将词嵌入向量与位置编码相加,并应用Dropoutreturn self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0),:])

4.编码器

由于编码器部分是通过堆叠多个子编码器层所构成的,子编码器包括:多头自注意力层、残差连接与归一化、前馈网络三部分,该部分代码全部被封装成TransformerEncoderLayer函数中,使用时只需要传递相应超参数即可,如词嵌入维度、多头注意力的头数、前馈网络的隐含层维度,代码实现为:

# 定义编码器层(单头注意力→多头注意力→前馈网络)
encoder_layer = TransformerEncoderLayer(d_model=emb_size,       # 输入特征维度(与词嵌入维度一致)nhead=NHEAD,            # 多头注意力的头数dim_feedforward=dim_feedforward  # 前馈网络隐藏层维度
)
# 堆叠多层编码器层形成完整编码器
self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

5.解码器

解码器同编码器类似,代码可以表述为:

# 定义解码器层(掩码多头注意力→编码器-解码器多头注意力→前馈网络)
decoder_layer = TransformerDecoderLayer(d_model=emb_size,       # 输入特征维度(与词嵌入维度一致)nhead=NHEAD,            # 多头注意力头数(与编码器一致)dim_feedforward=dim_feedforward  # 前馈网络隐藏层维度
)
# 堆叠多层解码器层形成完整解码器
self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

6.输出层

输出通过线性层得到每个单词的得分,可直接通过Linear层直接实现。

7.大体代码

基于上述介绍,完整代码如下:

from torch.nn import (TransformerEncoder, TransformerDecoder,TransformerEncoderLayer, TransformerDecoderLayer)class Seq2SeqTransformer(nn.Module):"""基于Transformer的序列到序列翻译模型(日中机器翻译核心模块)包含编码器(处理源语言序列)和解码器(生成目标语言序列)"""def __init__(self, num_encoder_layers: int, num_decoder_layers: int,emb_size: int, src_vocab_size: int, tgt_vocab_size: int,dim_feedforward: int = 512, dropout: float = 0.1):"""初始化Transformer模型参数和组件:param num_encoder_layers: 编码器层数(论文中通常为6,此处根据计算资源调整):param num_decoder_layers: 解码器层数(与编码器层数一致):param emb_size: 词嵌入维度(对应Transformer的d_model,需与多头注意力维度匹配):param src_vocab_size: 源语言(日语)词表大小:param tgt_vocab_size: 目标语言(中文)词表大小:param dim_feedforward: 前馈网络隐藏层维度(通常为4*d_model):param dropout:  dropout概率(用于正则化,防止过拟合)"""super(Seq2SeqTransformer, self).__init__()# 定义编码器层(单头注意力→多头注意力→前馈网络)encoder_layer = TransformerEncoderLayer(d_model=emb_size,       # 输入特征维度(与词嵌入维度一致)nhead=NHEAD,            # 多头注意力的头数(需满足 emb_size % nhead == 0)dim_feedforward=dim_feedforward  # 前馈网络隐藏层维度)# 堆叠多层编码器层形成完整编码器self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)# 定义解码器层(掩码多头注意力→编码器-解码器多头注意力→前馈网络)decoder_layer = TransformerDecoderLayer(d_model=emb_size,       # 输入特征维度(与词嵌入维度一致)nhead=NHEAD,            # 多头注意力头数(与编码器一致)dim_feedforward=dim_feedforward  # 前馈网络隐藏层维度)# 堆叠多层解码器层形成完整解码器self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)# 生成器:将解码器输出映射到目标词表(预测每个位置的目标词)self.generator = nn.Linear(emb_size, tgt_vocab_size)# 源语言词嵌入层(将词索引转换为连续向量)self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)# 目标语言词嵌入层(与源语言共享嵌入层可提升效果,此处未共享)self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)# 位置编码层(注入序列位置信息,解决Transformer的位置无关性)self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,tgt_mask: Tensor, src_padding_mask: Tensor,tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):"""前向传播(训练时使用教师强制,输入完整目标序列):param src: 源语言序列张量(形状:[seq_len, batch_size]):param trg: 目标语言序列张量(形状:[seq_len, batch_size]):param src_mask: 源序列注意力掩码(形状:[seq_len, seq_len],全0表示无掩码):param tgt_mask: 目标序列掩码(下三角掩码,防止关注未来词):param src_padding_mask: 源序列填充掩码(标记<pad>位置,形状:[batch_size, seq_len]):param tgt_padding_mask: 目标序列填充掩码(标记<pad>位置,形状:[batch_size, seq_len]):param memory_key_padding_mask: 编码器输出的填充掩码(与src_padding_mask一致):return: 目标序列的词表概率分布(形状:[seq_len, batch_size, tgt_vocab_size])"""# 源序列处理:词嵌入 + 位置编码src_emb = self.positional_encoding(self.src_tok_emb(src))# 目标序列处理:词嵌入 + 位置编码(训练时使用教师强制,输入完整目标序列)tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))# 编码器处理源序列,生成记忆向量(memory)memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)# 解码器利用记忆向量生成目标序列outs = self.transformer_decoder(tgt_emb,                # 目标序列嵌入(含位置信息)memory,                 # 编码器输出的记忆向量tgt_mask,               # 目标序列掩码(防止未来词)None,                   # 编码器-解码器注意力掩码(此处未使用)tgt_padding_mask,       # 目标序列填充掩码(忽略<pad>)memory_key_padding_mask # 记忆向量填充掩码(与源序列填充掩码一致))# 通过生成器输出目标词表的概率分布return self.generator(outs)def encode(self, src: Tensor, src_mask: Tensor):"""编码源序列(推理时单独调用,生成编码器记忆向量):param src: 源语言序列张量(形状:[seq_len, batch_size]):param src_mask: 源序列注意力掩码(形状:[seq_len, seq_len]):return: 编码器输出的记忆向量(形状:[seq_len, batch_size, emb_size])"""return self.transformer_encoder(self.positional_encoding(self.src_tok_emb(src)),  # 源序列嵌入+位置编码src_mask  # 源序列注意力掩码)def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):"""解码目标序列(推理时逐步生成目标词):param tgt: 当前已生成的目标序列前缀(形状:[current_seq_len, batch_size]):param memory: 编码器输出的记忆向量(形状:[seq_len, batch_size, emb_size]):param tgt_mask: 目标序列掩码(下三角掩码,防止关注未来词):return: 解码器输出(形状:[current_seq_len, batch_size, emb_size])"""return self.transformer_decoder(self.positional_encoding(self.tgt_tok_emb(tgt)),  # 目标前缀嵌入+位置编码memory,  # 编码器记忆向量tgt_mask  # 目标前缀掩码(仅允许关注已生成部分))
class PositionalEncoding(nn.Module):def __init__(self, emb_size: int, dropout, maxlen: int = 5000):super(PositionalEncoding, self).__init__()# 计算位置编码的衰减因子(控制正弦/余弦函数的频率)den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)# 位置索引(0到maxlen-1)pos = torch.arange(0, maxlen).reshape(maxlen, 1)# 初始化位置编码矩阵(形状:[maxlen, emb_size])pos_embedding = torch.zeros((maxlen, emb_size))# 偶数列用正弦函数填充(pos * den)pos_embedding[:, 0::2] = torch.sin(pos * den)# 奇数列用余弦函数填充(pos * den)pos_embedding[:, 1::2] = torch.cos(pos * den)# 调整维度(添加批次维度,便于与词嵌入向量相加)pos_embedding = pos_embedding.unsqueeze(-2)# Dropout层(正则化,防止过拟合)self.dropout = nn.Dropout(dropout)# 注册为缓冲区(模型保存/加载时自动处理)self.register_buffer('pos_embedding', pos_embedding)def forward(self, token_embedding: Tensor):# 将词嵌入向量与位置编码相加,并应用Dropoutreturn self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0),:])class TokenEmbedding(nn.Module):def __init__(self, vocab_size: int, emb_size):super(TokenEmbedding, self).__init__()# 词嵌入层:将词索引映射到emb_size维的向量self.embedding = nn.Embedding(vocab_size, emb_size)# 记录嵌入维度(用于缩放)self.emb_size = emb_sizedef forward(self, tokens: Tensor):# 将词索引转换为词向量,并乘以√emb_size(缩放,稳定梯度)return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

结语

至此,模型已完成搭建,后续博客将继续介绍模型训练部分的内容,希望本篇博客能够对你理解transformer有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/93578.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/93578.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/93578.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

上位机TCP/IP通信协议层常见问题汇总

以太网 TCP 通信是上位机开发中常用的通信方式&#xff0c;西门子 S7 通信、三菱 MC 通信以及 MQTT、OPC UA、Modbus TCP 等都是其典型应用。为帮助大家更好地理解 TCP 通信&#xff0c;我整理了一套常见问题汇总。一、OSI参考模型与TCP/IP参考模型基于TCP/IP的参考模型将协议分…

搭建ktg-mes

项目地址 该安装事项&#xff0c;基于当前最新版 2025年8月16日 之前的版本 下载地址&#xff1a; 后端JAVA 前端VUE 后端安装&#xff1a; 还原数据表 路径&#xff1a;根目录/sql/ry_20210908.sql、根目录/sql/quartz.sql、根目录/doc/实施文档/ktgmes-202505180846.sql.g…

uniapp纯前端绘制商品分享图

效果如图// useMpCustomShareImage.ts interface MpCustomShareImageData {canvasId: stringprice: stringlinePrice: stringgoodsSpecFirmName: stringimage: string }const CANVAS_WIDTH 500 const CANVAS_HEIGHT 400 const BG_IMAGE https://public-scjuchuang.oss-cn-ch…

醋酸镧:看不见的科技助力

虽然我们每天都在使用各种科技产品&#xff0c;但有些关键的化学物质却鲜为人知。醋酸镧&#xff0c;就是这样一种默默为科技进步贡献力量的“幕后英雄”。它不仅是稀土元素镧的一种化合物&#xff0c;还在许多高科技领域中发挥着重要作用。今天&#xff0c;让我们一起来了解这…

苍穹外卖日记

day 1 windows系统启动nginx报错: The system cannot find the path specified 在启动nginx的时候报错&#xff1a; /temp/client_body_temp" failed (3: The system cannot find the path specified) 解决办法&#xff1a; 1.检查nginx的目录是否存在中文 &#xff0c;路…

楼宇自控系统赋能建筑全维度管理,实现环境、安全与能耗全面监管

随着城市化进程加速和绿色建筑理念普及&#xff0c;现代楼宇管理正经历从粗放式运营向精细化管控的转型。楼宇自控系统&#xff08;BAS&#xff09;作为建筑智能化的核心载体&#xff0c;通过物联网、大数据和人工智能技术的深度融合&#xff0c;正在重构建筑管理的全维度框架&…

【HarmonyOS】Window11家庭中文版开启鸿蒙模拟器失败提示未开启Hyoer-V

【HarmonyOS】Window11家庭中文版开启鸿蒙模拟器失败提示未开启Hyoer-V一、问题背景 当鸿蒙模拟器启动时&#xff0c;提示如下图所示&#xff1a;因为Hyper-V 仅在 Windows 11 专业版、企业版和教育版中作为预装功能提供&#xff0c;而家庭版&#xff08;包括中文版&#xff09…

vscode远程服务器出现一直卡在正在打开远程和连接超时解决办法

项目场景&#xff1a; 使用ssh命令或者各种软件进行远程服务器之后&#xff0c;结果等到几分钟之后自动断开连接问题解决。vscode远程服务器一直卡在正在打开远程状态问题解决。问题描述 1.连接超时 2.vscode远程一直卡在正在打开远程...原因分析&#xff1a;需要修改设置超时断…

Maven下载和配置-IDEA使用

目录 一 MAVEN 二 三个仓库 1. 本地仓库&#xff08;Local Repository&#xff09; 2. 私有仓库&#xff08;Private Repository&#xff0c;公司内部仓库&#xff09; 3. 远程仓库&#xff08;Remote Repository&#xff09; 依赖查找流程&#xff08;优先级&#xff09…

Dify实战应用指南(上传需求稿生成测试用例)

一、Dify平台简介 Dify是一款开源的大语言模型&#xff08;LLM&#xff09;应用开发平台&#xff0c;融合了“Define&#xff08;定义&#xff09; Modify&#xff08;修改&#xff09;”的设计理念&#xff0c;通过低代码/无代码的可视化界面降低技术门槛。其核心价值在于帮助…

学习日志35 python

1 Python 列表切片一、切片完整语法列表切片的基本格式&#xff1a; 列表[start:end:step]start&#xff1a;起始索引&#xff08;包含该位置元素&#xff0c;可省略&#xff09;end&#xff1a;结束索引&#xff08;不包含该位置元素&#xff0c;可省略&#xff09;step&#…

Linux -- 文件【下】

目录 一、EXT2文件系统 1、宏观认识 2、块组内部构成 2.1 Data Block 2.2 i节点表(Inode Table) 2.3 块位图&#xff08;Block Bitmap&#xff09; 2.4 inode位图&#xff08;Inode Bitmap&#xff09; 2.5 GDT&#xff08;Group Descriptor Table&#xff09; 2.6 超…

谷歌手机刷机和面具ROOT保姆级别教程

#比较常用的谷歌输入root面具教程,逆向工程师必修课程# 所需工具与材料清单 真机设备 推荐使用 Google Pixel 4 或其他兼容设备&#xff0c;确保硬件支持刷机操作。 ADB 环境配置 通过安装 Android Studio 自动配置 ADB 和 Fastboot 工具。安装完成后&#xff0c;需在系统环境…

平衡二叉搜索树 - 红黑树详解

文章目录一、红黑树概念引申问题二、红黑树操作一、红黑树概念 红黑树是一棵二叉搜索树&#xff0c;它在每个节点上增加了一个存储位用来表示节点颜色(红色或者黑色)&#xff0c;红黑树通过约束颜色&#xff0c;可以保证最长路径不超过最短路径的两倍&#xff0c;因而近似平衡…

从0开始跟小甲鱼C语言视频使用linux一步步学习C语言(持续更新)8.14

第十六天 第五十二&#xff0c;五十三&#xff0c;五十四&#xff0c;五十五和五十六集 第五十二集 文件包含 一个include命令只能指定一个被包含文件 文件允许嵌套&#xff0c;就是一个被包含的文件可以包含另一个文件。 文件名可以用尖括号或者双引号括起来 但是两种的查找方…

B+树索引分析:单表最大存储记录数

在现代数据库设计中&#xff0c;随着数据量的增加&#xff0c;如何有效地管理和优化数据库成为了一个关键问题。根据阿里巴巴开发手册的标准&#xff0c;当一张表预计在三年内的数据量超过500万条或者2GB时&#xff0c;就应该考虑实施分库分表策略 Mysql B树索引介绍 及 页内储…

三、memblock 内存分配器

两个问题&#xff1a; 1、系统是怎么知道物理内存的&#xff1f;linux内存管理学习&#xff08;1&#xff09;&#xff1a;物理内存探测 2、在内存管理真正初始化之前&#xff0c;内核的代码执行需要分配内存该怎么处理&#xff1f; 在Linux内核启动初期&#xff0c;完整的内存…

Python 桌面应用形态后台管理系统的技术选型与方案报告

下面是一份面向“Python 桌面应用形态的后台管理系统”的技术选型与方案报告。我把假设前提→总体架构→客户端技术选型→服务端与数据层→基础设施与安全→交付与运维→质量保障→里程碑计划→风险与对策→最小可行栈逐层给出。 一、前置假设 & 非功能目标 业务假设 典型…

Winsows系统去除右键文件显示的快捷列表

前言&#xff1a;今天重做了电脑系统&#xff0c;安装的是纯净版的系统。然后手动指定D盘安装了下列软件。&#xff08;QQ&#xff0c;迅雷&#xff0c;百度网盘&#xff0c;搜狗输入法&#xff0c;驱动精灵&#xff09;然后我右键点击桌面的软件快捷方式&#xff0c;出现了一排…

【Go】Gin 超时中间件的坑:fatal error: concurrent map writes

Gin 社区超时中间件的坑&#xff1a;导致线上 Pod 异常重启 在最近的项目中&#xff0c;我们遇到了因为 Gin 超时中间件&#xff08;timeout&#xff09; 引发的生产事故&#xff1a;Pod 异常退出并重启。 问题现场 pod无故重启&#xff0c;抓取标准输出日志&#xff0c;问题…