DPO学习笔记

  • 1 原理
    • 1.0 名词
    • 1.1 preference model
    • 1.2 RLHF
    • 1.3 从RLHF到DPO
      • A.解的最优形式
      • B. DPO下参数估计
      • C. DPO下梯度更新
      • D. DPO训练的稳定性
  • 2 源代码
    • 2.1 数据集构成
    • 2.2 计算log prob
    • 2.3 DPO loss

1 原理

1.0 名词

  • preference model:对人类偏好进行建模,这个"model"不是DL model
  • policy model:最终要训练得到的LLM πθ\pi_\thetaπθ
  • reward model:用来评价LLM生成的结果有多符合人类偏好

1.1 preference model

  • 是一种者范式、定义,是用来预测人类对不同输出项之间相对偏好概率的模型,例如,在比较两个响应时,偏好模型可以估计出“响应A比响应B更受欢迎”的概率
  • DPO中使用的是Bradley–Terry 模型来定义偏好的概率形式,给定2个选项ywy_wywyly_lyl,Bradley–Terry 定义的的ywy_wywyly_lyl好的概率为
    p(yw≥yl)=exp(θw)exp(θw)+exp(θl)p(y_w \ge y_l)=\frac{exp(\theta_w)}{exp(\theta_w)+exp(\theta_l)} p(ywyl)=exp(θw)+exp(θl)exp(θw)

1.2 RLHF

在这里插入图片描述
RLHF需要使用人标注的偏好数据对,先训练一个reward model,然后再让reward model和LLM做强化学习
【1】SFT训练LLM: 使用目标任务的训练数据训练得到的模型记为πSFT\pi^{SFT}πSFT
【2】训练reward model: 使用目标任务的另一份数据xxx输入πSFT\pi^{SFT}πSFT,每份数据得到2个输出,记为(y1,y2)∼πSFT(y∣x)(y_1,y_2) \sim \pi^{SFT}(y \mid x)(y1,y2)πSFT(yx)。这些成对的数据给到人工标注者,进行偏好标注,(y1,y2)(y_1,y_2)(y1,y2)里面人工觉得回答的好的数据为ywy_wyw,觉得回答的不好的数据为yly_lyl,得到的数据集为D={xi,ywi,yli}i=1N\mathcal{D}=\{x^{i},y^i_w,y^i_l\}^N_{i=1}D={xi,ywi,yli}i=1N。假设这种偏好产生自一个隐藏的奖励模型r∗(y,x)r^*(y,x)r(y,x),当使用Bradley-Terry模型来建模,人类偏好p∗p^*p的分布可以表示为
p∗(yw≻yl∣x)=exp(r∗(x.y1))exp(r∗(x.y1))+exp(r∗(x.y2))p^*(y_w \succ y_l \mid x)=\frac{exp(r^*(x.y_1))}{exp(r^*(x.y_1))+exp(r^*(x.y_2))} p(ywylx)=exp(r(x.y1))+exp(r(x.y2))exp(r(x.y1))
  可以形式化奖励模型参数为rϕ(x,y)r_\phi(x,y)rϕ(x,y)并且使用极大似然估计在数据集D\mathcal{D}D上估计参数,建模为二分类问题,损失函数可以为(也可以是其他形式,相减比较符合认知):
LR(rϕ,D)=−E(x,yw,yl)∼D[logσ(rϕ(x,yw)−rϕ(x,yl))]\mathcal{L}_R(r_\phi,\mathcal{D})=-\mathbb{E}_{(x,y_w,y_l)\sim\mathcal{D}}[log \sigma(r_\phi(x,y_w)-r_\phi(x,y_l))]LR(rϕ,D)=E(x,yw,yl)D[logσ(rϕ(x,yw)rϕ(x,yl))]

【3】RL微调: 在RL阶段,优化目标带有KL约束
max⁡πθEx∼D,y∼πθ(y∣x)[rϕ(x,y)−βDKL[πθ(y∣x)∥πref(y∣x)]]\max_{\pi_{\theta}}\mathbb{E}_{x \sim \mathcal{D},y \sim \pi_{\theta}(y \mid x)}[r_\phi(x,y)-\beta\mathbb{D}_{KL}[\pi_{\theta}(y \mid x)\parallel \pi_{ref}(y \mid x)]] πθmaxExD,yπθ(yx)[rϕ(x,y)βDKL[πθ(yx)πref(yx)]]

1.3 从RLHF到DPO

A.解的最优形式

  首先,根据RL优化目标的形式,奖励函数为rrr,最优的策略π\piπ的形式为
πr(y∣x))=1Z(x)πref(y∣x)exp(1βr(x,y))\pi_r(y \mid x))=\frac{1}{Z(x)}\pi_{ref}(y \mid x) exp(\frac{1}{\beta}r(x,y)) πr(yx))=Z(x)1πref(yx)exp(β1r(x,y))
其中Z(x)=∑yπref(y∣x)exp(1βr(x,y))Z(x)=\sum_{y}\pi_{ref}(y \mid x) exp(\frac{1}{\beta}r(x,y))Z(x)=yπref(yx)exp(β1r(x,y))。之所以能得到这个形式在原论文的附录中有推导
在这里插入图片描述
  里面的第3步到第4步是因为可以引入Z(x)Z(x)Z(x)构造一个新的概率分布,Z(x)Z(x)Z(x)是归一化因子,保证π~(y∣x)\tilde{\pi} (y \mid x)π~(yx)是有效的概率分布:
π~(y∣x)=1Z(x)πrefexp(1βr(x,y))\tilde{\pi} (y \mid x)=\frac{1}{Z(x)}\pi_{ref}exp(\frac{1}{\beta}r(x,y))π~(yx)=Z(x)1πrefexp(β1r(x,y))

  这样,原来的式子
logπ(y∣x)πref(y∣x)=logπ(y∣x)−πref(y∣x)−log[exp(1βr(x,y))]=logπ(y∣x)π~(y∣x)−logZ(x)log \frac{\pi(y \mid x)}{\pi_{ref}(y \mid x)} =log\pi(y \mid x)-\pi_{ref}(y \mid x) - log[exp(\frac{1}{\beta}r(x,y))] \\ =log \frac{\pi(y \mid x)}{\tilde{\pi}_(y \mid x)} - log Z(x) logπref(yx)π(yx)=logπ(yx)πref(yx)log[exp(β1r(x,y))]=logπ~(yx)π(yx)logZ(x)

  又因π\piπ的形式只需要满足是合法的概率分布就可以,因此形式上可以替换,以及Z(x)Z(x)Z(x)不是yyy的函数,所以期望写进去不会对logZ(x)log Z(x)logZ(x)有影响,得到了最优策略下,策略函数的形式(给定xxx的情况下输出yyy的概率 / 在给定状态SSS的情况下,下一个时间的进入状态S′S'S的概率)
π∗(y∣x)=1Z(x)πref(y∣x)exp(1βr(x,y))\pi^*(y \mid x)= \frac{1}{Z(x)}\pi_{ref}(y \mid x) exp(\frac{1}{\beta} r(x,y)) π(yx)=Z(x)1πref(yx)exp(β1r(x,y))
在这里插入图片描述

B. DPO下参数估计

  • 即使得到了最优策略πr\pi_rπr的形式,并且即使把里面的r(x,y)r(x,y)r(x,y)用MLE估计的rrr来替换,里面也有一个Z(x)Z(x)Z(x)需要估计,Z(x)Z(x)Z(x)的计算是很复杂的,里面的"状态"或者说词表yyy很大的情况下开销大
  • 但是可以进一步把式子整理一下,重新表示一下reward函数
    r(x,y)=βlogπr(y∣x)πref(y∣x)+βlogZ(x)r(x,y)=\beta log \frac{\pi_r(y \mid x)}{\pi_{ref}(y \mid x)}+ \beta log Z(x)r(x,y)=βlogπref(yx)πr(yx)+βlogZ(x)
  • 带入原始的Bradley-Terry的式子,会发现,最后衡量偏好的函数里面,没有reward function Z(x)Z(x)Z(x)这一项需要计算了抵消掉了

在这里插入图片描述

  • 所以DPO的目标是提升yw≻yly_w \succ y_lywyl的概率,损失函数的形式为
    LDPO(πθ;πref)=−E(x,yw,wl)∼D[logσ(βlogπθ(yw∣x)πref(yw∣x)−βlogπθ(yl∣x)πref(yl∣x))]\mathcal{L}_{DPO}(\pi_\theta;\pi_{ref}) = -\mathbb{E}_{(x,y_w,w_l)\sim \mathcal{D}}[log \sigma(\beta log \frac{\pi_\theta(y_w \mid x)}{\pi_{ref}(y_w \mid x)} - \beta log \frac{\pi_\theta(y_l \mid x)}{\pi_{ref}(y_l \mid x)}) ] LDPO(πθ;πref)=E(x,yw,wl)D[logσ(βlogπref(ywx)πθ(ywx)βlogπref(ylx)πθ(ylx))]

C. DPO下梯度更新

在这里插入图片描述

  • 和人类偏好差异越大的,前面的系数越大

D. DPO训练的稳定性

在这里插入图片描述

  • 第二项为归一化项是常数是因为对当前xxx,遍历了所有的yyy
  • 减少极端值的影响:通过指数加权平均,极端值的影响会被削弱,从而使得奖励函数更加平滑
  • 稳定梯度估计:由于奖励函数变得更加平滑,策略梯度的估计也会更加稳定,方差会显著减小

2 源代码

RLAIF-V:https://github.com/RLHF-V/RLAIF-V/tree/main

2.1 数据集构成

  • chose——人类偏好的回答
  • rejected——SFT阶段的模型回答
  • ref_win_logp——人类偏好回答的所有token的log_probability之和
  • ref_rej_logp——模型回答的的所有token的log_probability之和
  • ref_win_avg_logp——人类偏好回答的所有token的log_probability之和 / 回答长度的token数
data_dict = {'image': image,"question": question,"chosen": chosen,"rejected": rejected,"idx": sample['idx'],"metainfo": metainfo
}
logps=json.loads(sample['logps']) # 调用/muffin下面的./eval/muffin_inference_logp.pyif type(logps) == type([]):(data_dict['ref_win_logp'], data_dict['ref_win_avg_logp'], data_dict['ref_win_per_token_logp'],data_dict['ref_rej_logp'], data_dict['ref_rej_avg_logp'], data_dict['ref_rej_per_token_logp']) = logps
else:(data_dict['ref_win_logp'], data_dict['ref_win_avg_logp'], data_dict['ref_win_per_token_logp'],data_dict['ref_rej_logp'], data_dict['ref_rej_avg_logp'], data_dict['ref_rej_per_token_logp']) = logps['logps']return data_dict

2.2 计算log prob

def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, return_per_token_logp=False, return_all=False, tokenizer=None) -> torch.FloatTensor:"""Compute the log probabilities of the given labels under the given logits.Args:logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)Returns:A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits."""assert logits.shape[:-1] == labels.shape, f'logits.shape[:-1]={logits.shape[:-1]}, labels.shape={labels.shape}'labels = labels[:, 1:].clone()logits = logits[:, :-1, :]loss_mask = (labels != -100)# dummy token; we'll ignore the losses on these tokens laterlabels[labels == -100] = 0per_token_logps = torch.gather(logits.log_softmax(-1), dim=2,index=labels.unsqueeze(2)).squeeze(2) # get log probabilities for each token in labelslog_prob = (per_token_logps * loss_mask).sum(-1)average_log_prob = log_prob / loss_mask.sum(-1)

2.3 DPO loss

  • policy model指的是正在训练的模型,ref model是之前SFT阶段的模型
  • 注意policy_chosen_logps这些是log 的probability,所以和原始的DPO的loss公式是完全等价的
def get_beta_and_logps(data_dict, model, args, is_minicpm=False, is_llava15=False):win_input_ids = data_dict.pop('win_input_ids')rej_input_ids = data_dict.pop('rej_input_ids')ref_win_logp = data_dict.pop('ref_win_logp')ref_rej_logp = data_dict.pop('ref_rej_logp')log_prob, average_log_prob = get_batch_logps(output.logits, concatenated_labels, return_per_token_logp=False)if args.dpo_use_average:concatenated_logp = average_log_probwin_size = win_input_ids.shape[0]rej_size = rej_input_ids.shape[0]policy_win_logp, policy_rej_logp = concatenated_logp.split([win_size, rej_size])  # 默认的是average的log_logits,值越大越置信return policy_win_logp, policy_rej_logp, ref_win_logp, ref_rej_logp, betadef dpo_loss(policy_chosen_logps: torch.FloatTensor,policy_rejected_logps: torch.FloatTensor,reference_chosen_logps: torch.FloatTensor,reference_rejected_logps: torch.FloatTensor,beta: float,reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:"""Compute the DPO loss for a batch of policy and reference model log probabilities.Args:policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.Returns:A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).The losses tensor contains the DPO loss for each example in the batch.The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively."""pi_logratios = policy_chosen_logps - policy_rejected_logps  # log(\pi(a_i | x)) - log(\pi(b_i | x)) = log(\pi(a_i | x) / \pi(b_i | x))ref_logratios = reference_chosen_logps - reference_rejected_logps  # 完全等价的if reference_free:ref_logratios = 0logits = pi_logratios - ref_logratioslosses = -F.logsigmoid(beta * logits)chosen_rewards = beta * (policy_chosen_logps -reference_chosen_logps).detach()rejected_rewards = beta * \(policy_rejected_logps - reference_rejected_logps).detach()return losses, chosen_rewards, rejected_rewards############# 调用为policy_win_logp, policy_rej_logp, ref_win_logp, ref_rej_logp, beta = get_beta_and_logps(data_dict, model, self.args, is_llava15=True) # 这些都是averaged的token的log_logitslosses, chosen_rewards, rejected_rewards = dpo_loss(policy_win_logp,policy_rej_logp,ref_win_logp,ref_rej_logp,beta=beta)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/92129.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/92129.shtml
英文地址,请注明出处:http://en.pswp.cn/web/92129.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025最新、UI媲美豆包、DeepSeek等AI大厂的AIGC系统 - IMYAI源码部署教程

IMYAI 系统部署与使用手册 一、系统演示 🔹 快速体验 前端演示地址:https://super.imyaigc.com后台演示地址:https://super.imyaigc.com/settings 🔹 技术架构 前端:Vite Vue3 NaiveUI TailwindCSS Plyr后端&…

【关于Java的反射】

在 Java 编程中,反射(Reflection) 是一个非常强大的工具,它允许你在运行时动态地获取类的信息、创建对象、调用方法和访问字段。虽然反射功能强大,但它也有一些局限性和性能开销,因此需要谨慎使用。一、什么…

Gitee推出“移动软件工厂“解决方案 解决嵌入式与涉密场景研发困局

Gitee推出"移动软件工厂"解决方案 破解嵌入式与涉密场景研发困局 随着数字化转型浪潮的推进,软件开发正面临着前所未有的复杂环境挑战。特别是在嵌入式系统、FPGA开发以及涉密信息系统等特殊场景下,研发团队往往需要在高安全要求与有限网络环境…

低功耗16*8位四线串行8*4按键阵矩LED驱动专用电路

概述:PC0340是占空比可调的LED显示控制驱动电路。由16根段输出、8根位输出、数字接口、数据锁存器、显示存储器、键扫描电路及相关控制电路组成了一个高可靠性的单片机外围LED驱动电路。串行数据通过4线串行接口输入到PC0340,采用LQFP44L的封装形式。本产…

通过自定义注解加aop切面实现权限控制

前言:自定义注解,通过aop切面前置通知,对请求接口进行权限控制1,创建枚举类package org.springblade.sample.annotationCommon;import lombok.AllArgsConstructor; import lombok.Getter;import java.util.Arrays; import java.ut…

IDS知识点

在网络安全工程师、系统运维工程师等岗位的面试中,​​IDS(Intrusion Detection System,入侵检测系统)​​ 是高频考点,尤其是对网络安全防护、安全监控类岗位。以下是IDS的核心考点和必须掌握的知识点,按优…

Adobe Analytics 数据分析平台|全渠道客户行为分析与体验优化

Adobe Analytics 是业界领先的数据分析平台,帮助企业实时追踪客户行为,整合多渠道数据,通过强大的分析与可视化工具深入分析客户旅程,优化数字体验。结合 Adobe Experience Cloud,Adobe Analytics 成为推动数字化增长和…

【轮播图】H5端轮播图、横向滑动、划屏效果实现方案——Vue3+CSS position/CSS scroller

文章目录定位实现滑屏效果前置知识CSS: touch-action属性CSS: transform属性触摸事件forEach回调占位符准备阶段实现移动效果实现跟手效果触摸结束优化完整代码滚动实现滑屏效果前置知识CSS: scroll-snap-type属性准备阶段实现滑动效果实现吸附效果滚动条隐藏存在问题完整代码s…

忘记了WordPress管理员密码的找回方法

WordPress管理员密码找回方法 如果您忘记了WordPress管理员密码,可以通过以下几种方法找回或重置: 方法1:通过电子邮件重置(最简单) 访问您的WordPress登录页面(通常是wodepress.com/wp-admin或wodepress.com/wp-login.php) 点击”忘记密…

RAFT:让语言模型更聪明地用文档答题

RAFT:让语言模型更聪明地用文档答题 作者注: 本文旨在面向零基础读者介绍 UC Berkeley 提出的 RAFT(Retrieval-Augmented Fine-Tuning)方法。它是一种训练语言模型的新方式,让模型更好地利用“外部知识”——比如文档、…

【紧急预警】NVIDIA Triton推理服务器漏洞链可导致RCE!

2025 年 8 月 4 日消息,NVIDIA 旗下的 Triton 推理服务器(一款支持 Windows 和 Linux 系统、用于大规模运行 AI 模型的开源平台)被曝出一系列安全漏洞。这些漏洞一旦被利用,攻击者有可能完全接管存在漏洞的服务器。 Wiz 安全公司…

基于深度学习的医学图像分析:使用PixelCNN实现医学图像生成

前言 医学图像分析是计算机视觉领域中的一个重要应用,特别是在医学图像生成任务中,深度学习技术已经取得了显著的进展。医学图像生成是指通过深度学习模型生成医学图像,这对于医学研究、疾病模拟和图像增强等任务具有重要意义。近年来&#x…

React ahooks——副作用类hooks之useDebounceFn

useDebounceFn 是 ahooks 提供的用于函数防抖的 Hook,它可以确保一个函数在连续触发时只执行最后一次。一、基本用法import { useDebounceFn } from ahooks; import { Button } from antd;const Demo () > {const { run } useDebounceFn(() > {console.log(…

【机器学习深度学习】 知识蒸馏

目录 前言 一、什么是知识蒸馏? 二、知识蒸馏的核心意义 2.1 降低算力与成本 2.2 加速推理与边缘部署 2.3 推动行业应用落地 2.4 技术自主可控 三、知识蒸馏的本质:大模型的知识传承 四、知识蒸馏的“四重红利” 五、DeepSeek的知识蒸馏实践 …

Python高级编程与实践:Python高级数据结构与编程技巧

高级数据结构:掌握Python中的高效编程技巧 学习目标 通过本课程,学员将深入了解Python中的高级数据结构,包括列表推导式、字典推导式、集合推导式和生成器表达式。学员将学习如何利用这些结构来编写更简洁、更高效的代码,并了解它…

【C++】Stack and Queue and Functor

本文是小编巩固自身而作,如有错误,欢迎指出!本次我们介绍STL中的stack和queue和其相关的一些容器和仿函数一.stack and queue1.适配器stack和queue其实不是真正意义上的容器,而是容器适配器,而容器适配器又是什么呢&am…

Python爬虫实战:研究OpenCV技术构建图像数据处理系统

1. 引言 1.1 研究背景 在当今数字化时代,图像作为一种重要的信息载体,广泛存在于各类网站、社交媒体和在线平台中。这些图像数据涵盖了从自然风光、人物肖像到商品展示、新闻事件等丰富内容,为数据分析和模式识别提供了宝贵的资源。随着计算机视觉技术的快速发展,对大规模…

电感矩阵-信号完整性分析

电感矩阵:正如电容矩阵用于存储许多信号路径和返回路径的所有电容量,我们也需要一个矩阵存储许多导线的回路自感和回路互感值。需要牢记的是,这里的电感元件是回路电感。当信号沿传输线传播时,电流回路沿信号路径传输,然后立即从返…

JUC相关知识点总结

Java JUC(java.util.concurrent)是Java并发编程的核心工具包,提供了丰富的并发工具类和框架。以下是JUC的主要知识点,按难易程度分类,供你参考: 1. 基础概念与工具类 1.1 并发与并行(易&#x…

激光频率梳 3D 测量方案革新:攻克光学扫描遮挡,130mm 深孔测量精度达 2um

一、深孔测量的光学遮挡难题在精密制造领域,130mm 级深孔(如航空发动机燃油孔、模具冷却孔)的 3D 测量长期受困于光学遮挡。传统激光扫描技术依赖直射光束,当深径比超过 10:1 时,孔壁中下部形成大量扫描盲区&#xff0…