Skip to content

离散扩散模型

不同于位置坐标等连续特征,分子的原子类型和键类型的类别变量需要特殊的离散扩散公式来是实现,其核心通过可控的概率转移完成前向加噪,并借助反向后验分布实现迭代去噪。

离散扩散是分子生成模型的主要组件之一,本文将从基础概率定义、转移矩阵设计、后验公式推导,到代码实现与损失函数设计,完整梳理离散扩散的数学原理与工程逻辑,帮助理解其在结构生成任务中的内在机制与实现细节。

离散变量的概率分布

Cat(x,p)表示离散随机变量x服从以向量p为概率参数的类别分布,即离散随机变量x取某个类别值的概率和p有关,例如:

离散随机变量x,x[1,2,3,4]四个类别,取某个类别的概率是p=[0.25,0.25,0.25,0.25],这里p是均匀分布,也就是x取类别1、2、3、4的概率都是0.25

注意到,随机变量x只能取有限个离散值(类别索引),不能取连续值,且p的概率和为1,每个类别的概率是互斥的。

离散扩散

离散情况下,对于有K个类别的离散随机变量xt,xt1[1,K],前向转移概率可以表示成矩阵[Qt]i,j=q(xt=j|xt1=i), 其中矩阵中的每一个元素Qi,j代表从状态 i到状态j的概率。

Qt=[x0,0x0,1x0,Kx1,0x1,1x1,KxK,0xK,1xK,K]

例如,矩阵的第一行表示xt1=0,(类别0)时,xt属于各类别的概率,这时所有概率值相加必须为1。用条件概率q(xt|xt1)表示xt1已知的情况下,xt的概率分布:

q(xt|xt1)=Cat(xt;p=xt1Qt)

式中 Cat() 是分类分布, p 是形状为[1,K]的概率向量。此时的xt1是独热编码向量[0,0,1,,0],第i个类别索引是1,其余为0,那么p的取值就是转移矩阵的第i行:

p=[0,0,1,,0]×[x0,0x0,1x0,Kx1,0x1,1x1,KxK,0xK,1xK,K]=[xi,0,xi,1,,xi,K]

这样q(xt|xt1)就描述了此时的xt所属的类别概率。

一步加噪

和连续扩散一样,在前向加噪过程中需要根据选择加噪的时间步 t,从初始值 x0 一步得到加噪后样本分布 q(xt|x0)

离散扩散的前向过程满足马尔可夫性,即 xt 仅依赖上一步的 xt1,与更早的状态无关:

q(x1:T|x0)=i=1Tq(xi|xi1)

用转移矩阵表示,单步转移为:

q(xt|xt1)=Cat(xt;p=xt1Qt)

其中 QtK×K 的转移概率矩阵(K 为类别数),(Qt)ij=q(xt=j|xt1=i)

根据马尔可夫链的性质,多步转移可以通过矩阵连乘得到:

q(xt|x0)=Cat(xt;p=x0Q¯t)Q¯t=Q1Q2Qt=i=1tQi

公式描述了原始样本x0与第t步的累计概率转移矩阵相乘,得到前向加噪样本xt。反向过程则将以xt,t为条件,学习预测反向转移分布p(xt1|xt)或直接恢复原始分布p(x0|xt)

反向去噪

根据贝叶斯定理,可以得到已知当前含噪样本xt和初始样本x0时,前一步样本xt1的概率分布:

q(xt1|xt,x0)=q(xt|xt1,x0)q(xt1|x0)q(xt|x0)

离散扩散的前向过程满足马尔可夫性:xt仅依赖上一步的xt1,与更早的x0无关,即:q(xt|xt1,x0)=q(xt|xt1),这时公式可以变成:

q(xt1|xt,x0)=q(xt|xt1)q(xt1|x0)q(xt|x0)

把前面得出的公式代入分子分母中得到,分母项:q(xt|x0)=Cat(xt;p=x0Qt);分子第二项:q(xt1|x0)=Cat(xt1;p=x0Qt1)

对于分子第一项的单步转移概率q(xt|xt1)xt1是未知项,xt是已知项,推理方向与前向过程相反。为了从当前状态xt反推出上一步可能的状态xt1,需要将转移概率矩阵转置,用QtT代替Qt,也就是q(xt|xt1)=xtQtT。因为:QtT的第k行等价于Qt的第k列,相当于已知结果求可能的来源。

最终把三个式子带入分子分母,得到:

q(xt1|xt,x0)=Cat(xt1;p=xtQtTx0Qt1x0QtxtT)

反向过程的目的是从含噪的xt,一步步倒推回前一步的xt1,重复这个过程,最终还原出原始数据x0,而上式告诉我们:已知当前含噪样本xt和原始样本x0时,上一步的干净样本xt1的最大概率

  • xtQtT:从“当前含噪样本xt”反向推,哪些xt1最可能生成xt(比如:现在看到的是模糊的含噪图像,猜它上一步可能是哪些不那么模糊的图像);
  • x0Qt1:从“原始样本x0”正向推,哪些xt1最可能是x0加噪一步的结果(比如:原始清晰图像,加噪一步后最可能变成哪些样子);
  • (哈达玛积):把上面两个“猜测”结合起来,取两者都认可的、最可能的xt1
  • x0QtxtT:分子的“猜测”是“未标准化”的(概率和不一定等于1),而概率的核心要求是“所有可能的结果加起来等于1”。

整个公式本质就是:先结合“正向加噪”和“反向去噪”的信息,初步猜测xt1的可能情况,再通过校准,得到xt1的真实概率分布

模型训练

在上式中,我们不知道真实的x0(如果知道x0,就不用去噪了)。因此我们需要利用神经网络来拟合一个近似分布pθ(xt1|xt),让它尽可能接近真实的反向分布q(xt1|xt,x0),训练流程如下:

  1. 前向加噪:随机选一个时间步t,把干净数据x0加噪成xt
  2. 模型预测:让模型输入xt和时间步t,输出对xt1分布的预测;
  3. 计算误差:将模型的预测结果与公式计算出的“真实xt1分布”做对比不断调优。

整个优化过程的损失可表述为:

L=Ex0,xt[DKL(q(xt1|xt,x0)pθ(xt1|xt))]

即让模型预测的分布 pθ 和真实分布 q 尽可能相似,差距越小,模型学得越好。

采样过程

训练好模型后即可从含噪样本还原出干净数据:

  1. 初始化:从一个完全随机的含噪样本xT开始(T是最大时间步,此时样本最模糊);
  2. 逐步去噪:从t=T开始,一步步往t=0倒推——每一步都用模型pθ(xt1|xt),根据当前的xt,预测出更干净的xt1
  3. 得到结果:当t=0时,我们就得到了还原后的干净样本x^0,完成去噪。

代码实现

接下来重点讲解一下离散扩散类的代码实现,它将接收β参数、类别数和初始概率类型,对外提供一系列处理离散概率的方法。

python
class GeneralCategoricalTransition(nn.Module):
    def __init__(self, betas, num_classes, init_prob=None):
        super().__init__()
        # 加微小值,防止概率值趋于0时的梯度爆炸
        self.eps = 1e-30
        # 类别数
        self.num_classes = num_classes

        # 初始类别分布
        # 默认均匀分布
        if init_prob is None:
            self.init_prob = np.ones(num_classes) / num_classes
        # 第一个类别概率接近 1,其他接近 0;
        elif init_prob == 'absorb':
            init_prob = 0.01 * np.ones(num_classes)
            init_prob[0] = 1
            self.init_prob = init_prob / np.sum(init_prob)
        # 最后一个类别概率接近 1,其他接近 0;
        elif init_prob == 'tomask':
            init_prob = 0.001 * np.ones(num_classes)
            init_prob[-1] = 1.
            self.init_prob = init_prob / np.sum(init_prob)
        # 所有类别概率相等;
        elif init_prob == 'uniform':
            self.init_prob = np.ones(num_classes) / num_classes
        # 将用户传入的自定义数组归一化
        else:
            self.init_prob = init_prob / np.sum(init_prob)
        
        # betas参数
        self.betas = betas
        # 总时间步数
        self.num_timesteps = len(betas)
        
        # q(x_t | x_{t-1})
        q_one_step_mats = [self._get_transition_mat(t) for t in range(0, self.num_timesteps)]
        q_one_step_mats = np.stack(q_one_step_mats, axis=0)  # (T, K, K)
        
        # q(x_t | x_0)
        q_mat_t = q_one_step_mats[0]
        q_mats = [q_mat_t]

        # 计算累积转移矩阵
        for t in range(1, self.num_timesteps):
            q_mat_t = np.tensordot(q_mat_t, q_one_step_mats[t], axes=[[1], [0]])
            q_mats.append(q_mat_t)
        q_mats = np.stack(q_mats, axis=0)
        
        transpopse_q_onestep_mats = np.transpose(q_one_step_mats, axes=[0, 2, 1])
        
        # 转成张量,绑定到实例self
        self.q_mats = to_torch_const(q_mats)
        self.transpopse_q_onestep_mats = to_torch_const(transpopse_q_onestep_mats)

概率转移矩阵

初始化概率转移矩阵,输入参数是时间步t,返回该时刻的概率转移矩阵。

若不设置初始概率,则按标准均匀扩散构造转移矩阵,对角线元素以大概率保持不变,其他非对角线元素以小概率进行类别转移:

Qt(i,j)={βtK,ij1βtK1K,i=j

若设置了初始概率,先构建初始转移概率矩阵Q0,随后得到第t步转移矩阵Qt,也就是有βt的概率按初始概率转移到指定类型,有1βt的概率保持不变:

Qt=(1βt)I+βtQ0T
python
    # 计算概率转移矩阵
    def _get_transition_mat(self, t):
        """
        计算t时刻的概率转移矩阵q(x_t|x_{t-1})

        Args:
            t: 时间步
        Returns:
            Q_t: 当前时间步的N阶方阵,表示转移矩阵
        """

        # 获取当前时间步的噪声参数beta
        beta_t = self.betas[t]
        # 无初始概率分支
        if self.init_prob is None:
            # 所有元素初始化为 “跨类别转移概率”
            mat = np.full(
                shape=(self.num_classes, self.num_classes),
                fill_value=beta_t/float(self.num_classes),
                dtype=np.float64
                )
            # 获取对角线索引
            diag_indices = np.diag_indices_from(mat)
            # 对角线元素的自转移概率
            diag_val = 1. - beta_t * (self.num_classes-1.)/self.num_classes
            # 替换对角线元素为自转移概率
            mat[diag_indices] = diag_val
        # 有初始概率分支
        else:
            # 初始概率扩充维度,得到二维矩阵
            mat = np.repeat(np.expand_dims(self.init_prob, 0), self.num_classes, axis=0)
            # 乘以噪声系数
            mat = beta_t * mat
            # 创建对角线为1 - beta_t、其余为 0 的矩阵
            mat_diag = np.eye(self.num_classes) * (1. - beta_t)
            # 两个矩阵相加
            mat = mat + mat_diag
        return mat

计算条件概率分布

从初始分布的对数概率分布和时间步t计算条件概率分布 q(xt|x0)=x0Q¯t,需要先得到每个节点(变量)的累计转移矩阵,再将离散变量的原始概率与累计转移矩阵相乘,得到条件概率,最后再对数化输出为对数概率。

python
    # 从初始分布的对数概率和时间步t计算条件概率分布
    def q_vt_pred(self, log_v0, time_step):
        """
        计算条件概率分布 q(v_t | v_0)

        Args:
            log_v0:     原始类别对数概率分布 [num_nodes, num_classes]
            time_step:  节点级时间步,形状 [num_nodes,]
        Returns:
            log_q_vt:   加噪后分布的对数概率,形状 [num_nodes, num_classes]
        """
        # 得到每个节点对应的累计概率转移矩阵,[t, K, K]
        qt_mat = self.q_mats[time_step]
        
        # 对应时间的向量乘以对应时间的累计转移矩阵,[t, 1, K]
        q_vt = torch.einsum('...i,...ij->...j', log_v0.exp(), qt_mat)
        
        # 转化为对数概率并约束最小值
        return torch.log(q_vt + self.eps).clamp_min(-32.)

从对数概率分布中采样

输入离散分布的对数概率,从中采样得到一组离散随机变量值(返回独热编码和对数独热编码)。

首先生成随机均匀噪声,并转换为标准 Gumbel 噪声,将噪声加在原始对数概率上,引入随机性。随后将概率最大的类别作为采样结果,得到独热编码和对数独热编码,完成离散变量的采样。

python
    def sample_log(self, logits):
        """
        从对数概率分布中采样

        Args:
            logits:      输入对数概率分布,形状 [num_nodes, num_classes]
        Return:
            onehot:      采样后的独热编码
            log_onehot:  采样后的对数独热编码
        """
        # 生成0-1的均匀噪声
        uniform = torch.rand_like(logits)
        # 转换为Gumbel噪声
        gumbel_noise = -torch.log(-torch.log(uniform + self.eps) + self.eps)
        # 加噪声后的分值
        noisy_logits = gumbel_noise + logits
        # 选最大值对应的类别索引
        sample_idx = noisy_logits.argmax(dim=-1)
        # 索引→独热编码(值为0/1,形状与输入一致)
        onehot = F.one_hot(sample_idx, num_classes=self.num_classes).float()
        # 独热编码→对数概率
        log_onehot = torch.log(onehot.clamp(min=self.eps)).clamp_min(-32.0)
        
        return onehot, log_onehot

加噪

对输入的独热向量加噪,输入的是独热向量的对数概率形式,输出加噪后的独热编码和对数独热编码。计算方法是先计算条件概率分布,随后从该分布中采样。

python
    # 加噪
    def add_noise(self, log_v0, time_step):
        """
        加噪函数

        Args:
            log_v0:         时间步t0的对数独热向量
            time_step:      要加噪的时间步,形状[num_nodes,]
        Return:
            onehot_t:       扰动后的独热编码特征 [num_nodes, num_classes]
            log_onehot_t:   扰动后的对数独热编码 [num_nodes, num_classes]
        """
        # 计算条件概率分布 q(v_t | v_0)
        log_q_vt_v0 = self.q_vt_pred(log_v0, time_step)
        # 采样
        onehot_t, log_onehot_t = self.sample_log(log_q_vt_v0)
        
        return onehot_t, log_onehot_t

后验分布

根据原始数据数据x0和含噪数据xt,计算出上一步数据xt1的概率分布,计算遵循前文的反向去噪公式

代码中fc1对应反向转移概率xtQtT,fc2对应前向先验概率x0Qt1,将其转为对数概率后相加(即逐元素相乘),再使用logsumexp进行归一化。

python
    # 离散扩散模型后验分布的核心计算函数
    def q_v_posterior(self, log_v0, log_vt, time_step, v0_prob:bool):
        """
        计算离散变量的后验概率分布
        Args:
            log_v0:     初始类别 v0 的对数概率分布
            log_vt:     目标类别 vt 的对数概率分布
            time_step:  对应的时间步
            v0_prob:    是否有初始概率
        Returns:
            out: 后验分布的对数概率
        """
        # 计算得到t-1的时间步
        t_minus_1 = torch.clamp(time_step - 1, min=0)
        # 计算反向转移概率 fact1
        fact1 = self.transpopse_q_onestep_mats[time_step]
        fact1 = torch.einsum('...i,...ij->...j', log_vt.exp(), fact1)

        # 计算前向先验概率 fact2
        fact2 = self.q_mats[t_minus_1]
        if not v0_prob:
            class_v0 = log_v0.argmax(dim=-1)
            fact2 = fact2[torch.arange(len(class_v0)), class_v0]
        else:
            fact2 = torch.einsum('...i,...ij->...j', log_v0.exp(), fact2)
        # 合并概率并归一化
        log_prob = torch.log(fact1 + self.eps).clamp_min(-32.) + torch.log(fact2 + self.eps).clamp_min(-32.)
        log_prob = log_prob - torch.logsumexp(log_prob, dim=-1, keepdim=True)
        # 如果 t=0,直接返回原始数据 v0​
        out = torch.where(time_step.unsqueeze(-1) == 0, log_v0, log_prob)
        return out

损失函数

t>0 时,使用 KL 散度损失,约束模型预测的反向去噪分布逼近理论真实后验分布;当 t=0 时,使用 负对数似然损失(NLL),直接监督模型精准重建原始离散数据。

python
    # 离散扩散模型损失函数
    def compute_v_Lt(self, log_v_post_true, log_v_post_pred, log_v0, time_step):
            """
            计算离散扩散模型的损失函数

            Args:
                log_v_post_true:    真实后验分布的对数概率,形状 [num_nodes, num_classes]
                log_v_post_pred:    预测后验分布的对数概率,形状 [num_nodes, num_classes]
                log_v0:             原始干净数据的对数概率,形状 [num_nodes, num_classes]
                time_step:          节点级时间步,形状 [num_nodes,]
            Returns:
                loss_v: 掩码加权后的合并损失,形状与输入一致
            """
            # 计算后验分布的 KL 散度损失
            kl_v = (log_v_post_true.exp() * (log_v_post_true - log_v_post_pred)).sum(dim=-1)
            # 计算解码器的负对数似然损失
            decoder_nll_v = - (log_v0.exp() * log_v_post_pred).sum(dim=-1)
            # 生成时间步掩码,区分t0/t
            mask = (time_step == 0).float()
            # 合并损失,t=0用NLL,t≠0用KL损失
            loss_v = mask * decoder_nll_v + (1 - mask) * kl_v
            return loss_v

采样初始化

输入采样个数n,返回n个离散变量的初始化值。同样也是遵循先生成概率分布,再使用sample_log方法采样的流程,计算得到初始化的类型、初始化独热编码和初始化对数独热。

python
    def sample_init(self, n):
        """
        离散扩散的采样初始化

        Args:
            n:                  批次大小
        Returns:
            init_types:         初始类别
            init_onehot:        初始类别的独热编码
            init_log_onehot:    初始类别的对数独热编码
        """
        # 加载初始概率分布并转为对数概率
        init_log_atom_vt = torch.log(
            torch.from_numpy(self.init_prob)+self.eps
        ).clamp_min(-32.).to(self.q_mats.device)
        # 拓展到n个样本
        init_log_atom_vt = init_log_atom_vt.unsqueeze(0).repeat(n, 1)
        
        init_onehot, init_log_onehot= self.sample_log(init_log_atom_vt)
        init_types = init_onehot.argmax(dim=-1)

        return init_types, init_onehot, init_log_onehot

计算流程

在实际的应用中,训练流通常遵循以下流程:

  • 初始化GeneralCategoricalTransition类,执行 _get_transition_mat() 方法,预计算转移矩阵和累计转移矩阵;
  • 将离散索引转为对数独热编码,调用 transition.add_noise() 方法加噪;
  • 神经网络预测得到类别向量,进行 softmax 归一化,得到重构的对数独热编码;
  • 调用 transition.q_v_posterior() 方法计算真实类别的后验概率分布和模型输出的重构后验概率分布;
  • 调用 transition.compute_v_Lt() 方法计算KL散度,反向传播更新参数。

而采样流通常遵循以下步骤:

  • 调用 transition.sample_init() 方法得到初始化的对数独热编码,输入神经网络;
  • 神经网络得到t1步的的预测分布,进行 log_softmax() 归一化得到对数独热编码;
  • transition.q_v_posterior() 计算后验概率分布,transition.sample_log() 从后验概率分布中采样得到t1步的独热编码类型;
  • 以此往复迭代T步。