CVAE 条件变分自编码器
之前的VAE模型虽然具备了一定的生成能力,但是我们还没有实现输入给定信息生成内容的功能,而这需要条件变分自编码器。
CVAE的架构
CVAE与VAE的运算过程几乎一致,唯一的区别是在编码器输入中拼接上了one-hat标签。
one-hat标签
one-hat标签又称为独热标签,是一种把分类标量转化为二进制向量的方法,在数字分类问题中可以把0~1数字转为如下标签向量:
0 = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1 = [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
2 = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
3 = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
......
图像一维化与拼合
对于单通道
权重共享
为了加快模型训练速度,可以将多个批次的一维化向量组合,得到二维矩阵,这样可以同时进行多个数据训练,其中的batch_size
作为超参数由初始化时定义,具体需要考虑计算资源。
若该例子中假设选定批次为
偏置
超参数
在CVAE模型中,有两个超参数:
batch_size
:数据集同时训练的批次数,如果该值为1,则为简单训练。同时训练n个批次将得到的输入矩阵。 z
:编码器输出的隐空间维数。
编码
和VAE模型一样,输入矩阵进行计算后得到
CVAE得到均值和方差后,从标准正态分布中采样得到随机变量,经过重参数化,得到隐空间采样值
解码
得到采样后的
上述解码器的架构流程如下:
损失函数
CVAE和VAE的损失函数计算公式完全一致,都是由重构损失+KL散度损失两部分组成,与VAE不同的时,CVAE在编码过程中就学习到了图像+标签到隐空间的映射,解码器学到了隐空间采样+标签到图像的映射。
经过数据集训练后,模型的解码器具备了采样+标签将生成与标签对应的图像的能力,也就是给定条件生成内容的能力。
CVAE 的损失函数在优化时,因为编码器和解码器都引入了标签信息,所以重构损失会隐性地受到标签的约束,生成的图像不仅要和原图像素匹配,还要符合标签对应的类别,这是 CVAE “条件生成” 能力的核心来源。
关于损失函数的具体计算,可以参见VAE部分。