TensorFlow深度学习实战——使用Hugging Face构建Transformer模型

    • 0. 前言
    • 1. 安装 Hugging Face
    • 2. 文本生成
    • 3. 自动模型选择和自动分词
    • 4. 命名实体识别
    • 5. 摘要生成
    • 6. 模型微调
    • 相关链接

0. 前言

除了需要实现特定的自定义结构,或者想要了解 Transformer 工作原理外,从零开始实现 Transformer 并不是最佳选择,和其它编程实践一样,通常并不需要从头开始造轮子。只有想要理解 Transformer 架构的内部细节,或者修改 Transformer 架构以得到新的变体时才需要从零开始构建。有很多优秀的库提供高质量的 Transformer 解决方案,Hugging Face 是其中的代表之一,它提供了一些构建 Transformer 的高效工具:

  • Hugging Face 提供了一个通用的 API 来处理多种 Transformer 架构
  • Hugging Face 不仅提供了基础模型,还提供了带有不同类型“头”的模型来处理特定任务(例如,对于 BERT 架构,提供了 TFBertModel,用于情感分析的 TFBertForSequenceClassification,用于命名实体识别的 TFBertForTokenClassification,以及用于问答的 TFBertForQuestionAnswering 等)
  • 可以通过使用 Hugging Face 提供的预训练权重来轻松创建自定义的网络,例如,使用 TFBertForPreTraining
  • 除了 pipeline() 方法,还可以以常规方式定义模型,使用 fit() 进行训练,使用 predict() 进行推理,就像普通的 TensorFlow 模型一样

1. 安装 Hugging Face

和其它第三方库一样,可以使用 pip 命令安装 Hugging Face 库:

$ pip install transformers[tf]

然后,通过下载一个用于情感分析的预训练模型来验证 Hugging Face 库是否安装成功:

$ python -c "from transformers import pipeline; print(pipeline('sentiment-analysis')('we love you'))"

如果成功安装,将显示如下输出结果:

[{'label': 'POSITIVE', 'score': 0.9998704791069031}]

接下来,介绍如何使用 Hugging Face 解决具体任务。

2. 文本生成

在本节中,我们将使用 GPT-2 进行自然语言生成,这是一个生成自然语言输出的过程。

(1) 使用 GPT-2 生成文本:

from transformers import pipeline
generator = pipeline(task="text-generation")

(2) 模型下载完成后,将文本传递给生成器,观察结果:

generator("Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone")generator ("The original theory of relativity is based upon the premise that all coordinate systems in relative uniform translatory motion to each other are equally valid and equivalent ")generator ("It takes a great deal of bravery to stand up to our enemies")

生成结果

3. 自动模型选择和自动分词

Hugging Face 能够尽可能帮助自动化多个步骤。

(1) 可以非常简单的从数十个可用的预训练模型中导入可用模型:

from transformers import TFAutoModelForSequenceClassification
model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

可以在下游任务上训练模型,以便用于预测和推理。

(2) 可以使用 AutoTokenizer 将单词转换为模型使用的词元:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
sequence = "The original theory of relativity is based upon the premise that all coordinate systems"
print(tokenizer(sequence))

输出结果

4. 命名实体识别

命名实体识别 (Named Entity Recognition, NER) 是经典的自然语言处理任务。命名实体识别也称实体识别 (entity identification)、实体分块 (entity chunking) 或实体提取 (entity extraction),是信息提取的一个子任务,旨在定位和分类在非结构化文本中提到的命名实体,将其划分为预定义的类别,例如人名、组织、地点、时间表达、数量、货币值和百分比等。接下来,我们使用 Hugging Face 完成命名实体识别任务。

(1) 创建一个 NER 管道:

from transformers import pipeline
ner_pipe = pipeline("ner")
sequence = """Mr. and Mrs. Dursley, of number four, Privet Drive, were
proud to say that they were perfectly normal, thank you very much."""
for entity in ner_pipe(sequence):print(entity)

(2) 结果如下所示,其中实体已经被识别出来:
识别结果

命名实体识别可以理解九个不同的类别:

  • O: 不属于命名实体
  • B-MIS: 在另一个杂项实体后开始的杂项实体
  • I-MIS: 杂项实体
  • B-PER: 在另一个人名后面开始的人名
  • I-PER: 人名
  • B-ORG: 在另一个组织后面开始的组织
  • I-ORG: 组织
  • B-LOC: 在另一个地点后面开始的地点
  • I-LOC: 地点

这些实体在 CoNLL-2003 数据集中定义,并由 Hugging Face 自动选择。

5. 摘要生成

摘要生成,是指用简短而清晰的形式表达有关某事或某人的最重要事实或观点。Hugging Face 使用 T5 模型作为完成此任务的默认模型。

(1) 首先,使用默认的 T5 small 模型创建一个摘要生成管道:

from transformers import pipeline
summarizer = pipeline("summarization")
ARTICLE = """Mr. and Mrs.Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr.Dursley was the director of a firm called Grunnings, which made drills.He was a big, beefy man with hardly any neck, although he did have a very large mustache.Mrs.Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors.The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere"""
print(summarizer(ARTICLE, max_length=130, min_length=30, do_sample=False))

输出结果如下:

输出结果

(2) 如果想要更换使用不同的模型,只需修改参数 model

summarizer = pipeline("summarization", model='t5-base')

输出结果如下:

输出结果

6. 模型微调

一种常见的 Transformer 使用模式是先使用预训练的大语言模型 (Large Language Model, LLM),然后对模型进行微调以适应特定的下游任务。微调步骤将基于自定义数据集,而预训练则是在非常大的数据集上进行的。这种策略的优点在于节省计算成本,此外,微调令我们使用最先进的模型,而不需要从头开始训练一个模型。接下来,我们介绍如何使用 TensorFlow 进行模型微调,使用的预训练模型是 bert-base-cased,在 Yelp Reviews 数据集上进行微调。
本节使用 datasets 库加载数据集,datasets 库是由 Hugging Face 提供的一个非常强大的工具,专门用于加载、处理和分享数据集,使用 pip 命令安装 datasets 库:

$ pip install datasets

(1) 首先,加载并对 Yelp 数据集进行分词:

from datasets import load_datasetdataset = load_dataset("yelp_review_full")
from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained("bert-base-cased")def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)tokenized_datasets = dataset.map(tokenize_function, batched=True)small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

(2) 然后,将数据集转换为 TensorFlow 格式:

from transformers import DefaultDataCollator
data_collator = DefaultDataCollator(return_tensors="tf")# convert the tokenized datasets to TensorFlow datasetstf_train_dataset = small_train_dataset.to_tf_dataset(columns=["attention_mask", "input_ids", "token_type_ids"],label_cols=["labels"],shuffle=True,collate_fn=data_collator,batch_size=8,
)tf_validation_dataset = small_eval_dataset.to_tf_dataset(columns=["attention_mask", "input_ids", "token_type_ids"],label_cols=["labels"],shuffle=False,collate_fn=data_collator,batch_size=8,
)

(3) 使用 TFAutoModelForSequenceClassification,选择 bert-base-cased

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassificationmodel = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

(4) 最后,微调模型的方法是使用 TensorFlow 中的标准训练方式,通过编译模型并使用 fit() 进行训练:

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=tf.metrics.SparseCategoricalAccuracy(),
)model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3)

相关链接

TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(8)——卷积神经网络
TensorFlow深度学习实战(12)——词嵌入技术详解
TensorFlow深度学习实战(13)——神经嵌入详解
TensorFlow深度学习实战(14)——循环神经网络详解
TensorFlow深度学习实战(15)——编码器-解码器架构
TensorFlow深度学习实战(16)——注意力机制详解
TensorFlow深度学习实战(21)——Transformer架构详解与实现
TensorFlow深度学习实战(22)——从零开始实现Transformer机器翻译
TensorFlow深度学习实战——Transformer变体模型
TensorFlow深度学习实战——Transformer模型评价指标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/84884.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/84884.shtml
英文地址,请注明出处:http://en.pswp.cn/web/84884.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SAP-ABAP:SAP全模块的架构化解析,涵盖核心功能、行业方案及技术平台

一、核心业务模块(Logistics & Operations) 模块代号核心功能典型流程关键事务码物料管理MM采购/库存/发票校验采购到付款 (P2P)ME21N(采购订单), MI31(库存盘点)销售与分销SD订单/定价/发货/开票订单…

实时预警!机场机坪井室无线智能液位监测系统助力安全降本

某沿海机场因地处多雨区域,每年雨季均面临排水系统超负荷运行压力。经勘测发现,5个井室因长期遭受地下水渗透侵蚀,井壁出现细微结构性裂缝,导致内部水位异常升高。作为机坪地下管网系统的核心节点,这些井室承担着雨水导…

边云协同 AI 视频分析系统设计方案

目录 一、项目背景与目标 二、系统架构概述 总体架构图 三、ER 图(核心数据库设计) 实体关系图简述 数据表设计(简要) 四、模型结构图(边缘云端AI推理架构) 边缘模型(YOLOv5-tiny/PP-YO…

vue3整合element-plus

为项目命名 选择vue 框架 选择TS 启动测试: npm run dev 开始整合 element-plus npm install element-plus --save npm install unplugin-vue-components unplugin vitejs/plugin-vue --save-dev 修改main.ts import { createApp } from vue import ./style.cs…

【AI 测试】测试用例设计:人工智能语言大模型性能测试用例设计

目录 一、性能测试可视化架构图 (1)测试整体架构图 (2)测试体系架构图 (3)测试流程时序图 二、性能测试架构总览 (1)性能测试功能点 (2)测试环境要…

Windsurf SWE-1模型评析:软件工程的AI革命

引言 软件开发领域正经历着前所未有的变革,AI辅助编程工具层出不穷,但大多数仅专注于代码生成这一环节。Windsurf公司近期推出的SWE-1系列模型打破了这一局限,首次将AI应用扩展至软件工程的全流程。这一举措不仅反映了行业对AI工具认知的深化…

Qt for OpenHarmony 编译鸿蒙调用的动态库

简介 Qt for Harmony​ 是跨平台开发框架 ​Qt​ 与华为 ​OpenHarmony​ 操作系统的深度集成方案,由 Qt Group 与华为联合推动。其核心目标是为开发者提供一套高效工具链,实现 ​​“一次开发,多端部署”​,加速 OpenHarmony 生…

退休时,按最低基数补缴医疗保险15年大概需要多少钱

在南京退休时,如果医保缴费年限不足(男需满25年/女需满20年),需补缴差额年限。若按最低基数一次性补缴15年医保,费用估算如下(以2024年政策为例): 一、补缴金额计算公式 总补缴费用…

wireshark过滤显示rtmp协议

wireshark中抓包显示的数据报文中,明明可以看到有 rtmp 协议的报文,但是过滤的时候却显示一条都没有 查看选项中的配置,已经没有 RTMP 这个协议了,已经被 RTMPT 替换了,过滤框中输入 rtmpt 过滤即可

《哈希表》K倍区间(解题报告)

文章目录 零、题目描述一、算法概述二、算法思路三、代码实现四、算法解释五、复杂度分析 零、题目描述 题目链接:K倍区间 一、算法概述 计算子数组和能被k整除的子数组数量的算法。通过前缀和与哈希表的结合,高效地统计满足条件的子数组。  需要注…

OpenShift 在 Kubernetes 多出的功能中,哪些开源?

OpenShift 在 Kubernetes 基础上增加的功能中,部分组件是开源的(代码可公开访问),而另一些则是 Red Hat 专有(闭源)。以下是详细分类: 1. 完全开源的功能(代码可查) 这些…

【每天一个知识点】CITE-seq 技术

一、技术背景 单细胞RNA测序(scRNA-seq)自问世以来,极大推动了细胞异质性和组织复杂性的研究。但RNA水平并不能完全代表蛋白质水平,因为蛋白质的表达受转录后调控、翻译效率及蛋白降解等多种因素影响。此外,许多细胞类…

中文Windows系统下程序输出重定向乱码问题解决方案

导言 最近我在用 Rust 开发时,遇到了一个让人头疼的问题:运行 cargo run -- version Cargo.toml > output.txt 将输出重定向到文件后,打开 output.txt 却发现里面全是乱码!我的程序确实是UTF8但是输出的文件却是UTF16LE编码的…

Python管理工具UV

常用 UV 命令 安装 pip install uv 版本相关 uv python list 打印所有uv支持的python版本uv python install cpython-3.12 安装指定的python版本uv run -p 3.12 test.py 用指定的python版本运行python代码uv run -p 3.12 python 进入python执行环境。假如输入的版本是一个本…

论文略读:ASurvey on Intent-aware Recommender Systems

202406 arxiv 推荐系统在许多现代在线服务中发挥着关键作用,例如电子商务或媒体流服务,它们能够为消费者和服务提供商创造巨大的价值。因此,过去几十年来,研究人员提出了大量生成个性化推荐的技术方法。传统算法——从早期的 Gro…

Neo4j 中存储和查询数组数据的完整指南

Neo4j 中存储和查询数组数据的完整指南 图形数据库 Neo4j 不仅擅长处理节点和关系,还提供了强大的数组(Array)存储和操作能力。本文将全面介绍如何在 Neo4j 中高效地使用数组,包括存储、查询、优化以及实际应用场景。 数组在 Neo4j 中的基本使用 数组…

Android 编译和打包image镜像流程

1. 编译命令 source build/envsetup.sh lunch aosp_car_arm64-userdebug make2. 编译流程 source build/envsetup.sh 定义一些函数的环境变量,如 lunchvalidate_current_shell,确认 shell 环境set_global_paths,设置环境变量 ANDROID_GLOB…

MySQL:SQL 慢查询优化的技术指南

1、简述 在 Java 后端开发中,数据库是系统性能瓶颈的高发地带,而 慢 SQL 查询 往往是系统响应迟缓的“罪魁祸首”。本文将全面梳理慢 SQL 的优化思路,并结合 Java 示例进行实战演练。 2、慢查询的常见表现 慢查询通常表现为: 接…

leetcode543-二叉树的直径

leetcode 543 思路 路径长度计算:任意两个节点之间的路径长度,等于它们的最低公共祖先到它们各自的深度之和递归遍历:通过后序遍历(左右根)计算每个节点的左右子树深度,并更新全局最大直径深度与直径的关…

详解main的参数并实现读取文件

在 C 语言中,main函数的参数argc和argv用于接收命令行传入的参数 main 函数的两个参数 int main(int argc, char* argv[]) 假设顾客通过手机 APP 点餐,订单信息会被传递给餐厅的处理系统(也就是你的程序)。 订单信息结构 argc…