DPO 概要

  1. DPO(Direct Preference Optimization,直接偏好优化)是由斯坦福大学等研究团队于2023年提出的一种偏好优化算法,可用于LLM、VLM与MLLM的对齐训练。

  2. 算法基于PPO的RLHF基础上进行了大幅简化。DPO算法跳过了训练奖励模型这一中间过程,直接(Direct)优化策略模型 ——这正是DPO命名中“D(Direct)”的含义所在。

主要流程

  1. 数据收集: 基于SFT训练的模型作为推理模型,用户输入prompt,模型多次推理,找到好的答案和不好的答案。如果都是不好(rejected)的答案,则人工修改把不好的答案变为好的答案。

    标数据收集
  2. 主要包含两个基础模型,策略模型&参考模型(不需要Reward模型)。 在trl强化学习框架中,只需要传入策略模型,参考模型会复制一份策略模型。

    1. 策略模型是DPO需要训练的模型,后用在项目中的模型。策略模型的权重直接复制SFT阶段微调模型的权重

    2. 参考模型是策略模型的帮衬,其权重参数冻结不变。主要两个作用,其一协助其计算reward loss,其二计算kl正则项,防止其训练偏移初始SFT模型太远,由一个β参数控制。

  3. β参数控制含义

    1. 较大 beta(如 1.0):放大 reward 或 logp 的差异,使模型更“自信”地倾向于较优样本,但容易过拟合或 reward 震荡。

    2. 较小 beta(如 0.1):差异被压缩,模型训练更稳定,但收敛较慢、辨别力较弱。

    3. 极小 beta(趋近于 0):差异几乎无效,模型无法区分好坏样本,退化为随机训练

  4.  整体流程如下:

  5. 具体流程

    DPO训练流程细节

九个损失函数解析

"loss": 1.8678"rewards/chosen": 42.519317626953125"rewards/rejected": -33.865535736083984"rewards/accuracies": 0.865429699420929"rewards/margins": 76.38734436035156"logps/chosen": -948.4149780273438"logps/rejected": -1285.1175537109375"logits/chosen": 5.363300800323486"logits/rejected": 4.879658222198486
  1. logps/chosen和logps/rejected: logps 是模型生成 token 概率,在归一化后(softmax)取 log 后的值(log prob)。

    #1 把 prompt 和 response 拼接起来作为输入
    input = prompt + response
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch# 加载 tokenizer 和模型
    tokenizer = AutoTokenizer.from_pretrained("your-model-name")
    model = AutoModelForCausalLM.from_pretrained("your-model-name").cuda()# 设置 prompt 和 response
    prompt = "你今天心情怎么样?"
    response = "我今天很开心,太阳出来了,我们一起去玩吧!"# 拼接输入
    full_input = prompt + response
    encodings = tokenizer(full_input, return_tensors="pt").to("cuda")
    input_ids = encodings["input_ids"]# 找到 response 的起始位置
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
    response_start = prompt_ids.shape[-1]# 前向推理,获取 logits
    with torch.no_grad():outputs = model(**encodings)logits = outputs.logits# 计算 log probabilities
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)# 获取 response 部分 token 的 log probability
    response_token_ids = input_ids[:, response_start:]
    response_logits = log_probs[:, response_start - 1:-1, :]  # 对应 shift
    response_logp = torch.gather(response_logits, 2, response_token_ids.unsqueeze(-1)).squeeze(-1)# 平均 log probability(整个 response)
    logp_response = response_logp.mean()logps_chosen = compute_logp(prompt, chosen, actor_model)
    logps_rejected = compute_logp(prompt, rejected, actor_model)
    logps_ref_chosen = compute_logp(prompt, chosen, ref_model)
    logps_ref_rejected = compute_logp(prompt, rejected, ref_model)
  2. logits/chosen和logits/rejected: 模型输出的raw score(未进行归一化)求平均

    # 模型输出:logits = [batch_size, seq_len, vocab_size]
    # 获取 chosen 的最后一个 token 的 logit:
    logit_chosen = logits[:, -1, :]  # 通常是这个位置
    logits/chosen = logit_chosen.mean().item()
    # 拿出 chosen response 部分的 token 对应的 logit 向量
    logits_response = logits[:, prompt_len:, :]  # mask 掉 prompt 部分
    logits/chosen = logits_response.mean().item()
  3. reward 计算方法

    chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
    rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
    reward_accuracies = (chosen_rewards > rejected_rewards).float()
    metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
    metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
    metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
    metrics[f"{prefix}rewards/margins"] = (
    self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
  4. Loss 计算方法

    本次默认使用sigmoidlogratios = chosen_logps - rejected_logpsref_logratios = ref_chosen_logps - ref_rejected_logps                logratios = logratios.to(self.accelerator.device)ref_logratios = ref_logratios.to(self.accelerator.device)logits = logratios - ref_logratios losses = (-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)- F.logsigmoid(-self.beta * logits) * self.label_smoothing )
    其他计算方法如下(后续介绍):"hinge","ipo",
    "exo_pair","nca_pair","robust","bco_pair",
    "sppo_hard","aot","apo_down""aot_pair","apo_zero","discopop",
  5. 关系理解

    指标

    含义

    关系

    logits

    每个 token 的原始输出分数(未归一化)

    模型输出的raw score(未进行归一化)求平均

    logps

    所有 token 的 log 概率之和(对 logit softmax 后求 log,token-wise 累加)

    来自 logits → softmax → log(prob) → sum over tokens

    rewards

    在 logp-based reward 情况下,reward 就是 sum(logps)/len(tokens)

    eval_rewards/chosen == eval_logps/chosen/len(tokens)

  6. 主要关注指标

    指标名

    含义

    影响

    loss

    当前 batch 的 DPO/IPO 损失值

    反映训练是否有效收敛,是否有发散/震荡

    rewards/margins

    reward_chosen - reward_rejected 的平均值

    反映模型区分正负样本的能力是否提升

    rewards/accuracies

    reward_chosen > reward_rejected 的比例

    反映偏好判断正确率是否提高

    logs/chosen& logs/rejected

    每个 sample 的对数似然总和

    趋势变化判断 token-level 拟合趋势

其他思考

1.  logps/chosen是负的合理吗

logps(y_{chosen}|x})logps(y_{chosen}|x}) 是模型对生成chosen回复时,每个token的概率取对数后加总, 由于每一个token的概率 ,所以。p(yt,y<t)∈(0,1),所以logp(yt)<0。 所以累加一段文本后,整个logp通常是一个比较大的负值。

2. reward为负值

因为是 rchosen=logπθ(ychosen|x) ,如果没有额外reward打分模型,则 r=sum(logps)/len(logps)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/87008.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/87008.shtml
英文地址,请注明出处:http://en.pswp.cn/web/87008.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UniApp完全支持快应用QUICKAPP-以及如何采用 Uni 模式开发发行快应用优雅草卓伊凡

UniApp完全支持快应用QUICKAPP-以及如何采用 Uni 模式开发发行快应用优雅草卓伊凡 一、UniApp 对快应用的支持深度 UniApp 已完全支持快应用的开发和发布&#xff0c;具体包括&#xff1a; 两种渲染模式&#xff1a; Webview 渲染&#xff08;快应用 Light 版&#xff09;&a…

js 允许生成特殊的变量名 基于字符集编码混淆的 XSS 绕过漏洞 -- Google 2025 Lost In Transliteration

题目实现了一个字符转换工具 在/file路由用户可以通过 ct 参数自定义 Content-Type // 文件路由 - 提供静态文件服务&#xff08;JS和CSS&#xff09;&#xff0c;支持内容类型验证 app.MapGet("/file", (string filename "", string? ct null, string?…

【仿muduo库实现并发服务器】LoopThreadPool模块

仿muduo库实现并发服务器 1.LoopThread模块1.1成员变量1.2构造函数13线程入口函数1.4获取eventloop对象GetLoop() 2.LoopThreadPool模块2.1成员变量2.2构造函数2.3配置线程数量2.4按照配置数量创建线程2.5依次分配Eventloop对象 1.LoopThread模块 这个模块是为了将EventLoop与…

华为云Flexus+DeepSeek征文|基于Dify构建文本/图像/视频生成工作流

华为云FlexusDeepSeek征文&#xff5c;基于Dify构建文本/图像/视频生成工作流 一、构建文本/图像/视频生成工作流前言二、构建文本/图像/视频生成工作流环境2.1 基于FlexusX实例的Dify平台2.2 基于MaaS的模型API商用服务 三、构建文本/图像/视频生成工作流实战3.1 配置Dify环境…

相机-IMU联合标定:IMU更新频率

文章目录 📚简介⚠️ IMU频率参数错误设置的影响❌ 相机-IMU联合标定失败:Optimization failed!🚀 确定IMU更新频率直接通过 rostopic hz 检查实际频率检查 IMU 驱动或数据手册从 bag 文件统计频率在这里插入图片描述修改 `update_rate` 的注意事项**最终建议****常见问题…

动手实践:如何提取Python代码中的字符串变量的值

要提取Python代码中所有变量类型为字符串的变量的值&#xff0c;但不执行代码&#xff08;避免安全风险&#xff09;&#xff0c;可以通过静态分析代码的抽象语法树&#xff08;AST&#xff09;来实现。以下是完整的解决方案&#xff1a; 本文由「大千AI助手」原创发布&#xf…

Python中字符串isalpha()函数详解

在 Python 中&#xff0c;isalpha() 是字符串&#xff08;string&#xff09;类型的内置方法&#xff0c;用于检查字符串中的所有字符是否都是字母字符&#xff08;alphabetic character&#xff09;。以下是详细说明&#xff1a; 一、基本功能 返回值&#xff1a;布尔值&…

Gradio全解13——MCP详解(4)——TypeScript包命令:npm与npx

Gradio全解13——MCP详解&#xff08;4&#xff09;——TypeScript包命令&#xff1a;npm与npx 第13章 MCP详解13.4 TypeScript包命令&#xff1a;npm与npx13.4.1 概念区分1. npm概念与运行逻辑2. npx概念及特点 13.4.2 操作示例1. 使用npm执行包2. 使用npx执行包3. 常用npm命令…

《推客小程序全链路开发指南:从架构设计到裂变运营》

在移动互联网流量红利逐渐消退的今天&#xff0c;如何低成本获客成为企业营销的核心痛点。推客小程序作为一种基于社交关系的裂变营销工具&#xff0c;正成为企业突破增长瓶颈的利器。本文将为您全面解析推客小程序的开发定制全流程&#xff0c;帮助您打造专属的社交裂变营销平…

中钧科技参加中亚数字经济对话会,引领新疆企业数字化新征程!

6月27 日&#xff0c;乌鲁木齐成为数字经济领域的焦点&#xff0c;中国新疆 - 中亚国家数字经济和数字贸易企业对话会在此盛大举行。 来自中亚国家及新疆数字经济领域的100 余位核心代表齐聚一堂&#xff0c;围绕数字经济时代的机遇、挑战与策略展开深度探讨。 本次对话会由新…

k8s一键部署tongweb企业版7049m6(by why+lqw)

声明 1.此贴仅供参考&#xff0c;请根据自身需求在测试环境测试和修改。 安装准备 1.获取对应的安装包和授权,并将授权和安装包放在同一个目录下 2.docekr已配置远程仓库 3.提前拉取jdk的镜像&#xff08;这里配置了使用openjdk:8&#xff09; 安装 将以下内容复制到k8s_…

Qt 与 Halcon 联合开发六:基于海康SDK设计完整的相机类【附源码】

在现代工业自动化、机器人视觉、等领域&#xff0c;相机模块的作用至关重要。通过相机模块采集到的图像数据&#xff0c;我们能够进行一系列的图像处理和分析。为了高效地控制相机和处理图像&#xff0c;本篇文章将介绍如何使用Qt和Halcon联合开发一个相机模块&#xff0c;帮助…

第7篇:Gin模板引擎——服务端页面渲染

作者:GO兔 博客:https://luckxgo.cn 分享大家都看得懂的博客 引言 在Web开发中&#xff0c;服务端页面渲染(SSR)依然是构建动态网页的重要方式。Gin框架虽然以API开发见长&#xff0c;但也内置了强大的模板引擎支持&#xff0c;基于Go标准库的html/template包实现。本文将深入…

RagFlow 源码部署启动指南

一、环境准备 1. 安装 uv 和 pre-commit 如果已安装&#xff0c;可跳过。推荐使用官方方式安装&#xff0c;避免报错&#xff1a; pipx install uv pre-commit export UV_INDEXhttps://mirrors.aliyun.com/pypi/simple安装报错 使用清华源安装&#xff1a; pipx install uv…

【Python基础】12 闲谈分享:Python用于无人驾驶的未来

引言&#xff1a;一个程序员的自动驾驶梦想 还记得2016年的那个秋天&#xff0c;我第一次坐进特斯拉Model S的驾驶座&#xff0c;体验Autopilot功能。当方向盘开始自己转动&#xff0c;车辆在高速公路上自动跟随前车时&#xff0c;我的内心涌起了一种奇妙的感觉——这不就是我…

为什么js是单线程?

js单线程&#xff0c;同一时间只能做一件事 。js的单线程 主要与它的用途有关。作为浏览器脚本语言&#xff0c;js的主要用途是与用户互动&#xff0c;以及操作DOM。这决定了它只能是单线程&#xff0c;否则会带来很复杂的同步问题。如果js同时有两个线程&#xff0c;一个线程在…

DVWA靶场通关笔记-文件包含(Medium级别 9种渗透方法)

目录 一、文件包含 1、原因 2、危害 3、防范措施 二、代码审计&#xff08;Medium级别&#xff09; 1、渗透准备 &#xff08;1&#xff09;配置php.ini &#xff08;2&#xff09;file1.php &#xff08;3&#xff09;file2.php &#xff08;4&#xff09;file3.php…

飞云翻倍布林(翻倍密码系统四线布林版)双安全系统+均价趋势指标+日线周线MACD,组合操盘技术图文分享

如上图组合操盘套装指标&#xff0c;主图指标-翻倍密码系统四线布林版-飞云翻倍布林。副图指标1-均价趋势指标&#xff0c;跟踪市场均价走势和趋势&#xff1b;副图指标2-日线周线MACD指标&#xff0c;跟踪日线和周线两个级别的MACD多空走势以及共振与否。 主图指标-飞云翻倍布…

《汇编语言:基于X86处理器》第6章 条件处理(1)

本章向程序员的汇编语言工具箱中引入一个重要的内容&#xff0c;使得编写出来的程序具备作决策的功能。几乎所有的程序都需要这种能力。首先&#xff0c;介绍布尔操作&#xff0c;由于能影响CPU状态标志&#xff0c;它们是所有条件指令的核心。然后&#xff0c;说明怎样使用演绎…

【分治思想】归并排序 与 逆序对

归并排序 归并排序是一种分治算法&#xff0c;怎么分&#xff0c;怎么治&#xff1f; 分&#xff1a;通过递归不断把数组分成两半&#xff0c;直到每个子数组只剩 1 个元素&#xff08;天然有序&#xff09;治&#xff1a;把两个已经排好序的子数组合并成一个有序数组。 把问…