清华 ReST-MCTS*:基于过程奖励引导树搜索的 LLM 自训练深度剖析
一、引言
在人工智能的快速发展进程中,大语言模型(LLM)的训练方法不断演进。传统的 LLM 训练往往严重依赖大量的人工标注数据,这不仅耗费巨大的人力、物力和时间成本,而且标注的准确性和一致性也难以保证。清华提出的 ReST-MCTS* 方法为解决这一困境带来了创新性的思路,通过蒙特卡罗树搜索(MCTS* )自动生成高质量的推理轨迹,并利用这些轨迹来训练策略模型和过程奖励模型,从而避免了传统方法中对人工标注的依赖。本文将对 ReST-MCTS* 进行全面而深入的解读,详细分析其核心原理、训练流程、推理机制以及实验结果。
论文链接:ReST-MCTS∗ : LLMSelf-Training via ProcessRewardGuidedTreeSearch
GitHub 地址:GitHub - THUDM/ReST-MCTS: ReST-MCTS*: LLMSelf-Training via ProcessRewardGuidedTreeSearch (NeurIPS2024)
项目地址:ReST-MCTS*: LLMSelf-Training via ProcessRewardGuidedTree Search
二、ReST-MCTS*核心原理
(一)主要组件
ReST-MCTS*方法主要由四个关键组件构成:
- MCTS*:在过程奖励模型(PRM)引导下进行树搜索,是整个方法的核心搜索机制。它通过不断地构建和搜索树结构,寻找最优的推理路径。
- 过程奖励模型(PRM):负责评估部分解决方案的质量,并依据评估结果指导 MCTS* 的搜索方向。其评估综合考虑了多个因素,对搜索过程起到关键的引导作用。
- 策略模型:针对每个问题生成多个中间推理步骤,为推理过程提供多样化的思路和方向,推动推理进程的发展。
- LLM 自训练模块:利用 MCTS* 生成的推理轨迹,对策略模型和 PRM 进行迭代训练,通过不断学习和优化,提升模型的性能。
(二)关键概念定义
- 部分解决方案的价值:部分解决方案 $p_C$ 的价值 $v_C$ 应满足以下基本性质:
- 有限限制:$v_C$ 被限制在特定范围内,通常为 $[0, 1]$。
- 正确性概率:$v_C$ 反映了部分解决方案是完整且正确的概率,$v_C$ 值越高,表示质量越好或更接近正确答案。
- 步骤贡献:$v_C$ 不仅考虑步骤的正确性,还考虑其对最终答案的贡献。
- 推理距离:推理距离 $d$ 表示从部分解决方案 $p_C$ 开始,到达正确答案所需的最小推理步骤数。它反映了当前步骤的进展情况以及后续推理的难度,是衡量推理进程的重要指标。
- 单步加权奖励:单步加权奖励 $r_k$ 用于反映当前步骤 $k$ 的质量,基于常见的 PRM 奖励和推理距离 $d$,其公式为$r_k=\frac{v_{k - 1}+(1 - d)r_{PRM}(p_k)}{2}$,其中 $v_{k - 1}$ 是前一步 $k - 1$ 的质量值,$d$ 是推理距离,$r_{PRM}(p_k)$ 是 PRM 对当前步骤的预测得分。随着 $d$ 的增加,$r_k$ 减少,说明需要更少的推理步骤才能得到正确答案时,当前步骤的加权奖励 $r_k$ 会被赋予更高的权重。
质量值 $v_k$ 的更新公式为$v_k=\frac{v_{k - 1}+r_k}{2}$,质量值 $v_k$ 融合了当前奖励 $r_k$ 以及前一步 $k - 1$ 的质量值 $v_{k - 1}$。单步加权奖励 $r_k$ 和质量值 $v_k$ 具有如下性质:
- 如果从 $p_C$ 开始的推理路径需要更多步数才能得到正确答案,那么单步加权奖励 $r_k$ 较低。
- 单步加权奖励 $r_k$ 和 PRM 奖励分正相关。
- 当达到正确答案时,$v_k$ 收敛到上界 1。
三、训练流程详解
(一)过程奖励模型初始化
- 数据来源
- MATH 数据集:采用 Mistral - 7B: MetaMATH 模型结合广度优先搜索(BFS)方式生成解决方案轨迹。对于每个数学问题,从初始状态开始,逐步生成推理步骤,每次生成多个可能的推理步骤,形成一颗搜索树。在生成搜索树后,验证所有叶节点的答案是否正确,通过简单的字符串匹配或 LLM 判断,确定生成的答案是否正确。
- 科学数据集:使用 SciInstruct 中的精简科学数据集,该数据集由 11,554 个问题组成,每个问题都配有正确的分步解决方案。采用 ChatGLM2 模型构建正样本和负样本,对于数据集中的每个问题和相应的解决方案,提取所有部分解决方案以形成正样本;同时,将 ChatGLM2 生成的步骤全部视为错误来构建负样本,总共收集了 473.4k 个样本用于训练初始过程奖励模型。
- 构建训练样本:从验证后的搜索树中提取部分解决方案及其对应的目标质量值。对于每个部分解决方案,计算其推理距离(即从该部分解决方案开始到达正确答案所需的最小推理步骤数),使用硬估计(Hard Estimation)方法计算每个推理步骤的 PRM 奖励(其中 $r_{PRM}$ 表示该步骤是否能够到达正确答案),根据推理距离和 $r_{PRM}$,计算部分解决方案的质量值 $v_i$ 和加权奖励 $r_i$。
质量值计算:$v_i=\frac{v_{i - 1}+r_i}{2}$
加权奖励计算:$r_i=\frac{v_{i - 1}+(1 - d)r_{PRM}(p_i)}{2}$ - 构建训练集与训练:将所有提取的部分解决方案及其对应的质量值 $v_i$ 和加权奖励 $r_i$ 组合成训练集,使用 AdamW 优化器和 MSE 损失进行优化,最终得到一个可以初步评估过程价值的奖励模型。使用 14k 个数据样本进行评测,评估函数为$L=\frac{1}{N}\sum_{i = 1}^{N}(v_i^{pred}-v_i^{true})^2$,其中 $N$ 是样本数量,$q_i$ 是样本 $i$ 的问题,$p_i$ 是样本 $i$ 的部分解,$v_i^{true}$ 是样本的目标值(也就是前面计算的质量值)。初始化训练之后,过程奖励模型的准确率为 69.3%。
(二)自训练过程
在每次迭代中,执行以下步骤:
- 使用策略模型和过程奖励模型(此时未使用过程奖励模型的预估值)进行 MCTS*搜索,生成推理轨迹。
- 验证推理轨迹的正确性,并提取训练数据。使用提取的数据训练策略模型,使用另一部分数据训练过程奖励模型。
- 重复上述步骤,进行多轮迭代,不断优化策略模型和过程奖励模型的性能。
四、推理过程剖析
在面对新问题时,ReST-MCTS* 运用 MCTS* 搜索生成推理路径,具体过程如下:
- 节点选择:从根节点开始,使用 UCB(Upper Confidence Bound)策略选择子节点,UCB 策略公式为$UCB(C)=\mu_C+\sqrt{\frac{2\ln N_p}{n_C}}$,其中,$\mu_C$ 是当前节点 $C$ 的质量值,$N_p$ 是父节点的访问次数,$n_C$ 是当前节点的访问次数,$\sqrt{\frac{2\ln N_p}{n_C}}$ 是探索项,用于平衡探索和利用。选择过程是从根节点开始,逐层选择 UCB 值最大的子节点,直到到达一个叶节点。
- 自我判断:判断当前步骤是否正确以及推理是否结束(结束标记:EoI)。除了使用过程奖励模型预估的奖励值,还使用自我判断机制。在每个扩展阶段之前提示推理模型根据现有的部分解决方案 $p$ 生成推理结束(EoI)信号或提供下一步探索建议。如果收到 EoI 信号,思维扩展和贪婪蒙特卡洛扩展就跳过;否则,建议将在后续扩展阶段作为推理提示的一部分被利用。
- 思维扩展:如果当前节点的质量值 $v_C$ 低于阈值(通常设置为 0.9),则扩展新的推理步骤。扩展过程为:首先使用策略模型生成新的推理步骤,为每个新生成的节点分配质量值,通过过程奖励模型进行评估,最后将新生成的节点添加到搜索树中。
- 贪婪蒙特卡洛展开:对扩展的节点进行模拟,选择最有价值的推理路径。模拟过程如下:
- 从扩展的节点开始,逐步生成推理步骤,并使用过程奖励模型评估每一步的质量值。
- 记录模拟过程中获得的最大质量值 $v_{max}$。
- 使用加权平均方法更新节点的质量值:$v_C=\frac{v_C+\lambda v_{max}}{1+\lambda}$,其中 $\lambda$ 是权重参数,通常设置为 0.5。
- 更新节点的访问次数。
- 值回传:从选择的节点开始,回传质量值,更新父节点的质量值。回传过程为:从选择的节点开始,逐层回传质量值,使用加权平均方法更新父节点的质量值,并更新父节点的访问次数。
- 最终选择质量值最大的路径作为输出结果。
以下是 ReST-MCTS*推理过程的流程图:
st=>start: 开始
ns=>operation: 节点选择(Node Selection)
sj=>operation: 自我判断(Self-Critic)
te=>operation: 思维扩展(Thought Expansion)
gm=>operation: 贪婪蒙特卡洛展开(Greedy MC Rollout)
vb=>operation: 值回传(Value Backpropagation)
end=>end: 输出结果
st->ns->sj->te->gm->vb->end
五、实验结果洞察
(一)基础评测
在多个数据集(如 MATH、GPQA、CEval - Hard 等)上进行评测,设置了一系列参数,如对于 ReST-MCTS*,判断搜索是否结束的质量值阈值设置为 $l = 0.9$,贪婪蒙特卡洛扩展的最大步数为 $m$,迭代次数为 $n$,每次生成推理步骤的分枝数为 $b$。
(二)关键因素影响
论文中提到,增加 MCTS* 采样次数,准确率会显著提升。此外,Rest - MCTS* 自训练两轮的效果会远好于一轮。通过调整这些关键参数和训练轮次,模型能够在不同数据集上不断优化性能,展示出 ReST-MCTS* 方法的有效性和潜力。
(三)算法对比优势
在与不同搜索算法的时间和效果对比中,ReST-MCTS* 也呈现出一定的优势。其在保证一定准确率的前提下,能够有效地控制搜索时间,提高推理效率,这使得它在实际应用场景中具有更强的竞争力,为解决复杂的推理问题提供了一种高效的解决方案。
六、结论
ReST-MCTS* 通过创新的架构和训练机制,有效地解决了传统 LLM 训练中人工标注的难题。其独特的组件设计和训练流程,使得模型能够自动生成高质量的推理轨迹,并通过迭代训练不断提升性能。在推理性能上,通过合理的参数设置和自训练优化,取得了可观的成果。这一方法为大语言模型的自训练开辟了新的方向,有望在未来的人工智能研究和应用中发挥重要作用,推动自然语言处理领域及相关应用的进一步发展。同时,随着研究的深入,未来还可以进一步探索如何进一步提高过程奖励模型的准确性、优化 MCTS* 搜索算法的效率以及拓展 ReST-MCTS*在更多领域和任务中的应用等问题,不断完善和拓展这一创新方法的潜力。
评论