Transformer实战(18)——微调Transformer语言模型进行回归分析

    • 0. 前言
    • 1. 回归模型
    • 2. 数据处理
    • 3. 模型构建与训练
    • 4. 模型推理
    • 小结
    • 系列链接

0. 前言

在自然语言处理领域中,预训练 Transformer 模型不仅能胜任离散类别预测,也可用于连续数值回归任务。本节介绍了如何将 DistilBert 转变为回归模型,为模型赋予预测连续相似度分值的能力。我们以 GLUE 基准中的语义文本相似度 (STS-B) 数据集为例,详细介绍配置 DistilBertConfig、加载数据集、分词并构建 TrainingArguments,并定义 Pearson/Spearman 相关系数等回归指标。

1. 回归模型

回归模型通常最后一层只有一个神经元,它不会通过 softmax 逻辑回归处理,而是进行归一化。为了定义模型并在顶部添加一个单神经元的输出层,有两种方法:直接在 BERT.from_pretrained() 方法中使用参数 num_labels=1,或者通过 config 对象传递此信息。首先需要从预训练模型的 config 对象中复制这些信息:

from transformers import DistilBertConfig, DistilBertTokenizerFast, DistilBertForSequenceClassification
MODEL_PATH='distilbert-base-uncased'
config = DistilBertConfig.from_pretrained(MODEL_PATH, num_labels=1)
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH, config=config)

由于我们设置了 num_labels=1 参数,因此预训练模型的输出层包含一个神经元。接下来,准备数据集微调模型进行回归分析。
在本节中,我们将使用语义文本相似度基准 (STS-B) 数据集,它包含从新闻标题等多种内容中提取的句子对。每对句子都有一个从 15 的相似度评分,我们的任务是微调 DistilBert 模型以预测这些评分,并使用 Pearson/Spearman 相关系数来评估模型。

2. 数据处理

(1) 加载数据。将原始数据分为三部分,但由于测试集没有标签,所以我们可以将验证数据分为两部分:

import datasets
from datasets import load_dataset
stsb_train= load_dataset('glue','stsb', split="train")
stsb_validation = load_dataset('glue','stsb', split="validation")
stsb_validation=stsb_validation.shuffle(seed=42)
stsb_val= datasets.Dataset.from_dict(stsb_validation[:750])
stsb_test= datasets.Dataset.from_dict(stsb_validation[750:])

(2) 使用 pandas 来整理 stsb_train 训练数据:

import pandas as pd
pd.DataFrame(stsb_train)

整理后的训练数据样本如下:

数据样本

(3) 查看三个数据集的形状:

stsb_train.shape, stsb_val.shape, stsb_test.shape
# ((5749, 4), (750, 4), (750, 4))

(4) 对数据集进行分词处理:

enc_train = stsb_train.map(lambda e: tokenizer( e['sentence1'],e['sentence2'], padding=True, truncation=True), batched=True, batch_size=1000) 
enc_val =   stsb_val.map(lambda e: tokenizer( e['sentence1'],e['sentence2'], padding=True, truncation=True), batched=True, batch_size=1000) 
enc_test =  stsb_test.map(lambda e: tokenizer( e['sentence1'],e['sentence2'], padding=True, truncation=True), batched=True, batch_size=1000) 

(5) 分词器将两个句子用 [SEP] 分隔符连接,并为句子对生成 input_idsattention_mask

pd.DataFrame(enc_train)

输出结果如下:

输出结果

3. 模型构建与训练

(1)TrainingArguments 类中定义参数集:

from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(# The output directory where the model predictions and checkpoints will be writtenoutput_dir='./stsb-model', do_train=True,do_eval=True,#  The number of epochs, defaults to 3.0 num_train_epochs=3,              per_device_train_batch_size=32,  per_device_eval_batch_size=64,# Number of steps used for a linear warmupwarmup_steps=100,                weight_decay=0.01,# TensorBoard log directorylogging_strategy='steps',                logging_dir='./logs',            logging_steps=50,# other options : no, stepsevaluation_strategy="epoch",save_strategy="epoch",fp16=True,load_best_model_at_end=True
)

(2) 定义 compute_metrics 函数。其中,评估指标基于皮尔逊相关系数 (Pearson correlation coefficient) 和斯皮尔曼等级相关系数 (Spearman’s rank correlation) 法,此外,还提供均方误差 (Mean Square Error, MSE)、均方根误差 (Root Mean Square Error, RMSE) 和平均绝对误差 (Mean Absolute Error, MAE) 等常用的回归模型评估指标:

from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
import numpy as np
from scipy.stats import pearsonr
from scipy.stats import spearmanr
def compute_metrics(pred):preds = np.squeeze(pred.predictions) return {"MSE": ((preds - pred.label_ids) ** 2).mean().item(),"RMSE": (np.sqrt ((  (preds - pred.label_ids) ** 2).mean())).item(),"MAE": (np.abs(preds - pred.label_ids)).mean().item(),"Pearson" : pearsonr(preds,pred.label_ids)[0],"Spearman's Rank" : spearmanr(preds,pred.label_ids)[0]}

(3) 实例化 Trainer 对象:

trainer = Trainer(model=model,args=training_args,train_dataset=enc_train,eval_dataset=enc_val,compute_metrics=compute_metrics,tokenizer=tokenizer)

(4) 运行训练过程:

train_result = trainer.train()
metrics = train_result.metrics

输出结果如下:

输出结果
最佳验证损失为 0.542073,评估最佳权重模型:

q=[trainer.evaluate(eval_dataset=data) for data in [enc_train, enc_val, enc_test]]
pd.DataFrame(q, index=["train","val","test"]).iloc[:,:6]

输出结果如下:

输出结果

在测试数据集上,PearsonSpearman 相关系数得分分别为 87.6987.64

4. 模型推理

(1) 运行模型进行推理。以下面两个意义相同的句子为例,将它们输入模型:

s1,s2="A plane is taking off.",	"An air plane is taking off."
encoding = tokenizer(s1,s2, return_tensors='pt', padding=True, truncation=True, max_length=512)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(input_ids, attention_mask=attention_mask)
outputs.logits.item()
# 4.57421875

(2) 接下来,将语义不同的句子对输入模型:

s1,s2="The men are playing soccer.",	"A man is riding a motorcycle."
encoding = tokenizer("hey how are you there","hey how are you", return_tensors='pt', padding=True, truncation=True, max_length=512)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(input_ids, attention_mask=attention_mask)
outputs.logits.item()
# 3.1953125

(3) 最后,保存模型:

model_path = "sentence-pair-regression-model"
trainer.save_model(model_path)
tokenizer.save_pretrained(model_path)

小结

本节介绍了如何基于预训练 DistilBert 架构完成语义相似度回归分析。首先,通过修改配置或传参的方式,为模型顶层添加单神经元回归头;随后,借助 STS-B 数据集构建训练、验证与测试集,并应用分词器生成模型输入。接着,使用 Trainer 框架与自定义的 compute_metrics 函数,对模型在 MSERMSEMAEPearsonSpearman 相关性等多维度指标上进行评估,验证了微调方法在回归任务中的有效性。

系列链接

Transformer实战(1)——词嵌入技术详解
Transformer实战(2)——循环神经网络详解
Transformer实战(3)——从词袋模型到Transformer:NLP技术演进
Transformer实战(4)——从零开始构建Transformer
Transformer实战(5)——Hugging Face环境配置与应用详解
Transformer实战(6)——Transformer模型性能评估
Transformer实战(7)——datasets库核心功能解析
Transformer实战(8)——BERT模型详解与实现
Transformer实战(9)——Transformer分词算法详解
Transformer实战(10)——生成式语言模型 (Generative Language Model, GLM)
Transformer实战(11)——从零开始构建GPT模型
Transformer实战(12)——基于Transformer的文本到文本模型
Transformer实战(13)——从零开始训练GPT-2语言模型
Transformer实战(14)——微调Transformer语言模型用于文本分类
Transformer实战(15)——使用PyTorch微调Transformer语言模型
Transformer实战(16)——微调Transformer语言模型用于多类别文本分类
Transformer实战(17)——微调Transformer语言模型进行多标签文本分类

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/96594.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/96594.shtml
英文地址,请注明出处:http://en.pswp.cn/web/96594.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】【实战向】Linux 进程替换避坑指南:从理解 bash 阻塞等待,到亲手实现能执行 ls/cd 的 Shell

前言:欢迎各位光临本博客,这里小编带你直接手撕,文章并不复杂,愿诸君耐其心性,忘却杂尘,道有所长!!!! IF’Maxue:个人主页🔥 个人专栏…

linux常用命令 (3)——系统包管理

博客主页:christine-rr-CSDN博客 ​​​​​ ​​ hi,大家好,我是christine-rr ! 今天来分享一下linux常用命令——系统包管理 目录linux常用命令---系统包管理(一)Debian 系发行版(Ubuntu、Debian、Linux …

YOLOv8 mac-intel芯片 部署指南

🚀 在 Jupyter Notebook 和 PyCharm 中使用 Conda 虚拟环境(YOLOv8 部署指南,Python 3.9) YOLOv8 是 Ultralytics 开源的最新目标检测模型,轻量高效,支持分类、检测、分割等多种任务。 在 Mac(…

【高等数学】第十一章 曲线积分与曲面积分——第六节 高斯公式 通量与散度

上一节:【高等数学】第十一章 曲线积分与曲面积分——第五节 对坐标的曲面积分 总目录:【高等数学】 目录 文章目录1. 高斯公式2. 沿任意闭曲面的曲面积分为零的条件3. 通量与散度1. 高斯公式 设空间区域ΩΩΩ是由分片光滑的闭曲面ΣΣΣ所围成&#x…

IDEA试用过期,无法登录,重置方法

IDEA过期,重置方法: IntelliJ IDEA 2024.2.0.2 (亲测有效) 最新Idea重置办法!: 方法一: 1、删除C:\Users\{用户名}\AppData\Local\JetBrains\IntelliJIdea2024.2 下所有文件(注意:是子目录全部删除) 2、删除C:\Users\{用户名}\App…

创建用户自定义桥接网络并连接容器

1.创建用户自定义的 alpine-net 网络[roothost1 ~]# docker network create --driver bridge alpine-net 9f6d634e6bd7327163a9d83023e435da6d61bc6cf04c9d96001d1b64eefe4a712.列出 Docker 主机上的网络[roothost1 ~]# docker network ls NETWORK ID NAME DRIVER …

Vue3 + Vite + Element Plus web转为 Electron 应用,解决无法登录、隐藏自定义导航栏

如何在vue3 Vite Element Plus搭好的架构下转为 electron应用呢? https://www.electronjs.org/zh/docs/latest/官方文档 https://www.electronjs.org/zh/docs/latest/ 第一步:安装 electron相关依赖 npm install electron electron-builder concurr…

qt QAreaLegendMarker详解

1. 概述QAreaLegendMarker 是 Qt Charts 模块中的一部分,用于在图例(Legend)中表示 QAreaSeries 的标记。它负责显示区域图的图例项,通常包含区域颜色样例和对应的描述文字。图例标记和对应的区域图关联,显示区域的名称…

linux 函数 kstrtoul

kstrtoul 函数概述 kstrtoul 是 Linux 内核中的一个函数&#xff0c;用于将字符串转换为无符号长整型&#xff08;unsigned long&#xff09;。该函数定义在 <linux/kernel.h> 头文件中&#xff0c;常用于内核模块中解析用户空间传递的字符串参数。 函数原型 int kstrtou…

LLM(三)

一、人类反馈的强化学习&#xff08;RLHF&#xff09;微调的目标是通过指令&#xff0c;包括路径方法&#xff0c;进一步训练你的模型&#xff0c;使他们更好地理解人类的提示&#xff0c;并生成更像人类的回应。RLHF&#xff1a;使用人类反馈微调型语言模型&#xff0c;使用强…

DPO vs PPO,偏好优化的两条技术路径

1. 背景在大模型对齐&#xff08;alignment&#xff09;里&#xff0c;常见的两类方法是&#xff1a;PPO&#xff1a;强化学习经典算法&#xff0c;OpenAI 在 RLHF 里用它来“用奖励模型更新策略”。DPO&#xff1a;2023 年提出的新方法&#xff08;参考论文《Direct Preferenc…

BLE6.0信道探测,如何重构物联网设备的距离感知逻辑?

在物联网&#xff08;IoT&#xff09;无线通信技术快速渗透的当下&#xff0c;实现人与物、物与物之间对物理距离的感知响应能力已成为提升设备智能高度与人们交互体验的关键所在。当智能冰箱感知用户靠近而主动亮屏显示内部果蔬时、当门禁系统感知到授权人士靠近而主动开门时、…

【计算机 UTF-8 转换为本地编码的含义】

UTF-8 转换为本地编码的含义 详细解释一下"UTF-8转换为本地编码"的含义以及为什么在处理中文时这很重要。 基本概念 UTF-8 编码 国际标准&#xff1a;UTF-8 是一种能够表示世界上几乎所有字符的 Unicode 编码方式跨平台兼容&#xff1a;无论在哪里&#xff0c;UTF-8 …

4.6 变体

1.变体简介 2.为什么需要变体 3.变体是如何产生的 4.变体带来的麻烦 5.multi_compile和shader_feature1.变体简介 比如我们开了一家餐厅, 你有一本万能的菜单(Shader源代码), 上面包含了所有可能的菜式; 但是顾客每次来点餐时, 不可能将整本菜单都做一遍, 他们会根据今天有没有…

猿辅导Android开发面试题及参考答案(下)

为什么开发中要使用线程池,而不是直接创建线程(如控制线程数量、复用线程、降低开销)? 开发中优先使用线程池而非直接创建线程,核心原因是线程池能优化线程管理、降低资源消耗、提高系统稳定性,而直接创建线程存在难以解决的缺陷,具体如下: 控制线程数量,避免资源耗尽…

【网络通信】IP 地址深度解析:从技术原理到企业级应用​

IP 地址深度解析&#xff1a;从技术原理到企业级应用​ 文章目录IP 地址深度解析&#xff1a;从技术原理到企业级应用​前言一、基础认知&#xff1a;IP 地址的技术定位与核心特性​1.1 定义与网络层角色1.2 核心属性与表示法深化二、地址分类&#xff1a;从类别划分到无类别路…

grafana实践

一、如何找到grafana的插件目录 whereis grafana grafana: /etc/grafana /usr/share/grafana插件安装目录、默认安装目录&#xff1a; 把vertamedia-clickhouse-datasource-3.4.4.zip解压到下面目录&#xff0c;然后重启就可以了 /var/lib/grafana/plugins# 6. 设置权限 sudo …

uniapp 文件查找失败:main.js

重装HbuilderX vue.config.js 的 配置 有问题main.js 框架能自动识别 到&#xff0c;不用多余的配置

KEIL烧录时提示“SWD/JTAG communication failure”的解决方法

最新在使用JTAG仿真器串口下载调试程序时&#xff0c;老是下载不成功&#xff0c;识别不到芯片&#xff0c;我尝试重启keil5或者重新插拔仿真器连接线、甚至重启电脑也都不行&#xff0c;每次下载程序都提示如下信息&#xff1a;在确定硬件连接没有问题之后&#xff0c;就开始分…

红日靶场(三)——个人笔记

环境搭建 添加一张网卡&#xff08;仅主机模式&#xff09;&#xff0c;192.168.93.0/24 网段 开启centos&#xff0c;第一次运行&#xff0c;重启网络服务 service network restart192.168.43.57/24&#xff08;外网ip&#xff09; 192.168.93.100/24&#xff08;内网ip&am…