GWM: Towards Scalable Gaussian World Models for Robotic Manipulation
- 文章概括
- 摘要
- 1. 引言
- 2. 相关工作
- 3. 高斯世界模型(Gaussian World Model)
- 3.1. 世界状态编码(World State Encoding)
- 3.2. 基于扩散的动态建模(Diffusion-based Dynamics Modeling)
- 4. 实验(Experiments)
- 4.1. 动作条件场景预测(Action-conditioned Scene Prediction)
- 4.2. 基于GWM的模仿学习(GWM-based Imitation Learning)
- 4.3. 基于GWM的强化学习(GWM-based Reinforcement Learning)
- 4.4. 真实世界部署(Real-world Deployment)
- 4.5. 消融分析(Ablation Analysis)
- 5. 结论(Conclusion)
- GWM: Towards Scalable Gaussian World Models for Robotic Manipulation (补充材料)
- A. 数据集与基准(Datasets and Benchmarks)
- B. 实现细节(Implementation Details)
- B.1. EDM 预处理(EDM Preconditioning)
- B.2. 架构设计(Architectural Design)
- B.3. 超参数(Hyper-parameters)
文章概括
引用:
@article{lu2025gwm,title={GWM: Towards Scalable Gaussian World Models for Robotic Manipulation},author={Lu, Guanxing and Jia, Baoxiong and Li, Puhao and Chen, Yixin and Wang, Ziwei and Tang, Yansong and Huang, Siyuan},journal={arXiv preprint arXiv:2508.17600},year={2025}
}
Lu, G., Jia, B., Li, P., Chen, Y., Wang, Z., Tang, Y. and Huang, S., 2025. GWM: Towards Scalable Gaussian World Models for Robotic Manipulation. arXiv preprint arXiv:2508.17600.
主页: https://gaussian-world-model.github.io
原文: https://arxiv.org/abs/2508.17600
代码、数据和视频:
系列文章:
请在 《《《文章》》》 专栏中查找
宇宙声明!
引用解析部分属于自我理解补充,如有错误可以评论讨论然后改正!
摘要
在学习得到的世界模型中训练机器人策略正成为一种趋势,这是由于真实世界交互的低效性。现有的基于图像的世界模型和策略虽然已经展示了早期的成功,但缺乏健壮的几何信息,而这种信息需要对三维世界保持一致的空间和物理理解,即便在互联网规模的视频源上进行过预训练,也依然不足。为此,我们提出了一种新的世界模型分支,称为高斯世界模型(Gaussian World Model, GWM),用于机器人操作,它通过推断在机器人动作作用下高斯基元(Gaussian primitives)的传播来重建未来状态。其核心是一个与三维变分自编码器(3D variational autoencoder)结合的潜在扩散Transformer(latent Diffusion Transformer, DiT),能够利用高斯点绘(Gaussian Splatting)实现精细的场景级未来状态重建。GWM不仅可以通过自监督的未来预测训练来增强模仿学习智能体的视觉表征,还可以作为一种神经模拟器支持基于模型的强化学习。无论在仿真还是现实实验中,结果都表明GWM能够在多样化的机器人动作条件下精确预测未来场景,并且可以进一步用于训练出显著优于当前最先进方法的策略,展现了三维世界模型在初始数据扩展方面的潜力。
图1. 高斯世界模型(Gaussian World Model, GWM)是一种新颖的世界模型分支,它基于三维高斯点绘(3D Gaussian Splatting)表示来预测动态的未来状态,并支持机器人操作。它促进了基于动作条件的三维视频预测,提升了模仿学习中的视觉表示学习能力,并作为一种稳健的神经模拟器服务于基于模型的强化学习。
1. 引言
人类能够从有限的感官输入中构建预测性世界模型,使其能够预见未来的结果并适应新的情境 [12, 18]。受到这一能力的启发,世界模型学习推动了智能体的重大进展,使其在自动驾驶 [14, 26, 27, 80, 98] 和游戏 [1, 18–22, 66, 89] 等领域表现出色。随着智能体日益与物理世界互动,推进面向机器人操作的世界模型学习成为一项重要的研究方向,因为它理想情况下能够使机器人具备关于交互进行推理、预测物理动力学、并适应多样化未知环境的能力。
这自然引出了以下问题:如何有效地表示、构建并利用世界模型来增强机器人操作?这样的需求对现有的表示方法和模型提出了重大挑战。
-
三维表示的必要性
高容量的架构 [25, 77] 和互联网规模的预训练,使基于视频的生成模型成为捕捉世界动态信息的强大工具,这极大地提升了策略学习 [82, 87]。然而,它们依赖图像输入,使其容易受到未见过的视觉变化(例如光照、相机姿态、纹理等)[40] 的影响,因为它们缺乏三维几何和空间理解。尽管RGB-D和多视角 [16, 17] 方案试图缓解这一差距,但在一致的三维空间中隐式对齐图像补丁特征仍然具有挑战性 [62, 100],这使得稳健性问题依然没有解决。这凸显了需要一种能够将精细视觉细节与三维空间信息相结合的表示方式,以提升面向机器人操作的世界建模。 -
效率与可扩展性
为了从二维图像中识别出一种既能保留三维几何结构又能保留精细视觉细节的三维表示,多视角三维重建方法(例如神经辐射场 NeRF [57] 和三维高斯点绘 3D Gaussian Splatting, 3D-GS [35])提供了自然的解决方案。其中,3D-GS 尤其具有吸引力,因为它对三维场景进行了显式的逐高斯建模,将点云等高效的三维表示与高保真渲染相结合。 然而,由于这些方法主要依赖于离线的逐场景重建,它们的计算需求在应用到机器人操作时带来了重大挑战 [49, 91],尤其是在基于模型的强化学习(Model-based Reinforcement Learning, MBRL)中,从而限制了它们的可扩展性。
为此,我们提出了高斯世界模型(Gaussian World Model, GWM),这是一种新颖的三维世界模型,它将3D-GS与高容量生成模型结合,用于机器人操作。具体来说,我们的方法结合了前馈式3D-GS重建的最新进展与扩散Transformer(Diffusion Transformers, DiTs),使得在当前观测和机器人动作条件下,通过高斯渲染实现精细的未来场景重建。 为了实现实时训练和推理,我们设计了一种三维高斯变分自编码器(3D Gaussian Variational Autoencoder, VAE),用于从三维高斯中提取潜在表示,使基于扩散的世界模型能够在紧凑的潜在空间中高效运行。通过这种新颖的设计,我们证明了GWM能够增强视觉表示学习,提升其作为模仿学习视觉编码器的作用,同时还可以作为一种稳健的神经模拟器服务于基于模型的强化学习(RL)。
为了全面评估GWM,我们在动作条件视频预测、模仿学习和基于模型的强化学习设置下进行了广泛的实验,涵盖了跨越三个领域的31个多样化机器人任务。针对现实场景的评估,我们引入了一个包含20种变体的Franka PnP任务套件,涵盖了域内和域外的设置。在消融实验中,我们同时评估了感知指标和成功率,以验证每个组成模块的有效性。GWM持续优于之前的基线方法,包括最先进的基于图像的世界模型,展现了显著的优势,并突出了其数据扩展潜力。
总而言之,我们的主要贡献有三点:
- 我们提出了GWM,这是一种新颖的三维世界模型,由高斯扩散Transformer和高斯VAE实现高效的动态建模。GWM能够以可扩展的端到端方式学习预测准确的未来状态和动力学,而无需人工干预。
- GWM可以轻松集成到离线模仿学习和在线强化学习中,并具备卓越的效率,展现出在基于学习的机器人操作中令人印象深刻的扩展潜力。
- 我们通过在两个具有挑战性的仿真环境中的大量实验验证了GWM的有效性,其性能较之前的最先进基线提高了16.25%的巨大幅度。此外,我们在现实场景中验证了其实用性,在20次试验中,GWM将典型的扩散策略提升了30%。
2. 相关工作
世界模型(World Models)
世界模型捕捉场景动态,并通过基于当前观测和动作预测未来状态,从而实现高效学习。它们已在自动驾驶 [14, 26, 27, 80, 98, 102]、游戏智能体 [1, 18–22, 66, 89] 和机器人操作 [23, 67, 83] 中得到了广泛研究。早期的工作 [18–23, 56, 65–67, 89, 96] 学习了一种用于未来预测的潜在空间,并在仿真和真实环境中都取得了强有力的结果 [83]。然而,虽然潜在表示简化了建模,但其难以捕捉世界的精细细节。近期在扩散模型 [24, 71, 72] 和Transformer [64, 77] 的进展推动了世界建模向直接像素空间建模 [1, 50, 51, 87] 转变,从而能够捕捉精细细节并从互联网视频中实现大规模学习。然而,基于图像的模型往往缺乏物理常识 [4],因此限制了它们在机器人操作中的适用性。
高斯点绘(Gaussian Splatting)
3D-GS [35] 使用三维高斯来表示场景,并通过可微分的投射高效地映射到二维平面。与隐式表示(如NeRF [57])相比,它具有更高的效率,因此受益于一些应用,例如侵入性手术 [46]、SLAM [34] 和自动驾驶 [99]。这种优势扩展到四维动态建模 [28, 52, 84],因为三维高斯与点云类似,具有空间意义。然而,这些方法所需的离线逐场景重建为实时应用(如机器人操作)带来了计算挑战。近期的研究 [6, 13, 74, 85, 86, 94, 97, 101] 通过使用大规模数据集学习从像素到高斯的生成映射来解决这一问题,但仍然依赖已知的相机位姿,从而限制了可扩展性。另一条并行研究路径 [8, 37, 70, 79] 探索了从无位姿图像进行前馈的新视角合成,利用预测的点图(point map)作为显式多视角对齐的代理。在这些进展的基础上,本工作开发了一种从无位姿图像构建的可扩展高斯世界模型,从而保证空间感知与可扩展性,以支持策略训练。
视觉操作(Visual Manipulation)
构建具有人类般能力的视觉驱动机器人一直是一项长期挑战。视觉模仿学习方法 [5, 36, 39, 45, 75] 通过使用各种视觉表示模仿专家演示,例如点云 [7, 15]、体素 [44, 69]、NeRFs [11, 32, 41, 43, 68, 91] 和 3D-GS [49]。尽管这些模型在已学习的任务中有效,但它们在未见过的真实场景中表现不佳 [53, 54]。强化学习(RL)通过试错来优化策略,从而弥补了这一缺陷,但它需要昂贵的现实世界执行过程。因此,许多方法采用“仿真到现实迁移”(sim-to-real transfer),即在世界的数字孪生体中学习RL策略并将其部署到任务执行中。然而,由于这些方法依赖于预定义资产 [3, 58, 78] 或将现实世界物体转化为仿真的劳动密集型过程 [10, 38, 42, 47, 48, 63],可扩展性仍然是一个挑战。为了解决这些局限性,GWM专注于同时为模仿学习提供更强的视觉表示,并为视觉强化学习提供一种高效的神经模拟器,从而实现更有效且更具可扩展性的机器人操作。
3. 高斯世界模型(Gaussian World Model)
我们的方法的整体流程如图2所示,其中我们构建了一个高斯世界模型,用来推断由三维高斯基元(3D Gaussian primitives)表示的未来场景重建。具体来说,我们首先将真实世界的视觉输入编码为潜在的三维高斯表示(第3.1节),然后利用基于扩散的条件生成模型,在给定机器人状态和动作的情况下,学习表示的动态变化(第3.2节)。我们展示了GWM可以灵活地集成到离线模仿学习和在线基于模型的强化学习中,以应对多样化的机器人操作任务(第3.3节)。
图2. GWM的整体流程,主要由一个三维变分编码器和一个潜在扩散Transformer组成。三维变分编码器将由基础重建模型估计得到的高斯点(Gaussian Splats)嵌入到一个紧凑的潜在空间中,而扩散Transformer则在这些潜在补丁(latent patches)上进行操作,在给定机器人动作和去噪时间步的条件下,交互式地“想象”未来的高斯点。
1. 输入图像 → 高斯点绘 (Splatt3R)
输入可以是单张或多张未配准图像(Unposed Images)。
通过 Splatt3R [70] 将输入图像转化为 3D高斯点云 (Gaussian Splats GtG_tGt),这一步得到的是场景在当前时刻的三维结构和外观表示。
2. 3D VAE 编码器 → 潜在表示
高斯点云 GtG_tGt 被送入 3D VAE,压缩成一个紧凑的潜在表示(Compact Representation)。
在潜在空间中引入了随机噪声(Random Noise),用于扩散建模。
3. 位置嵌入 & 条件信息
对潜在特征加入 位置嵌入 (Positional Embedding, RoPE),让模型感知空间关系。
条件信息包括:
时间步 τ\tauτ(扩散噪声步数),通过 AdaLN(自适应层归一化) 融合到特征中;
机器人动作 ata_tat,作为 Cross Attention 的键值 (KV),让预测与动作相关联。
4. 潜在扩散Transformer
在潜在空间上运行扩散Transformer,包含:
Cross Attention:将机器人动作与潜在表示对齐;
Feed-Forward 层:进一步建模时序和空间特征;
AdaLN:对每一层的输入做自适应调制,提高稳定性。
所有注意力机制采用 RMSNorm 归一化,保证训练稳定。
5. 解码器 → 未来高斯点云 Gt+1G_{t+1}Gt+1
预测得到的潜在表示通过 3D VAE 解码器 转换回三维高斯点云。
得到的是 未来时刻 t+1t+1t+1 的场景重建,即模型根据当前状态和机器人动作想象出的未来画面。
3.1. 世界状态编码(World State Encoding)
前馈式三维高斯点绘(Feed-forward 3D Gaussian Splatting)
给定一个世界状态的单视角或双视角图像输入 I={I}i={1,2}\mathcal{I}=\{I\}_{i=\{1,2\}}I={I}i={1,2},我们的目标是首先将场景编码为三维高斯表示,以便进行动力学学习和预测。三维高斯点绘(3D-GS)使用多个非结构化的三维高斯核来表示一个三维场景:
G={xp,σp,Σp,Cp}p∈P,G=\{x_p, \sigma_p, \Sigma_p, \mathcal{C}_p\}_{p\in \mathcal{P}}, G={xp,σp,Σp,Cp}p∈P,
其中,xp,σp,Σp,Cpx_p, \sigma_p, \Sigma_p, \mathcal{C}_pxp,σp,Σp,Cp 分别表示高斯核的中心、透明度、不相关矩阵以及球谐函数系数。
- 它指的是,这些三维高斯核(3D Gaussian kernels)在空间中的排列、数量和形状是不规则的、没有预设网格或拓扑结构的。
- 非结构化的表示:这更像是一团漂浮在空中的彩色点云。每个点都是独立的个体,它们没有预设的邻居关系,也没有固定的排列顺序。三维高斯核就是这种“非结构化点”的升级版。每个“点”不仅仅是一个位置,它还是一个高斯球,拥有自己的位置、尺寸、形状和颜色。
想象一下你正在用成千上万个“彩色光球”来重建一个三维世界。这四个参数就是用来描述每一个光球的属性。
xpx_pxp : 高斯核的中心(Position)
直观理解:这就像光球的位置坐标,比如 (1.5,2.3,0.8)。它告诉我们这个光球在三维空间中的具体位置。
作用:决定了高斯核在场景中的落脚点。在训练过程中,模型会不断地调整这些位置,让它们最有效地“填满”整个场景,特别是在物体的表面。
σp\sigma_pσp : 透明度(Opacity)
直观理解:这就像光球的透明度或不透明度。它的值通常在 0 到 1 之间。
作用:决定了该高斯核对最终图像的贡献程度。
如果透明度接近 1,它就是一个几乎不透明的“实心球”,会强烈地影响它所在位置的颜色。
如果透明度接近 0,它就是一个“虚影”,几乎不会对图像产生影响。
这对于表示半透明物体(如玻璃、烟雾)或者在遮挡关系中至关重要。模型会学习让被遮挡的高斯核具有较低的透明度,从而让它们“消失”在背景中。
Σp\Sigma_pΣp : 协方差矩阵(Covariance Matrix)
直观理解:这是最复杂的一个,但我们可以用一个简单的比喻:它决定了光球的尺寸、形状和旋转方向。
一个简单的球形高斯核,其协方差矩阵决定了它的半径。
一个椭球形高斯核,其协方差矩阵决定了它的长轴、短轴以及它在空间中的旋转角度。
作用:让高斯核拥有了弹性。
在平坦、大面积的表面(如墙壁、地面)上,模型可以生成一个又大又扁平的椭球形高斯核,用很少的核就覆盖很大一片区域,这大大提高了效率。
在物体边缘、角落或细节丰富的区域,模型则会生成更小、更接近球形的高斯核,以精确地捕捉这些细节。
为什么叫“不相关矩阵”?
- 这实际上是对协方差矩阵的简化描述。在 3D-GS 的原始论文中,协方差矩阵可以被分解为两个部分:一个缩放矩阵(scaling matrix)和一个旋转矩阵(rotation matrix)。“不相关矩阵”这个词可能指的是一种简化的表示方式,但其核心作用始终是控制高斯核的形状和方向。
Cp\mathcal{C}_pCp : 球谐函数系数(Spherical Harmonics Coefficients)
直观理解:这就像光球的 “彩色涂料”。但它不仅仅是单一的颜色,它是一种更复杂的“涂料”,可以根据观察视角的不同而改变颜色。
作用:决定了高斯核的颜色和光照效果。
简单来说,球谐函数(Spherical Harmonics)是一种数学工具,可以非常有效地表示三维空间中的复杂函数,比如光照。
通过这些系数,模型能够学习到当从不同角度观察一个物体时,它的颜色会有什么变化(比如高光、阴影)。这使得 3D-GS
渲染出来的图像具有非常真实的光照效果,而不仅仅是简单的贴图。举例:想象一个闪亮的红色苹果。
从正面看,它可能大部分是红色。
但从某个角度看,你会看到一个白色的高光点。
这并不是因为苹果表面有不同的颜色,而是因为它对光线的反射方式不同。
球谐函数系数正是用来捕捉这种视角依赖的颜色变化,让渲染结果看起来更加逼真。
总结
这四个参数共同工作,就像一个艺术家的工具箱:
xpx_pxp :决定了你的笔触落在哪里。
σp\sigma_pσp :决定了你的笔触的透明度。
Σp\Sigma_pΣp :决定了你的笔触的大小和形状。
Cp\mathcal{C}_pCp :决定了你的笔触的颜色和光影效果。
通过对成千上万个这些“笔触”(高斯核)进行精心的优化和组合,3D-GS 就能够从几张二维图片中,神奇地重建出一个高质量的三维场景。
为了从给定视角获得每个像素的颜色,3D-GS将三维高斯投影到图像平面,并计算像素颜色如下:
C(G)=∑p∈PαpSH(dp;Cp)∏j=1p−1(1−αj),(1)C(G) = \sum_{p\in \mathcal{P}} \alpha_p \, \text{SH}(d_p; \mathcal{C}_p) \prod_{j=1}^{p-1}(1-\alpha_j), \tag{1} C(G)=p∈P∑αpSH(dp;Cp)j=1∏p−1(1−αj),(1)
其中:
-
αp\alpha_pαp 表示按照深度顺序排列的有效透明度,即由 Σp\Sigma_pΣp 推导出的二维高斯权重与其整体透明度 σp\sigma_pσp 的乘积;
-
dpd_pdp 表示从相机到 xpx_pxp 的视角方向;
-
SH(⋅)\text{SH}(\cdot)SH(⋅) 是球谐函数。
由于原始的3D-GS依赖于耗时的逐场景离线优化,我们采用可泛化的3D-GS来学习从图像到三维高斯的前馈映射,以加速这一过程。具体来说,我们使用 Splatt3R [70] 获取三维高斯世界状态 GGG:该方法首先利用立体重建模型 Mast3R [37] 从输入图像生成三维点图(3D point maps),然后使用一个额外的预测头,在这些点图的基础上预测每个三维高斯的参数。
这部分的核心是理解公式 (1),它描述了如何计算图像中一个特定像素的颜色。这个过程被称为 “Splating”,就像把颜料“泼洒”到画布上。 想象你面前有一个场景,由成千上万个三维高斯(那些彩色光球)组成。现在,你想用一个相机去拍摄它,得到一张二维照片。这个公式就是相机捕捉颜色的数学描述。
C(G)=∑p∈PαpSH(dp;Cp)∏j=1p−1(1−αj),(1)C(G) = \sum_{p\in \mathcal{P}} \alpha_p \, \text{SH}(d_p; \mathcal{C}_p) \prod_{j=1}^{p-1}(1-\alpha_j), \tag{1} C(G)=p∈P∑αpSH(dp;Cp)j=1∏p−1(1−αj),(1)
- C(G)C(G)C(G): 这是最终得到的像素颜色。
- ∑p∈P\sum_{p\in \mathcal{P}}∑p∈P : 这表示对所有影响这个像素的三维高斯进行累加求和。
- αp\alpha_pαp: 这是高斯核的有效透明度。它不是单纯的 σp\sigma_pσp (整体透明度),而是由两部分相乘得到的:
- 高斯二维投影权重: 当一个三维高斯被投影到二维图像平面时,它的能量分布是呈高斯曲线的。离高斯中心越近的像素,得到的“能量”或权重就越大。
- σp\sigma_pσp 这个高斯核自身的整体透明度。
- 例子:想象一个半透明的蓝色气球。它的 σp\sigma_pσp 可能只有 0.5。当它投影到图像上时,只有它中心位置的像素会获得最大的权重,而边缘的像素权重较小。最终的 αp\alpha_pαp 结合了这两者,告诉我们这个高斯对这个特定像素的贡献有多大。
- SH(dp;Cp)\text{SH}(d_p; \mathcal{C}_p)SH(dp;Cp): 这部分是高斯核在特定视角下的颜色。
- Cp\mathcal{C}_pCp: 前面我们提到的球谐函数系数,它包含了这个高斯的颜色信息和光照信息。
- dpd_pdp: 观察方向,即从相机到这个高斯中心 xpx_pxp 的方向。
- 作用:球谐函数利用 dpd_pdp 和 Cp\mathcal{C}_pCp 计算出,在当前这个视角下,这个高斯应该呈现出什么颜色。这正是为什么 3D-GS 能够渲染出带有高光、阴影等真实光影效果的原因。
- ∏j=1p−1(1−αj)\prod_{j=1}^{p-1}(1-\alpha_j)∏j=1p−1(1−αj): 这部分是累积的透明度,它处理了遮挡关系。
- 直观理解:当你从一个方向看物体时,离你近的物体会遮挡住后面的物体。
- 公式解释:3D-GS 会将所有高斯按深度(从远到近或从近到远,这里是从远到近)进行排序。这个乘积项计算的是在当前高斯 ppp 前面的所有高斯 jjj 的累积透明度。
- 1−αj1−α_j1−αj 表示高斯 jjj 的“透明度”部分(即没有被它遮挡的光线)。
- 将前面的所有 (1−αj)(1−α_j)(1−αj) 相乘,就得到了光线在到达当前高斯 ppp 之前,还剩下多少能量。如果前面的高斯非常不透明(αjα_jαj 接近 1),那么这个乘积就会接近 0,意味着光线基本都被前面的高斯挡住了,后面的高斯 ppp 对最终颜色的贡献就会很小。
从耗时离线优化到前馈式学习
原始的 3D-GS 有一个很大的缺点:它需要对每一个新场景都从头开始进行数小时的离线优化。这就像每次想渲染一个新物体,你都必须让艺术家从零开始雕刻它。 为了解决这个问题,研究者提出了 “可泛化的 3D-GS” 。它的目标是:学习一个通用的模型,能够直接从输入图像快速预测出三维高斯表示,而不需要逐场景的优化。
方法的核心:
- 立体重建 (Mast3R):这个模型首先从给定的单视角或双视角图像中,生成一个三维点图。你可以把这看作一个“粗略”的三维点云,它已经捕捉了场景的基本几何形状。
- 额外的预测头 (Prediction Head):这个神经网络是真正的“魔术”所在。它接收前面生成的三维点图作为输入,然后预测出每个点对应的完整三维高斯参数(xp,σp,Σp,Cpx_p, \sigma_p, \Sigma_p, \mathcal{C}_pxp,σp,Σp,Cp)。
优势:
一旦这个“可泛化”的模型训练完成,它就可以像一个“速写大师”一样,瞬间从新的输入图像中生成三维高斯,省去了漫长的离线优化过程。
这使得 3D-GS 的应用场景大大扩展,例如实时三维重建、快速虚拟现实内容生成等。
总而言之,前馈式的 3D-GS 将原本耗时费力的 “优化问题” 转换成了一个更高效的 “预测问题”。它通过学习一个通用的映射关系,实现了从二维图像到三维高斯表示的快速、直接的转换。
三维高斯 VAE
想象一下,你用 3D-GS 重建了两个不同的场景:一个简单的立方体和一个复杂的雕像。
重建立方体可能只需要 1000 个高斯。
重建雕像可能需要 100 万个高斯。
这就带来了问题:如果你想用一个神经网络来处理这些三维数据,比如让机器人根据这些数据做决策,该怎么办?这个网络的输入层必须是固定大小的,但你的高斯数量却是可变的。
三维高斯VAE(变分自编码器) 就是为了解决这个“可变大小”的问题而设计的。它的作用是:
将一个可变数量的三维高斯 GGG 压缩成一个固定长度的潜在编码 x\text{x}x,然后再将这个编码解压回一个三维高斯表示 G^\hat{G}G^。
通过这个过程,我们就可以将三维数据作为固定大小的向量输入到其他网络中,比如一个策略网络(policy network),让机器人来学习如何操作。
由于每个世界状态中学习到的三维高斯数量在不同场景和任务中可能存在显著差异,我们采用一个三维高斯 VAE(Eθ,Dθ)\text{VAE}(E_\theta, D_\theta)VAE(Eθ,Dθ) 来将重建的三维高斯 GGG 编码为一个固定长度的 NNN 个潜在嵌入 x∈RN×D\text{x} \in \mathbb{R}^{N\times D}x∈RN×D。具体来说,我们首先使用最远点采样(Farthest Point Sampling, FPS)将重建的三维高斯 GGG 下采样为固定数量 NNN 个高斯 GNG_NGN:
GN=FPS(G).G_N = \text{FPS}(G). GN=FPS(G).
最远点采样 (FPS): GN=FPS(G).G_N = \text{FPS}(G). GN=FPS(G).
直观理解:这是一种智能的“抽样”方法。想象你有一大堆三维点,你想从中选出 NNN 个点来代表整体。FPS 的做法是:
随机选择第一个点。
从剩下的点中,选择离所有已选点最远的点作为下一个。
重复这个过程,直到选出 N 个点。
作用:这种方法能确保选出的 NNN 个点均匀地分布在原始的高斯集合 GGG 中,从而保证这 NNN 个点具有代表性,不会集中在某个角落。这解决了数量不固定的问题,将我们感兴趣的高斯从 GGG 变成了固定数量的 GNG_NGN 。
接下来,我们使用这些采样到的高斯 GNG_NGN 作为查询项(queries),通过一个基于交叉注意力的 LLL 层编码器 EθE_\thetaEθ 从所有高斯 GGG 中聚合信息,得到潜在嵌入 x\text{x}x [93]:
X=Eθ(GN,G)=Eθ(L)∘⋯∘Eθ(1)(GN,G),(2)X = E_\theta(G_N, G) = E_\theta^{(L)} \circ \cdots \circ E_\theta^{(1)}(G_N, G), \tag{2} X=Eθ(GN,G)=Eθ(L)∘⋯∘Eθ(1)(GN,G),(2)
其中
Eθ(l)(Q,G)=LayerNorm(CrossAttn(Q,PosEmbed(G))).E_\theta^{(l)}(Q, G) = \text{LayerNorm}(\text{CrossAttn}(Q, \text{PosEmbed}(G))). Eθ(l)(Q,G)=LayerNorm(CrossAttn(Q,PosEmbed(G))).
交叉注意力编码器 (Cross-Attention Encoder): X=Eθ(GN,G)X = E_\theta(G_N, G)X=Eθ(GN,G)
直观理解:现在我们有了 NNN 个具有代表性的“查询”高斯 GNG_NGN 。编码器要做的是让这 NNN 个查询去“问”所有原始高斯 GGG:“你们各自是什么样的?”,然后把得到的信息聚合起来,形成 NNN 个固定长度的向量。
工作方式:
查询(Queries): 采样得到的 NNN 个高斯 GNG_NGN 作为查询项。
键值(Keys & Values): 所有原始高斯 GGG 作为键值。
交叉注意力: 编码器中的交叉注意力 (Cross-Attention) 机制会计算每个查询高斯 GNG_NGN 与所有原始高斯 GGG 之间的关系。这个机制能够让每个查询高斯有效地 “看到”并聚合 整个场景中所有高斯的信息,而不仅仅是它自己的信息。
结果:经过多层(LLL 层)这样的交叉注意力处理后,每个查询高斯都获得了整个场景的上下文信息。最终,我们得到一个固定大小的潜在嵌入 x∈RN×D\text{x}∈\mathbb{R}^{N×D}x∈RN×D 。
在获得潜在编码 x\text{x}x 之后,我们采用一个对称的基于Transformer的解码器 DθD_\thetaDθ,在潜在编码集合内部传播并聚合信息,从而得到重建的高斯 G^\hat{G}G^:
G^=Dθ(x)=LayerNorm(CrossAttn(x,x)).(3)\hat{G} = D_\theta(\text{x}) = \text{LayerNorm}(\text{CrossAttn}(\text{x}, \text{x})). \tag{3} G^=Dθ(x)=LayerNorm(CrossAttn(x,x)).(3)
为了训练三维高斯 VAE(Eθ,Dθ)\text{VAE}(E_\theta, D_\theta)VAE(Eθ,Dθ) ,我们使用重建高斯 G^\hat{G}G^ 的中心与原始高斯 GGG 的中心之间的 Chamfer 损失作为监督。同时,我们还添加了一个基于渲染的损失,以确保我们的重建高斯 G^\hat{G}G^ 能够实现高保真的图像渲染,从而服务于基于图像的策略学习:
LVAE=Chamfer(G^,G)+∥C(G^)−C(G)∥1.(4)\mathcal{L}_{\text{VAE}} = \text{Chamfer}(\hat{G}, G) + \| C(\hat{G}) - C(G)\|_1. \tag{4} LVAE=Chamfer(G^,G)+∥C(G^)−C(G)∥1.(4)
Chamfer\text{Chamfer}Chamfer:
- 原始点云中的每个点,到重建点云中最近点的距离。
- 重建点云中的每个点,到原始点云中最近点的距离。
三维高斯VAE 结构总结表
模块 输入 处理方式 输出 公式 编码器 EθE_\thetaEθ - 原始三维高斯集合 GGG
- 通过FPS采样得到的 GNG_NGN基于交叉注意力(Cross-Attn)的 LLL 层Transformer编码器,将 GNG_NGN 作为查询,从 GGG 中聚合信息 潜在嵌入 X∈RN×DX \in \mathbb{R}^{N\times D}X∈RN×D X=Eθ(GN,G)X = E_\theta(G_N, G)X=Eθ(GN,G)
Eθ(l)(Q,G)=LayerNorm(CrossAttn(Q,PosEmbed(G)))E_\theta^{(l)}(Q,G) = LayerNorm(CrossAttn(Q, PosEmbed(G)))Eθ(l)(Q,G)=LayerNorm(CrossAttn(Q,PosEmbed(G))) (公式2)解码器 DθD_\thetaDθ 潜在嵌入 XXX 基于自注意力(Self-Attn)的Transformer解码器,在潜在集合内部传播并聚合信息 重建高斯 G^\hat{G}G^ G^=Dθ(X)=LayerNorm(SelfAttn(X,X))\hat{G} = D_\theta(X) = LayerNorm(SelfAttn(X,X))G^=Dθ(X)=LayerNorm(SelfAttn(X,X)) (公式3) 损失函数 LVAE\mathcal{L}_{VAE}LVAE - 原始高斯 GGG
- 重建高斯 G^\hat{G}G^Chamfer距离(中心点级别) + 渲染损失(像素级别差异) 优化目标,用于训练 Eθ,DθE_\theta, D_\thetaEθ,Dθ LVAE=Chamfer(G^,G)+∣C(G^)−C(G)∣1\mathcal{L}_{VAE} = Chamfer(\hat{G}, G) + |C(\hat{G}) - C(G)|_1LVAE=Chamfer(G^,G)+∣C(G^)−C(G)∣1 (公式4)
3.2. 基于扩散的动态建模(Diffusion-based Dynamics Modeling)
在时刻 ttt 的编码世界状态嵌入 xt\text{x}_txt 以及其未来状态 xt+1\text{x}_{t+1}xt+1 已知的情况下,我们的目标是学习世界动力学 p(xt+1∣x≤t,a≤t)p(\text{x}_{t+1}\mid \text{x}_{\leq t}, a_{\leq t})p(xt+1∣x≤t,a≤t),其中 x≤t\text{x}_{\leq t}x≤t 和 a≤ta_{\leq t}a≤t 分别表示历史状态和历史动作。具体来说,我们采用基于扩散的动力学模型,将动力学学习转化为一个条件生成问题,即从噪声中生成未来状态 xt+1\text{x}_{t+1}xt+1,条件为历史状态和动作 yt=(x≤t,a≤t)\text{y}_t = (\text{x}_{\leq t}, a_{\leq t})yt=(x≤t,a≤t)。
扩散公式(Diffusion Formulation)
为了生成未来状态,我们从扩散过程的表述开始。具体来说,我们首先向真实的未来状态 xt+10=xt+1\text{x}^0_{t+1} = \text{x}_{t+1}xt+10=xt+1 添加噪声,以通过高斯扰动核获得带噪的未来状态样本 xt+1τ\text{x}^\tau_{t+1}xt+1τ:
p0→τ(xt+1τ∣xt+10)=N(xt+1τ;xt+10,σ2(τ)I),(5)p^{0 \to \tau}(\text{x}^\tau_{t+1}\mid \text{x}^0_{t+1}) = \mathcal{N}(\text{x}^\tau_{t+1}; \text{x}^0_{t+1}, \sigma^2(\tau) I), \tag{5} p0→τ(xt+1τ∣xt+10)=N(xt+1τ;xt+10,σ2(τ)I),(5)
其中,τ\tauτ 是噪声步索引,σ(τ)\sigma(\tau)σ(τ) 是噪声调度函数。该扩散过程可以通过以下随机微分方程(SDE)的解来描述 [72]:
dx=f(x,τ)dτ+g(τ)dw,(6)d\mathbf{x} = \mathbf{f}(\mathbf{x}, \tau)\, d\tau + g(\tau)\, d\mathbf{w}, \tag{6} dx=f(x,τ)dτ+g(τ)dw,(6)
其中,w\mathbf{w}w 表示标准维纳过程(Wiener process),f\mathbf{f}f 是漂移系数(drift coefficient),ggg 是扩散系数(diffusion coefficient)。在这种表述下,高斯扰动核的作用等价于设置 f(x,τ)=0\mathbf{f}(\mathbf{x},\tau)=0f(x,τ)=0,并令 g(τ)=2σ˙(τ)σ(τ)g(\tau)=\sqrt{2 \dot{\sigma}(\tau)\sigma(\tau)}g(τ)=2σ˙(τ)σ(τ)。
直观理解:这是一个描述 “随机过程随时间演变” 的方程。它将公式 (5) 的离散加噪过程,用一个连续的数学工具来描述。
具体含义:
dxd\mathbf{x}dx:表示状态 x 的微小变化。
f(x,τ)dτ\mathbf{f}(\mathbf{x}, \tau)d\tauf(x,τ)dτ:漂移项。它描述了状态 x\mathbf{x}x 的确定性变化,就像一个物体在水流中漂移。
g(τ)dwg(\tau)\, d\mathbf{w}g(τ)dw:扩散项。它描述了状态 x\mathbf{x}x 的随机性变化,就像布朗运动一样。w\mathbf{w}w 是维纳过程,代表着随机噪声。
简化:在你的例子中,f(x,τ)=0\mathbf{f}(\mathbf{x}, \tau)=0f(x,τ)=0,这意味着没有“漂移”,整个变化过程完全由随机噪声($g(τ)$0)驱动,这与公式 (5) 的纯高斯加噪是等价的。
为了从噪声中生成样本,我们可以使用逆时间SDE [2] 对式(6)进行反演,从而得到采样公式:
dx=[f(x,τ)−g(τ)2∇xlogpτ(x)]dτ+g(τ)dwˉ,(7)d\mathbf{x} = \big[\mathbf{f}(\mathbf{x}, \tau) - g(\tau)^2 \nabla_\mathbf{x} \log p^\tau(\mathbf{x})\big] d\tau + g(\tau)\, d\bar{\mathbf{w}}, \tag{7} dx=[f(x,τ)−g(τ)2∇xlogpτ(x)]dτ+g(τ)dwˉ,(7)
其中,wˉ\bar{\mathbf{w}}wˉ 表示逆时间维纳过程,∇xlogpτ(x)\nabla_\mathbf{x} \log p^\tau(\mathbf{x})∇xlogpτ(x) 是得分函数(score function),即关于 x\mathbf{x}x 的对数边际概率的梯度 [30]。
直观理解:这是对公式 (6) 的反演。它告诉我们,如果想让一个随机过程“倒着走”,需要额外加上一个“导向力”来抵消随机性。
关键项:∇xlogpτ(x)\nabla_\mathbf{x} \log p^\tau(\mathbf{x})∇xlogpτ(x)
这被称为得分函数 (score function)。
logpτ(x)\log p^\tau(\mathbf{x})logpτ(x):表示带噪数据 x\mathbf{x}x 的概率密度函数的对数。
∇x\nabla_\mathbf{x}∇x :表示对 x\mathbf{x}x 求梯度。
作用:这个梯度项指向概率密度增加最快的方向。想象一下你在一座山(概率分布)上,得分函数就像一个指南针,永远指向山顶的方向。它告诉我们,如果想从一个随机点(带噪数据)回到真实数据(概率山顶),应该往哪个方向走。
与神经网络的联系:由于我们无法直接计算这个得分函数,所以我们用一个神经网络来近似估计它。这就是扩散模型训练的核心。
由于得分函数可以通过神经网络来估计,我们通过最小化采样到的未来状态 x^t+10=Dθ(xt+1τ,yt)\hat{\mathbf{x}}^0_{t+1} = \mathcal{D}_\theta(\mathbf{x}^\tau_{t+1}, \mathbf{y}_t)x^t+10=Dθ(xt+1τ,yt) 与真实未来状态 xt+10\mathbf{x}^0_{t+1}xt+10 之间的差异来学习条件去噪模型 Dθ\mathcal{D}_\thetaDθ:
L(θ)=E[∥Dθ(xt+1τ,ytτ)−xt+10∥22].(8)\mathcal{L}(\theta) = \mathbb{E}\left[\left\| \mathcal{D}_\theta(x^\tau_{t+1}, y^\tau_t) - x^0_{t+1}\right\|_2^2\right]. \tag{8} L(θ)=E[Dθ(xt+1τ,ytτ)−xt+1022].(8)
yt=(x≤t,a≤t)\text{y}_t = (\text{x}_{\leq t}, a_{\leq t})yt=(x≤t,a≤t)
基于EDM的学习(Learning with EDM)
正如 [33] 中指出的,直接学习去噪器 Dθ(xt+1τ,yt)\mathcal{D}_\theta(\text{x}^\tau_{t+1}, \text{y}_t)Dθ(xt+1τ,yt) 可能会受到噪声幅度变化等问题的影响。因此,我们遵循 [1],并采用EDM [33] 中的做法,改为学习一个带有预条件的网络 Fθ\mathcal{F}_\thetaFθ。具体来说,我们将去噪器 Dθ(xt+1τ,yt+1τ)\mathcal{D}_\theta(\text{x}^\tau_{t+1}, \text{y}^\tau_{t+1})Dθ(xt+1τ,yt+1τ) 参数化为:
Dθ(xt+1τ,ytτ)=cskipτxt+1τ+coutτFθ(cinτxt+1τ,ytτ;cnoiseτ),(9)\mathcal{D}_\theta(\text{x}^\tau_{t+1}, \text{y}^\tau_t) = c^\tau_{\text{skip}} \text{x}^\tau_{t+1} + c^\tau_{\text{out}} \, \mathcal{F}_\theta\big(c^\tau_{\text{in}} \text{x}^\tau_{t+1}, \text{y}^\tau_t; c^\tau_{\text{noise}}\big), \tag{9} Dθ(xt+1τ,ytτ)=cskipτxt+1τ+coutτFθ(cinτxt+1τ,ytτ;cnoiseτ),(9)
其中:
-
预条件器 cinτc^\tau_{\text{in}}cinτ 和 coutτc^\tau_{\text{out}}coutτ 用于缩放输入与输出的幅度,
-
cskipτc^\tau_{\text{skip}}cskipτ 调节跳跃连接(skip connection),
-
cnoiseτc^\tau_{\text{noise}}cnoiseτ 将噪声水平映射为额外的条件输入,送入 Fθ\mathcal{F}_\thetaFθ。
这些预条件器的细节在附录 B.1 中给出。
基于 EDM 的学习:为什么要改变训练方式?
首先,让我们回到最初的问题:为什么要引入 EDM (Elucidating Diffusion Models) 这个框架? 原始的扩散模型训练有一个潜在问题:如果直接学习去噪器 Dθ\mathcal{D}_\thetaDθ,它需要处理各种不同噪声水平的输入。当噪声非常多或非常少时,网络的行为可能变得不稳定,难以收敛到最佳解。
高噪声:输入几乎是纯噪声,网络很难识别出其中的微弱信号,训练效率低。
低噪声:输入和真实数据几乎一模一样,网络很容易学会 “什么也不做”,直接输出输入(即平凡解),这导致它没有真正学到去噪的能力。
EDM 的核心思想是:我们不直接训练去噪器 Dθ\mathcal{D}_\thetaDθ ,而是训练一个 “预处理”过的网络 Fθ\mathcal{F}_\thetaFθ 。这个网络经过精心设计,它的输入和输出都经过了缩放(scaling),从而使得无论噪声水平如何,它面临的训练任务都更加稳定和一致。 这就像一个厨师,他不会直接处理各种形状、大小不一的食材,而是先将所有食材切成标准化的块状,然后再进行烹饪。这样,他的烹饪过程就变得更加稳定和可控。
网络 Fθ\mathcal{F}_\thetaFθ :这是真正需要学习的神经网络。它接收预处理后的输入,并执行核心的去噪操作。
cskipτc^\tau_{\text{skip}}cskipτ (跳跃连接):这个参数控制原始带噪数据 xt+1τ\text{x}^\tau_{t+1}xt+1τ 在最终去噪结果中的占比。
当噪声很小时,σ(τ)σ(τ)σ(τ) 接近 0,数据和信号的差异很小,我们希望去噪器主要依赖原始数据,这时 cskipτc^\tau_{\text{skip}}cskipτ 接近 1。
当噪声很大时,σ(τ)σ(τ)σ(τ) 很大,原始数据几乎被噪声淹没,我们希望去噪器主要依赖网络 Fθ\mathcal{F}_\thetaFθ 的输出,这时 cskipτc^\tau_{\text{skip}}cskipτ 接近 0。
cinτc^\tau_{\text{in}}cinτ (输入预条件):这个参数用于缩放网络 Fθ\mathcal{F}_\thetaFθ 的输入。它的设计目标是让网络的输入在任何噪声水平下都具有相似的幅度。这使得网络不必处理幅度差异巨大的输入,从而稳定了训练。
coutτc^\tau_{\text{out}}coutτ (输出预条件):这个参数用于缩放网络 Fθ\mathcal{F}_\thetaFθ 的输出,以确保它与跳跃连接的部分正确组合。
cnoiseτc^\tau_{\text{noise}}cnoiseτ (噪声条件):这是一个将噪声水平 τ 转换为一个可供网络 Fθ\mathcal{F}_\thetaFθ 理解的额外输入。它通常是一个简单的映射函数,比如将 log(σ(τ))\text{log}(σ(τ))log(σ(τ)) 转换为一个嵌入向量,这个向量会通过 AdaLN 或其他方式注入到网络中。
通过这种转换,我们可以将公式 (8) 的目标改写为:
L(θ)=E[∥Fθ(cinτxt+1τ,ytτ)−1coutτ(xt+10−cskipτxt+1τ)∥22].(10)\mathcal{L}(\theta) = \mathbb{E}\left[\Big\|\mathbf{F}_\theta(c^\tau_{\text{in}} \mathbf{x}^\tau_{t+1}, y^\tau_t) - \tfrac{1}{c^\tau_{\text{out}}}\big(\mathbf{x}^0_{t+1} - c^\tau_{\text{skip}} \mathbf{x}^\tau_{t+1}\big)\Big\|_2^2\right]. \tag{10} L(θ)=E[Fθ(cinτxt+1τ,ytτ)−coutτ1(xt+10−cskipτxt+1τ)22].(10)
这一转换的一个关键见解在于:为更好地训练网络 Fθ\mathcal{F}_\thetaFθ 创建了一个新的训练目标,它能够根据噪声调度 σ(τ)\sigma(\tau)σ(τ) 自适应地混合信号与噪声。直观来说:在高噪声水平下(σ(τ)≫σdata\sigma(\tau) \gg \sigma_{\text{data}}σ(τ)≫σdata),cskipτ→0c^\tau_{\text{skip}} \to 0cskipτ→0,网络主要学习预测干净信号。相反,在低噪声水平下(σ(τ)→0\sigma(\tau) \to 0σ(τ)→0),cskipτ→1c^\tau_{\text{skip}} \to 1cskipτ→1,训练目标变为噪声部分,从而避免目标退化为平凡解。
这个公式是将公式 (8) 中的 Dθ\mathcal{D}_\thetaDθ 替换为公式 (9) 之后,对整个损失函数进行的数学变换。虽然看起来复杂,但它背后的思想非常直观:
训练目标不再是直接让网络预测 xt+10\mathbf{x}^0_{t+1}xt+10 ,而是让它预测一个经过精心设计的 “目标值” 。
这个“目标值”是 1coutτ(xt+10−cskipτxt+1τ)\tfrac{1}{c^\tau_{\text{out}}}\big(\mathbf{x}^0_{t+1} - c^\tau_{\text{skip}} \mathbf{x}^\tau_{t+1}\big)coutτ1(xt+10−cskipτxt+1τ),它根据噪声水平 (τττ) 动态变化。
高噪声时 (cskipτ→0c^\tau_{\text{skip}}→0cskipτ→0):训练目标接近 1coutτxt+10\tfrac{1}{c^\tau_{\text{out}}} \mathbf{x}^0_{t+1}coutτ1xt+10 ,网络主要学习预测干净的信号。
低噪声时 (cskipτ→1c^\tau_{\text{skip}}→1cskipτ→1):训练目标接近 1coutτ(xt+10−xt+1τ)\tfrac{1}{c^\tau_{\text{out}}}\big(\mathbf{x}^0_{t+1} - \mathbf{x}^\tau_{t+1}\big)coutτ1(xt+10−xt+1τ),这正是噪声本身!这时,网络主要学习预测噪声。
这种自适应的目标让网络在任何噪声水平下都能够学习到有用的信息,从而极大地提高了训练的稳定性和效率。
实现(Implementation)
在技术实现上,我们使用 DiT [60] 来实现网络 Fθ\mathcal{F}_\thetaFθ。给定一系列实际世界状态的潜在嵌入 {xt0=xt}t=1T\{\mathbf{x}^0_t = \mathbf{x}_t\}_{t=1}^T{xt0=xt}t=1T,我们首先根据公式 (5) 中描述的高斯扰动生成带噪潜在嵌入 {xtτ}t=1T\{\mathbf{x}^\tau_t\}_{t=1}^T{xtτ}t=1T。接下来,我们将这些带噪潜在嵌入与旋转位置编码(Rotary Position Embedding, RoPE [73])拼接,并作为输入传递给 DiT。关于条件 yt=(x≤t0,a≤t,cnoiseτ)\mathbf{y}_t = (\mathbf{x}^0_{\leq t}, a_{\leq t}, c^\tau_{\text{noise}})yt=(x≤t0,a≤t,cnoiseτ),时间嵌入通过自适应层归一化(Adaptive Layer Normalization, AdaLN [61])进行调制,而当前的机器人动作则作为键(keys)和值(values)输入到 DiT 内部的交叉注意力层中,用于条件生成。为了在所有注意力机制中保持稳定性和效率,我们采用具有可学习缩放因子的均方根归一化(Root Mean Square Normalization, RMSNorm [92]),以在处理空间表示的同时,结合时间动作序列作为条件,从而稳定训练。
1. 初始化
- 策略 π(at∣st)π(a_t ∣s_t )π(at∣st): 这是机器人的大脑,它决定了在当前状态 sts_tst 下应该采取什么动作 ata_tat 。
- 高斯世界模型 pθ(st+1,rt∣st,at)p_θ (s_{t+1} ,r_t ∣s_t ,a_t )pθ(st+1,rt∣st,at): 这是机器人的“虚拟世界”模型。它接收当前状态 sts_tst 和动作 ata_tat 作为输入,并预测下一步的状态 st+1s_{t+1}st+1 和奖励 rtr_trt 。这里的 pθp_θpθ 就是我们之前讨论的高斯世界模型(GWM)。
- 回放缓冲区 B\mathcal{B}B: 这是一个存储机器人过去经验的数据库。
2. 循环 N 个周期 (for N epochs do)
- 算法会重复执行以下步骤,直到学习完成。
3. 收集数据(Collect data with π in real environment)
机器人使用当前版本的策略 π 在真实世界中与环境进行交互。
它观察自己的状态 sts_tst,执行一个动作 ata_tat,然后观察环境如何变化到新状态 st+1s_{t+1}st+1 ,并获得奖励rtr_trt 。
这个经验元组 (st,at,st+1,rt)(s_t ,a_t ,s_{t+1} ,r_t )(st,at,st+1,rt) 被收集起来,并添加到回放缓冲区 B\mathcal{B}B中。
这一步的目的是确保模型和策略能够接触到真实世界的数据,从而避免只在虚拟世界中“闭门造车”。
4. 训练世界模型(Train Gaussian world model pθp_θpθ on dataset B\mathcal{B}B via maximum likelihood)
现在,我们使用回放缓冲区 B 中收集到的真实数据,来训练高斯世界模型 p θ 。
最大似然(maximum likelihood) 是训练目标。这表示我们希望高斯世界模型能够尽可能准确地预测出与真实数据相符的结果。
公式中的 arg maxθEB[logpθ(st+1,rt∣st,at)]\argmax_θ \mathbb{E}_{\mathcal{B}} [\log p_θ (s_{t+1} ,r_t ∣s_t ,a_t )]argmaxθEB[logpθ(st+1,rt∣st,at)] 意味着,我们找到一组最优的参数 θ,使得在给定的真实数据上,模型预测出这个真实结果的概率最大。
这一步让机器人的“虚拟世界”模型变得越来越逼真。
5. 优化策略(Optimize policy π inside predictive model)
这是最关键的一步。现在我们有了一个相当准确的“虚拟世界”模型 pθp_θpθ 。我们可以利用这个模型,让机器人进行大量的虚拟练习,而不需要再消耗宝贵的真实世界经验。
公式中的 arg maxπEπ[∑t≥0γtrt]\argmax_π \mathbb{E}_π [∑_{t≥0} γ^t r_t ]argmaxπEπ[∑t≥0γtrt] 是强化学习的标准目标,它的含义是:找到一个最优的策略 πππ,能够最大化机器人未来获得的累积奖励。
机器人会不断地在它的 “虚拟世界” 中模拟行动,尝试各种策略,找到能够获得最高奖励的行动序列。这使得它能在短时间内进行数百万次模拟,从而快速提升自己的决策能力。
6. 循环 (end)
- 当这一轮的策略优化完成后,机器人会回到步骤 3,带着更新后的、更强大的策略 πππ,再次进入真实世界收集数据,然后继续训练世界模型并优化策略。
4. 实验(Experiments)
在实验中,我们主要聚焦于以下问题:
- 在不同领域下,基于动作条件的视频预测结果质量如何?
- 高斯世界模型(GWM)是否能对下游的模仿学习和强化学习带来益处?它是否比基于图像的世界模型表现出更强的鲁棒性?
- 在真实世界的机器人操作任务中,高斯世界模型如何帮助典型的策略(例如扩散策略 [9])?
在以下小节中,我们将详细描述模型在这些关键问题上的性能表现。具体来说,我们在实验中利用了以下三个测试环境和四个任务:
环境(Environments)
为了对GWM的能力进行全面分析,我们在两个合成环境和一个真实环境中评估了我们的方法:
- META-WORLD [90]:一个合成环境,用于学习机器人操作的强化学习策略;
- ROBOCASA [59]:一个大规模、多尺度的合成模仿学习基准,涵盖厨房环境中的多样化机器人操作任务;
- FRANKA-PNP:一个真实世界的抓取与放置环境,使用 Franka Emika FR3 机械臂。
任务(Tasks)
我们精心设计了四个任务,以系统性地在不同测试环境中评估GWM:
- 动作条件场景预测(Action-conditioned scene prediction):评估GWM在世界建模和未来预测中的有效性;
- 基于GWM的模仿学习(GWM-based imitation learning):考察其表示质量及其对基于模仿学习的机器人操作的益处;
- 基于GWM的强化学习(GWM-based RL):探索其在基于模型的强化学习中的潜力;
- 真实任务部署(Real-world task deployment):评估GWM在真实世界机器人操作中的鲁棒性。
4.1. 动作条件场景预测(Action-conditioned Scene Prediction)
实验设置(Experiment Setup) 一个世界模型生成高保真且与动作对齐的 rollout 的能力,对有效的策略优化至关重要。为了评估这一能力,我们在所有考虑的真实与合成环境中,使用人类演示来训练GWM,并在评估时将模型条件化在验证集中采样得到的、从未见过的动作轨迹上,以进行未来预测质量测试。在定量评估方面,我们采用常见的生成质量指标,包括 FVD [76] 来衡量时间一致性,基于图像的指标 PSNR [29] 用于像素级精度,同时还使用 SSIM [81] 和 LPIPS [95] 来评估感知质量。
结果与分析(Results and Analyses) 我们在表1中提供了本方法与 iVideoGPT 的定量比较。如表1所示,我们的方法在合成和真实环境中始终优于当前最先进的基于图像的世界建模方法 iVideoGPT,这表明我们的基于扩散的高斯世界模型学习流程的有效性。值得注意的是,如图3所示,像 iVideoGPT 这样的基于图像的模型容易在捕捉动态细节时出现失败(例如机械手夹爪的动作)。尽管这些细节可能不会在视觉指标上造成大的差异,但它们会显著影响策略学习,这一点我们将在第4.3节进一步讨论。我们在图4中提供了GWM在 ROBOCASA 和 FRANKA-PNP 上的预测结果的更多定性可视化。
4.2. 基于GWM的模仿学习(GWM-based Imitation Learning)
实验设置(Experiment Setup) 如第3.3节所讨论,GWM可以从图像观测中提取信息量丰富的表示,这有望为模仿学习带来益处。我们通过在 ROBOCASA 上测试GWM在模仿学习中的有效性来验证这一性质。ROBOCASA中的任务集包含24个厨房环境的原子任务,并配有相关的语言指令,包括诸如抓取与放置(pick-and-place)、打开(open)和关闭(close)等动作。每个任务都提供了一组有限的 50个真人演示,以及一组 3000个来自 MimicGen [55] 生成的演示。我们在这些演示上训练GWM,并将其作为状态编码传递给最先进的 BC-transformer [59],以在成功率指标上进行定量比较。
结果与分析(Results and Analyses) 我们在 ROBOCASA 基准上的实验结果展示于表2,结果证明了我们的方法在多任务模仿学习场景中的有效性。在24个厨房操作任务中,我们的方法始终优于 BC-Transformer 基线。
- 在有限真人演示(H-50)的情况下,我们的方法在成功率上平均提升了 10.5%;
- 在使用生成演示(G-3000)训练时,我们的方法依然保持了可扩展的性能,平均增益为 7.6%。
值得注意的是,我们的方法在复杂操作任务(如抓取与放置)以及交互性任务(如打开/关闭电器)中表现出特别的优势,这些场景中的性能提升最为显著。这些结果确认了我们的方法能够从视觉观测中提取信息量丰富的表示,从而在实际的机器人操作场景中,有效增强模仿学习的能力。
4.3. 基于GWM的强化学习(GWM-based Reinforcement Learning)
实验设置(Experiment Setup) 我们在 Meta-World [90] 中的六个机器人操作任务上评估了GWM对强化学习策略的支持能力,这些任务具有递增的复杂性。我们实现了一种受 MBPO [31] 启发的基于模型的强化学习方法,使用GWM生成的合成rollout来增强 DrQ-v2 [88] actor-critic 算法的回放缓冲区。 我们将最先进的基于图像的世界模型 iVideoGPT [82] 作为强基线方法。为了公平比较,我们没有对两种方法使用预训练初始化。同时,为了保证公平性,所有比较方法使用相同的上下文长度、预测范围,并且最大训练步数为 1 × 10^5。
结果与分析(Results and Analyses) 如图5所示,GWM在所有六个Meta-World任务上都始终优于iVideoGPT。平均而言,GWM的收敛速度大约比iVideoGPT快 2倍,并且在复杂操作任务上达到了更高的渐近性能。 其优越性能的来源在于:GWM的 三维高斯表示 能够相比纯粹基于图像的方法,更准确地预测操作中的接触动力学与物体运动。这些结果证实了:显式的三维表示在需要精确空间推理的机器人控制任务中,提供了显著的优势。
4.4. 真实世界部署(Real-world Deployment)
实验设置(Experiment Setup) 我们在真实机器人实验中部署了一台 Franka Emika FR3 机械臂 和一个 Panda 夹爪。实验聚焦于现实世界中的一个任务:抓取一个有颜色的杯子,并将其放置到桌子上的盘子上。我们使用 Mujoco AR 远程操作接口 收集了30个演示数据。此外,我们还设置了一台第三视角的 Realsense D435i 相机,用于提供未配准的仅RGB图像作为观测输入。我们在图6中给出了该真实世界任务设置的概览。类似于第4.2节的实验设置,我们将最先进的基于RGB的策略 Diffusion Policy [9] 与“是否使用GWM表示”进行比较,以任务成功率为指标进行定量分析。
结果与分析(Results and Analysis) 如表3所示,在20次试验中,GWM在任务成功率上优于Diffusion Policy(65% 对 35%),这些试验包含了不同的初始起始位置和物体位置(即干扰物)。当出现新的干扰物时,两者的性能差距进一步扩大,这表明GWM具有更强的泛化能力。我们的方法在不同任务变体中始终保持一致的性能,这是因为其高效的世界模型能够捕捉任务相关的动态特征,同时对视觉差异具有鲁棒性。在补充文件中展示了真实世界的rollout,GWM的优势主要源于 更精确的物体定位和更准确的放置操作。这些结果表明,GWM在真实世界机器人操作任务中具备稳健的时空理解能力。
4.5. 消融分析(Ablation Analysis)
我们在 ROBOCASA 上进行了额外实验,以进一步验证我们的设计选择。
高斯点绘的选择(Choice of Gaussian Splatting) 如表4所示,与直接使用扩散Transformer构建基于图像的世界模型(类似于 [1])相比,引入高斯点绘(Gaussian Splatting)显著提升了成功率(SR),从 4% 提高到 18%。虽然PSNR略有下降,但SSIM和LPIPS指标都有所提升,这表明高斯点绘在不同时间步之间提供了更好的三维一致性。这验证了我们的假设:相比于纯二维方法,显式的三维表示能够增强机器人学习的空间理解能力。
三维高斯VAE的选择(Choice of 3D Gaussian VAE) 进一步引入三维VAE组件,使得所有指标(包括PSNR)都持续提升。成功率从 18% 提高到 24%。结果表明,我们的三维高斯VAE能够高效捕捉场景的潜在结构,实现更紧凑的场景表示,同时保持空间理解。
5. 结论(Conclusion)
在本文中,我们提出了一种新颖的 高斯世界模型(Gaussian World Model, GWM),用于机器人操作。该模型通过引入稳健的几何信息,解决了基于图像的世界模型的局限性。我们的方法通过建模机器人动作下高斯基元的传播来重建未来状态。该方法将 扩散Transformer (DiT) 与 三维感知的变分自编码器 相结合,并通过 高斯点绘(Gaussian Splatting) 实现了精确的场景级未来状态重建。我们开发了一个可扩展的数据处理流程,以便在基于模型的强化学习框架中支持测试时更新,从未配准图像中提取对齐的高斯点。在仿真和真实环境中的实验均表明,GWM在未来场景预测和训练更优策略方面具有有效性。
GWM: Towards Scalable Gaussian World Models for Robotic Manipulation (补充材料)
A. 数据集与基准(Datasets and Benchmarks)
Robocasa. 该数据集由机器人操作数据组成,这些数据来自 MuJoCo 仿真环境,使用 Franka Emika Panda 机械臂 收集,主要聚焦于厨房场景。
在我们的实验中,我们使用了 Human-50 (H-50) 和 Generated-3000 (G-3000) 两个数据集,它们由 RoboCasa 提供,并基于人类演示使用 MimicGen [55] 自动生成。该基准包含 24个原子任务,详细信息见表2。
Metaworld. MetaWorld 是一个常用的基准,用于 元强化学习 和 多任务学习。它包含 50个不同的机器人操作任务,这些任务均在仿真环境中使用 Sawyer 机械臂 完成。观测输入是大小为 64 × 64 的RGB图像,动作为一个 4维连续向量。
B. 实现细节(Implementation Details)
B.1. EDM 预处理(EDM Preconditioning)
如第3.2节所述,我们在此列出了为改进网络训练而设计的预处理器 [33]:
cinτ=1σ(τ)2+σdata2(A1)c^\tau_{in} = \frac{1}{\sqrt{\sigma(\tau)^2 + \sigma^2_{data}}} \tag{A1} cinτ=σ(τ)2+σdata21(A1)
coutτ=σ(τ)σdataσ(τ)2+σdata2(A2)c^\tau_{out} = \frac{\sigma(\tau)\sigma_{data}}{\sqrt{\sigma(\tau)^2 + \sigma^2_{data}}} \tag{A2} coutτ=σ(τ)2+σdata2σ(τ)σdata(A2)
cnoiseτ=14log(σ(τ))(A3)c^\tau_{noise} = \frac{1}{4} \log(\sigma(\tau)) \tag{A3} cnoiseτ=41log(σ(τ))(A3)
cskipτ=σdata2σdata2+σ(τ)2(A4)c^\tau_{skip} = \frac{\sigma^2_{data}}{\sigma^2_{data} + \sigma(\tau)^2} \tag{A4} cskipτ=σdata2+σ(τ)2σdata2(A4)
其中,σdata=0.5\sigma_{data} = 0.5σdata=0.5。
噪声参数 σ(τ)\sigma(\tau)σ(τ) 的采样方式如下,以最大化训练的有效性:
log(σ(τ))∼N(Pmean,Pstd2),(A5)\log(\sigma(\tau)) \sim \mathcal{N}(P_{mean}, P^2_{std}), \tag{A5} log(σ(τ))∼N(Pmean,Pstd2),(A5)
其中,Pmean=−0.4P_{mean} = -0.4Pmean=−0.4,Pstd=1.2P_{std} = 1.2Pstd=1.2。
B.2. 架构设计(Architectural Design)
变分自编码器(VAE)采用 基于Transformer的架构,并使用 点嵌入(point embedding) 来编码点云输入。它通过 最远点采样(farthest point sampling) 将原始点云从 N=2048N = 2048N=2048 下采样到可管理的潜在点数量 M=512M = 512M=512,随后通过一系列 自注意力(self-attention) 和 交叉注意力(cross-attention) 模块进行处理。
对于概率变体,编码器输出 均值(mean) 和 对数方差(logvar) 参数,通过 重参数化技巧(reparameterization trick) 来采样潜在向量,同时可以选择性地引入 KL散度正则项。
扩散模型 Dθ\mathcal{D}_\thetaDθ 采用 视觉Transformer (Vision Transformer, DiT) 结构,通过多个Transformer模块处理点图(pointmap)补丁,并使用 自适应层归一化(adaLN) 来对时间步和动作进行条件化。其输入由以下部分组成:
- 当前观测(current observation)、
- 加噪的下一个观测(noisy next observation)、
- 时间嵌入(time embedding)、
- 当前动作嵌入(current action embedding)。
该模型根据 EDM(Elucidated Diffusion Models) 的公式来预测去噪后的下一个状态。
奖励模型 RψR_\psiRψ 结合了 卷积编码 和 序列建模,由带有可选注意力层的 残差块(ResBlocks) 以及后续的 LSTM 组成。编码器处理 一对观测(当前状态与下一个状态),同时以嵌入的动作作为条件,而LSTM则捕获时间依赖关系,最终通过一个 MLP头 来预测奖励。
在推理之前,LSTM的隐藏状态通过一个 热身过程(burn-in procedure) 使用条件帧进行初始化。
B.3. 超参数(Hyper-parameters)
RoboCasa 和 MetaWorld 实验的超参数分别列于表 A2 和表 A1 中。