参考:
Transformer模型详解(图解最完整版) - 知乎https://zhuanlan.zhihu.com/p/338817680GitHub - liaoyanqing666/transformer_pytorch: 完整的原版transformer程序,complete origin transformer programhttps://github.com/liaoyanqing666/transformer_pytorcharxiv.org/pdf/1706.03762https://arxiv.org/pdf/1706.03762

一. Transformer的整体架构

Transformer 由 Encoder (编码器)和 Decoder (解码器)两个部分组成,Encoder 和 Decoder 都包含 6 个 block(块)。Transformer 的工作流程大体如下:

第一步:获取输入句子的每一个单词的表示向量 X,X由单词本身的 Embedding(Embedding就是从原始数据提取出来的特征(Feature)) 和单词位置的 Embedding 相加得到。

第二步:将得到的单词表示向量矩阵 (如上图所示,每一行是一个单词的表示 x)传入 Encoder 中,经过 6 个 Encoder block (编码器块)后可以得到句子所有单词的编码信息矩阵 C。如下图,单词向量矩阵用 X_{n\times d}表示, n 是句子中单词个数,d 是表示向量的维度(论文中 d=512)。每一个 Encoder block (编码器块)输出的矩阵维度与输入完全一致。

第三步:将 Encoder (编码器)输出的编码信息矩阵 C传递到 Decoder(解码器)中,Decoder(解码器) 依次会根据当前翻译过的单词 1~ i 翻译下一个单词 i+1,如下图所示。在使用的过程中,翻译到单词 i+1 的时候需要通过 Mask (掩盖) 操作遮盖住 i+1 之后的单词。

上图 Decoder 接收了 Encoder 的编码矩阵 C,然后首先输入一个翻译开始符 "<Begin>",预测第一个单词 "I";然后输入翻译开始符 "<Begin>" 和单词 "I",预测单词 "have",以此类推。

二. Transformer 的输入

Transformer 中单词的输入表示 单词本身的 Embedding 和单词位置 Embedding (Positional Encoding)相加得到。

2.1 单词 Embedding(词嵌入层)

单词本身的 Embedding 有很多种方式可以获取,例如可以采用 Word2Vec、Glove 等算法预训练得到,也可以在 Transformer 中训练得到。

self.embedding = nn.Embedding(vocabulary, dim)

功能解释:

  1. 作用:将离散的整数索引(单词ID)转换为连续的向量表示

  2. 输入:形状为 [sequence_length] 的整数张量

  3. 输出:形状为 [sequence_length, dim] 的浮点数张量(X_{n\times d},n是序列长度,d是特征维度)

参数详解:

参数含义示例值说明
vocabulary词汇表大小10000表示模型能处理的不同单词/符号总数
dim嵌入维度512每个单词被表示成的向量长度

工作原理:

  1. 创建一个可学习的嵌入矩阵[vocabulary, dim],例如当 vocabulary=10000dim=512 时,是一个 10000×512 的矩阵;

  2. 每个整数索引对应矩阵中的一行:

# 假设单词"apple"的ID=42
apple_vector = embedding_matrix[42]  # 形状 [512]

在Transformer中的具体作用:

# 输入:src = torch.randint(0, 10000, (2, 10))
# 形状:[batch_size=2, seq_len=10]src_embedded = self.embedding(src)# 输出形状变为:[2, 10, 512]
# 每个整数单词ID被替换为512维的向量

可视化表现:

原始输入 (单词ID):
[ [ 25,  198, 3000, ... ],   # 句子1[ 1,   42,  999,  ... ] ]  # 句子2经过嵌入层后 (向量表示):
[ [ [0.2, -0.5, ..., 1.3],   # ID=25的向量[0.8, 0.1, ..., -0.9],   # ID=198的向量... ],[ [0.9, -0.2, ..., 0.4],   # ID=1的向量[0.3, 0.7, ..., -1.2],   # ID=42的向量... ] ]

为什么需要词嵌入:

  • 语义表示:相似的单词会有相似的向量表示

  • 降维:将离散的ID映射到连续空间(one-hot编码需要10000维 → 嵌入只需512维)

  • 可学习:在训练过程中,这些向量会不断调整以更好地表示语义关系

2.2  位置 Embedding(位置编码)

Transformer 的位置编码(Positional Encoding,PE)是模型的关键创新之一,它解决了传统序列模型(如 RNN)固有的顺序处理问题。Transformer 的自注意力机制本身不具备感知序列位置的能力,位置编码通过向输入嵌入添加位置信息,使模型能够理解序列中元素的顺序关系。位置编码计算之后的输出维度和词嵌入层相同,均为(X_{n\times d})。

位置编码的核心作用:

  1. 注入位置信息:让模型区分不同位置的相同单词(如 "bank" 在句首 vs 句尾)

  2. 保持距离关系:编码相对位置和绝对位置信息

  3. 支持并行计算:避免像 RNN 那样依赖顺序处理

为什么需要位置编码?

  1. 自注意力的位置不变性
    Attention(Q,K,V)=softmax\left ( \frac{QK^{T}}{\sqrt{d_k}} \right )V,计算过程不包含位置信息

  2. 序列顺序的重要性

  • 自然语言:"猫追狗" ≠ "狗追猫"
  • 时序数据:股价序列的顺序决定趋势替代方案对比
方法优点缺点
正弦/余弦泛化性好,理论保证固定模式不灵活
可学习适应任务特定模式长度受限,需训练
相对位置直接建模相对距离实现复杂

位置编码的实际效果

  1. 早期层作用:帮助模型建立位置感知

  2. 后期层作用:位置信息被融合到语义表示中

  3. 可视化示例

Input:    [The,   cat,   sat,   on,   mat]
Embed:    [E_The, E_cat, E_sat, E_on, E_mat]
Position: [P0,    P1,    P2,    P3,   P4]Final: [E_The+P0, E_cat+P1, ... E_mat+P4]
(1)正余弦位置编码(论文采用)

正余弦位置编码的计算公式:

其中:

  •  `pos` 是token在序列中的位置(从0开始)
  •  `d_model` 是模型的嵌入维度(即每个token的向量维度)
  •  `i` 是维度的索引(从0到d_model/2-1)

特点:

  • 波长几何级数:覆盖不同频率
  • 相对位置可学习:位置偏移的线性变换 PE_{pos+k} 可表示为 PE_{pos} 的线性函数
  • 泛化性强:可处理比训练时更长的序列
  • 对称性:sin/cos 组合允许模型学习相对位置

代码实现:

class PositionalEncoding(nn.Module):# Sine-cosine positional codingdef __init__(self, emb_dim, max_len, freq=10000.0):super(PositionalEncoding, self).__init__()assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'self.emb_dim = emb_dimself.max_len = max_lenself.pe = torch.zeros(max_len, emb_dim)pos = torch.arange(0, max_len).unsqueeze(1)# pos: [max_len, 1]div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)# div: [ceil(emb_dim / 2)]self.pe[:, 0::2] = torch.sin(pos / div)# torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))# torch.cos(pos / div): [max_len, floor(emb_dim / 2)]def forward(self, x, len=None):if len is None:len = x.size(-2)return x + self.pe[:len, :]

例如,指定emb_dim=512和max_len=100,句子长度为10,则位置embedding的数值计算如下(三角函数取弧度制):

\begin{bmatrix} sin\left ( \frac{0}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{0}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{0}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{0}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{0}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{0}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{1}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{1}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{1}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{1}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{1}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{1}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{2}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{2}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{2}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{2}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{2}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{2}{10000^{\frac{510}{512}}} \right )\\ ... & ... & ... & ... & ... & ... & ...\\ sin\left ( \frac{7}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{7}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{7}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{7}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{7}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{7}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{8}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{8}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{8}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{8}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{8}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{8}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{9}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{9}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{9}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{9}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{9}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{9}{10000^{\frac{510}{512}}} \right )\\ \end{bmatrix}_{10\times 512}=\begin{bmatrix} 0 & 1 & 0 & ... & 1 & 0 & 1\\ 0.8415 & 0.5403 & 0.8219 & ... & 1.0000 & 1.0366\times 10^{-4} & 1.0000\\ 0.9093 & -0.4161 & 0.9364 & ... & 1.0000 & 2.0733\times 10^{-4} & 1.0000\\ ... & ... & ... & ... & ... & ... & ...\\ 0.6570& 0.7539 & 0.4524 & ... & 1.0000 & 7.2564\times 10^{-4} & 1.0000\\ 0.9894 & -0.1455 & 0.9907 & ... & 1.0000 & 8.2931\times 10^{-4} & 1.0000\\ 0.4121 & -0.9111 & 0.6764 & ... & 1.0000 & 9.3297\times 10^{-4} & 1.0000 \end{bmatrix}_{10\times 512}

(2)可学习位置编码
class LearnablePositionalEncoding(nn.Module):# Learnable positional encodingdef __init__(self, emb_dim, len):super(LearnablePositionalEncoding, self).__init__()assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'self.emb_dim = emb_dimself.len = lenself.pe = nn.Parameter(torch.zeros(len, emb_dim))def forward(self, x):return x + self.pe[:x.size(-2), :]

特性

  • 直接学习位置嵌入:作为模型参数训练
  • 灵活性高:可适应特定任务的位置模式
  • 长度受限:只能处理预定义的最大长度
  • 计算效率高:直接查表无需计算

三. Self-Attention(自注意力机制)和Multi-Head Attention(多头自注意力)

Transformer 的内部结构图,左侧为 Encoder block(编码器),右侧为 Decoder block(解码器)。可以看到:

(1)Encoder block 包含一个 Multi-Head Attention;

(2)Decoder block 包含两个 Multi-Head Attention (其中有一个用到 Masked)。Multi-Head Attention 上方还包括一个 Add & Norm 层,Add 表示残差连接(Residual Connection),用于防止网络退化,Norm 表示Layer Normalization,用于对每一层的激活值进行归一化。

Multi-Head Attention 是 Transformer 的重点,它由 Self-Attention 演变而来,我们先从 Self-Attention 讲起。

3.1  Self-Attention(自注意力机制)

Self-Attention(自注意力)是 Transformer 架构的核心创新,它彻底改变了序列建模的方式。与传统的循环神经网络(RNN)和卷积神经网络(CNN)不同,self-attention 能够直接捕捉序列中任意两个元素之间的关系,无论它们之间的距离有多远:

Self-Attention 的输入用矩阵X_{n\times d}(n是序列长度,d是特征维度)进行表示,计算如下:

(1)通过可学习的权重矩阵生成Q(查询),K(键值),V(值):

\left\{\begin{matrix} Q = XW^Q \\ K = XW^K \\ V = XW^V \end{matrix}\right.

其中W^Q,W^K,W^V是可学习参数。

(2)计算 Self-Attention 的输出:Attention(Q,K,V)=softmax\left ( \frac{QK^{T}}{\sqrt{d_k}} \right )V

步骤分解:

  1. 相似度计算QK^T计算所有查询-键对之间的点积相似度,QK^T得到的矩阵行列数都为 n,n为句子单词数,这个矩阵可以表示单词之间的 attention 强度。

  2. 缩放:除以\sqrt{d_k}防止点积过大导致梯度消失

  3. 归一化:softmax 将相似度转换为概率分布

  4. 加权求和:用注意力权重对值向量加权求和,得到最终的输出

输入序列: [x1, x2, x3, x4]步骤1: 为每个输入生成Q,K,V向量
x1 → q1, k1, v1
x2 → q2, k2, v2
x3 → q3, k3, v3
x4 → q4, k4, v4步骤2: 计算注意力权重 (以x1为例)
权重1 = softmax(q1·k1 / √d_k)
权重2 = softmax(q1·k2 / √d_k)
权重3 = softmax(q1·k3 / √d_k)
权重4 = softmax(q1·k4 / √d_k)步骤3: 加权求和
输出1 = 权重1*v1 + 权重2*v2 + 权重3*v3 + 权重4*v4

3.2  Multi-Head Attention(多头注意力)

Transformer 使用多头机制增强模型表达能力:

MultiHead(Q,K,V)=Concat(head_1,head_2...head_h)W^O

其中每个注意力头:

head_i=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

  • h:注意力头的数量

  • W_i^Q, W_i^K, W_i^V:每个头的独立参数

  • W^O:输出投影矩阵

代码实现:

(1)多头分割处理:使用view将特征维度分割为多个头,确保每个头的维度:dim_head = dim_qk // num_heads

q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
k = ... # 类似处理
v = ... # 类似处理

(2)高效的矩阵运算:使用矩阵乘法并行计算所有位置的注意力分数

attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)

(3)多头合并:使用view合并多头:num_heads * d_v = dim_v

output = output.transpose(1, 2)
output = output.contiguous().view(-1, len_q, self.dim_v)

完整Multi-Head Attention(多头注意力)的代码实现,这里已经考虑了掩码处理的实现,关于掩码将在后面介绍。

class MultiHeadAttention(nn.Module):def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):super(MultiHeadAttention, self).__init__()dim_qk = dim if dim_qk is None else dim_qkdim_v = dim if dim_v is None else dim_vassert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, 'dim must be divisible by num_heads'self.dim = dimself.dim_qk = dim_qkself.dim_v = dim_vself.num_heads = num_headsself.dropout = nn.Dropout(dropout)self.w_q = nn.Linear(dim, dim_qk)self.w_k = nn.Linear(dim, dim_qk)self.w_v = nn.Linear(dim, dim_v)def forward(self, q, k, v, mask=None):# q: [B, len_q, D]# k: [B, len_kv, D]# v: [B, len_kv, D]assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'assert len_k == len_v, 'len_k and len_v must be equal'len_kv = len_vq = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)# q: [B, len_q, num_heads, dim_qk//num_heads]# k: [B, len_kv, num_heads, dim_qk//num_heads]# v: [B, len_kv, num_heads, dim_v//num_heads]# The following 'dim_(qk)//num_heads' is writen as d_(qk)q = q.transpose(1, 2)k = k.transpose(1, 2)v = v.transpose(1, 2)# q: [B, num_heads, len_q, d_qk]# k: [B, num_heads, len_kv, d_qk]# v: [B, num_heads, len_kv, d_v]attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)# attn: [B, num_heads, len_q, len_kv]if mask is not None:attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)attn = torch.softmax(attn, dim=-1)attn = self.dropout(attn)output = torch.matmul(attn, v)# output: [B, num_heads, len_q, d_v]output = output.transpose(1, 2)# output: [B, len_q, num_heads, d_v]output = output.contiguous().view(-1, len_q, self.dim_v)# output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]return output

六.完整代码实现

import torch
import torch.nn as nnclass LearnablePositionalEncoding(nn.Module):# Learnable positional encodingdef __init__(self, emb_dim, len):super(LearnablePositionalEncoding, self).__init__()assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'self.emb_dim = emb_dimself.len = lenself.pe = nn.Parameter(torch.zeros(len, emb_dim))def forward(self, x):return x + self.pe[:x.size(-2), :]class PositionalEncoding(nn.Module):# Sine-cosine positional codingdef __init__(self, emb_dim, max_len, freq=10000.0):super(PositionalEncoding, self).__init__()assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'self.emb_dim = emb_dimself.max_len = max_lenself.pe = torch.zeros(max_len, emb_dim)pos = torch.arange(0, max_len).unsqueeze(1)# pos: [max_len, 1]div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)# div: [ceil(emb_dim / 2)]self.pe[:, 0::2] = torch.sin(pos / div)# torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))# torch.cos(pos / div): [max_len, floor(emb_dim / 2)]def forward(self, x, len=None):if len is None:len = x.size(-2)print(self.pe[:len, :])return x + self.pe[:len, :]class MultiHeadAttention(nn.Module):def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):super(MultiHeadAttention, self).__init__()dim_qk = dim if dim_qk is None else dim_qkdim_v = dim if dim_v is None else dim_vassert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, 'dim must be divisible by num_heads'self.dim = dimself.dim_qk = dim_qkself.dim_v = dim_vself.num_heads = num_headsself.dropout = nn.Dropout(dropout)self.w_q = nn.Linear(dim, dim_qk)self.w_k = nn.Linear(dim, dim_qk)self.w_v = nn.Linear(dim, dim_v)def forward(self, q, k, v, mask=None):# q: [B, len_q, D]# k: [B, len_kv, D]# v: [B, len_kv, D]assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'assert len_k == len_v, 'len_k and len_v must be equal'len_kv = len_vq = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)# q: [B, len_q, num_heads, dim_qk//num_heads]# k: [B, len_kv, num_heads, dim_qk//num_heads]# v: [B, len_kv, num_heads, dim_v//num_heads]# The following 'dim_(qk)//num_heads' is writen as d_(qk)q = q.transpose(1, 2)k = k.transpose(1, 2)v = v.transpose(1, 2)# q: [B, num_heads, len_q, d_qk]# k: [B, num_heads, len_kv, d_qk]# v: [B, num_heads, len_kv, d_v]attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)# attn: [B, num_heads, len_q, len_kv]if mask is not None:attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)attn = torch.softmax(attn, dim=-1)attn = self.dropout(attn)output = torch.matmul(attn, v)# output: [B, num_heads, len_q, d_v]output = output.transpose(1, 2)# output: [B, len_q, num_heads, d_v]output = output.contiguous().view(-1, len_q, self.dim_v)# output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]return outputclass Feedforward(nn.Module):def __init__(self, dim, hidden_dim=2048, dropout=0., activate=nn.ReLU()):super(Feedforward, self).__init__()self.dim = dimself.hidden_dim = hidden_dimself.dropout = nn.Dropout(dropout)self.fc1 = nn.Linear(dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, dim)self.act = activatedef forward(self, x):x = self.act(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return xdef attn_mask(len):""":param len: length of sequence:return: mask tensor, False for not replaced, True for replaced as -infe.g. attn_mask(3) =tensor([[[False,  True,  True],[False, False,  True],[False, False, False]]])"""mask = torch.triu(torch.ones(len, len, dtype=torch.bool), 1)return maskdef padding_mask(pad_q, pad_k):""":param pad_q: pad label of query (0 is padding, 1 is not padding), [B, len_q]:param pad_k: pad label of key (0 is padding, 1 is not padding), [B, len_k]:return: mask tensor, False for not replaced, True for replaced as -infe.g. pad_q = tensor([[1, 1, 0]], [1, 0, 1])padding_mask(pad_q, pad_q) =tensor([[[False, False,  True],[False, False,  True],[ True,  True,  True]],[[False,  True, False],[ True,  True,  True],[False,  True, False]]])"""assert pad_q.ndim == pad_k.ndim == 2, 'pad_q and pad_k must be 2-dimensional'assert pad_q.size(0) == pad_k.size(0), 'batch size mismatch'mask = pad_q.bool().unsqueeze(2) * pad_k.bool().unsqueeze(1)mask = ~mask# mask: [B, len_q, len_k]return maskclass EncoderLayer(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):super(EncoderLayer, self).__init__()self.attn = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.ffn = Feedforward(dim, dim * 4, dropout)self.pre_norm = pre_normself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)def forward(self, x, mask=None):if self.pre_norm:res1 = self.norm1(x)x = x + self.attn(res1, res1, res1, mask)res2 = self.norm2(x)x = x + self.ffn(res2)else:x = self.attn(x, x, x, mask) + xx = self.norm1(x)x = self.ffn(x) + xx = self.norm2(x)return xclass Encoder(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):super(Encoder, self).__init__()self.layers = nn.ModuleList([EncoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])def forward(self, x, mask=None):for layer in self.layers:x = layer(x, mask)return xclass DecoderLayer(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):super(DecoderLayer, self).__init__()self.attn1 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.attn2 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.ffn = Feedforward(dim, dim * 4, dropout)self.pre_norm = pre_normself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)self.norm3 = nn.LayerNorm(dim)def forward(self, x, enc, self_mask=None, pad_mask=None):if self.pre_norm:res1 = self.norm1(x)x = x + self.attn1(res1, res1, res1, self_mask)res2 = self.norm2(x)x = x + self.attn2(res2, enc, enc, pad_mask)res3 = self.norm3(x)x = x + self.ffn(res3)else:x = self.attn1(x, x, x, self_mask) + xx = self.norm1(x)x = self.attn2(x, enc, enc, pad_mask) + xx = self.norm2(x)x = self.ffn(x) + xx = self.norm3(x)return xclass Decoder(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):super(Decoder, self).__init__()self.layers = nn.ModuleList([DecoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])def forward(self, x, enc, self_mask=None, pad_mask=None):for layer in self.layers:x = layer(x, enc, self_mask, pad_mask)return xclass Transformer(nn.Module):def __init__(self, dim, vocabulary, num_heads=1, num_layers=1, dropout=0., learnable_pos=False, pre_norm=False):super(Transformer, self).__init__()self.dim = dimself.vocabulary = vocabularyself.num_heads = num_headsself.num_layers = num_layersself.dropout = dropoutself.learnable_pos = learnable_posself.pre_norm = pre_normself.embedding = nn.Embedding(vocabulary, dim)self.pos_enc = LearnablePositionalEncoding(dim, 100) if learnable_pos else PositionalEncoding(dim, 100)self.encoder = Encoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)self.decoder = Decoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)self.linear = nn.Linear(dim, vocabulary)def forward(self, src, tgt, src_mask=None, tgt_mask=None, pad_mask=None):# src.shape: torch.Size([2, 10])src = self.embedding(src)# src.shape: torch.Size([2, 10, 512])src = self.pos_enc(src)# src.shape: torch.Size([2, 10, 512])src = self.encoder(src, src_mask)# src.shape: torch.Size([2, 10, 512])# tgt.shape: torch.Size([2, 8])tgt = self.embedding(tgt)# tgt.shape: torch.Size([2, 8, 512])tgt = self.pos_enc(tgt)# tgt.shape: torch.Size([2, 8, 512])tgt = self.decoder(tgt, src, tgt_mask, pad_mask)# tgt.shape: torch.Size([2, 8, 512])output = self.linear(tgt)# output.shape: torch.Size([2, 8, 10000])return outputdef get_mask(self, tgt, src_pad=None):# Under normal circumstances, tgt_pad will perform mask processing when calculating loss, and it isn't necessarily in decoderif src_pad is not None:src_mask = padding_mask(src_pad, src_pad)else:src_mask = Nonetgt_mask = attn_mask(tgt.size(1))if src_pad is not None:pad_mask = padding_mask(torch.zeros_like(tgt), src_pad)else:pad_mask = None# src_mask: [B, len_src, len_src]# tgt_mask: [len_tgt, len_tgt]# pad_mask: [B, len_tgt, len_src]return src_mask, tgt_mask, pad_maskif __name__ == '__main__':model = Transformer(dim=512, vocabulary=10000, num_heads=8, num_layers=6, dropout=0.1, learnable_pos=False, pre_norm=True)src = torch.randint(0, 10000, (2, 10))  # torch.Size([2, 10])tgt = torch.randint(0, 10000, (2, 8))   # torch.Size([2, 8])src_pad = torch.randint(0, 2, (2, 10))  # torch.Size([2, 10])src_mask, tgt_mask, pad_mask = model.get_mask(tgt, src_pad)model(src, tgt, src_mask, tgt_mask, pad_mask)# output.shape: torch.Size([2, 8, 10000])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/912296.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/912296.shtml
英文地址,请注明出处:http://en.pswp.cn/news/912296.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Adobe InDesign 2025

Adobe InDesign 2025(ID2025)桌面出版软件和在线发布工具,报刊杂志印刷排版设计软件。Adobe InDesign中文版主要用于传单设计,海报设计,明信片设计,电子书设计,排版,手册设计,数字杂志,iPad应用程序和在线交互文档。它是首款支持Unicode文本处理的主流DTP应用程序,率先使用新型…

Linux下获取指定时间内某个进程的平均CPU使用率

一、引言 通过pidstat工具可以测量某个进程在两个时间点之间的平均CPU利用率。 二、pidstat工具的安装 pidstat属于sysstat套件的一部分。以Ubuntu系统为例&#xff0c;执行下面命令下载安装sysstat套件&#xff1a; apt-get install sysstat 执行完后&#xff0c;终端执行p…

1.4 蜂鸟E203处理器NICE接口详解

一、NICE接口的概念 NICE&#xff08;Nuclei Instruction Co-unit Extension&#xff09;接口是蜂鸟E203处理器中用于扩展自定义指令的协处理器接口&#xff0c;基于RISC-V标准协处理器扩展机制设计。它允许用户在不修改处理器核流水线的情况下&#xff0c;通过外部硬件加速特…

Oracle 递归 + Decode + 分组函数实现复杂树形统计进阶(第二课)

在上篇文章基础上&#xff0c;我们进一步解决层级数据递归汇总问题 —— 让上级部门的统计结果自动包含所有下级部门数据&#xff08;含多级子部门&#xff09;&#xff0c;并新增请假天数大于 3 天的统计维度。通过递归 CTE、DECODE函数与分组函数的深度结合&#xff0c;实现真…

MySQL 数据类型全面指南:详细说明与关键注意事项

MySQL 数据类型全面指南&#xff1a;详细说明与关键注意事项 MySQL 提供了丰富的数据类型&#xff0c;合理选择对数据库性能、存储效率和数据准确性至关重要。以下是所有数据类型的详细说明及使用注意事项&#xff1a; 一、数值类型 整数类型 类型字节有符号范围无符号范围说…

leetcode437-路径总和III

leetcode 437 思路 利用前缀和hash map解答 前缀和在这里的含义是&#xff1a;从根节点到当前节点的路径上所有节点值的总和 我们使用一个 Map 数据结构来记录这些前缀和及其出现的次数 具体思路如下&#xff1a; 初始化&#xff1a;创建一个 Map &#xff0c;并将前缀和 …

UI前端与数字孪生融合探索新领域:智慧家居的可视化设计与实现

hello宝子们...我们是艾斯视觉擅长ui设计、前端开发、数字孪生、大数据、三维建模、三维动画10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩! 一、引言&#xff1a;智慧家居的数字化转型浪潮 在物联网与人工智能技术的推动下&#xff0c…

数据结构知识点总结--绪论

1.1 数据结构的基本概念 1.1.1 基本概念和术语 主要涉及概念有&#xff1a; 数据、数据元素、数据对象、数据类型、数据结构 #mermaid-svg-uyyvX6J6ofC9rFSB {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-uyyvX6…

pip install mathutils 安装 Blender 的 mathutils 模块时,编译失败了

你遇到的问题是因为你试图通过 pip install mathutils 安装 Blender 的 mathutils 模块时&#xff0c;编译失败了&#xff0c;主要原因是&#xff1a; 2018年 的老版本也不行 pip install mathutils2.79 ❌ 报错核心总结&#xff1a; 缺失头文件 BLI_path_util.h&#xff1a;…

编译安装交叉工具链 riscv-gnu-toolchain

参考链接&#xff1a; https://zhuanlan.zhihu.com/p/258394849 1&#xff0c;下载源码 git clone https://gitee.com/mirrors/riscv-gnu-toolchain 2&#xff0c;进入目录 cd riscv-gnu-toolchain 3&#xff0c;去掉qemu git rm qemu 4&#xff0c;初始化 git submodule…

复制 生成二维码

一、安装插件 1、复制 npm install -g copy-to-clipboard import copy from copy-to-clipboard; 2、生成二维码 & 下载 npm install -g qrcode import QRCode from qrcode.react; 二、功能&#xff1a;生成二维码 & 下载 效果图 1、常规使用&#xff08;下载图片模糊…

自由职业的经营视角

“领导力的核心是帮助他人看到自己看不到的东西。” — 彼得圣吉 最近与一些自由职业者的交流中&#xff0c;发现很多专业人士都会从专业视角来做交流&#xff0c;这也让我更加理解我们海外战略顾问庄老师在每月辅导时的提醒——经营者视角和专业人士视角的不同。这不仅让大家获…

MR30分布式 IO在物流堆垛机的应用

在现代物流行业蓬勃发展的浪潮中&#xff0c;物流堆垛机作为自动化仓储系统的核心设备&#xff0c;承担着货物的高效存取与搬运任务。它凭借自动化操作、高精度定位等优势&#xff0c;极大地提升了仓储空间利用率和货物周转效率。然而&#xff0c;随着物流行业的高速发展&#…

告别固定密钥!在单一账户下用 Cognito 实现 AWS CLI 的 MFA 单点登录

大家好&#xff0c;很多朋友&#xff0c;特别是通过合作伙伴或服务商使用 AWS 的同学&#xff0c;可能会发现自己的 IAM Identity Center 功能受限&#xff0c;无法像在组织管理账户里那样轻松配置 CLI 的 SSO (aws configure sso)。那么&#xff0c;我们就要放弃治疗&#xff…

未来机器视觉软件将更注重成本控制,边缘性能,鲁棒性、多平台支持、模块优化与性能提升,最新版本opencv-4.11.0更新了什么

OpenCV 4.11.0 作为 4.10.0 的后续版本,虽然没有在提供的搜索结果中直接列出详细更新内容,但结合 OpenCV 4.10.0 的重大改进方向(发布于 2024 年 6 月),可以合理推断 4.11.0 版本可能延续了对多平台支持、模块优化和性能提升的强化。以下是基于 OpenCV 近期更新模式的推测…

小程序入门:数据请求全解析

在微信小程序开发中&#xff0c;数据请求是实现丰富功能的关键环节。本文将带你深入了解小程序数据请求的相关知识&#xff0c;包括请求限制、配置方法以及不同请求方式的实现&#xff0c;还会介绍如何在页面加载时自动请求数据&#xff0c;同时附上详细代码示例&#xff0c;让…

开源版gpt4o 多模态MiniGPT-4 实现原理详解

MiniGPT-4是开源的GPT-4的平民版。本文用带你快速掌握多模态大模型MiniGPT-4的模型架构、训练秘诀、实战亮点与改进方向。 1 模型架构全景&#xff1a;三层协同 &#x1f4ca; 模型底部实际输入图像&#xff0c;经 ViT Q-Former 编码。蓝色方块 (视觉编码器)&#xff1a;左侧…

Flutter基础(控制器)

第1步&#xff1a;找个遥控器&#xff08;创建控制器&#xff09;​ // 就像买新遥控器要装电池 TextEditingController myController TextEditingController(); ​​第2步&#xff1a;连上你的玩具&#xff08;绑定到组件&#xff09;​​ TextField(controller: myContro…

Spring Boot使用Redis常用场景

Spring Boot使用Redis常用场景 一、概述&#xff1a;Redis 是什么&#xff1f;为什么要用它&#xff1f; Redis&#xff08;Remote Dictionary Server&#xff09;是一个内存中的数据存储系统&#xff08;类似一个“超级大字典”&#xff09;&#xff0c;它能存各种类型的数据…

CAD文件处理控件Aspose.CAD教程:在 C# 中将 DXF 文件转换为 SVG - AutoCAD C# 示例

概述 使用 C# 轻松将DXF文件转换为SVG。此转换可更好地兼容 Web 应用程序&#xff0c;并增强 CAD 图纸的视觉呈现效果。使用Aspose.CAD for .NET &#xff0c;开发人员可以轻松实现此转换过程。该 SDK 提供强大的功能&#xff0c;使其成为 C# 开发人员的可靠选择。Aspose.CAD …