目录

前言

一、什么是模型微调(Fine-tuning)?

二、预训练 vs 微调:什么关系?

三、微调的基本流程(以BERT为例)

1️⃣ 准备数据

2️⃣ 加载预训练模型和分词器

3️⃣ 数据编码与加载

4️⃣ 定义优化器

5️⃣ 开始训练

6️⃣ 评估与保存模型

四、是否要冻结 BERT 层?

 五、完整训练示例代码

5.1 环境依赖

5.2 执行代码

总结:微调的优势


前言

在自然语言处理(NLP)快速发展的今天,预训练模型如 BERT 成为了众多任务的基础。但光有预训练模型并不能解决所有问题,模型微调 的技术应运而生。它让通用模型具备了“专才”的能力,使其能更好地服务于特定任务,如情感分析、问答系统、命名实体识别等。

本文将带你快速理解——什么是模型微调,它的基本流程又是怎样的?


一、什么是模型微调(Fine-tuning)?

✅ 概念通俗解释:

微调,就是在别人学得很好的“通用知识”上,加入你自己的“专业训练”。

具体来说,像 BERT 这样的预训练语言模型已经通过大规模语料学习了大量语言规律,比如语法结构、词语搭配等。我们不需要从头训练它,而是在此基础上继续用小规模、特定领域的数据进行训练,让模型更好地完成某个具体任务。


二、预训练 vs 微调:什么关系?

阶段目标数据类型举例
预训练(Pre-training)学习通用语言知识大规模通用语料维基百科、图书馆语料
微调(Fine-tuning)适应特定任务少量任务特定数据产品评论、医疗文本、法律文书

一句话总结:
🔁 预训练是“打基础”,微调是“练专业”。


三、微调的基本流程(以BERT为例)

让我们以“使用 BERT 进行情感分析”为例,梳理整个微调流程:

1️⃣ 准备数据

我们需要将文本和标签准备好,通常是一个 CSV 文件,比如:

评论内容情感标签
这部电影太好看了!正面
烂片,浪费时间。负面

我们会将“正面”转换为 1,负面为 0,方便模型学习。


2️⃣ 加载预训练模型和分词器

from transformers import BertTokenizer, BertForSequenceClassificationtokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)

 此时模型的主体结构已经包含了 BERT 和一个分类头(Classification Head)。


3️⃣ 数据编码与加载

使用分词器将文本转为模型输入格式:

tokens = tokenizer("这部电影太好看了!", padding='max_length', truncation=True, return_tensors="pt")

你还需要构建自定义数据集类(Dataset),并使用 DataLoader 加载:

from torch.utils.data import DataLoadertrain_loader = DataLoader(my_dataset, batch_size=16, shuffle=True)

4️⃣ 定义优化器

from transformers import AdamWoptimizer = AdamW(model.parameters(), lr=5e-5)

 ▲优化器的作用是:根据损失函数的值,自动调整模型的参数,使模型表现越来越好。

▲通俗理解

优化器就像你走路的“策略”:
它告诉你“往哪边走,走多快,怎么避开障碍”,最终尽可能走到山底。

 ▲优化器做了什么?

神经网络训练时,每一轮都会:

  1. 计算当前模型的预测误差(损失函数 loss)

  2. 反向传播得到每个参数的梯度(方向)

  3. 👉 优化器根据梯度,更新参数的值

就像你爬山时,不断踩点 → 看地形 → 决定下一个落脚点。

组件类比作用
损失函数地图高度告诉你你离目标有多远
梯度当前坡度告诉你往哪里走
优化器走路策略告诉你怎么调整步伐走得更快更稳

5️⃣ 开始训练

model.train()
for batch in train_loader:outputs = model(**batch)loss = outputs.lossloss.backward()optimizer.step()optimizer.zero_grad()

通常我们会训练几个 epoch,让模型逐渐学会如何从文本中识别情感。


6️⃣ 评估与保存模型

训练完成后,我们可以在验证集上评估准确率,并保存模型:

torch.save(model.state_dict(), "bert_sentiment.pth")

四、是否要冻结 BERT 层?

微调过程中,有两种策略:

  • 全模型微调(默认): 所有 BERT 层和分类头都参与训练。效果通常更好,但对显存要求高。

  • 冻结 BERT,仅训练分类头: 保持 BERT 权重不变,只训练新加的分类层。适合数据量小或设备受限的场景。

冻结代码示例:

for param in model.bert.parameters():param.requires_grad = False

 【两种微调策略对比】

策略是否冻结BERT层?训练内容优点缺点
✅ 全部微调❌ 不冻结BERT + 分类层一起训练效果最好,能深入适配任务训练慢,显存占用大
🚫 只微调分类层✅ 冻结 BERT只训练分类层快速,适合小数据、低配置表现可能略逊一筹

【举个通俗类比】

想象你雇了一个精通语文的老师(BERT),但你只想让他教学生写作文(分类任务):

  • 全模型微调:老师重新备课、重新学习学生情况,全面参与教学(耗时但有效)。

  • 只调分类层:老师照搬旧知识,只教作文技巧,不深入了解学生(快速但效果一般)。


【什么时候选择“冻结 BERT 层”?】

  • 数据量很小(如只有几百个样本)

  • 硬件资源有限(显存小、设备性能弱)

  • 快速原型验证,先试试看效果


【小结一句话】

冻结 BERT 层就是让预训练好的 BERT 不再学习,只训练新增的部分;
不冻结则是让整个 BERT 跟着任务数据一起再“进修”一轮,效果更强,但代价更高。

你也可以先冻结一部分,再逐步解冻,称为 层级微调(layer-wise unfreezing),也是一个进阶策略。


 五、完整训练示例代码

5.1 环境依赖

 1、安装Pytorch

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

注意:

▲安装pytorch前先确定自己电脑是否有GPU,没有请安装cpu版本的;

pip3 install torch torchvision torchaudio

▲确保CUDA 12.6版本可以兼容

确定是否兼容可参考该文章对应内容:【CUDA&cuDNN安装】深度学习基础环境搭建_cudnn安装教程-CSDN博客


2、安装transformers

pip install transformers

3、安装scikit-learn

pip install scikit-learn

scikit-learn 是一个专注于传统机器学习的工具箱,涵盖从模型训练、评估到数据处理的一整套流程。


5.2 执行代码

import torch
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score
import pandas as pd# 1. 自定义数据集类
class SentimentDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_len=128):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):inputs = self.tokenizer(self.texts[idx],truncation=True,padding='max_length',max_length=self.max_len,return_tensors='pt')return {'input_ids': inputs['input_ids'].squeeze(0),'attention_mask': inputs['attention_mask'].squeeze(0),'labels': torch.tensor(self.labels[idx], dtype=torch.long)}# 2. 加载数据
df = pd.read_csv("data.csv")  # 假设 CSV 有 'text' 和 'label' 列
train_texts, train_labels = df['text'].tolist(), df['label'].tolist()# 3. 初始化 tokenizer 和 model
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)# 4. 构建 DataLoader
train_dataset = SentimentDataset(train_texts, train_labels, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)# 5. 设置优化器
optimizer = AdamW(model.parameters(), lr=5e-5)# 6. 训练过程
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()for epoch in range(3):  # 可改为你想要的轮数total_loss = 0preds, targets = [], []for batch in train_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losslogits = outputs.logitsoptimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()preds += torch.argmax(logits, dim=1).tolist()targets += labels.tolist()acc = accuracy_score(targets, preds)print(f"Epoch {epoch+1} | Loss: {total_loss:.4f} | Accuracy: {acc:.4f}")# 7. 保存模型
torch.save(model.state_dict(), "bert_finetuned.pth")

 


总结:微调的优势

少量数据就能训练出效果不错的模型
迁移学习加速开发,节省计算资源
灵活应对不同领域任务,如医学、法律、金融等

模型微调是现代 AI 应用的关键技能之一。如果说预训练模型是“万能工具箱”,那么微调就是“选对合适的工具并精修”。掌握这项技术,你就能迅速把通用模型打造成特定任务的专家。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/87515.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/87515.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/87515.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大语言模型预训练数据——数据采样方法介绍以GPT3为例

大语言模型预训练数据——数据采样方法介绍以GPT3为例一、数据采样核心逻辑二、各列数据含义一、数据采样核心逻辑 这是 GPT - 3 训练时的数据集配置,核心是非等比例采样——不按数据集原始大小分配训练占比,而是人工设定不同数据集在训练中被抽取的概率…

针对同一台电脑,为使用不同 SSH Key 的不同用户分别设置 Git 远程仓库凭据的操作指南

一、准备工作 生成多对 SSH Key 为每个用户(如“个人”、“公司”)生成一对独立的 SSH Key。 示例(在 Git Bash 或 Linux 终端中执行): # 个人 ssh-keygen -t rsa -b 4096 -C "personalexample.com" -f ~/.…

【V5.0 - 视觉篇】AI的“火眼金睛”:用OpenCV量化“第一眼缘”,并用SHAP验证它的“审美”

系列回顾: 在上一篇 《给AI装上“写轮眼”:用SHAP看穿模型决策的每一个细节》 中,我们成功地为AI装上了“透视眼镜”,看穿了它基于数字决策的内心世界。 但一个巨大的问题暴露了:它的世界里,还只有数字。 它…

Open3D 基于最大团(MAC)的点云粗配准

MAC 一、算法原理1、原理概述2、实现流程3、总结二、代码实现三、结果展示博客长期更新,本文最新更新时间为:2025年7月1日。 一、算法原理 1、原理概述 最大团(Maximal Cliques, MAC)法在点云配准中的应用,是近年来解决高离群值(outlier)和低重叠场景下配准问题的重要…

Science Robotics发表 | 20m/s自主飞行+避开2.5mm电线的微型无人机!

从山火搜救到灾后勘察,时间常常意味着生命。分秒必争的任务要求无人机在陌生狭窄环境中既要飞得快、又要飞得稳。香港大学机械工程系张富教授团队在Science Robotics(2025)发表论文“Safety-assured High-speed Navigation for MAVs”提出了微型无人机的安全高速导航…

【数据分析】如何在PyCharm中高效配置和使用SQL

PyCharm 作为 Python 开发者的首选 IDE,其 Professional 版本提供了强大的数据库集成功能,让开发者无需切换工具即可完成数据库操作。本文将手把手教你配置和使用 PyCharm 的 SQL 功能。 一、安装和配置 PyCharm 老生常谈,第一步自然是安装并…

OpenShift AI - 使用 NVIDIA Triton Runtime 运行模型

《OpenShift / RHEL / DevSecOps 汇总目录》 说明:本文已经在 OpenShift 4.18 OpenShift AI 2.19 的环境中验证 文章目录 准备 Triton Runtime 环境添加 Triton Serving Runtime运行基于 Triton Runtime 的 Model Server 在 Triton Runtime 中运行模型准备模型运行…

物联网数据安全区块链服务

物联网数据安全区块链服务 下面是一个专为物联网数据安全设计的区块链服务实现,使用Python编写并封装为RESTful API。该服务确保物联网设备数据的不可篡改性、可追溯性和安全性。 import hashlib import json import time from datetime import datetime from uui…

数据集-目标检测系列- 卡车 数据集 truck >> DataBall

数据集-目标检测系列- 卡车 数据集 truck >> DataBall贵在坚持!* 相关项目1)数据集可视化项目:gitcode: https://gitcode.com/DataBall/DataBall-detections-100s/overview2)数据集训练、推理相关项目&…

vue/微信小程序/h5 实现react的boundary

ErrorBoundary react的boundary实现核心逻辑无法处理的情况包含函数详细介绍getDerivedStateFromError和componentDidCatch作用为什么分开调用 代码实现(补充其他异常捕捉)函数组件与useErrorBoundary(需自定义Hook) vue的boundar…

Day113 切换Node.js版本、多数据源配置

切换Node.js版本 1.nvm简介nvm(Node Version Manager),在Windows上管理Node.js版本,可以在同一台电脑上轻松管理和切换多个Node.js版本 nvm下载地址:https://github.com/coreybutler/nvm-windows/2.配置nvm安装之后检查nvm是否已经安装好了&a…

应急响应靶机-linux2-知攻善防实验室

题目: 1.提交攻击者IP2.提交攻击者修改的管理员密码(明文)3.提交第一次Webshell的连接URL(http://xxx.xxx.xxx.xx/abcdefg?abcdefg只需要提交abcdefg?abcdefg)4.提交Webshell连接密码5.提交数据包的flag16.提交攻击者使用的后续上传的木马文件名称7.提交攻击者隐藏…

新手前端使用Git(常用命令和规范)

发一篇文章来说一下前端在开发项目的时候常用的一些git命令 注:这篇文章只说最常用的,最下面有全面的 一:从git仓库拉取项目到本地 1:新建文件夹存放项目代码 2:在git上复制一下项目路径(看那个顺眼复制…

【面试题】常用Git命令

【面试题】常用Git命令1. 常用Git命令1. 常用Git命令 1.git clone git clone https://gitee.com/Blue_Pepsi_Cola/straw.git 2.使用-v选项,可以参看远程主机的网址 git remote -v origin https://ccc.ddd.com/1-java/a-admin-api.git (fetch) origin https://ccc.…

Webpack构建工具

构建工具系列 Gulp构建工具Grunt构建工具Webpack构建工具Vite构建工具 Webpack构建工具 构建工具系列前言一、安装打包配置webpack安装样式加载器devtoolwebpack devtool 配置详解常见 devtool 值及适用场景选择建议性能影响注意事项 module处理流程module.rulesmodule.usemod…

重学前端002 --响应式网页设计 CSS

文章目录 css 样式特殊说明 根据在这里 Freecodecamp 实践,调整顺序后做的总结。 css 样式 body {background-color: red; # 跟background-image 不同时使用background-image: url(https://cdn.freecodecamp.org/curriculum/css-cafe/beans.jpg);font-family: san…

RabbitMQ简单消息监听和确认

如何监听RabbitMQ队列 简单代码实现RabbitMQ消息监听 需要的依赖 <!--rabbitmq--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-amqp</artifactId><version>x.x.x</version>&l…

Docker学习笔记:Docker网络

本文是自己的学习笔记 1、Linux中的namespace1.1、创建namespace1.2、两个namespace互相通信2、Docker中的namespace2.1 容器中的默认Bridge3、容器的三种网络模式1、Linux中的namespace Docker中使用了虚拟网络技术&#xff0c;让各个容器的网络隔离。好像每个容器从网卡到端…

用自定义注解解决excel动态表头导出的问题

导入的excel有固定表头动态表头如何解决 自定义注解&#xff1a; import java.lang.annotation.*;/*** 自定义注解&#xff0c;用于动态生成excel表头*/ Target(ElementType.FIELD) Retention(RetentionPolicy.RUNTIME) public interface FieldLabel {// 字段中文String label(…

Android-EDLA 解决 GtsMediaRouterTestCases 存在 fail

问题描述&#xff1a;[原因]R10套件新增模块&#xff0c;getRemoteDevice获取远程蓝牙设备时&#xff0c;蓝牙MAC为空 [对策]实际蓝牙MAC非空;测试时绕过处理 1.release/ebsw_skg/skg/frameworks/base/packages/SettingsLib/src/com/android/settingslib/media/InfoMediaManage…