核心思想与定义
扩散模型的核心思想是:学习一个去噪过程,以逆转一个固定的加噪过程。
-
前向过程(固定): 定义一个马尔可夫链,逐步向数据 x0∼q(x0)\mathbf{x}_0 \sim q(\mathbf{x}_0)x0∼q(x0) 添加高斯噪声,产生一系列噪声逐渐增大的隐变量 x1,...,xT\mathbf{x}_1, ..., \mathbf{x}_Tx1,...,xT。最终 xT\mathbf{x}_TxT 近似为一个标准高斯分布。
q(x1:T∣x0)=∏t=1Tq(xt∣xt−1),其中q(xt∣xt−1)=N(xt;1−βtxt−1,βtI) q(\mathbf{x}_{1:T} | \mathbf{x}_0) = \prod_{t=1}^T q(\mathbf{x}_t | \mathbf{x}_{t-1}), \quad \text{其中} \quad q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}) q(x1:T∣x0)=t=1∏Tq(xt∣xt−1),其中q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)
这里 {βt}t=1T\{\beta_t\}_{t=1}^T{βt}t=1T 是预先定义好的方差调度表。 -
反向过程(可学习): 我们想要学习一个参数化的反向马尔可夫链 pθp_\thetapθ,从噪声 xT∼N(0,I)\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})xT∼N(0,I) 开始,逐步去噪以生成数据。
pθ(x0:T)=p(xT)∏t=1Tpθ(xt−1∣xt),其中pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t)) p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t), \quad \text{其中} \quad p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mathbf{\mu}_\theta(\mathbf{x}_t, t), \mathbf{\Sigma}_\theta(\mathbf{x}_t, t)) pθ(x0:T)=p(xT)t=1∏Tpθ(xt−1∣xt),其中pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))
我们的目标是让 pθ(x0)p_\theta(\mathbf{x}_0)pθ(x0) 尽可能接近真实数据分布 q(x0)q(\mathbf{x}_0)q(x0)。 -
前向过程的闭式解: 得益于高斯分布的可加性,我们可以直接从 x0\mathbf{x}_0x0 采样任意时刻 ttt 的 xt\mathbf{x}_txt:
q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I) q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I}) q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)
其中 αt=1−βt\alpha_t = 1 - \beta_tαt=1−βt, αˉt=∏i=1tαi\bar{\alpha}_t = \prod_{i=1}^t \alpha_iαˉt=∏i=1tαi。使用重参数化技巧,可以写为:
xt=αˉtx0+1−αˉtϵ,其中ϵ∼N(0,I) \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \mathbf{\epsilon}, \quad \text{其中} \quad \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) xt=αˉtx0+1−αˉtϵ,其中ϵ∼N(0,I)
这个公式至关重要,它允许我们随机采样时间步 ttt 并高效地计算训练损失。
优化目标:变分下界 (VLB/ELBO)
我们的目标是最大化模型生成真实数据的对数似然 logpθ(x0)\log p_\theta(\mathbf{x}_0)logpθ(x0)。由于其难以直接计算,我们转而最大化其变分下界(VLB),也称为证据下界(ELBO)。
logpθ(x0)≥Eq(x1:T∣x0)[logpθ(x0:T)q(x1:T∣x0)]=Eq[logp(xT)∏t=1Tpθ(xt−1∣xt)∏t=1Tq(xt∣xt−1)]≜−LVLB
\begin{aligned}
\log p_\theta(\mathbf{x}_0)
&\geq \mathbb{E}_{q(\mathbf{x}_{1:T} | \mathbf{x}_0)} \left[ \log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} | \mathbf{x}_0)} \right] \\
&= \mathbb{E}_{q} \left[ \log \frac{ p(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t) }{ \prod_{t=1}^T q(\mathbf{x}_t | \mathbf{x}_{t-1}) } \right] \\
&\triangleq -L_{\text{VLB}}
\end{aligned}
logpθ(x0)≥Eq(x1:T∣x0)[logq(x1:T∣x0)pθ(x0:T)]=Eq[log∏t=1Tq(xt∣xt−1)p(xT)∏t=1Tpθ(xt−1∣xt)]≜−LVLB
因此,我们最小化 LVLBL_{\text{VLB}}LVLB。
通过对 LVLBL_{\text{VLB}}LVLB 进行推导(利用马尔可夫性和贝叶斯定理),可以将其分解为以下几项:
LVLB=Eq[DKL(q(xT∣x0)∥p(xT))⏟LT−logpθ(x0∣x1)⏟L0+∑t=2TDKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))⏟Lt−1] L_{\text{VLB}} = \mathbb{E}_q [\underbrace{D_{\text{KL}}(q(\mathbf{x}_T | \mathbf{x}_0) \parallel p(\mathbf{x}_T))}_{L_T} - \underbrace{\log p_\theta(\mathbf{x}_0 | \mathbf{x}_1)}_{L_0} + \sum_{t=2}^T \underbrace{D_{\text{KL}}(q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t))}_{L_{t-1}} ] LVLB=Eq[LTDKL(q(xT∣x0)∥p(xT))−L0logpθ(x0∣x1)+t=2∑TLt−1DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))]
- LTL_TLT: 衡量最终噪声分布与先验分布 N(0,I)\mathcal{N}(\mathbf{0}, \mathbf{I})N(0,I) 的差异。此项没有可学习参数,接近于0,可以忽略。
- L0L_0L0: 重建项,衡量最后一步生成图像与真实图像的差异。此项在原始DDPM中通过一个离散化decoder处理,实践中发现其影响较小。
- Lt−1L_{t-1}Lt−1 (1≤t≤T1 \le t \le T1≤t≤T): 这是最关键的一项。它衡量的是对于每一个去噪步,真实的去噪分布 q(xt−1∣xt,x0)q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0)q(xt−1∣xt,x0) 和 学习的去噪分布 pθ(xt−1∣xt)p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t)pθ(xt−1∣xt) 之间的KL散度。
核心推导:真实的后验分布 q(xt−1∣xt,x0)q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0)q(xt−1∣xt,x0)
根据贝叶斯定理和马尔可夫性,我们可以推导出这个真实的后验分布。它也是一个高斯分布,这意味着我们可以用另一个高斯分布 pθp_\thetapθ 去匹配它。
q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt−1∣x0)q(xt∣x0)∝N(xt;αtxt−1,(1−αt)I)⋅N(xt−1;αˉt−1x0,(1−αˉt−1)I) \begin{aligned} q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) &= \frac{q(\mathbf{x}_t | \mathbf{x}_{t-1}, \mathbf{x}_0) q(\mathbf{x}_{t-1} | \mathbf{x}_0)}{q(\mathbf{x}_t | \mathbf{x}_0)} \\ &\propto \mathcal{N}(\mathbf{x}_t; \sqrt{\alpha_t} \mathbf{x}_{t-1}, (1 - \alpha_t)\mathbf{I}) \cdot \mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0, (1 - \bar{\alpha}_{t-1})\mathbf{I}) \end{aligned} q(xt−1∣xt,x0)=q(xt∣x0)q(xt∣xt−1,x0)q(xt−1∣x0)∝N(xt;αtxt−1,(1−αt)I)⋅N(xt−1;αˉt−1x0,(1−αˉt−1)I)
经过一系列高斯分布密度函数的乘积和配方,可以得出其均值和方差为:
q(xt−1∣xt,x0)=N(xt−1;μ~t(xt,x0),β~tI) q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \mathbf{\tilde{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I}) q(xt−1∣xt,x0)=N(xt−1;μ~t(xt,x0),β~tI)
其中μ~t(xt,x0)=1αt(xt−βt1−αˉtϵ),β~t=1−αˉt−11−αˉtβt \text{其中} \quad \mathbf{\tilde{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{\epsilon} \right), \quad \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t 其中μ~t(xt,x0)=αt1(xt−1−αˉtβtϵ),β~t=1−αˉt1−αˉt−1βt
注意:这里 ϵ\mathbf{\epsilon}ϵ 是前向过程中添加到 x0\mathbf{x}_0x0 上生成 xt\mathbf{x}_txt 的噪声。这个 μ~t\mathbf{\tilde{\mu}}_tμ~t 的表达式非常关键!
简化损失函数:从均值预测到噪声预测
现在我们来看要最小化的 Lt−1L_{t-1}Lt−1,它是两个高斯分布的KL散度。高斯分布的KL散度主要由其均值的差异主导(假设方差固定)。
Lt−1=Eq[DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))]=Eq[12σt2∥μ~t(xt,x0)−μθ(xt,t)∥2]+C \begin{aligned} L_{t-1} &= \mathbb{E}_q \left[ D_{\text{KL}}(q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t)) \right] \\ &= \mathbb{E}_q \left[ \frac{1}{2\sigma_t^2} \| \mathbf{\tilde{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \mathbf{\mu}_\theta(\mathbf{x}_t, t) \|^2 \right] + C \end{aligned} Lt−1=Eq[DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))]=Eq[2σt21∥μ~t(xt,x0)−μθ(xt,t)∥2]+C
现在我们有两个选择:
- 让网络 μθ\mathbf{\mu}_\thetaμθ 直接预测均值 μ~t\mathbf{\tilde{\mu}}_tμ~t。
- 根据 μ~t\mathbf{\tilde{\mu}}_tμ~t 的表达式,重新参数化模型。
DDPM选择了第二种方式,因为它效果更好。我们将 μ~t\mathbf{\tilde{\mu}}_tμ~t 的表达式代入:
μθ(xt,t)=1αt(xt−βt1−αˉtϵθ(xt,t)) \mathbf{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \right) μθ(xt,t)=αt1(xt−1−αˉtβtϵθ(xt,t))
这里,我们不再让网络预测均值,而是让它预测噪声 ϵ\mathbf{\epsilon}ϵ,即 ϵθ(xt,t)\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)ϵθ(xt,t)。将这个表达式代入上面的损失函数,经过简化(忽略权重系数),我们得到最终极其简洁的损失函数:
Lsimple=Ex0,t,ϵ∼N(0,I)[∥ϵ−ϵθ(αˉtx0+1−αˉtϵ,t)∥2] L_{\text{simple}} = \mathbb{E}_{\mathbf{x}_0, t, \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})} \left[ \| \mathbf{\epsilon} - \mathbf{\epsilon}_\theta( \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \mathbf{\epsilon}, t ) \|^2 \right] Lsimple=Ex0,t,ϵ∼N(0,I)[∥ϵ−ϵθ(αˉtx0+1−αˉtϵ,t)∥2]
这个损失函数的直观解释是:对于一张真实图像 x0\mathbf{x}_0x0,随机选择一个时间步 ttt,随机采样一个噪声 ϵ\mathbf{\epsilon}ϵ,构造出噪声图像 xt\mathbf{x}_txt。然后,我们训练一个网络 ϵθ\mathbf{\epsilon}_\thetaϵθ,让它根据 xt\mathbf{x}_txt 和 ttt 来预测出我们添加的噪声 ϵ\mathbf{\epsilon}ϵ。损失就是预测噪声和真实噪声之间的均方误差。
总结:优化流程
- 输入:从训练集中采样一张真实图像 x0\mathbf{x}_0x0。
- 加噪:
- 均匀采样一个时间步 t∼Uniform(1,...,T)t \sim \text{Uniform}(1, ..., T)t∼Uniform(1,...,T)。
- 从标准高斯分布采样噪声 ϵ∼N(0,I)\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})ϵ∼N(0,I)。
- 计算 xt=αˉtx0+1−αˉtϵ\mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \mathbf{\epsilon}xt=αˉtx0+1−αˉtϵ。
- 预测:将 xt\mathbf{x}_txt 和 ttt 输入神经网络 ϵθ\mathbf{\epsilon}_\thetaϵθ,得到其对噪声的预测 ϵθ(xt,t)\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)ϵθ(xt,t)。
- 优化:计算损失 L=∥ϵ−ϵθ∥2L = \| \mathbf{\epsilon} - \mathbf{\epsilon}_\theta \|^2L=∥ϵ−ϵθ∥2,并通过梯度下降更新网络参数 θ\thetaθ。
- 重复:重复步骤1-4直至收敛。
这个框架的巧妙之处在于,它将一个复杂的生成问题,分解为了 TTT 个相对简单的去噪问题。网络 ϵθ\mathbf{\epsilon}_\thetaϵθ 不需要一步生成完美图像,只需要在每一步完成一个更简单的任务:预测噪声。这使得训练非常稳定,也是扩散模型成功的核心原因。