References:
- Transformer模型详解(图解最完整版) - 知乎: https://zhuanlan.zhihu.com/p/338817680
- GitHub - liaoyanqing666/transformer_pytorch (complete original Transformer implementation): https://github.com/liaoyanqing666/transformer_pytorch
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762
1. Overall Architecture of the Transformer
The Transformer consists of an Encoder and a Decoder, each made up of 6 blocks. Its overall workflow is roughly as follows:
Step 1: obtain the representation vector X for every word of the input sentence. X is the sum of the word's own Embedding (an Embedding is a feature representation extracted from the raw data) and the Embedding of the word's position.
Step 2: pass the resulting matrix of word representations (each row is the representation x of one word) into the Encoder. After 6 Encoder blocks we obtain the encoding matrix C for all words of the sentence. The word-vector matrix is denoted $X_{n \times d}$, where n is the number of words in the sentence and d is the dimension of the representation vectors (d = 512 in the paper). The output matrix of every Encoder block has exactly the same shape as its input.
Step 3: pass the encoding matrix C produced by the Encoder to the Decoder. The Decoder predicts word i+1 from the words 1 through i that have already been translated. During inference, a Mask operation is used to hide the words after position i+1 when translating word i+1.
The Decoder receives the Encoder's encoding matrix C. It first takes the start-of-translation token "<Begin>" as input and predicts the first word "I"; it then takes "<Begin>" and "I" as input and predicts "have", and so on.
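To make this loop concrete, here is a minimal greedy-decoding sketch (my own illustration, not from the referenced article). It assumes the Transformer class and the attn_mask helper defined in Section 6 below, and hypothetical begin_id / end_id token IDs.

import torch

def greedy_decode(model, src, begin_id=1, end_id=2, max_len=20):
    # Naive autoregressive decoding: the whole model (encoder + decoder) is simply
    # re-run on the growing prefix "<Begin>, I, have, ..." at every step.
    model.eval()
    with torch.no_grad():
        tgt = torch.full((src.size(0), 1), begin_id, dtype=torch.long)  # start with "<Begin>"
        for _ in range(max_len):
            tgt_mask = attn_mask(tgt.size(1))                 # hide future positions
            logits = model(src, tgt, None, tgt_mask, None)    # [B, len_tgt, vocabulary]
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # most likely next word
            tgt = torch.cat([tgt, next_token], dim=1)
            if (next_token == end_id).all():                  # every sentence produced "<End>"
                break
    return tgt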
2. Transformer Inputs
The input representation x of a word in the Transformer is the sum of the word's own Embedding and its positional Embedding (Positional Encoding).
2.1 Word Embedding (token embedding layer)
The word Embedding can be obtained in many ways: it can be pretrained with algorithms such as Word2Vec or GloVe, or it can be learned inside the Transformer itself.
self.embedding = nn.Embedding(vocabulary, dim)
What it does:
- Purpose: maps discrete integer indices (word IDs) to continuous vector representations.
- Input: an integer tensor of shape [sequence_length].
- Output: a float tensor of shape [sequence_length, dim] (i.e. the matrix $X_{n \times d}$, where n is the sequence length and d is the feature dimension).
Parameter details:

| Parameter | Meaning | Example | Description |
|---|---|---|---|
| vocabulary | vocabulary size | 10000 | total number of distinct words/symbols the model can handle |
| dim | embedding dimension | 512 | length of the vector each word is represented by |
How it works:
- A learnable embedding matrix of shape [vocabulary, dim] is created; with vocabulary=10000 and dim=512 it is a 10000×512 matrix.
- Each integer index selects one row of that matrix:

# suppose the word "apple" has ID 42
apple_vector = embedding_matrix[42]  # shape [512]
Concrete role in the Transformer:

# input: src = torch.randint(0, 10000, (2, 10))
# shape: [batch_size=2, seq_len=10]
src_embedded = self.embedding(src)
# output shape: [2, 10, 512]
# every integer word ID has been replaced by a 512-dimensional vector
Visualized:

Original input (word IDs):
[ [ 25, 198, 3000, ... ],   # sentence 1
  [  1,  42,  999, ... ] ]  # sentence 2

After the embedding layer (vector representations):
[ [ [0.2, -0.5, ..., 1.3],    # vector for ID=25
    [0.8,  0.1, ..., -0.9],   # vector for ID=198
    ... ],
  [ [0.9, -0.2, ..., 0.4],    # vector for ID=1
    [0.3,  0.7, ..., -1.2],   # vector for ID=42
    ... ] ]
Why word embeddings are needed:
- Semantic representation: similar words end up with similar vector representations.
- Dimensionality reduction: discrete IDs are mapped into a continuous space (one-hot encoding would need 10000 dimensions, the embedding only 512).
- Learnable: during training the vectors are continually adjusted to better capture semantic relationships (a runnable sketch follows this list).
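A minimal runnable sketch of the embedding layer described above, using the example values vocabulary=10000 and dim=512 from the table (the word "apple" and ID 42 are just the illustrative values used earlier):

import torch
import torch.nn as nn

embedding = nn.Embedding(10000, 512)         # learnable [10000, 512] matrix
src = torch.randint(0, 10000, (2, 10))       # two sentences of 10 word IDs each
src_embedded = embedding(src)
print(src_embedded.shape)                    # torch.Size([2, 10, 512])
apple_vector = embedding(torch.tensor(42))   # the row for the hypothetical ID 42 ("apple")
print(apple_vector.shape)                    # torch.Size([512])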
2.2 Positional Embedding (Positional Encoding)
Positional Encoding (PE) is one of the Transformer's key innovations: it sidesteps the inherently sequential processing of traditional sequence models such as RNNs. The self-attention mechanism by itself cannot perceive positions in a sequence, so positional encoding adds position information to the input embeddings, allowing the model to understand the order of the elements. The output of the positional encoding has the same shape as the word embedding, namely $n \times d$.
Core roles of positional encoding:
- Inject position information: lets the model distinguish the same word at different positions (e.g. "bank" at the start vs. at the end of a sentence).
- Preserve distance relations: encodes both relative and absolute position information.
- Support parallel computation: avoids the sequential processing that RNNs rely on.
Why is positional encoding needed?
- Position invariance of self-attention: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$ contains no position information anywhere in the computation.
- Sequence order matters:
  - Natural language: "the cat chases the dog" ≠ "the dog chases the cat"
  - Time series: the order of a stock-price sequence determines its trend

Comparison of alternatives:
| Method | Pros | Cons |
|---|---|---|
| Sine/cosine | generalizes well, theoretical guarantees | fixed pattern, less flexible |
| Learnable | adapts to task-specific patterns | limited length, needs training |
| Relative position | models relative distance directly | more complex to implement |
Practical effect of positional encoding:
- In early layers: helps the model build position awareness.
- In later layers: the position information has been fused into the semantic representation.
- Visualized example:

Input:    [The, cat, sat, on, mat]
Embed:    [E_The, E_cat, E_sat, E_on, E_mat]
Position: [P0, P1, P2, P3, P4]
Final:    [E_The+P0, E_cat+P1, ..., E_mat+P4]
(1) Sine-cosine positional encoding (used in the paper)

The sine-cosine positional encoding is computed as:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

where:
- `pos` is the position of the token in the sequence (starting from 0)
- `d_model` is the model's embedding dimension (i.e. the vector dimension of each token)
- `i` is the dimension-pair index (from 0 to d_model/2 - 1)
Properties:
- Geometric progression of wavelengths: covers a range of frequencies.
- Relative position is learnable: a position offset is a linear transformation, i.e. PE_{pos+k} can be expressed as a linear function of PE_{pos} (see the numeric check after this list).
- Strong generalization: can handle sequences longer than those seen during training.
- Symmetry: the sin/cos pairing allows the model to learn relative positions.
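A quick numeric check of the linearity claim above (my own sketch, not from the article): for each frequency w of the encoding, the pair (sin(w·(pos+k)), cos(w·(pos+k))) is a fixed rotation of (sin(w·pos), cos(w·pos)) that depends only on the offset k.

import math

w = 1.0 / 10000 ** (2 * 3 / 512)    # frequency of dimension pair i=3 for d_model=512 (arbitrary choice)
pos, k = 17, 5
# Rotation matrix determined only by the offset k
rot = [[math.cos(w * k),  math.sin(w * k)],
       [-math.sin(w * k), math.cos(w * k)]]
vec = [math.sin(w * pos), math.cos(w * pos)]          # PE components at position pos
shifted = [rot[0][0] * vec[0] + rot[0][1] * vec[1],
           rot[1][0] * vec[0] + rot[1][1] * vec[1]]
print(shifted)                                        # ≈ PE components at position pos + k
print([math.sin(w * (pos + k)), math.cos(w * (pos + k))])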
Code implementation:
class PositionalEncoding(nn.Module):
    # Sine-cosine positional coding
    def __init__(self, emb_dim, max_len, freq=10000.0):
        super(PositionalEncoding, self).__init__()
        assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'
        self.emb_dim = emb_dim
        self.max_len = max_len
        self.pe = torch.zeros(max_len, emb_dim)
        pos = torch.arange(0, max_len).unsqueeze(1)
        # pos: [max_len, 1]
        div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)
        # div: [ceil(emb_dim / 2)]
        self.pe[:, 0::2] = torch.sin(pos / div)
        # torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]
        self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))
        # torch.cos(pos / div): [max_len, floor(emb_dim / 2)]

    def forward(self, x, len=None):
        if len is None:
            len = x.size(-2)
        return x + self.pe[:len, :]
For example, with emb_dim=512, max_len=100 and a sentence of length 10, the positional embedding values are computed as follows (trigonometric functions take radians):
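Since the original figure with the computed values is not reproduced here, the sketch below (my own) instantiates the PositionalEncoding class above with these settings and prints a few of the values that would be added to a length-10 sentence:

import torch

pos_enc = PositionalEncoding(emb_dim=512, max_len=100)
x = torch.zeros(1, 10, 512)        # a dummy length-10 "sentence" of all zeros
out = pos_enc(x)                   # since x is zero, out[0] is exactly pe[:10, :]
print(out.shape)                   # torch.Size([1, 10, 512])
print(out[0, :3, :4])              # first 3 positions, first 4 embedding dimensions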
(2) Learnable positional encoding
class LearnablePositionalEncoding(nn.Module):
    # Learnable positional encoding
    def __init__(self, emb_dim, len):
        super(LearnablePositionalEncoding, self).__init__()
        assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'
        self.emb_dim = emb_dim
        self.len = len
        self.pe = nn.Parameter(torch.zeros(len, emb_dim))

    def forward(self, x):
        return x + self.pe[:x.size(-2), :]
Properties:
- Position embeddings are learned directly: they are trained as model parameters (see the comparison sketch after this list).
- Highly flexible: can adapt to task-specific position patterns.
- Length-limited: can only handle the predefined maximum length.
- Computationally efficient: a direct table lookup, no computation required.
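The contrast with the sinusoidal version can be seen directly from the two classes above (my own sketch): the learnable table is an nn.Parameter, so it appears in parameters() and is updated by the optimizer, while the sine-cosine table is a fixed tensor.

import torch

learn_pe = LearnablePositionalEncoding(emb_dim=512, len=100)
fixed_pe = PositionalEncoding(emb_dim=512, max_len=100)
print(sum(p.numel() for p in learn_pe.parameters()))    # 51200 trainable values (100 * 512)
print(sum(p.numel() for p in fixed_pe.parameters()))    # 0, nothing to train
x = torch.randn(1, 10, 512)
print(learn_pe(x).shape, fixed_pe(x).shape)             # both torch.Size([1, 10, 512])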
3. Self-Attention and Multi-Head Attention
In the diagram of the Transformer's internal structure, the left side is the Encoder block and the right side is the Decoder block. We can see that:
(1) the Encoder block contains one Multi-Head Attention;
(2) the Decoder block contains two Multi-Head Attention layers (one of them masked). Above each Multi-Head Attention there is an Add & Norm layer: Add denotes a residual connection, used to prevent network degradation, and Norm denotes Layer Normalization, which normalizes the activations of each layer.
Multi-Head Attention is the centerpiece of the Transformer. It is built on top of Self-Attention, so we start with Self-Attention.
3.1 Self-Attention
Self-Attention is the core innovation of the Transformer architecture and fundamentally changed sequence modeling. Unlike traditional recurrent neural networks (RNNs) and convolutional neural networks (CNNs), self-attention directly captures the relationship between any two elements of a sequence, regardless of how far apart they are.
The input to Self-Attention is represented by a matrix $X_{n \times d}$ (n is the sequence length, d is the feature dimension), and the computation proceeds as follows:

(1) Generate Q (queries), K (keys) and V (values) through learnable weight matrices:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$

where $W^Q$, $W^K$, $W^V$ are learnable parameters.

(2) Compute the Self-Attention output:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Step by step:
- Similarity computation: $QK^\top$ computes the dot-product similarity of every query-key pair. The resulting matrix is n×n, where n is the number of words in the sentence, and it expresses the attention strength between the words.
- Scaling: dividing by $\sqrt{d_k}$ keeps the dot products from becoming so large that the softmax saturates and gradients vanish.
- Normalization: softmax turns the similarities into a probability distribution.
- Weighted sum: the attention weights are used to form a weighted sum of the value vectors, giving the final output.
Input sequence: [x1, x2, x3, x4]

Step 1: generate Q, K, V vectors for every input
x1 → q1, k1, v1
x2 → q2, k2, v2
x3 → q3, k3, v3
x4 → q4, k4, v4

Step 2: compute attention weights (taking x1 as the example; softmax is applied jointly over all four scores)
[w1, w2, w3, w4] = softmax([q1·k1, q1·k2, q1·k3, q1·k4] / √d_k)

Step 3: weighted sum
output1 = w1*v1 + w2*v2 + w3*v3 + w4*v4
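The same computation in a few lines of PyTorch (my own bare-bones sketch of scaled dot-product self-attention, skipping the learned W_Q / W_K / W_V projections; the full multi-head version follows in 3.2):

import torch

d_k = 64
x = torch.randn(4, d_k)                          # [x1, x2, x3, x4], already projected to d_k dims
q, k, v = x, x, x                                # self-attention: Q, K, V come from the same sequence
scores = q @ k.transpose(-2, -1) / d_k ** 0.5    # [4, 4] scaled similarity matrix
weights = torch.softmax(scores, dim=-1)          # each row sums to 1
output = weights @ v                             # weighted sum of the value vectors, [4, 64]
print(weights[0])                                # attention of x1 over x1..x4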
3.2 Multi-Head Attention

The Transformer uses the multi-head mechanism to increase the model's expressive power:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O$$

where each attention head is

$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

- h: the number of attention heads
- $W_i^Q$, $W_i^K$, $W_i^V$: the independent parameters of each head
- $W^O$: the output projection matrix
Code implementation:

(1) Splitting into heads: view splits the feature dimension across the heads, so that each head has dimension dim_head = dim_qk // num_heads.

q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
k = ...  # handled the same way
v = ...  # handled the same way

(2) Efficient matrix operations: a single matrix multiplication computes the attention scores of all positions in parallel.

attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)

(3) Merging the heads: view concatenates the heads back together, using num_heads * d_v = dim_v.

output = output.transpose(1, 2)
output = output.contiguous().view(-1, len_q, self.dim_v)
The complete Multi-Head Attention implementation follows. It already includes the mask handling, which is introduced later.
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        dim_qk = dim if dim_qk is None else dim_qk
        dim_v = dim if dim_v is None else dim_v
        assert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, \
            'dim must be divisible by num_heads'
        self.dim = dim
        self.dim_qk = dim_qk
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.w_q = nn.Linear(dim, dim_qk)
        self.w_k = nn.Linear(dim, dim_qk)
        self.w_v = nn.Linear(dim, dim_v)

    def forward(self, q, k, v, mask=None):
        # q: [B, len_q, D]
        # k: [B, len_kv, D]
        # v: [B, len_kv, D]
        assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'

        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)
        assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'
        assert len_k == len_v, 'len_k and len_v must be equal'
        len_kv = len_v

        q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
        k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)
        v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)
        # q: [B, len_q, num_heads, dim_qk//num_heads]
        # k: [B, len_kv, num_heads, dim_qk//num_heads]
        # v: [B, len_kv, num_heads, dim_v//num_heads]

        # The following 'dim_(qk)//num_heads' is written as d_(qk)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # q: [B, num_heads, len_q, d_qk]
        # k: [B, num_heads, len_kv, d_qk]
        # v: [B, num_heads, len_kv, d_v]

        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)
        # attn: [B, num_heads, len_q, len_kv]
        if mask is not None:
            attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        output = torch.matmul(attn, v)
        # output: [B, num_heads, len_q, d_v]
        output = output.transpose(1, 2)
        # output: [B, len_q, num_heads, d_v]
        output = output.contiguous().view(-1, len_q, self.dim_v)
        # output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]
        return output
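A quick shape check of the module above (my own usage sketch; the self-attention case with q = k = v, where attn_mask is the causal-mask helper defined in Section 6):

import torch

mha = MultiHeadAttention(dim=512, num_heads=8, dropout=0.1)
x = torch.randn(2, 10, 512)            # [B=2, len=10, dim=512]
out = mha(x, x, x)                     # no mask
print(out.shape)                       # torch.Size([2, 10, 512])
causal = attn_mask(10)                 # [10, 10] upper-triangular mask
out_masked = mha(x, x, x, mask=causal)
print(out_masked.shape)                # torch.Size([2, 10, 512])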
6. Complete Code Implementation
import torch
import torch.nn as nn


class LearnablePositionalEncoding(nn.Module):
    # Learnable positional encoding
    def __init__(self, emb_dim, len):
        super(LearnablePositionalEncoding, self).__init__()
        assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'
        self.emb_dim = emb_dim
        self.len = len
        self.pe = nn.Parameter(torch.zeros(len, emb_dim))

    def forward(self, x):
        return x + self.pe[:x.size(-2), :]


class PositionalEncoding(nn.Module):
    # Sine-cosine positional coding
    def __init__(self, emb_dim, max_len, freq=10000.0):
        super(PositionalEncoding, self).__init__()
        assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'
        self.emb_dim = emb_dim
        self.max_len = max_len
        self.pe = torch.zeros(max_len, emb_dim)
        pos = torch.arange(0, max_len).unsqueeze(1)
        # pos: [max_len, 1]
        div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)
        # div: [ceil(emb_dim / 2)]
        self.pe[:, 0::2] = torch.sin(pos / div)
        # torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]
        self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))
        # torch.cos(pos / div): [max_len, floor(emb_dim / 2)]

    def forward(self, x, len=None):
        if len is None:
            len = x.size(-2)
        return x + self.pe[:len, :]


class MultiHeadAttention(nn.Module):
    def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        dim_qk = dim if dim_qk is None else dim_qk
        dim_v = dim if dim_v is None else dim_v
        assert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, \
            'dim must be divisible by num_heads'
        self.dim = dim
        self.dim_qk = dim_qk
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.w_q = nn.Linear(dim, dim_qk)
        self.w_k = nn.Linear(dim, dim_qk)
        self.w_v = nn.Linear(dim, dim_v)

    def forward(self, q, k, v, mask=None):
        # q: [B, len_q, D]
        # k: [B, len_kv, D]
        # v: [B, len_kv, D]
        assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'

        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)
        assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'
        assert len_k == len_v, 'len_k and len_v must be equal'
        len_kv = len_v

        q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
        k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)
        v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)
        # q: [B, len_q, num_heads, dim_qk//num_heads]
        # k: [B, len_kv, num_heads, dim_qk//num_heads]
        # v: [B, len_kv, num_heads, dim_v//num_heads]

        # The following 'dim_(qk)//num_heads' is written as d_(qk)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # q: [B, num_heads, len_q, d_qk]
        # k: [B, num_heads, len_kv, d_qk]
        # v: [B, num_heads, len_kv, d_v]

        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)
        # attn: [B, num_heads, len_q, len_kv]
        if mask is not None:
            attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        output = torch.matmul(attn, v)
        # output: [B, num_heads, len_q, d_v]
        output = output.transpose(1, 2)
        # output: [B, len_q, num_heads, d_v]
        output = output.contiguous().view(-1, len_q, self.dim_v)
        # output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]
        return output


class Feedforward(nn.Module):
    def __init__(self, dim, hidden_dim=2048, dropout=0., activate=nn.ReLU()):
        super(Feedforward, self).__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = activate

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def attn_mask(len):
    """
    :param len: length of sequence
    :return: mask tensor, False for not replaced, True for replaced as -inf
    e.g. attn_mask(3) =
        tensor([[[False,  True,  True],
                 [False, False,  True],
                 [False, False, False]]])
    """
    mask = torch.triu(torch.ones(len, len, dtype=torch.bool), 1)
    return mask


def padding_mask(pad_q, pad_k):
    """
    :param pad_q: pad label of query (0 is padding, 1 is not padding), [B, len_q]
    :param pad_k: pad label of key (0 is padding, 1 is not padding), [B, len_k]
    :return: mask tensor, False for not replaced, True for replaced as -inf
    e.g. pad_q = tensor([[1, 1, 0], [1, 0, 1]])
        padding_mask(pad_q, pad_q) =
        tensor([[[False, False,  True],
                 [False, False,  True],
                 [ True,  True,  True]],

                [[False,  True, False],
                 [ True,  True,  True],
                 [False,  True, False]]])
    """
    assert pad_q.ndim == pad_k.ndim == 2, 'pad_q and pad_k must be 2-dimensional'
    assert pad_q.size(0) == pad_k.size(0), 'batch size mismatch'
    mask = pad_q.bool().unsqueeze(2) * pad_k.bool().unsqueeze(1)
    mask = ~mask
    # mask: [B, len_q, len_k]
    return mask


class EncoderLayer(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):
        super(EncoderLayer, self).__init__()
        self.attn = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)
        self.ffn = Feedforward(dim, dim * 4, dropout)
        self.pre_norm = pre_norm
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        if self.pre_norm:
            res1 = self.norm1(x)
            x = x + self.attn(res1, res1, res1, mask)
            res2 = self.norm2(x)
            x = x + self.ffn(res2)
        else:
            x = self.attn(x, x, x, mask) + x
            x = self.norm1(x)
            x = self.ffn(x) + x
            x = self.norm2(x)
        return x


class Encoder(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList(
            [EncoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):
        super(DecoderLayer, self).__init__()
        self.attn1 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)
        self.attn2 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)
        self.ffn = Feedforward(dim, dim * 4, dropout)
        self.pre_norm = pre_norm
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, enc, self_mask=None, pad_mask=None):
        if self.pre_norm:
            res1 = self.norm1(x)
            x = x + self.attn1(res1, res1, res1, self_mask)
            res2 = self.norm2(x)
            x = x + self.attn2(res2, enc, enc, pad_mask)
            res3 = self.norm3(x)
            x = x + self.ffn(res3)
        else:
            x = self.attn1(x, x, x, self_mask) + x
            x = self.norm1(x)
            x = self.attn2(x, enc, enc, pad_mask) + x
            x = self.norm2(x)
            x = self.ffn(x) + x
            x = self.norm3(x)
        return x


class Decoder(nn.Module):
    def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(
            [DecoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])

    def forward(self, x, enc, self_mask=None, pad_mask=None):
        for layer in self.layers:
            x = layer(x, enc, self_mask, pad_mask)
        return x


class Transformer(nn.Module):
    def __init__(self, dim, vocabulary, num_heads=1, num_layers=1, dropout=0., learnable_pos=False, pre_norm=False):
        super(Transformer, self).__init__()
        self.dim = dim
        self.vocabulary = vocabulary
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.learnable_pos = learnable_pos
        self.pre_norm = pre_norm
        self.embedding = nn.Embedding(vocabulary, dim)
        self.pos_enc = LearnablePositionalEncoding(dim, 100) if learnable_pos else PositionalEncoding(dim, 100)
        self.encoder = Encoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)
        self.decoder = Decoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)
        self.linear = nn.Linear(dim, vocabulary)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, pad_mask=None):
        # src.shape: torch.Size([2, 10])
        src = self.embedding(src)
        # src.shape: torch.Size([2, 10, 512])
        src = self.pos_enc(src)
        # src.shape: torch.Size([2, 10, 512])
        src = self.encoder(src, src_mask)
        # src.shape: torch.Size([2, 10, 512])

        # tgt.shape: torch.Size([2, 8])
        tgt = self.embedding(tgt)
        # tgt.shape: torch.Size([2, 8, 512])
        tgt = self.pos_enc(tgt)
        # tgt.shape: torch.Size([2, 8, 512])
        tgt = self.decoder(tgt, src, tgt_mask, pad_mask)
        # tgt.shape: torch.Size([2, 8, 512])

        output = self.linear(tgt)
        # output.shape: torch.Size([2, 8, 10000])
        return output

    def get_mask(self, tgt, src_pad=None):
        # Under normal circumstances, tgt_pad will perform mask processing when calculating loss,
        # and it isn't necessarily in decoder
        if src_pad is not None:
            src_mask = padding_mask(src_pad, src_pad)
        else:
            src_mask = None
        tgt_mask = attn_mask(tgt.size(1))
        if src_pad is not None:
            # tgt positions are treated as non-padding here; only padded src keys are masked
            pad_mask = padding_mask(torch.ones_like(tgt), src_pad)
        else:
            pad_mask = None
        # src_mask: [B, len_src, len_src]
        # tgt_mask: [len_tgt, len_tgt]
        # pad_mask: [B, len_tgt, len_src]
        return src_mask, tgt_mask, pad_mask


if __name__ == '__main__':
    model = Transformer(dim=512, vocabulary=10000, num_heads=8, num_layers=6, dropout=0.1,
                        learnable_pos=False, pre_norm=True)
    src = torch.randint(0, 10000, (2, 10))   # torch.Size([2, 10])
    tgt = torch.randint(0, 10000, (2, 8))    # torch.Size([2, 8])
    src_pad = torch.randint(0, 2, (2, 10))   # torch.Size([2, 10])
    src_mask, tgt_mask, pad_mask = model.get_mask(tgt, src_pad)
    output = model(src, tgt, src_mask, tgt_mask, pad_mask)
    # output.shape: torch.Size([2, 8, 10000])
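As a usage note (my own sketch, not part of the referenced code): a single teacher-forcing training step on the toy batch from the __main__ block could look like the following; it hypothetically treats token ID 0 as padding for the loss, which real data would replace with proper tokenization and pad labels.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=0)            # assume ID 0 is the padding token
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]                  # predict token t+1 from tokens <= t
src_mask, tgt_mask, pad_mask = model.get_mask(tgt_in, src_pad)
logits = model(src, tgt_in, src_mask, tgt_mask, pad_mask)  # [2, 7, 10000]
loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
loss.backward()
optimizer.step()
print(loss.item())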