以下是基于飞桨平台实现的多模态分类详细案例,结合图像和文本信息进行分类任务。案例包含数据处理、模型构建、训练和评估的完整流程,并提供详细注释:

一、多模态分类案例实现

import os
import json
import numpy as np
from PIL import Image
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset, DataLoader
from paddle.vision import models
import paddlenlp as ppnlp
from paddlenlp.transformers import ErnieTokenizer, ErnieModel# 设置随机种子,确保结果可复现
paddle.seed(42)
np.random.seed(42)# ---------------------- 1. 数据集定义 ----------------------
class MultiModalDataset(Dataset):"""多模态图像-文本分类数据集"""def __init__(self, data_path, image_dir, tokenizer, max_seq_len=128, mode='train'):"""data_path: 标注文件路径image_dir: 图像文件夹路径tokenizer: 文本tokenizermax_seq_len: 文本最大长度mode: 模式,train/val/test"""super().__init__()self.image_dir = image_dirself.tokenizer = tokenizerself.max_seq_len = max_seq_lenself.mode = mode# 加载数据集with open(data_path, 'r', encoding='utf-8') as f:self.data = json.load(f)# 定义类别到ID的映射(根据数据集调整)self.label2id = {'科技': 0, '娱乐': 1, '体育': 2, '财经': 3, '教育': 4}self.id2label = {v: k for k, v in self.label2id.items()}def __len__(self):return len(self.data)def __getitem__(self, idx):# 获取单条数据item = self.data[idx]image_path = os.path.join(self.image_dir, item['image'])text = item['text']label = self.label2id[item['label']]# 处理图像image = Image.open(image_path).convert('RGB')image = self._preprocess_image(image)# 处理文本encoded_inputs = self.tokenizer(text=text,max_seq_len=self.max_seq_len,pad_to_max_seq_len=True,return_attention_mask=True,return_token_type_ids=True)# 转换为Tensorinput_ids = paddle.to_tensor(encoded_inputs['input_ids'], dtype='int64')attention_mask = paddle.to_tensor(encoded_inputs['attention_mask'], dtype='int64')token_type_ids = paddle.to_tensor(encoded_inputs['token_type_ids'], dtype='int64')label = paddle.to_tensor(label, dtype='int64')return {'image': image,'input_ids': input_ids,'attention_mask': attention_mask,'token_type_ids': token_type_ids,'label': label}def _preprocess_image(self, image):"""图像预处理:缩放、归一化、转Tensor"""# 调整图像大小为224x224image = image.resize((224, 224), Image.BICUBIC)# 转换为numpy数组image = np.array(image).astype('float32')# 归一化image = image / 255.0# 标准化(ImageNet均值和标准差)image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])# 调整通道顺序 (HWC -> CHW)image = np.transpose(image, (2, 0, 1))return paddle.to_tensor(image, dtype='float32')# ---------------------- 2. 多模态分类模型 ----------------------
class MultiModalClassifier(nn.Layer):"""基于图像和文本的多模态分类模型"""def __init__(self, num_classes, text_encoder='ernie-1.0', pretrained=True):super().__init__()# 图像编码器(使用预训练ResNet50)self.image_encoder = models.resnet50(pretrained=pretrained)# 移除最后的全连接层self.image_encoder.fc = nn.Identity()# 添加投影层,将图像特征映射到共同空间self.image_proj = nn.Linear(2048, 512)# 文本编码器(使用预训练ERNIE)self.text_encoder = ErnieModel.from_pretrained(text_encoder)# 添加投影层,将文本特征映射到共同空间self.text_proj = nn.Linear(768, 512)# 特征融合层self.fusion = nn.Sequential(nn.Linear(1024, 512),  # 拼接图像和文本特征 (512+512)nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 256),nn.ReLU(),nn.Dropout(0.5))# 分类器self.classifier = nn.Linear(256, num_classes)def forward(self, image, input_ids, attention_mask, token_type_ids=None):# 提取图像特征image_features = self.image_encoder(image)  # [batch_size, 2048]image_features = self.image_proj(image_features)  # [batch_size, 512]# 提取文本特征text_outputs = self.text_encoder(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)# 获取[CLS] token的表示text_features = text_outputs[1]  # [batch_size, 768]text_features = self.text_proj(text_features)  # [batch_size, 512]# 特征融合fused_features = paddle.concat([image_features, text_features], axis=1)  # [batch_size, 1024]fused_features = self.fusion(fused_features)  # [batch_size, 256]# 分类预测logits = self.classifier(fused_features)  # [batch_size, num_classes]return logits# ---------------------- 3. 模型训练 ----------------------
def train_model(model, train_loader, val_loader, optimizer, criterion, epochs, save_dir):"""训练多模态分类模型"""best_acc = 0.0for epoch in range(epochs):# 训练模式model.train()train_loss = 0.0correct = 0total = 0for batch in train_loader:# 获取数据image = batch['image']input_ids = batch['input_ids']attention_mask = batch['attention_mask']token_type_ids = batch['token_type_ids']label = batch['label']# 前向传播logits = model(image, input_ids, attention_mask, token_type_ids)loss = criterion(logits, label)# 反向传播loss.backward()optimizer.step()optimizer.clear_grad()# 统计训练指标train_loss += loss.numpy()[0]total += label.shape[0]pred = paddle.argmax(logits, axis=1)correct += (pred == label).sum().numpy()[0]# 计算训练准确率train_acc = correct / totalprint(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.4f}')# 验证val_acc = evaluate_model(model, val_loader)print(f'Epoch [{epoch+1}/{epochs}], Val Acc: {val_acc:.4f}')# 保存最佳模型if val_acc > best_acc:best_acc = val_accpaddle.save(model.state_dict(), os.path.join(save_dir, 'best_model.pdparams'))print(f'Model saved at acc: {best_acc:.4f}')# ---------------------- 4. 模型评估 ----------------------
def evaluate_model(model, data_loader):"""评估模型性能"""model.eval()correct = 0total = 0with paddle.no_grad():for batch in data_loader:# 获取数据image = batch['image']input_ids = batch['input_ids']attention_mask = batch['attention_mask']token_type_ids = batch['token_type_ids']label = batch['label']# 模型预测logits = model(image, input_ids, attention_mask, token_type_ids)pred = paddle.argmax(logits, axis=1)# 统计准确率total += label.shape[0]correct += (pred == label).sum().numpy()[0]return correct / total# ---------------------- 5. 主函数 ----------------------
def main():# 配置参数config = {'train_data_path': 'data/train.json',  # 训练数据路径'val_data_path': 'data/val.json',      # 验证数据路径'image_dir': 'data/images',            # 图像文件夹路径'save_dir': 'checkpoints',             # 模型保存路径'num_classes': 5,                      # 分类类别数'batch_size': 16,                      # 批次大小'epochs': 10,                          # 训练轮数'learning_rate': 1e-4,                 # 学习率'max_seq_len': 128                     # 文本最大长度}# 创建保存目录os.makedirs(config['save_dir'], exist_ok=True)# 初始化tokenizertokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')# 创建数据集train_dataset = MultiModalDataset(config['train_data_path'], config['image_dir'], tokenizer, config['max_seq_len'],mode='train')val_dataset = MultiModalDataset(config['val_data_path'], config['image_dir'], tokenizer, config['max_seq_len'],mode='val')# 创建数据加载器train_loader = DataLoader(train_dataset,batch_size=config['batch_size'],shuffle=True,num_workers=4)val_loader = DataLoader(val_dataset,batch_size=config['batch_size'],shuffle=False,num_workers=4)# 初始化模型model = MultiModalClassifier(config['num_classes'])# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = paddle.optimizer.AdamW(learning_rate=config['learning_rate'],parameters=model.parameters())# 训练模型train_model(model, train_loader, val_loader, optimizer, criterion, config['epochs'], config['save_dir'])# 加载最佳模型并评估model.set_state_dict(paddle.load(os.path.join(config['save_dir'], 'best_model.pdparams')))test_acc = evaluate_model(model, val_loader)print(f'Final Test Accuracy: {test_acc:.4f}')if __name__ == '__main__':main()

二、数据集格式说明

数据集采用JSON格式,每条数据包含图像路径、文本描述和类别标签:

[{"image": "image_001.jpg","text": "这款新手机的相机功能非常出色,拍照效果堪比专业相机","label": "科技"},{"image": "image_002.jpg","text": "这支足球队在本赛季表现出色,有望夺得冠军","label": "体育"},...
]

三、模型架构解析

  1. 图像编码器:使用预训练的ResNet50提取图像特征,最后通过全连接层投影到512维空间。
  2. 文本编码器:使用预训练的ERNIE模型提取文本特征,取[CLS]标记表示,再通过全连接层投影到512维空间。
  3. 特征融合:将图像和文本特征拼接后,通过多层感知机进行融合和降维。
  4. 分类器:基于融合特征进行多分类预测。

四、训练和评估流程

  1. 数据加载:使用自定义数据集类加载图像和文本数据,并进行预处理。
  2. 模型训练:采用交叉熵损失函数和AdamW优化器,训练10个epoch。
  3. 模型评估:在验证集上计算分类准确率,并保存性能最佳的模型。

五、扩展建议

  1. 特征融合改进:尝试更复杂的融合方法,如注意力机制、双线性池化等。
  2. 数据增强:对图像进行随机裁剪、翻转等增强,对文本进行同义词替换、插入等操作。
  3. 模型调优:调整学习率、批次大小、dropout率等超参数。
  4. 多模态权重平衡:为图像和文本分支设计可学习的权重,自适应调整各模态的重要性。

这个案例展示了如何结合图像和文本信息进行多模态分类,您可以根据实际需求调整模型架构和数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/84348.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/84348.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/84348.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Express框架:Node.js的轻量级Web应用利器

Hi,我是布兰妮甜 !在当今快速发展的Web开发领域,Node.js已成为构建高性能、可扩展网络应用的重要基石。而在这片肥沃的生态系统中,Express框架犹如一座经久不衰的灯塔,指引着无数开发者高效构建Web应用的方向。本文章在为读者提供一份全面而深入的Express框架指南。无论您…

K-Means颜色变卦和渐变色

一、理论深度提升:补充算法细节与数学基础 1. K-Means 算法核心公式(增强专业性) 在 “原理步骤” 中加入数学表达式,说明聚类目标: K-Means 的目标是最小化簇内平方和(Within-Cluster Sum of Squares, W…

深入解析C#表达式求值:优先级、结合性与括号的魔法

—— 为什么2/6*4不等于1/12? 🔍 一、表达式求值顺序为何重要? 表达式如精密仪器,子表达式求值顺序直接决定结果。例如: int result 3 * 5 2;若先算乘法:(3*5)2 17 ✅若先算加法:3*(52)21…

Docker 离线安装指南

参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性,不同版本的Docker对内核版本有不同要求。例如,Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本,Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…

Spring——Spring相关类原理与实战

摘要 本文深入探讨了 Spring 框架中 InitializingBean 接口的原理与实战应用,该接口是 Spring 提供的一个生命周期接口,用于在 Bean 属性注入完成后执行初始化逻辑。文章详细介绍了接口定义、作用、典型使用场景,并与其他相关概念如 PostCon…

Angular微前端架构:Module Federation + ngx-build-plus (Webpack)

以下是一个完整的 Angular 微前端示例,其中使用的是 Module Federation 和 npx-build-plus 实现了主应用(Shell)与子应用(Remote)的集成。 🛠️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…

ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放

简介 前面两期文章我们介绍了I2S的读取和写入,一个是通过INMP441麦克风模块采集音频,一个是通过PCM5102A模块播放音频,那如果我们将两者结合起来,将麦克风采集到的音频通过PCM5102A播放,是不是就可以做一个扩音器了呢…

冯诺依曼架构是什么?

冯诺依曼架构是什么? 冯诺依曼架构(Von Neumann Architecture)是现代计算机的基础设计框架,由数学家约翰冯诺依曼(John von Neumann)及其团队在1945年提出。其核心思想是通过统一存储程序与数据&#xff0…

【持续更新】linux网络编程试题

问题1 请简要说明TCP/IP协议栈的四层结构,并分别举出每一层出现的典型协议或应用。 答案 应用层:ping,telnet,dns 传输层:tcp,udp 网络层:ip,icmp 数据链路层:arp,rarp 问题2 下列协议或应用分别属于TCP/IP协议…

椭圆曲线密码学(ECC)

一、ECC算法概述 椭圆曲线密码学(Elliptic Curve Cryptography)是基于椭圆曲线数学理论的公钥密码系统,由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA,ECC在相同安全强度下密钥更短(256位ECC ≈ 3072位RSA…

【JVM】- 内存结构

引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…

React 基础入门笔记

一、JSX语法规则 1. 定义虚拟DOM时,不要写引号 2.标签中混入JS表达式时要用 {} (1).JS表达式与JS语句(代码)的区别 (2).使用案例 3.样式的类名指定不要用class,要用className 4.内…

Linux链表操作全解析

Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表?1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…

SQL进阶之旅 Day 19:统计信息与优化器提示

【SQL进阶之旅 Day 19】统计信息与优化器提示 文章简述 在数据库性能调优中,统计信息和优化器提示是两个至关重要的工具。统计信息帮助数据库优化器评估查询成本并选择最佳执行计划,而优化器提示则允许开发人员对优化器的行为进行微调。本文深入探讨了…

安宝特方案丨船舶智造AR+AI+作业标准化管理系统解决方案(维保)

船舶维保管理现状:设备维保主要由维修人员负责,根据设备运行状况和维护计划进行定期保养和故障维修。维修人员凭借经验判断设备故障原因,制定维修方案。 一、痛点与需求 1 Arbigtec 人工经验限制维修效率: 复杂设备故障的诊断和…

MFC内存泄露

1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…

基于区块链的供应链溯源系统:构建与实践

前言 在当今全球化的经济环境中,供应链的复杂性不断增加,商品从原材料采购到最终交付给消费者的过程涉及多个环节和众多参与者。如何确保供应链的透明度、可追溯性和安全性,成为企业和消费者关注的焦点。区块链技术以其去中心化、不可篡改和透…

Web攻防-SQL注入数据格式参数类型JSONXML编码加密符号闭合

知识点: 1、Web攻防-SQL注入-参数类型&参数格式 2、Web攻防-SQL注入-XML&JSON&BASE64等 3、Web攻防-SQL注入-数字字符搜索等符号绕过 案例说明: 在应用中,存在参数值为数字,字符时,符号的介入&#xff0c…

探秘鸿蒙 HarmonyOS NEXT:实战用 CodeGenie 构建鸿蒙应用页面

在开发鸿蒙应用时,你是否也曾为一个页面的布局反复调整?是否还在为查 API、写模板代码而浪费大量时间?今天带大家实战体验一下鸿蒙官方的 AI 编程助手——CodeGenie(代码精灵) ,如何从 0 到 1 快速构建一个…

DBAPI如何优雅的获取单条数据

API如何优雅的获取单条数据 案例一 对于查询类API,查询的是单条数据,比如根据主键ID查询用户信息,sql如下: select id, name, age from user where id #{id}API默认返回的数据格式是多条的,如下: {&qu…