Qwen3MLP

Qwen3MLP是基于门控机制的MLP模块,采用了类似门控线性单元(GLU)的结构。它通过三个线性变换层(gate_proj、up_proj和down_proj)和SiLU激活函数,先将输入从隐藏维度扩展到中间维度,经过门控计算后再投影回原始维度。该模块保持了输入输出形状的一致性,演示了如何逐步执行前向传播并验证计算正确性,展示了Transformer模型中常用的前馈神经网络结构。
具体代码与测试如下:

import torch
import torch.nn as nn
from transformers.activations import ACT2FNclass Qwen3MLP(nn.Module):def __init__(self, config):super().__init__()self.config = configself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_sizeself.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)self.act_fn = ACT2FN[config.hidden_act] # siludef forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_proj# 模拟配置类
class MockConfig:def __init__(self):self.hidden_size = 1024self.intermediate_size = 2048self.hidden_act = "silu"# 完整示例
if __name__ == "__main__":# 1. 创建配置对象config = MockConfig()# 2. 初始化Qwen3MLP模块mlp = Qwen3MLP(config)# 3. 创建测试输入数据batch_size = 2seq_length = 8hidden_size = config.hidden_size  # 1024# 输入张量形状: (batch_size, seq_length, hidden_size)input_tensor = torch.randn(batch_size, seq_length, hidden_size)print("=== Qwen3MLP 示例 ===")print(f"配置信息:")print(f"  - hidden_size: {config.hidden_size}")print(f"  - intermediate_size: {config.intermediate_size}")print(f"  - activation: {config.hidden_act}")print(f"\n输入张量形状: {input_tensor.shape}")# 4. 前向传播with torch.no_grad():output_tensor = mlp(input_tensor)print(f"输出张量形状: {output_tensor.shape}")# 5. 验证输出形状与输入形状一致assert output_tensor.shape == input_tensor.shape, \f"输出形状 {output_tensor.shape} 与输入形状 {input_tensor.shape} 不一致"print("\n=== MLP 层内部组件 ===")print(f"gate_proj 权重形状: {mlp.gate_proj.weight.shape}")print(f"up_proj 权重形状: {mlp.up_proj.weight.shape}")print(f"down_proj 权重形状: {mlp.down_proj.weight.shape}")# 6. 逐步计算过程演示print("\n=== 前向传播步骤 ===")with torch.no_grad():# 第一步: 门控投影gate_output = mlp.gate_proj(input_tensor)print(f"1. gate_proj 输出形状: {gate_output.shape}")# 第二步: 激活函数gate_activated = mlp.act_fn(gate_output)print(f"2. 激活函数后形状: {gate_activated.shape}")# 第三步: 上投影up_output = mlp.up_proj(input_tensor)print(f"3. up_proj 输出形状: {up_output.shape}")# 第四步: 门控线性单元 (GLU)glu_output = gate_activated * up_outputprint(f"4. GLU 输出形状: {glu_output.shape}")# 第五步: 下投影final_output = mlp.down_proj(glu_output)print(f"5. down_proj 输出形状: {final_output.shape}")# 验证与直接调用forward的结果一致direct_output = mlp(input_tensor)assert torch.allclose(final_output, direct_output, atol=1e-6), "逐步计算结果与直接调用不一致"print("✓ 逐步计算结果与直接调用结果一致")print("\n=== 示例完成 ===")print(f"MLP 成功处理了形状为 {input_tensor.shape} 的输入,输出形状为 {output_tensor.shape}")
=== Qwen3MLP 示例 ===
配置信息:- hidden_size: 1024- intermediate_size: 2048- activation: silu输入张量形状: torch.Size([2, 8, 1024])
输出张量形状: torch.Size([2, 8, 1024])=== MLP 层内部组件 ===
gate_proj 权重形状: torch.Size([2048, 1024])
up_proj 权重形状: torch.Size([2048, 1024])
down_proj 权重形状: torch.Size([1024, 2048])=== 前向传播步骤 ===
1. gate_proj 输出形状: torch.Size([2, 8, 2048])
2. 激活函数后形状: torch.Size([2, 8, 2048])
3. up_proj 输出形状: torch.Size([2, 8, 2048])
4. GLU 输出形状: torch.Size([2, 8, 2048])
5. down_proj 输出形状: torch.Size([2, 8, 1024])
✓ 逐步计算结果与直接调用结果一致=== 示例完成 ===
MLP 成功处理了形状为 torch.Size([2, 8, 1024]) 的输入,输出形状为 torch.Size([2, 8, 1024])

Qwen3MoeSparseMoeBlock

Qwen3 模型的稀疏混合专家(Sparse MoE)模块,核心是通过“路由机制+多专家并行计算”提升模型在大参数量下的效率与能力。

Qwen3MoeSparseMoeBlock 处理输入的流程可分为 路由计算→专家选择→并行计算→结果聚合 四步:

1. 路由计算:为每个 token 选专家
  • 输入 hidden_states(形状 [batch_size, seq_length, hidden_size])先展平为 [batch*seq, hidden_size]
  • self.gate(线性层)生成 router_logits(每个 token 对 8 个专家的“匹配分数”);
  • 通过 softmax+topk,为每个 token 选 num_experts_per_tok=2 个“最匹配专家”,并得到归一化的路由权重(决定每个专家对 token 的贡献占比)。
2. 专家选择:标记活跃专家

通过 one_hot 编码生成 expert_mask,标记“哪些专家被哪些 token 选中”;再通过 expert_hit 筛选出至少被一个 token 选中的活跃专家(示例中 8 个专家都有 token 命中)。

3. 并行计算:专家各自处理 token

对每个活跃专家,执行:

  • 筛选出“属于当前专家”的 token(通过 expert_mask 定位);
  • 调用该专家的 Qwen3MoeMLP 层(结构同普通 MLP,但参数量仅服务部分 token),完成“门控投影→激活→上投影→下投影”的计算;
  • 用路由权重对专家输出加权(确保不同专家的贡献按匹配度分配)。
4. 结果聚合:合并所有专家输出

通过 index_add_ 将每个专家处理后的 token 结果,按原始位置合并,最终还原为 [batch_size, seq_length, hidden_size] 的输出。


具体代码与测试如下:

import torch.nn as nn
from transformers.activations import ACT2FN
import torch.nn.functional as Fclass Qwen3MoeMLP(nn.Module):def __init__(self, config, intermediate_size=None):super().__init__()self.config = configself.hidden_size = config.hidden_size  # 512self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size# 256self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# 512, 256self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) # 512, 256self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) # 256, 512self.act_fn = ACT2FN[config.hidden_act]def forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_projclass Qwen3MoeSparseMoeBlock(nn.Module):def __init__(self, config):super().__init__()self.num_experts = config.num_experts # 8self.top_k = config.num_experts_per_tok # 2self.norm_topk_prob = config.norm_topk_prob # Trueself.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) # 512 -> 8self.experts = nn.ModuleList([Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]) #  512 -> 256 -> 512def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape  # 2, 6, 512hidden_states = hidden_states.view(-1, hidden_dim) # 2, 6, 512 -> 12, 512router_logits = self.gate(hidden_states) # 12 8routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # 12 8routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # 12 2if self.norm_topk_prob:  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)routing_weights = routing_weights.to(hidden_states.dtype)final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device) # 12 512expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)# 12 2 8    8 2 12 print("expert_mask: \n",expert_mask)expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() # 8print("expert hit: \n",expert_hit)for expert_idx in expert_hit:expert_layer = self.experts[expert_idx]  # Qwen3MoeMLPidx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) # 4 4 if expert_idx == 0:print("expert_mask[expert_idx].squeeze(0):",expert_mask[expert_idx].squeeze(0))print("idx:",idx)print("top_x:",top_x)print("hidden_states[None, top_x]:",hidden_states[None, top_x].shape)current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)   # 1, 4, 512 -> 4, 512current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]# 4, 512 * 4, 512 -> 4, 512final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) # 2, 6, 512return final_hidden_states, router_logits
class MockConfig:def __init__(self):self.hidden_size = 512self.moe_intermediate_size = 256self.hidden_act = "silu"self.num_experts = 8self.num_experts_per_tok = 2self.norm_topk_prob = Trueimport numpy as np
import random# 设置随机种子以确保可重复性
def set_random_seed(seed=42):"""设置所有随机种子以确保结果可重复"""torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False
# 完整示例
if __name__ == "__main__":set_random_seed(42)config = MockConfig()moe_block = Qwen3MoeSparseMoeBlock(config)batch_size = 2seq_length = 6hidden_size = config.hidden_size  # 512input_tensor = torch.randn(batch_size, seq_length, hidden_size)print("=== Qwen3MoeSparseMoeBlock 示例 ===")print(f"配置信息:")print(f"  - hidden_size: {config.hidden_size}")print(f"  - moe_intermediate_size: {config.moe_intermediate_size}")print(f"  - activation: {config.hidden_act}")print(f"  - num_experts: {config.num_experts}")print(f"  - num_experts_per_tok: {config.num_experts_per_tok}")print(f"  - norm_topk_prob: {config.norm_topk_prob}")print(f"\n输入张量形状: {input_tensor.shape}")with torch.no_grad():output_tensor, router_logits = moe_block(input_tensor)print(f"输出张量形状: {output_tensor.shape}")print(f"路由逻辑形状: {router_logits.shape}")
=== Qwen3MoeSparseMoeBlock 示例 ===
配置信息:- hidden_size: 512- moe_intermediate_size: 256- activation: silu- num_experts: 8- num_experts_per_tok: 2- norm_topk_prob: True输入张量形状: torch.Size([2, 6, 512])
expert_mask: tensor([[[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],[[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1],[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]],[[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],[0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]]])
expert hit: tensor([[0],[1],[2],[3],[4],[5],[6],[7]])
expert_mask[expert_idx].squeeze(0): tensor([[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]])
idx: tensor([0, 0, 0, 1])
top_x: tensor([ 0,  2, 10,  6])
hidden_states[None, top_x]: torch.Size([1, 4, 512])
输出张量形状: torch.Size([2, 6, 512])
路由逻辑形状: torch.Size([12, 8])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/96074.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/96074.shtml
英文地址,请注明出处:http://en.pswp.cn/web/96074.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

产线相机问题分析思路

现象:复现问题 原因:问题分析、溯源,定位根本原因; 方案:提出解决方案、规避措施 验证:导入、验证方案是否可行(先小批量、再大批量);一. 现象产线反馈4pcs预览又脏污、划…

【开关电源篇】EMI输入电路-超简单解读

1. 输入电路主要包含哪些元件?滤波设计需遵循什么原则? 输入电路是电子设备(如开关电源)的“入口”,核心作用是抑制电磁干扰(EMI)、保护后级电路,其设计直接影响设备的稳定性和电磁…

胜券POS:打造智能移动终端,让零售智慧运营触手可及

零售企业运营中依然存在重重挑战:收银台前的长队消磨着顾客的耐心,仓库里的库存盘点不断侵蚀着员工的精力,导购培训的成本长期居高不下却收效甚微……面对这些痛点,零售企业或许都在等待一个破局的答案。百胜软件胜券POS&#xff…

(回溯/组合)Leetcode77组合+39组合总和+216组合总和III

为什么不能暴力,因为不知道要循环多少次,如果长度为n,难道要循环n次么,回溯的本质还是暴力,但是是可以知道多少层的暴力 之所以要pop是因为回溯相当于一个树形结构,要pop进行第二个分支 剪枝:…

07 下载配置很完善的yum软件源

文章目录前言ping 测试网络排查原因排查虚拟机的虚拟网络是否开启检查net8虚拟网络和Centos 7的ip地址是否在一个局域网点击虚拟网络编辑器点击更改设置记录net8的虚拟网络地址ip a记录Centos 7的ip地址比较net8和Centos 7的ip地址是否在一个网段解决问题问题解决办法修改net8的…

SpringBoot中添加健康检查服务

问题 今天需要给一个Spring工程添加健康检查。 pom.xml <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-actuator</artifactId> </dependency>application.yml management:endpoints:web:e…

AI工具深度测评与选型指南 - AI工具测评框架及方法论

目录引言&#xff1a;AI工具爆发期的机遇与挑战一、从AI模型到AI工具&#xff1a;核心认知与生态解析1.1 DeepSeek&#xff1a;快速出圈的国产大模型代表1.2 大模型的核心能力与类型划分1.2.1 大模型的三层能力与“双系统”类比1.2.2 生成模型与推理模型的核心差异1.3 AI工具与…

Spring Cloud Alibaba快速入门02-Nacos(中)

文章目录实现注册中心-服务发现模拟掉线远程调用1.订单和商品模块的接口商品服务订单服务2.抽取实体类3.订单服务拿到需要调用服务的ip和端口负载均衡步骤1步骤2步骤3步骤4面试题&#xff1a;注册中心宕机&#xff0c;远程调用还能成功吗&#xff1f;1、调用过;远程调用不在依赖…

【Python】数据可视化之热力图

热力图&#xff08;Heatmap&#xff09;是一种通过颜色深浅来展示数据分布、密度和强度等信息的可视化图表。它通过对色块着色来反映数据特征&#xff0c;使用户能够直观地理解数据模式&#xff0c;发现规律&#xff0c;并作出决策。 目录 基本原理 sns.heatmap 代码实现 基…

如何 正确使用 nrm 工具 管理镜像源

目录 nrm 是啥&#xff1f; nrm 的安装 查看你当前已有的镜像源 怎么切换到目标镜像源 添加镜像源 删除镜像源 测试镜像源速度 nrm 是啥&#xff1f; 镜像源&#xff1a;可以理解为&#xff0c;你访问或下载某jar包或依赖的仓库。 nrm&#xff08;Node Registry Manag…

关于对逾期提醒的定时任务~改进完善

Spring Boot 中实现到期提醒任务的定时Job详解在金融或借贷系统中&#xff0c;到期提醒是常见的功能需求。通过定时任务&#xff0c;可以定期扫描即将到期的借款记录&#xff0c;并生成或更新提醒信息。本文基于提供的三个JobHandler类&#xff08;FarExpireRemindJob、MidExpi…

springboot配置请求日志

springboot配置请求日志 一般情况下&#xff0c;接口请求都需要日志记录&#xff0c;Java springboot中的日志记录相对复杂一点 经过实践&#xff0c;以下方案可行&#xff0c;记录一下完整过程 一、创建日志数据模型 创建实体类&#xff0c;也就是日志文件中要记录的数据格式 …

Redis(50) Redis哨兵如何与客户端进行交互?

Redis 哨兵&#xff08;Sentinel&#xff09;不仅负责监控和管理 Redis 主从复制集群的高可用性&#xff0c;还需要与客户端进行有效的交互来实现故障转移后的透明连接切换。下面详细探讨 Redis 哨兵如何与客户端进行交互&#xff0c;并结合代码示例加以说明。 哨兵与客户端的交…

【.Net技术栈梳理】04-核心框架与运行时(线程处理)

文章目录1. 线程管理1.1 线程的核心概念&#xff1a;System.Threading.Thread1.2 现代线程管理&#xff1a;System.Threading.Tasks.Task 和 Task Parallel Library (TPL)1.3 状态管理和异常处理1.4 协调任务&#xff1a;async/await 模式2. 线程间通信2.1 共享内存与竞态条件2…

(JVM)四种垃圾回收算法

在 JVM 中&#xff0c;垃圾回收&#xff08;GC&#xff09;是核心机制之一。为了提升性能与内存利用率&#xff0c;JVM 采用了多种垃圾回收算法。本文总结了 四种常见的 GC 算法&#xff0c;并结合其优缺点与应用场景进行说明。1. 标记-清除&#xff08;Mark-Sweep&#xff09;…

论文阅读:VGGT Visual Geometry Grounded Transformer

论文阅读&#xff1a;VGGT: Visual Geometry Grounded Transformer 今天介绍一篇 CVPR 2025 的 best paper&#xff0c;这篇文章是牛津大学的 VGG 团队的工作&#xff0c;主要围绕着 3D 视觉中的各种任务&#xff0c;这篇文章提出了一种多任务统一的架构&#xff0c;实现一次输…

python编程:一文掌握pypiserver的详细使用

更多内容请见: python3案例和总结-专栏介绍和目录 文章目录 一、 pypiserver 概述 1.1 pypiserver是什么? 1.2 核心特性 1.3 典型应用场景 1.4 pypiserver优缺点 二、 安装与基本使用 2.1 安装 pypiserver 2.2 快速启动(最简模式) 2.3 使用私有服务器安装包 2.4 向私有服务…

Git reset 回退版本

- 第 121 篇 - Date: 2025 - 09 - 06 Author: 郑龙浩&#xff08;仟墨&#xff09; 文章目录Git reset 回退版本1 介绍三种命令区别3 验证三种的区别3 如果不小心git reset --hard将「工作区」和「暂存区」中的内容删除&#xff0c;刚才的记录找不到了&#xff0c;怎么办呢&…

ARM 基础(2)

ARM内核工作模式及其切换条件用户模式(User Mode, usr) 权限最低&#xff0c;运行普通应用程序。只能通过异常被动切换到其他模式。快速中断模式(FIQ Mode, fiq) 处理高速外设中断&#xff0c;专用寄存器减少上下文保存时间&#xff0c;响应周期约4个时钟周期。触发条件为FIQ中…

Flutter 性能优化

Flutter 性能优化是一个系统性的工程&#xff0c;涉及多个层面。 一、性能分析工具&#xff08;Profiling Tools&#xff09; 在开始优化前&#xff0c;必须使用工具定位瓶颈。切忌盲目优化。 1. DevTools 性能视图 DevTools 性能视图 (Performance View) 作用&#xff1a;…