前情回顾
在之前的章节我们已经构建好了视觉编码器,预处理模块,以及gemma模型的顶层。gemma模型的顶层,主要是构建图中圈出的输入,它把视觉编码器里每个图像patch的编码维度对齐到自然语言token的嵌入维度,并组装成了一个大的输入向量。同时在模型的顶层,我们准备好了位置id 以及attention mask,用来在后面的模型层计算旋转位置编码和注意力得分矩阵。接下来,我们要开始构建gemma模型的架构了。
顶层模型 GemmaForCausalLM
还记得吗,在之前的paligemma模型的顶层,我们有一个GemmaForCausalLM,然后我们通过下面的代码把输入传入了语言模型:
self.language_model = GemmaForCausalLM(config.text_config)
outputs = self.language_model(
inputs_embeds = input_embeds,
position_ids = position_ids,
attention_mask = attention_mask,
kv_cache = kv_cache,
**kwargs
)
现在我们首先要实现这个GemmaForCausalLM。
一般模型的上层是对整个模型逻辑的简单封装,故这里GemmaForCausalLM的作用很简单,它仅仅把上下文编码后的注意力嵌入通过一个MLP转换为不同token的输出概率,也就是logits,然后返回给上层,从而让上层根据概率分布来采样下一个要输出的token是什么。
先给出代码:
class GemmaForCausalLM(nn.Module): ## 匹配
def __init__(self,config:GemmaConfig): ##CasualLM实际上是Transformer模型加一个投影层,即将嵌入转换为对数概率
super().__init__()
self.config = config
self.model = GemmaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size,config.vocab_size,bias=False) def get_input_embeddings(self): ##这里返回的是模型对象本身
return self.model.embed_tokens def tie_weights(self):
self.lm_head.weight = self.model.embed_tokens.weight def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[KVCache] = None,
position_ids: Optional[torch.Tensor] = None
):
'''
input: [Batch_size, Seq_len, Hidden_size]
output: [Batch_size, Seq_len, Vocab_size]
'''
## [Batch_size, Seq_len, Hidden_size]
outputs = self.model(
attention_mask = attention_mask,
inputs_embeds = inputs_embeds,
kv_cache = kv_cache,
position_ids = position_ids
) hidden_states = outputs
logits = self.lm_head(hidden_states) #lm_head负责将hidden_states映射到vocab_size维度的向量,即logits
logits = logits.float() return_data = {
"logits": logits
}
if kv_cache is not None:
return_data["kv_cache"] = kv_cache ##这里kv cache是要传递下去的,因为自回归的逻辑下,后面生成的token的注意力计算要能够通过kv cache来看到之前的token的kv return return_data
以上便是顶层模型的前向传递过程:
- 就是通过 GemmaModel 生成的注意力嵌入来计算logits
- 注意:由于我们在推理过程中,后续的token计算要用到之前的kv,所以kv cache必须在推理的过程中依次传递下去,同时也要返回给上层,从而在下一次推理运算的时候有kv cache可以传入。
- 我们之前用到了参数捆绑的策略,即token嵌入的模型参数等于嵌入反解码成logits的模型参数,所以我们提供这两个函数供上层调用:
def get_input_embeddings(self): ##这里返回的是模型对象本身return self.model.embed_tokensdef tie_weights(self):self.lm_head.weight = self.model.embed_tokens.weight
GemmaModel
GemmaModel里面实际上就是一个注意力块的序列,就像一个注意力块数组一样,而该层需要做的仅仅是将输入在不同的注意力块里依次传递,并把最后一个注意力块的输出返回给上层即可。
class GemmaModel(nn.Module): ## 匹配def __init__(self,config:GemmaConfig):super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.embed_tokens = nn.Embedding(config.vocab_size,config.hidden_size,padding_idx=config.pad_token_id)
self.layers = nn.ModuleList([GemmaLayer(config, _) for _ in range(config.num_hidden_layers)])
self.norm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps) ##Root Mean Square Normalization均方根标准化,该论文表明并不一定要标准化到标准正态分布,而是只要方差为1就可以def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[KVCache] = None,
position_ids: Optional[torch.Tensor] = None):#[Batch_size, Seq_len, Hidden_size]
hidden_states = inputs_embeds
normalizer = torch.tensor(self.hidden_size ** 0.5,dtype= inputs_embeds.dtype)
hidden_states = hidden_states * normalizerfor layer in self.layers:
hidden_states = layer(
hidden_states = hidden_states,
attention_mask = attention_mask,
kv_cache = kv_cache,
position_ids = position_ids)## 均方根归一化,不改变shape
hidden_states = self.norm(hidden_states)return hidden_states
这里我们用一个nn.ModuleList来存储所有的GemmaLayer,一个GemmaLayer实际上就是一个attention 块。值得注意的是,在每个attention块内部我们将会做两次归一化,但是每个attention layer的输出不会做归一化,为了使得上层的计算能拿到归一化后的结果,我们在整个list前向传递完了之后再补一个normalization的过程:
hidden_states = self.norm(hidden_states)
- 注意:我们此处用的是RMSNorm,即均方根归一化,关于这个归一化与之前的其他归一化的不同我们会在文末补充一些资料。
有人可能想问,为什么嵌入模型会放到这里:self.embed_tokens
这是因为paligemma的作者是这么实现的,而我们将从huggingface来导入整个模型的参数,所以我们的架构也必须和作者一样才能正确导入参数,所以我们不得不放在这里。
GemmaLayer
在一个attention块里面我们有一个多头注意力层和一个前向传播网络,以及两个归一化,但我们实际的实现中会把归一化提前,即add&norm -> attention -> add&norm -> ff。
这也就是为什么上面提到在layer的输出没有做归一化。
代码如下:
class GemmaLayer(nn.Module): ##匹配def __init__(self,config:GemmaConfig,layer_idx:int): ##layer_idx是当前layer的索引,辅助attention存储kv_cachesuper().__init__()self.config = configself.layer_idx = layer_idxself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_sizeself.input_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.mlp = GemmaMLP(config)self.self_attn = GemmaAttention(config,layer_idx)def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,kv_cache: Optional[KVCache] = None,position_ids: Optional[torch.Tensor] = None)-> Tuple[torch.Tensor,Optional[Tuple[torch.FloatTensor,torch.FloatTensor]]]:
'''input: [Batch_size, Seq_len, Hidden_size]output: [Batch_size, Seq_len, Hidden_size] '''residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)hidden_states,_ = self.self_attn(hidden_states = hidden_states,attention_mask = attention_mask,kv_cache = kv_cache,position_ids = position_ids)hidden_states = residual + hidden_statesresidual = hidden_stateshidden_states = self.post_attention_layernorm(hidden_states)hidden_states = self.mlp(hidden_states)hidden_states = residual + hidden_statesreturn hidden_states
- 在这里的两个归一化我们也用RMSNorm来进行归一化,注意除了归一化,我们还要处理好残差。
- 残差的作用是防止梯度为0导致训练缓慢。
RMSNorm
在前面的第四章节:手搓多模态-04 归一化介绍 里面我们介绍了BatchNormalization和LayerNormalization,我们了解到以下信息:
- 归一化是为了防止不同模型层的输入输出不稳定,分布不均匀导致的训练速度过慢
- BN 依赖于batch 的规模,而batch的规模过大会导致训练速度变相过慢
- LN 通过对单个样本的所有特征进行标准化规避了BN的问题,主要做法是对单个样本的所有特征计算均值和方差,从而将其分布转换为0-1分布。
RMSNormalization,又称均方差归一化,是由论文《Root Mean Square Layer Normalization》提出的,该文章发现,其实分布不稳定的问题和均值没有关系,主要是方差的问题,所以只需要特征的方差稳定即可,不需要计算均值,这样可以减少计算的时间,从而加速训练。
论文提出用均方根来对每个值进行缩放,从而使得方差更小,如图所示。
其中,a_i 表示缩放前的特征值,RMS(a)表示所有特征值计算出来的均方根,g是一个可学习的参数向量,b是偏置。
在paligemma的实现中,RMSNorm的代码如下:
class GemmaRMSNorm(nn.Module): ##匹配
def __init__(self,dim,eps=1e-6): ##dim是hidden_size
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) def _norm(self,x): return x * torch.rsqrt(x.pow(2).mean(dim = -1,keepdim=True) + self.eps) ##rsqrt表示平方的倒数,self.eps是防止分母为0 def forward(self,x):
x = self._norm(x)
output = x * (1.0 + self.weight.float()) ##论文中的可学习参数g
return output.type_as(x)
其中特征的维度为嵌入的维度大小。