变分下界是变分推断 (Variational Inference, VI) 中的一个重要概念,它为我们提供了一种优化概率模型参数的可行方法。VAE 是 Kingma 和 Welling 在 2013 年发表的文章 Auto-Encoding Variational Bayes 中提出的一种生成模型,它结合了变分推断和自动编码器的思想,通过最大化变分下界来学习数据的潜在表示,从而同时实现数据的生成和重构。
变分下界
这部分内容主要来自 Machine Learning: A Probabilistic Perspective 的第 11.4.7.1 节。
为了更好的理解 VAE1,我们首先理解变分下界的概念和原理2。假设我们有一组观测到的数据点 \({\{ {\bf{x}}_i \}}_{i=1}^{N}\sim p({\bf{x}})\),以及每个数据点 \(\bf{x}_i\) 对应的潜变量 \({\bf{z}_i}\sim q({\bf{z}})\)。那么观测数据的对数似然可以表示为:
\[\begin{align} \ell(\bm{\theta}) &= \sum_{i=1}^{N} \log p_{\bm{\theta}}({\bf{x}}_i) \\ &= \sum_{i=1}^{N} \log \left[ \sum_{ {\bf{z}}_i } p_{\bm{\theta}}({\bf{x}}_i, {\bf{z}}_i) \right] \\ &= \sum_{i=1}^{N} \log \left[ \sum_{ {\bf{z}}_i } q({\bf{z}}_i) \frac{ p_{\bm{\theta}}({\bf{x}}_i, {\bf{z}}_i) }{ q({\bf{z}}_i) } \right] \end{align}\]由于对数函数是一个凹函数,我们可以应用 Jensen 不等式来得到对数似然的下界:
\[\label{eq:variational_lower_bound} \ell(\bm{\theta}) \geq \sum_{i=1}^{N} \sum_{ {\bf{z}}_i } q({\bf{z}}_i) \log \left[ \frac{ p_{\bm{\theta}}({\bf{x}}_i, {\bf{z}}_i) }{ q({\bf{z}}_i) } \right] = \mathcal{Q}(\bm{\theta}, q)\]通过拆分分子和分母,我们可以把变分下界 \(\mathcal{Q}(\bm{\theta}, q)\) 用期望表示为:
\[\label{eq:variational_lower_bound_expectation} \mathcal{Q}(\bm{\theta}, q) = \sum_{i=1}^{N} \mathbb{E}_{ {\bf{z}}_i \sim q({\bf{z}}_i) } \left[ \log p_{\bm{\theta}}({\bf{x}}_i, {\bf{z}}_i) \right] + \mathbb{H}(q({\bf{z}}_i))\]其中 \(\mathbb{H}(q({\bf{z}}_i))\) 是分布 \(q({\bf{z}}_i)\) 的熵。基于式 \(\eqref{eq:variational_lower_bound}\) 我们对下界 \(\mathcal{Q}(\bm{\theta}, q)\) 的第 \(i\) 项进行进一步的拆分:
\[\label{eq:variational_lower_bound_decomposition} \begin{aligned} \mathcal{L}(\bm{\theta}, q) &= \sum_{ {\bf{z}}_i } q({\bf{z}}_i) \log \left[ \frac{ p_{\bm{\theta}}({\bf{x}}_i, {\bf{z}}_i) }{ q({\bf{z}}_i) } \right] \\ &= \sum_{ {\bf{z}}_i } q({\bf{z}}_i) \log \left[ \frac{ p_{\bm{\theta}}({\bf{z}}_i \lvert {\bf{x}}_i) p_{\bm{\theta}}({\bf{x}}_i) }{ q({\bf{z}}_i) } \right] \\ &= \sum_{ {\bf{z}}_i } q({\bf{z}}_i) \log p_{\bm{\theta}}({\bf{x}}_i) + \sum_{ {\bf{z}}_i } q({\bf{z}}_i) \log \left[ \frac{ p_{\bm{\theta}}({\bf{z}}_i \lvert {\bf{x}}_i) }{ q({\bf{z}}_i) } \right] \\ &= \log p_{\bm{\theta}}({\bf{x}}_i) - D_{KL} \left( q({\bf{z}}_i) \lVert p_{\bm{\theta}}({\bf{z}}_i \lvert {\bf{x}}_i) \right) \end{aligned}\]注意到 \(p({\bf{x}}_i)\) 与求和变量 \(\bf{z}_i\) 无关,因此可以提取到求和符号外面。上式表明,变分下界 \(\mathcal{L}(\bm{\theta}, q)\) 等于观测数据点 \(\bf{x}_i\) 的对数似然减去分布 \(q({\bf{z}}_i)\) 与后验分布 \(p_{\bm{\theta}}({\bf{z}}_i \lvert {\bf{x}}_i)\) 之间的 KL 散度。因为 \(p({\bf{x}}_i)\) 与分布 \(q\) 无关,所以最大化变分下界 \(\mathcal{L}(\bm{\theta}, q)\) 等价于最小化 KL 散度 \(D_{KL} \left( q({\bf{z}}_i) \lVert p_{\bm{\theta}}({\bf{z}}_i \lvert {\bf{x}}_i) \right)\)。换句话说,通过最大化变分下界,我们可以使得近似分布 \(q({\bf{z}}_i)\) 尽可能接近真实的后验分布 \(p_{\bm{\theta}}({\bf{z}}_i \lvert {\bf{x}}_i)\)。
与 EM 算法的联系
基于上面推导的变分下界,假如我们已知潜变量 \(\bf{z}_i\) 的真实后验分布 \(p_{\bm{\theta}}({\bf{z}}_i \lvert {\bf{x}}_i)\),那么我们可以选择 \(q^t({\bf{z}}_i) = p_{\bm{\theta}^t}({\bf{z}}_i \lvert {\bf{x}}_i)\) (在 EM 算法中,由于真实的 \(\bm{\theta}\) 是未知的,因此取的是上一步中 M step 的估计值 \(\bm{\theta}^t\)),此时 KL 散度为零,变分下界达到最大值 \(\log p_{\bm{\theta}}({\bf{x}}_i)\)。这与 EM 算法中的 E 步骤是一致的,在 E 步骤中我们计算潜变量的后验分布。然后把 \(q^t({\bf{z}}_i)\) 代入变分下界式 \(\eqref{eq:variational_lower_bound_expectation}\),得到:
\[\mathcal{Q}(\bm{\theta}, q^t) = \sum_{i=1}^{N} \mathbb{E}_{ {\bf{z}}_i \sim q^t({\bf{z}}_i) } \left[ \log p_{\bm{\theta}}({\bf{x}}_i, {\bf{z}}_i) \right] + \mathbb{H}(q^t({\bf{z}}_i))\]由于 \(\mathbb{H}(q^t({\bf{z}}_i))\) 与参数 \(\bm{\theta}\) 无关,因此可以忽略,因此 M step 可以简化为最大化期望项:
\[\bm{\theta}^{t+1} = \arg\max_{\bm{\theta}} \sum_{i=1}^{N} \mathbb{E}_{ {\bf{z}}_i \sim q^t({\bf{z}}_i) } \left[ \log p_{\bm{\theta}}({\bf{x}}_i, {\bf{z}}_i) \right]\]即在 M step 中我们最大化观测数据的对数似然。
概率图模型
VAE 可以看作是一个概率图模型 (Probabilistic Graphical Model, PGM),它假设数据点 \(\bf{x}\) 是由潜变量 \(\bf{z}\) 生成的,如下图所示:
图 1. VAE 考虑的概率图模型。
其中 \(\bm{\theta}\) 同时包含了随机变量 \(\bf{z}\) 分布的参数和随机变量 \(\bf{x}\) 分布的参数。实线表示生成模型 \(p_{\bm{\theta}}{(\bf{z})}p_{\bm{\theta}}({\bf{x}\lvert\bf{z}})\),虚线表示后验分布 \(p_{\bm{\theta}}({\bf{z}\lvert\bf{x}})\) 的变分近似 \(q_{\bm{\phi}}({\bf{z}\lvert\bf{x}})\)。
问题提出
基于上面的概率图模型,我们除了知道观测的数据点 \(\bf{x}\) 之外,并不知道真实参数 \(\bm{\theta}^*\),也不知道潜变量 \(\bf{z}\) 的真实分布。此外,这个模型还有下面两个问题:
-
不易处理性(Intractability): 积分得到的边际分布 \(p_{\bm{\theta}}({\bf{x}}) = \int p_{\bm{\theta}}({\bf{x}\lvert\bf{z}}) p_{\bm{\theta}}({\bf{z}}) \text{d}{\bf{z}}\) 通常是不可解的;此外,后验分布 \(p_{\bm{\theta}}({\bf{z}\lvert\bf{x}}) = \frac{ p_{\bm{\theta}}({\bf{x}\lvert\bf{z}}) p_{\bm{\theta}}(\bf{z}) }{ p_{\bm{\theta}}({\bf{x}}) }\) 也是不可解的,所以也无法应用 EM 算法。
-
大数据集: 特别是在大数据集上,使用蒙特卡洛 EM 方法需要对每个数据点进行大量采样,这会导致计算开销过大。
再看变分下界
VAE 引入一个模型 \(q_{\bm{\phi}}({\bf{z}\lvert\bf{x}})\) 来近似真实的后验分布 \(p_{\bm{\theta}}({\bf{z}\lvert\bf{x}})\),其中 \(\bm{\phi}\) 是该近似分布的参数,那么式 \(\eqref{eq:variational_lower_bound_decomposition}\) 中的分布 \(q({\bf{z}}_i)\) 就被替换为 \(q_{\bm{\phi}}({\bf{z}}_i \lvert {\bf{x}}_i)\) (下面简写为 \(q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i)\)),从而得到:
\[\log p_{\bm{\theta}}({\bf{x}}_i) = D_{KL} \left( q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \lVert p_{\bm{\theta}}({\bf{z}} \lvert {\bf{x}}_i) \right) + \mathcal{L}(\bm{\theta}, \bm{\phi}; {\bf{x}}_i)\]回顾式 \(\eqref{eq:variational_lower_bound_expectation}\),VAE 把变分下界做进一步的变形:
\[\begin{aligned} \mathcal{L}(\bm{\theta}, \bm{\phi}; {\bf{x}}_i) &= \mathbb{E}_{ {\bf{z}} \sim q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) } \left[ \log p_{\bm{\theta}}({\bf{x}}_i, {\bf{z}}) \right] + \mathbb{H}(q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i)) \\ &= \sum_{ {\bf{z}} } q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \log p_{\bm{\theta}}({\bf{x}}_i, {\bf{z}}) - \sum_{ {\bf{z}} } q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \log q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \\ &= \sum_{ {\bf{z}} } q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \log \left[ \frac{ p_{\bm{\theta}}({\bf{z}}) p_{\bm{\theta}}({\bf{x}}_i \lvert {\bf{z}}) }{ q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) } \right] \\ &= \sum_{ {\bf{z}} } q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \log p_{\bm{\theta}}({\bf{x}}_i \lvert {\bf{z}}) - \sum_{ {\bf{z}} } q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \log \left[ \frac{ q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) }{ p_{\bm{\theta}}({\bf{z}}) } \right] \\ &= \mathbb{E}_{ {\bf{z}} \sim q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) } \left[ \log p_{\bm{\theta}}({\bf{x}}_i \lvert {\bf{z}}) \right] - D_{KL} \left( q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \lVert p_{\bm{\theta}}({\bf{z}}) \right) \end{aligned}\]上式表明,变分下界 \(\mathcal{L}(\bm{\theta}, \bm{\phi}; {\bf{x}}_i)\) 包含两个部分:第一部分是对数似然的期望,表示在潜变量 \(\bf{z}\) 的近似后验分布 \(q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i)\) 下,观测数据点 \({\bf{x}}_i\) 的对数似然;第二部分是 KL 散度,表示近似后验分布 \(q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i)\) 与先验分布 \(p_{\bm{\theta}}({\bf{z}})\) 之间的差异。
根据前面的推导,VAE 的目的是最大化变分下界 \(\mathcal{L}(\bm{\theta}, \bm{\phi}; {\bf{x}}_i)\),从而使得近似后验分布 \(q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i)\) 尽可能接近真实的后验分布 \(p_{\bm{\theta}}({\bf{z}} \lvert {\bf{x}}_i)\)。通过最大化变分下界,我们可以同时优化生成模型的参数 \(\bm{\theta}\) 和近似后验分布的参数 \(\bm{\phi}\)。直接使用蒙特卡洛梯度估计会导致高方差的问题,因此 VAE 引入了重参数化技巧 (Reparameterization Trick) 来降低梯度估计的方差,从而更有效地训练模型。
损失计算
在实际应用中,VAE 的损失函数通常表示为变分下界的负值,我们把负号移到期望里面,有:
\[\mathcal{L}_{\text{VAE}}(\bm{\theta}, \bm{\phi}; {\bf{x}}_i) = \underbrace{ \mathbb{E}_{ {\bf{z}} \sim q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) } \left[ -\log p_{\bm{\theta}}({\bf{x}}_i \lvert {\bf{z}}) \right] }_{\text{重构损失 (Reconstruction Loss)}} + \underbrace{ D_{KL} \left( q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \lVert p_{\bm{\theta}}({\bf{z}}) \right) }_{\text{正则化项 (Regularization Term)}}\]我们逐个分析:
重构损失: 该项衡量了在潜变量 \(\bf{z}\) 的近似后验分布 \(q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i)\) 下,生成的数据点 \(\bf{x}_i\) 与原始数据点之间的差异。我们假设观测数据的生成过程服从高斯分布 \(p_{\bm{\theta}}({\bf{x}}_i \lvert {\bf{z}}) = \mathcal{N}({\bf{x}}_i; \mu_{\bm{\theta}}({\bf{z}}), \sigma^2 I)\),那么重构损失可以简化为均方误差 (Mean Squared Error, MSE):
\[\mathbb{E}_{ {\bf{z}} \sim q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) } \left[ -\log p_{\bm{\theta}}({\bf{x}}_i \lvert {\bf{z}}) \right] \propto \mathbb{E}_{ {\bf{z}} \sim q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) } \left[ \lVert {\bf{x}}_i - \mu_{\bm{\theta}}({\bf{z}}) \rVert^2 \right]\]正则化项: 该项衡量了近似后验分布 \(q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i)\) 与先验分布 \(p_{\bm{\theta}}({\bf{z}})\) 之间的差异。通常我们假设先验分布为标准正态分布 \(p_{\bm{\theta}}({\bf{z}}) = \mathcal{N}({\bf{z}}; 0, I)\),而近似后验分布为高斯分布 \(q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) = \mathcal{N}({\bf{z}}; \mu_{\bm{\phi}}({\bf{x}}_i), \sigma_{\bm{\phi}}^2({\bf{x}}_i) I)\),那么 KL 散度可以通过解析计算得到:
\[D_{KL} \left( q_{\bm{\phi}}({\bf{z}} \lvert {\bf{x}}_i) \lVert p_{\bm{\theta}}({\bf{z}}) \right) = \frac{1}{2} \sum_{j=1}^{d} \left( \sigma_{\bm{\phi}, j}^2({\bf{x}}_i) + \mu_{\bm{\phi}, j}^2({\bf{x}}_i) - 1 - \log \sigma_{\bm{\phi}, j}^2({\bf{x}}_i) \right)\]其中 \(d\) 是潜变量 \(\bf{z}\) 的维度,\(\mu_{\bm{\phi}, j}({\bf{x}}_i)\) 和 \(\sigma_{\bm{\phi}, j}^2({\bf{x}}_i)\) 分别是近似后验分布在第 \(j\) 个维度上的均值和方差。
补充材料
苏剑林在 变分自编码器(二):从贝叶斯观点出发3 中从联合分布的角度推导了 VAE 的变分下界。