Skip to content

CVAE 条件变分自编码器

之前的VAE模型虽然具备了一定的生成能力,但是我们还没有实现输入给定信息生成内容的功能,而这需要条件变分自编码器。

CVAE的架构

CVAE与VAE的运算过程几乎一致,唯一的区别是在编码器输入中拼接上了one-hat标签。

one-hat标签

one-hat标签又称为独热标签,是一种把分类标量转化为二进制向量的方法,在数字分类问题中可以把0~1数字转为如下标签向量:

bash
0 = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1 = [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
2 = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
3 = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
......

图像一维化与拼合

对于单通道28×28图像,将其一维化为784大小的一维向量,随后将其和one-hat标签进行拼接,得到794的一维向量。

权重共享

为了加快模型训练速度,可以将多个批次的一维化向量组合,得到二维矩阵,这样可以同时进行多个数据训练,其中的batch_size作为超参数由初始化时定义,具体需要考虑计算资源。

若该例子中假设选定批次为n,则拼接得到的[n,794]输入矩阵的全连接层将通过Z=XW+b来计算,这里的W是权重矩阵,b是偏置矩阵,X=[x1,x2,,xn],xi=[794]的每一行数据都将用同一个权重矩阵W训练,这样可以减少训练量,避免过拟合。

偏置b是一个一维向量[z]XW[n,794]矩阵,二者不能进行计算,因此会经过PyTorch的广播机制扩展复制n份得到[n,z]矩阵。

超参数

在CVAE模型中,有两个超参数:

  1. batch_size:数据集同时训练的批次数,如果该值为1,则为简单训练。同时训练n个批次将得到[n,794]的输入矩阵。
  2. z:编码器输出的隐空间维数。

编码

和VAE模型一样,输入矩阵进行计算后得到[n,z]大小的均值矩阵和方差矩阵,n是批次数,z是隐空间维数。

CVAE得到均值和方差后,从标准正态分布中采样得到随机变量,经过重参数化,得到隐空间采样值z=μ+σe,此时z是一个[n,z]矩阵,具体过程如下图所示:

解码

得到采样后的[n,z]矩阵再次与one-hat矩阵拼接,这里是[n,10]矩阵,拼接得到[n,z+10]矩阵,该矩阵将作为解码器网络的输入,经过前向计算后最终的解码器将输出[n,784]的矩阵,其中每一行将代表一个一维化的图像结果,将它展开得到28×28的图像。

上述解码器的架构流程如下:

损失函数

CVAE和VAE的损失函数计算公式完全一致,都是由重构损失+KL散度损失两部分组成,与VAE不同的时,CVAE在编码过程中就学习到了图像+标签到隐空间的映射,解码器学到了隐空间采样+标签到图像的映射。

经过数据集训练后,模型的解码器具备了采样+标签将生成与标签对应的图像的能力,也就是给定条件生成内容的能力。

CVAE 的损失函数在优化时,因为编码器和解码器都引入了标签信息,所以重构损失会隐性地受到标签的约束,生成的图像不仅要和原图像素匹配,还要符合标签对应的类别,这是 CVAE “条件生成” 能力的核心来源。

关于损失函数的具体计算,可以参见VAE部分。