一般而言,训练一个完整的 LLM 需要经过图1中的三个阶段——Pretrain、SFT 和 RLHF。
1.预训练
Pretrain,即预训练,是训练 LLM 最核心也是工程量最大的第一步。LLM 的预训练和传统预训练模型非常类似,同样是使用海量无监督文本对随机初始化的模型参数进行训练。目前主流的 LLM 几乎都采用了 Decoder-Only 的类 GPT 架构(LLaMA 架构),它们的预训练任务也都沿承了 GPT 模型的经典预训练任务——因果语言模型(Causal Language Model,CLM)。因果语言模型建模,即和最初的语言模型一致,通过给出上文要求模型预测下一个 token 来进行训练。
LM 的核心特点即在于其具有远超传统预训练模型的参数量,同时在更海量的语料上进行预训练。根据由 OpenAI 提出的 Scaling Law:C ~ 6ND,其中 C 为计算量,N 为模型参数,D 为训练的 token 数,可以实验得出训练 token 数应该是模型参数的 1.7倍,也就是说 175B 的 GPT-3,需要使用 300B token 进行预训练。而 LLaMA 更是进一步提出,使用 20倍 token 来训练模型能达到效果最优,因此 175B 的 GPT-3,可以使用3.5T token 数据预训练达到最优性能。
如此庞大的模型参数和预训练数据,使得预训练一个 LLM 所需要的算力资源极其庞大。事实上,哪怕是预训练一个 1B 的大模型,也至少需要多卡分布式 GPU 集群,通过分布式框架对模型参数、训练的中间参数和训练数据进行切分,才能通过以天为单位的长时间训练来完成。一般来说,百亿级 LLM 需要 1024张 A100 训练一个多月,而十亿级 LLM 一般也需要 256张 A100 训练两、三天,计算资源消耗非常高。
也正因如此,分布式训练框架也成为 LLM 训练必不可少的组成部分。分布式训练框架的核心思路是数据并行和模型并行。所谓数据并行,是指训练模型的尺寸可以被单个 GPU 内存容纳,但是由于增大训练的 batch_size 会增大显存开销,无法使用较大的 batch_size 进行训练;同时,训练数据量非常大,使用单张 GPU 训练时长难以接受。
在数据并行的情况下,每张 GPU 上的模型参数是保持一致的,训练的总批次大小等于每张卡上的批次大小之和。
但是,当 LLM 扩大到上百亿参数,单张 GPU 内存往往就无法存放完整的模型参数。在这种情况下,可以将模型拆分到多个 GPU 上,每个 GPU 上存放不同的层或不同的部分,从而实现模型并行。
在数据并行和模型并行的思想基础上,还演化出了多种更高效的分布式方式,例如张量并行、3D 并行、ZeRO(Zero Redundancy Optimizer,零冗余优化器)等。目前,主流的分布式训练框架包括 Deepspeed、Megatron-LM、ColossalAI 等,其中,Deepspeed 使用面最广。
Deepspeed 的核心策略是 ZeRO 和 CPU-offload。ZeRO 是一种显存优化的数据并行方案,其核心思想是优化数据并行时每张卡的显存占用,从而实现对更大规模模型的支持。ZeRO 将模型训练阶段每张卡被占用的显存分为两类:
- 模型状态(Model States),包括模型参数、模型梯度和优化器 Adam 的状态参数。假设模型参数量为 1M,一般来说,在混合精度训练的情况下,该部分需要 16M 的空间进行存储,其中 Adam 状态参数会占据 12M 的存储空间。
- 剩余状态(Residual States),除了模型状态之外的显存占用,包括激活值、各种缓存和显存碎片。
针对上述显存占用,ZeRO 提出了三种不断递进的优化策略:
- ZeRO-1,对模型状态中的 Adam 状态参数进行分片,即每张卡只存储 1N 的 Adam 状态参数,其他参数仍然保持每张卡一份。
- ZeRO-2,继续对模型梯度进行分片,每张卡只存储 1N 的模型梯度和 Adam 状态参数,仅模型参数保持每张卡一份。
- ZeRO-3,将模型参数也进行分片,每张卡只存储 1N 的模型梯度、模型参数和 Adam 状态参数。
可以看出,随着分片的参数量不断增加,每张卡需要占用的显存也不断减少。当然,分片的增加也就意味着训练中通信开销的增加,一般而言,每张卡的 GPU 利用率 ZeRO-1 最高而 ZeRO-3 最低。具体使用什么策略,需要结合计算资源的情况和需要训练的模型体量动态确定。
预训练数据的处理与清洗也是 LLM 预训练的一个重要环节。诸多研究证明,预训练数据的质量往往比体量更加重要。预训练数据处理一般包括以下流程:
- 文档准备。由于海量预训练语料往往是从互联网上获得,一般需要从爬取的网站来获得自然语言文档。文档准备主要包括 URL 过滤(根据网页 URL 过滤掉有害内容)、文档提取(从 HTML 中提取纯文本)、语言选择(确定提取的文本的语种)等。
- 语料过滤。语料过滤的核心目的是去除低质量、无意义、有毒有害的内容,例如乱码、广告等。语料过滤一般有两种方法:基于模型的方法,即通过高质量语料库训练一个文本分类器进行过滤;基于启发式的方法,一般通过人工定义 web 内容的质量指标,计算语料的指标值来进行过滤。
- 语料去重。实验表示,大量重复文本会显著影响模型的泛化能力,因此,语料去重即删除训练语料中相似度非常高的文档,也是必不可少的一个步骤。去重一般基于 hash 算法计算数据集内部或跨数据集的文档相似性,将相似性大于指定阈值的文档去除;也可以基于子串在序列级进行精确匹配去重。
目前,已有很多经过处理的高质量预训练语料和专用于预训练数据处理的框架。例如,有基于 LLaMA 思路收集、清洗的预训练数据集RedPajama-1T,以及在 RedPajama 基础上进行筛选去重的SlimPajama-627B数据集,实验证明高质量的 627B Slimpajama 数据集能够获得比 1T 的 RedPajama 数据集更好的效果。
2. 有监督微调
预训练是 LLM 强大能力的根本来源,事实上,LLM 所覆盖的海量知识基本都是源于预训练语料。LLM 的性能本身,核心也在于预训练的工作。但是,预训练赋予了 LLM 能力,却还需要第二步将其激发出来,也就是 SFT——Supervisor Finetune,有监督微调。所谓有监督微调就是对于能力有限的传统预训练模型,我们需要针对每一个下游任务单独对其进行微调以训练模型在该任务上的表现。
而面对能力强大的 LLM,我们往往不再是在指定下游任务上构造有监督数据进行微调,而是选择训练模型的“通用指令遵循能力”,也就是一般通过指令微调
的方式来进行 SFT。所谓指令微调,即我们训练的输入是各种类型的用户指令,而需要模型拟合的输出则是我们希望模型在收到该指令后做出的回复。例如,我们的一条训练样本可以是:
input:告诉我今天的天气预报?
output:根据天气预报,今天天气是晴转多云,最高温度26摄氏度,最低温度9摄氏度,昼夜温差大,请注意保暖哦
SFT 的主要目标是让模型从多种类型、多种风格的指令中获得泛化的指令遵循能力,也就是能够理解并回复用户的指令。因此,类似于 Pretrain,SFT 的数据质量和数据配比也是决定模型指令遵循能力的重要因素。
高质量的指令数据集具有较高的获取难度。不同于预训练使用的无监督语料,SFT 使用的指令数据集是有监督语料,除去设计广泛、合理的指令外,还需要对指令回复进行人工标注,并保证标注的高质量。
{
"instruction":"即输入的用户指令",
"input":"执行该指令可能需要的补充输入,没有则置空",
"output":"即模型应该给出的回复"
}
同时,为使模型能够学习到和预训练不同的范式,在 SFT 的过程中,往往会针对性设置特定格式。例如,LLaMA 的 SFT 格式为:
### Instruction:\n{{content}}\n\n### Response:\n
其中的 content 即为具体的用户指令,也就是说,对于每一个用户指令,将会嵌入到上文的 content 部分,这里的用户指令不仅指上例中的 “instruction”,而是指令和输入的拼接,即模型可以执行的一条完整指令。例如,针对上例,LLaMA 获得的输入应该是:
### Instruction:\n将下列文本翻译成英文:今天天气真好\n\n### Response:\n
其需要拟合的输出则是:
### Instruction:\n将下列文本翻译成英文:今天天气真好\n\n### Response:\nToday is a nice day!
指令微调本质上仍然是对模型进行 CLM 训练,只不过要求模型对指令进行理解和回复而不是简单地预测下一个 token,所以模型预测的结果不仅是 output,而应该是 input + output,只不过 input 部分不参与 loss 的计算,但回复指令本身还是以预测下一个 token 的形式来实现的。
随着 LLM 能力的不断增强,模型的多轮对话能力逐渐受到重视。所谓多轮对话,是指模型在每一次对话时能够参考之前对话的历史记录来做出回复。模型的多轮对话能力完全来自于 SFT 阶段。如果要使模型支持多轮对话,我们需要在 SFT 时将训练数据构造成多轮对话格式,让模型能够利用之前的知识来生成回答。
构造多轮对话样本一般有三种方式:
1. 直接将最后一次模型回复作为输出,前面所有历史对话作为输入,直接拟合最后一次回复:
input=<prompt_1><completion_1><prompt_2><completion_2><prompt_3><completion_3>
output=[MASK][MASK][MASK][MASK][MASK]<completion_3>
2. 将 N 轮对话构造成 N 个样本:
input_1 = <prompt_1><completion_1>
output_1 = [MASK]<completion_1>input_2 = <prompt_1><completion_1><prompt_2><completion_2>
output_2 = [MASK][MASK][MASK]<completion_2>input_3=<prompt_1><completion_1><prompt_2><completion_2><prompt_3><completion_3>
output_3=[MASK][MASK][MASK][MASK][MASK]<completion_3>
3. 直接要求模型预测每一轮对话的输出:
input=<prompt_1><completion_1><prompt_2><completion_2><prompt_3><completion_3>
output=[MASK]<completion_1>[MASK]<completion_2>[MASK]<completion_3>
第一种方式会丢失大量中间信息,第二种方式造成了大量重复计算,只有第三种方式是最合理的多轮对话构造。我们之所以可以以第三种方式来构造多轮对话样本,是因为 LLM 本质还是进行的 CLM 任务,进行单向注意力计算,因此在预测时会从左到右依次进行拟合,前轮的输出预测不会影响后轮的预测。目前,绝大部分 LLM 均使用了多轮对话的形式来进行 SFT。
3.人类反馈强化学习
RLHF,全称是 Reinforcement Learning from Human Feedback,即人类反馈强化学习,是利用强化学习来训练 LLM 的关键步骤。相较于 SFT,RLHF 往往被认为是 ChatGPT 相较于 GPT-3 的最核心突破。事实上,从功能上出发,我们可以将 LLM 的训练过程分成预训练与对齐(alignment)两个阶段。预训练的核心作用是赋予模型海量的知识,而所谓对齐,其实就是让模型与人类价值观一致,从而输出人类希望其输出的内容。在这个过程中,SFT 是让 LLM 和人类的指令对齐,从而具有指令遵循能力;而 RLHF 则是从更深层次令 LLM 和人类价值观对齐,令其达到安全、有用、无害的核心标准。
ChatGPT 在技术报告中将对齐分成三个阶段,后面两个阶段训练 RM 和 PPO 训练,就是 RLHF 的步骤:
LHF 的思路是,引入强化学习的技术,通过实时的人类反馈令 LLM 能够给出更令人类满意的回复。强化学习是有别于监督学习的另一种机器学习方法,主要讨论的问题是智能体怎么在复杂、不确定的环境中最大化它能获得的奖励。强化学习主要由两部分构成:智能体和环境。在强化学习过程中,智能体会不断行动并从环境获取反馈,根据反馈来调整自己行动的策略。应用到 LLM 的对齐上,其实就是针对不同的问题,LLM 会不断生成对应的回复,人工标注员会不断对 LLM 的回复做出反馈,从而让 LLM 学会人类更偏好、喜欢的回复。
如上图,RLHF 分为两个步骤:训练 RM 和 PPO 训练。
RM,Reward Model,即奖励模型。RM 是用于拟合人类偏好,来给 LLM 做出反馈的。在强化学习的训练中,对于 LLM 的每一个回复,RM 会进行打分,这个打分反映了生成回复符合人类偏好的程度。然后 LLM 会根据强化学习的原理,基于 RM 的打分来进行优化训练。所以,RM 本质上是一个文本分类模型,对于一个文本输出一个标量奖励,和文本分类任务中的隐藏层输出非常类似。在具体实现上,RM 也往往就是传统的 LLM 架构(或 BERT 架构)加上一层分类层,和用于文本分类的 LLM 架构完全一致,只不过使用隐藏层输出而不是最后的分类输出而已。
在完成 RM 训练之后,就可以使用 PPO 算法来进行强化学习训练。PPO,Proximal Policy Optimization,近端策略优化算法,是一种经典的 RL 算法。事实上,强化学习训练时也可以使用其他的强化学习算法,但目前 PPO 算法因为成熟、成本较低,还是最适合 RLHF 的算法。
在具体 PPO 训练过程中,会存在四个模型。如图所示,两个 LLM 和两个 RM。两个 LLM 分别是进行微调、参数更新的 actor model 和不进行参数更新的 ref model,均是从 SFT 之后的 LLM 初始化的。两个 RM 分别是进行参数更新的 critic model 和不进行参数更新的 reward model,均是从上一步训练的 RM 初始化的。
如上图,使用 PPO 算法的强化学习训练过程如下:
在上述过程中,因为要使用到四个模型,显存占用会数倍于 SFT。 Actor Model 和 Critic Model 较为容易理解,而之所以我们还需要保持原参数不更新的 Ref Model 和 Reward Model,是为了限制模型的更新不要过于偏离原模型以至于丢失了 Pretrain 和 SFT 赋予的能力。
当然,如此大的资源占用和复杂的训练过程,使 RLHF 成为一个门槛非常高的阶段。也有学者从监督学习的思路出发,提出了 DPO(Direct Preference Optimization,直接偏好优化),可以低门槛平替 RLHF。DPO 的核心思路是,将 RLHF 的强化学习问题转化为监督学习来直接学习人类偏好。DPO 通过使用奖励函数和最优策略间的映射,展示了约束奖励最大化问题完全可以通过单阶段策略训练进行优化,也就是说,通过学习 DPO 所提出的优化目标,可以直接学习人类偏好,而无需再训练 RM 以及进行强化学习。由于直接使用监督学习进行训练,DPO 只需要两个 LLM 即可完成训练,且训练过程相较 PPO 简单很多,是 RLHF 更简单易用的平替版本。
参考文献:《Happy-LLM从零开始的大语言模型原理与实践教程》