离散扩散模型
不同于位置坐标等连续特征,分子的原子类型和键类型的类别变量需要特殊的离散扩散公式来是实现,其核心通过可控的概率转移完成前向加噪,并借助反向后验分布实现迭代去噪。
离散扩散是分子生成模型的主要组件之一,本文将从基础概率定义、转移矩阵设计、后验公式推导,到代码实现与损失函数设计,完整梳理离散扩散的数学原理与工程逻辑,帮助理解其在结构生成任务中的内在机制与实现细节。
离散变量的概率分布
离散随机变量
注意到,随机变量
离散扩散
离散情况下,对于有
例如,矩阵的第一行表示
式中
这样
一步加噪
和连续扩散一样,在前向加噪过程中需要根据选择加噪的时间步
离散扩散的前向过程满足马尔可夫性,即
用转移矩阵表示,单步转移为:
其中
根据马尔可夫链的性质,多步转移可以通过矩阵连乘得到:
公式描述了原始样本
反向去噪
根据贝叶斯定理,可以得到已知当前含噪样本
离散扩散的前向过程满足马尔可夫性:
把前面得出的公式代入分子分母中得到,分母项:
对于分子第一项的单步转移概率
最终把三个式子带入分子分母,得到:
反向过程的目的是从含噪的
:从“当前含噪样本 ”反向推,哪些 最可能生成 (比如:现在看到的是模糊的含噪图像,猜它上一步可能是哪些不那么模糊的图像); :从“原始样本 ”正向推,哪些 最可能是 加噪一步的结果(比如:原始清晰图像,加噪一步后最可能变成哪些样子); (哈达玛积):把上面两个“猜测”结合起来,取两者都认可的、最可能的 。 :分子的“猜测”是“未标准化”的(概率和不一定等于1),而概率的核心要求是“所有可能的结果加起来等于1”。
整个公式本质就是:先结合“正向加噪”和“反向去噪”的信息,初步猜测
模型训练
在上式中,我们不知道真实的
- 前向加噪:随机选一个时间步
,把干净数据 加噪成 ; - 模型预测:让模型输入
和时间步 ,输出对 分布的预测; - 计算误差:将模型的预测结果与公式计算出的“真实
分布”做对比不断调优。
整个优化过程的损失可表述为:
即让模型预测的分布
采样过程
训练好模型后即可从含噪样本还原出干净数据:
- 初始化:从一个完全随机的含噪样本
开始( 是最大时间步,此时样本最模糊); - 逐步去噪:从
开始,一步步往 倒推——每一步都用模型 ,根据当前的 ,预测出更干净的 ; - 得到结果:当
时,我们就得到了还原后的干净样本 ,完成去噪。
代码实现
接下来重点讲解一下离散扩散类的代码实现,它将接收
class GeneralCategoricalTransition(nn.Module):
def __init__(self, betas, num_classes, init_prob=None):
super().__init__()
# 加微小值,防止概率值趋于0时的梯度爆炸
self.eps = 1e-30
# 类别数
self.num_classes = num_classes
# 初始类别分布
# 默认均匀分布
if init_prob is None:
self.init_prob = np.ones(num_classes) / num_classes
# 第一个类别概率接近 1,其他接近 0;
elif init_prob == 'absorb':
init_prob = 0.01 * np.ones(num_classes)
init_prob[0] = 1
self.init_prob = init_prob / np.sum(init_prob)
# 最后一个类别概率接近 1,其他接近 0;
elif init_prob == 'tomask':
init_prob = 0.001 * np.ones(num_classes)
init_prob[-1] = 1.
self.init_prob = init_prob / np.sum(init_prob)
# 所有类别概率相等;
elif init_prob == 'uniform':
self.init_prob = np.ones(num_classes) / num_classes
# 将用户传入的自定义数组归一化
else:
self.init_prob = init_prob / np.sum(init_prob)
# betas参数
self.betas = betas
# 总时间步数
self.num_timesteps = len(betas)
# q(x_t | x_{t-1})
q_one_step_mats = [self._get_transition_mat(t) for t in range(0, self.num_timesteps)]
q_one_step_mats = np.stack(q_one_step_mats, axis=0) # (T, K, K)
# q(x_t | x_0)
q_mat_t = q_one_step_mats[0]
q_mats = [q_mat_t]
# 计算累积转移矩阵
for t in range(1, self.num_timesteps):
q_mat_t = np.tensordot(q_mat_t, q_one_step_mats[t], axes=[[1], [0]])
q_mats.append(q_mat_t)
q_mats = np.stack(q_mats, axis=0)
transpopse_q_onestep_mats = np.transpose(q_one_step_mats, axes=[0, 2, 1])
# 转成张量,绑定到实例self
self.q_mats = to_torch_const(q_mats)
self.transpopse_q_onestep_mats = to_torch_const(transpopse_q_onestep_mats)概率转移矩阵
初始化概率转移矩阵,输入参数是时间步t,返回该时刻的概率转移矩阵。
若不设置初始概率,则按标准均匀扩散构造转移矩阵,对角线元素以大概率保持不变,其他非对角线元素以小概率进行类别转移:
若设置了初始概率,先构建初始转移概率矩阵
# 计算概率转移矩阵
def _get_transition_mat(self, t):
"""
计算t时刻的概率转移矩阵q(x_t|x_{t-1})
Args:
t: 时间步
Returns:
Q_t: 当前时间步的N阶方阵,表示转移矩阵
"""
# 获取当前时间步的噪声参数beta
beta_t = self.betas[t]
# 无初始概率分支
if self.init_prob is None:
# 所有元素初始化为 “跨类别转移概率”
mat = np.full(
shape=(self.num_classes, self.num_classes),
fill_value=beta_t/float(self.num_classes),
dtype=np.float64
)
# 获取对角线索引
diag_indices = np.diag_indices_from(mat)
# 对角线元素的自转移概率
diag_val = 1. - beta_t * (self.num_classes-1.)/self.num_classes
# 替换对角线元素为自转移概率
mat[diag_indices] = diag_val
# 有初始概率分支
else:
# 初始概率扩充维度,得到二维矩阵
mat = np.repeat(np.expand_dims(self.init_prob, 0), self.num_classes, axis=0)
# 乘以噪声系数
mat = beta_t * mat
# 创建对角线为1 - beta_t、其余为 0 的矩阵
mat_diag = np.eye(self.num_classes) * (1. - beta_t)
# 两个矩阵相加
mat = mat + mat_diag
return mat计算条件概率分布
从初始分布的对数概率分布和时间步
# 从初始分布的对数概率和时间步t计算条件概率分布
def q_vt_pred(self, log_v0, time_step):
"""
计算条件概率分布 q(v_t | v_0)
Args:
log_v0: 原始类别对数概率分布 [num_nodes, num_classes]
time_step: 节点级时间步,形状 [num_nodes,]
Returns:
log_q_vt: 加噪后分布的对数概率,形状 [num_nodes, num_classes]
"""
# 得到每个节点对应的累计概率转移矩阵,[t, K, K]
qt_mat = self.q_mats[time_step]
# 对应时间的向量乘以对应时间的累计转移矩阵,[t, 1, K]
q_vt = torch.einsum('...i,...ij->...j', log_v0.exp(), qt_mat)
# 转化为对数概率并约束最小值
return torch.log(q_vt + self.eps).clamp_min(-32.)从对数概率分布中采样
输入离散分布的对数概率,从中采样得到一组离散随机变量值(返回独热编码和对数独热编码)。
首先生成随机均匀噪声,并转换为标准 Gumbel 噪声,将噪声加在原始对数概率上,引入随机性。随后将概率最大的类别作为采样结果,得到独热编码和对数独热编码,完成离散变量的采样。
def sample_log(self, logits):
"""
从对数概率分布中采样
Args:
logits: 输入对数概率分布,形状 [num_nodes, num_classes]
Return:
onehot: 采样后的独热编码
log_onehot: 采样后的对数独热编码
"""
# 生成0-1的均匀噪声
uniform = torch.rand_like(logits)
# 转换为Gumbel噪声
gumbel_noise = -torch.log(-torch.log(uniform + self.eps) + self.eps)
# 加噪声后的分值
noisy_logits = gumbel_noise + logits
# 选最大值对应的类别索引
sample_idx = noisy_logits.argmax(dim=-1)
# 索引→独热编码(值为0/1,形状与输入一致)
onehot = F.one_hot(sample_idx, num_classes=self.num_classes).float()
# 独热编码→对数概率
log_onehot = torch.log(onehot.clamp(min=self.eps)).clamp_min(-32.0)
return onehot, log_onehot加噪
对输入的独热向量加噪,输入的是独热向量的对数概率形式,输出加噪后的独热编码和对数独热编码。计算方法是先计算条件概率分布,随后从该分布中采样。
# 加噪
def add_noise(self, log_v0, time_step):
"""
加噪函数
Args:
log_v0: 时间步t0的对数独热向量
time_step: 要加噪的时间步,形状[num_nodes,]
Return:
onehot_t: 扰动后的独热编码特征 [num_nodes, num_classes]
log_onehot_t: 扰动后的对数独热编码 [num_nodes, num_classes]
"""
# 计算条件概率分布 q(v_t | v_0)
log_q_vt_v0 = self.q_vt_pred(log_v0, time_step)
# 采样
onehot_t, log_onehot_t = self.sample_log(log_q_vt_v0)
return onehot_t, log_onehot_t后验分布
根据原始数据数据
代码中fc1对应反向转移概率
# 离散扩散模型后验分布的核心计算函数
def q_v_posterior(self, log_v0, log_vt, time_step, v0_prob:bool):
"""
计算离散变量的后验概率分布
Args:
log_v0: 初始类别 v0 的对数概率分布
log_vt: 目标类别 vt 的对数概率分布
time_step: 对应的时间步
v0_prob: 是否有初始概率
Returns:
out: 后验分布的对数概率
"""
# 计算得到t-1的时间步
t_minus_1 = torch.clamp(time_step - 1, min=0)
# 计算反向转移概率 fact1
fact1 = self.transpopse_q_onestep_mats[time_step]
fact1 = torch.einsum('...i,...ij->...j', log_vt.exp(), fact1)
# 计算前向先验概率 fact2
fact2 = self.q_mats[t_minus_1]
if not v0_prob:
class_v0 = log_v0.argmax(dim=-1)
fact2 = fact2[torch.arange(len(class_v0)), class_v0]
else:
fact2 = torch.einsum('...i,...ij->...j', log_v0.exp(), fact2)
# 合并概率并归一化
log_prob = torch.log(fact1 + self.eps).clamp_min(-32.) + torch.log(fact2 + self.eps).clamp_min(-32.)
log_prob = log_prob - torch.logsumexp(log_prob, dim=-1, keepdim=True)
# 如果 t=0,直接返回原始数据 v0
out = torch.where(time_step.unsqueeze(-1) == 0, log_v0, log_prob)
return out损失函数
当
# 离散扩散模型损失函数
def compute_v_Lt(self, log_v_post_true, log_v_post_pred, log_v0, time_step):
"""
计算离散扩散模型的损失函数
Args:
log_v_post_true: 真实后验分布的对数概率,形状 [num_nodes, num_classes]
log_v_post_pred: 预测后验分布的对数概率,形状 [num_nodes, num_classes]
log_v0: 原始干净数据的对数概率,形状 [num_nodes, num_classes]
time_step: 节点级时间步,形状 [num_nodes,]
Returns:
loss_v: 掩码加权后的合并损失,形状与输入一致
"""
# 计算后验分布的 KL 散度损失
kl_v = (log_v_post_true.exp() * (log_v_post_true - log_v_post_pred)).sum(dim=-1)
# 计算解码器的负对数似然损失
decoder_nll_v = - (log_v0.exp() * log_v_post_pred).sum(dim=-1)
# 生成时间步掩码,区分t0/t
mask = (time_step == 0).float()
# 合并损失,t=0用NLL,t≠0用KL损失
loss_v = mask * decoder_nll_v + (1 - mask) * kl_v
return loss_v采样初始化
输入采样个数n,返回n个离散变量的初始化值。同样也是遵循先生成概率分布,再使用sample_log方法采样的流程,计算得到初始化的类型、初始化独热编码和初始化对数独热。
def sample_init(self, n):
"""
离散扩散的采样初始化
Args:
n: 批次大小
Returns:
init_types: 初始类别
init_onehot: 初始类别的独热编码
init_log_onehot: 初始类别的对数独热编码
"""
# 加载初始概率分布并转为对数概率
init_log_atom_vt = torch.log(
torch.from_numpy(self.init_prob)+self.eps
).clamp_min(-32.).to(self.q_mats.device)
# 拓展到n个样本
init_log_atom_vt = init_log_atom_vt.unsqueeze(0).repeat(n, 1)
init_onehot, init_log_onehot= self.sample_log(init_log_atom_vt)
init_types = init_onehot.argmax(dim=-1)
return init_types, init_onehot, init_log_onehot计算流程
在实际的应用中,训练流通常遵循以下流程:
- 初始化GeneralCategoricalTransition类,执行 _get_transition_mat() 方法,预计算转移矩阵和累计转移矩阵;
- 将离散索引转为对数独热编码,调用 transition.add_noise() 方法加噪;
- 神经网络预测得到类别向量,进行 softmax 归一化,得到重构的对数独热编码;
- 调用 transition.q_v_posterior() 方法计算真实类别的后验概率分布和模型输出的重构后验概率分布;
- 调用 transition.compute_v_Lt() 方法计算KL散度,反向传播更新参数。
而采样流通常遵循以下步骤:
- 调用 transition.sample_init() 方法得到初始化的对数独热编码,输入神经网络;
- 神经网络得到
步的的预测分布,进行 log_softmax() 归一化得到对数独热编码; - transition.q_v_posterior() 计算后验概率分布,transition.sample_log() 从后验概率分布中采样得到
步的独热编码类型; - 以此往复迭代T步。