目录
1、基础:它到底是个啥?
1. 1、一句话理解核心
1.2、 为啥厉害?
1.3、怎么发展来的?
2、架构:它的 “身体构造” 是啥样的?
2.1、视觉语言模型架构:让 AI “看懂” 世界的核心系统
2.1.1、双塔模型(如 CLIP)
2.1.2、交叉注意力模型(如 BLIP-2)
2.1.3、端到端模型(如 Flamingo)
2.1.4、轻量级模型(如 Fuyu-8B)
2.2、语音语言模型架构:让 AI “听懂” 声音的核心系统
2.2.1、语音特征提取(MFCC)
2.2.2、序列对齐(CTC 损失)
2.2.3、端到端模型(如 Whisper)
2.3、多模态大语言模型架构:让 AI “感知” 世界的超级系统
2.3.1、模态编码器
2.3.2、连接器(Connector)
2.3.3、大型语言模型(LLM)
3、多模态架构的核心公式与训练流程
3.1、跨模态对齐公式
3.2、训练流程
3.3、训练:怎么让它变聪明的?
4、架构对比与选择指南
5、应用:它能帮我们做啥?
6、多模态LLM模型(图像-文本生成)(简化版)
7、多模态LLM模型(图像-文本生成+问答系统)(简化版)
1、基础:它到底是个啥?
1. 1、一句话理解核心
普通大模型(比如 ChatGPT)只能处理文字,而多模态大语言模型(简称 “多模态 LLM”)能同时 “看懂图、听懂声、读得懂字”,还能用文字回答你所有问题。比如你给它一张电路图,它能直接告诉你 “这里接反了会短路”;给它一段机器运转的声音,它能说 “轴承快坏了,得换”。
1.2、 为啥厉害?
以前的 AI 是 “偏科生”:有的只能看图(比如识别图片里的猫),有的只能处理文字(比如写作文),但多模态 LLM 是 “全能选手”—— 它用语言把所有信息打通了。就像人既会看路标(图像),又会读路牌(文字),还能跟人打听路(语言),最后找到目的地,而不是只认其中一种。
1.3、怎么发展来的?
- 先有 “文字学霸”:比如 GPT-3、Llama,只会处理文字,逻辑推理超强但 “看不见东西”。
- 再加上 “图像 / 声音翻译官”:比如 CLIP 能把图片转成文字能懂的 “密码”,让文字学霸能 “间接看图”。
- 最后合体:把 “翻译官” 和 “文字学霸” 绑在一起,就成了多模态 LLM,比如 GPT-4V、Llava 这些。
2、架构:它的 “身体构造” 是啥样的?
2.1、视觉语言模型架构:让 AI “看懂” 世界的核心系统
视觉语言模型(如 CLIP、BLIP-2)的核心是将图像和文字映射到同一语义空间,实现跨模态理解。其架构通常包含三个模块:
2.1.1、双塔模型(如 CLIP)
- 架构原理: 独立的图像编码器(如 ResNet)和文本编码器(如 Transformer)分别处理图片和文字,通过对比学习将两者特征投影到同一向量空间。
-
关键公式:对比损失函数
-
其中,
为余弦相似度,
是温度参数,控制相似度分布的平滑程度。
-
- 训练步骤:
- 输入图片和对应文本,分别编码为特征向量
和
。
- 计算所有图片 - 文本对的相似度矩阵,最大化正确对的相似度,最小化错误对的相似度。
- 输入图片和对应文本,分别编码为特征向量
- 应用场景:图片检索(如从百万张图中找出 “戴红帽子的猫”)
2.1.2、交叉注意力模型(如 BLIP-2)
- 架构原理: 引入Query-Former 模块,通过跨模态注意力机制让图像和文本特征直接交互。
- 关键公式:跨模态注意力
其中,Q 来自文本,K 和 V 来自图像,通过多头注意力实现深度融合。
- 关键公式:跨模态注意力
- 训练步骤:
- 冻结图像编码器(如 CLIP 的 ViT),仅微调 Query-Former 和 LLM。
- 输入图像和文本,Query-Former 生成融合特征,LLM 生成回答(如 “图中猫在做什么”)。
- 应用场景:视觉问答(VQA)、图文生成(如根据图片写故事)。
2.1.3、端到端模型(如 Flamingo)
- 架构原理: 冻结视觉编码器,仅微调语言模型(如 Llama),通过视觉提示引导语言模型生成答案。
- 关键设计:
- 视觉特征直接输入 LLM 的 Transformer 层,无需独立编码器。
- 采用 “视觉 token”(如<image>标签)标记输入中的图像部分。
- 关键设计:
- 训练步骤:
- 预训练阶段:在海量图文数据上对齐视觉和语言特征。
- 微调阶段:针对特定任务(如医学影像分析),用领域数据训练语言模型。
- 应用场景:实时图像标注(如直播中自动生成字幕)。
2.1.4、轻量级模型(如 Fuyu-8B)
- 架构原理: 摒弃传统图像编码器,直接将图像分块后通过线性投影输入 Transformer 解码器。
- 关键公式:图像分块投影
其中,x是原始图像,
是第 i 个图像块的特征向量。
- 关键公式:图像分块投影
- 训练步骤:
- 将图像切分为 16x16 的小块,每个块线性投影到文本模型的维度。
- 与文本 token 混合输入解码器,联合训练生成回答。
- 应用场景:边缘设备(如手机)的实时图像问答。
2.2、语音语言模型架构:让 AI “听懂” 声音的核心系统
语音语言模型(如 Whisper、DeepSpeech)的核心是将语音信号转化为文字序列,其架构通常包含三个模块:
2.2.1、语音特征提取(MFCC)
- 步骤解析:
- 预加重:提升高频信号,公式为
。
- 分帧加窗:将语音切分为 20-30ms 的帧,加汉明窗减少边界效应。
- FFT 变换:将时域信号转为频域,得到功率谱。
- 梅尔滤波:通过三角形滤波器组提取人耳敏感的频率特征。
- DCT 变换:将梅尔谱转换为倒谱系数(MFCC),去除冗余信息。
- 预加重:提升高频信号,公式为
- 输出结果:每帧生成 12-16 维 MFCC 特征,叠加能量、一阶 / 二阶差分,共 40 维左右。
2.2.2、序列对齐(CTC 损失)
- 架构原理: 解决语音和文本的时序不对齐问题,通过动态规划计算路径概率。
- 关键公式:CTC 损失函数
其中,
是对齐路径,
是时刻 t 输出字符
的概率。
- 关键公式:CTC 损失函数
- 训练步骤:
- 输入 MFCC 特征序列,通过 RNN 或 CNN 生成预测概率矩阵。
- 使用 CTC 算法计算所有可能对齐路径的概率之和,最大化正确路径的概率。
2.2.3、端到端模型(如 Whisper)
- 架构原理: 基于 Transformer 的编码器 - 解码器架构,直接输入音频波形生成文本。
- 关键设计:
- 编码器:将 30 秒音频转为 80 维 log-Mel 频谱,输入多层 Transformer。
- 解码器:在文本生成时引入交叉注意力,融合音频编码和历史文本。
- 关键设计:
- 训练步骤:
- 预训练:在 68 万小时多语言音频上训练,支持 99 种语言。
- 微调:针对特定领域(如医疗)优化转录准确率。
- 应用场景:实时语音转写(如会议记录)、跨语言翻译(如法语→英语)。
2.3、多模态大语言模型架构:让 AI “感知” 世界的超级系统
多模态大语言模型(如 GPT-4V、Llama 4 Maverick)的核心是整合视觉、语音、文本多模态信息,实现复杂推理。其架构通常包含四个模块:
2.3.1、模态编码器
- 功能:将图像、语音等非文本信息转化为特征向量。
- 技术方案:
- 图像:CLIP、Swin Transformer(如 GPT-4V)。
- 语音:MFCC+Transformer(如 Whisper)。
- 文本:Llama、Qwen(如 Qwen-VL)。
2.3.2、连接器(Connector)
- 功能:统一不同模态的特征格式,便于 LLM 处理。
- 技术方案:
- 线性投影:将图像 / 语音特征调整为与文本 token 相同维度(如 Fuyu-8B)。
- 跨模态注意力:在 Transformer 层引入图像 - 文本交互(如 BLIP-2)。
2.3.3、大型语言模型(LLM)
- 功能:作为 “大脑” 进行跨模态推理和生成。
- 技术方案:
- 参数量:通常为 7B-400B(如 Llama 4 Maverick 的 400B 参数)。
- 架构:混合专家(MoE)、稀疏注意力(如 DeepSeek-V3)。
-
生成器(可选)
- 功能:输出非文本模态(如图像、视频)。
- 技术方案:
- 图像生成:扩散模型(如 Stable Diffusion),基于 LLM 输出的文本描述生成图片。
- 视频生成:Transformer + 时空注意力,生成连贯视频序列。
3、多模态架构的核心公式与训练流程
3.1、跨模态对齐公式
- 对比学习(CLIP):
-
- 掩码语言建模
3.2、训练流程
- 预训练阶段:
- 多模态数据构建:爬取图文对、语音 - 文本对(如 SBU 数据集的 50 万图文对)。
- 特征对齐:通过对比学习或掩码建模,让模型理解跨模态关联。
- 微调阶段:
- 领域数据注入:如医疗影像 + 诊断报告,提升特定任务准确率(如 BakLLaVA-1 的 92% 诊断率)。
- 指令微调:设计多模态指令(如 “根据 X 光片诊断肺炎风险”),引导模型生成符合人类逻辑的回答。
- 优化技术:
- 混合专家(MoE):减少训练成本,如 Llama 4 通过 MoE 实现 400B 参数高效训练。
- 模型量化:将参数压缩至 4-bit/8-bit,如 Llama 4 Scout 支持单卡部署。
3.3、训练:怎么让它变聪明的?
就像教一个小孩 “认识世界”,分三步:
1. 先学 “基础知识”(预训练)
给它喂海量 “图文配对” 的资料:比如 “猫的图片 +‘这是一只猫’”“汽车图片 +‘四个轮子的交通工具’”。
目的是让它知道 “图片里的内容和文字说的是一回事”,就像小孩看绘本,把图画和文字对应起来。2. 再练 “具体技能”(微调)
针对具体任务 “补课”:比如想让它看懂 X 光片,就专门喂 “X 光片 + 医生诊断文字” 的资料;想让它讲题,就喂 “数学题图片 + 解题步骤”。
这一步就像学生上完基础课,再去学 “物理、化学” 等专业课。3. 关键技巧:让它 “不瞎猜”
训练时故意 “藏起一部分信息” 让它猜:比如盖住图片的一半让它补全,或者遮住文字的几个字让它填。这样能逼它更认真地 “看” 和 “想”,减少胡说八道(专业叫 “减少幻觉”)。
4、架构对比与选择指南
架构类型 | 代表模型 | 核心优势 | 适用场景 | 参数量范围 |
---|---|---|---|---|
双塔模型 | CLIP | 轻量、高检索效率 | 图片 / 文本匹配 | 400M-10B |
交叉注意力 | BLIP-2 | 复杂推理、多模态生成 | 视觉问答、图文生成 | 13B-65B |
端到端 | Flamingo | 高效适配、低延迟 | 实时交互、边缘设备 | 7B-30B |
混合专家(MoE) | Llama 4 Maverick | 高性能、稀疏计算 | 科学研究、工业级推理 | 100B-400B |
轻量级 | Fuyu-8B | 低功耗、单卡部署 | 手机、物联网设备 | 8B-16B |
5、应用:它能帮我们做啥?
生活里到处都能用,举几个接地气的例子:
- 看病:给医生当助手,拍张 X 光片,它能立刻标出 “这里可能有炎症”,再结合病历文字,提醒医生重点检查。
- 学习:学生拍一张数学题图片,它不光给答案,还能用文字讲 “第一步为什么要这么算”,比课本好懂。
- 干活:工厂里拍张零件照片,它能说 “这个螺丝松了,会导致机器异响”,工人不用自己盯着看半天。
- 日常:旅游时拍张外语路标,它能翻译文字,还能告诉你 “往前走 300 米有地铁站”(结合图片里的箭头)。
特别说明:训练度和数据集不够,结果存在问题,主要用于理解知识
6、多模态LLM模型(图像-文本生成)(简化版)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import GPT2LMHeadModel, GPT2Tokenizer, get_linear_schedule_with_warmup
import os
from PIL import Image
import json
import requests
from io import BytesIO
import random
import numpy as np
from tqdm import tqdm
import warnings
import matplotlib.pyplot as plt # 用于绘图
from matplotlib.table import Table # 用于生成对比表格
import seaborn as sns # 美化图表
sns.set_style("whitegrid")# ---------------------------- 核心修复1:设置Matplotlib支持中文显示 ----------------------------
plt.rcParams["font.family"] = ["SimHei"] # 支持中文的字体
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题# ---------------------------- 消除其他警告配置 ----------------------------
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, message="Arguments other than a weight enum or `None` for 'weights' are deprecated")# ---------------------------- 数据集类定义(平衡样本分布) ----------------------------
class BalancedDemoDataset(Dataset):"""平衡的演示数据集(确保各类别样本数量均等)"""def __init__(self, img_size=224, max_text_length=512):self.categories = [{"name": "cat", "url": "https://picsum.photos/seed/cat1/500/300"},{"name": "dog", "url": "https://picsum.photos/seed/dog1/500/300"},{"name": "bird", "url": "https://picsum.photos/seed/bird1/500/300"},{"name": "city", "url": "https://picsum.photos/seed/city1/500/300"},{"name": "mountain", "url": "https://picsum.photos/seed/mountains1/500/300"},{"name": "beach", "url": "https://picsum.photos/seed/beach1/500/300"},{"name": "forest", "url": "https://picsum.photos/seed/forest1/500/300"},{"name": "library", "url": "https://picsum.photos/seed/library1/500/300"},{"name": "restaurant", "url": "https://picsum.photos/seed/restaurant1/500/300"},{"name": "airport", "url": "https://picsum.photos/seed/airport1/500/300"}]# 为每个类别生成5个不同描述(更具体,避免模糊)self.data = []for cat in self.categories:base_descriptions = [f"A {cat['name']} scene with typical features",f"The {cat['name']} showing natural details",f"An image of {cat['name']} with clear views",f"View of {cat['name']} in daylight",f"Close-up of {cat['name']} key elements"]for desc in base_descriptions:self.data.append({"image_url": cat["url"],"text": desc,"category": cat["name"]})self.img_size = img_sizeself.max_text_length = max_text_length# 图像预处理self.image_transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size)),transforms.RandomHorizontalFlip(p=0.3),transforms.RandomRotation(5),transforms.ColorJitter(brightness=0.1, contrast=0.1),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 初始化tokenizer(左padding,明确设置pad_token)self.text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")self.text_tokenizer.pad_token = self.text_tokenizer.eos_tokenself.text_tokenizer.padding_side = "left"self.cached_images = {}def __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]if item["image_url"] not in self.cached_images:try:response = requests.get(item["image_url"], timeout=10)image = Image.open(BytesIO(response.content)).convert("RGB")self.cached_images[item["image_url"]] = imageexcept Exception as e:print(f"Image download failed for {item['category']}: {e}, using random image")color_map = {"cat": (255, 200, 200), "dog": (200, 255, 200), "bird": (200, 200, 255),"city": (255, 255, 200), "mountain": (200, 255, 255), "beach": (255, 220, 180),"forest": (180, 255, 180), "library": (220, 220, 220),"restaurant": (255, 180, 180), "airport": (200, 200, 200)}color = color_map.get(item["category"], (255, 255, 255))image = Image.new('RGB', (self.img_size, self.img_size), color=color)self.cached_images[item["image_url"]] = imageimage = self.cached_images[item["image_url"]]image_tensor = self.image_transform(image)text_tokens = self.text_tokenizer(item["text"],padding="max_length",truncation=True,max_length=self.max_text_length,return_tensors="pt")return {"image": image_tensor,"input_ids": text_tokens["input_ids"].squeeze(0),"attention_mask": text_tokens["attention_mask"].squeeze(0),"text": item["text"],"category": item["category"]}# ---------------------------- 模型架构(保持不变) ----------------------------
class ImageEncoder(nn.Module):def __init__(self, output_dim=768):super().__init__()self.base_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)modules = list(self.base_model.children())[:-1]self.feature_extractor = nn.Sequential(*modules)self.projection = nn.Sequential(nn.Linear(2048, 1024),nn.GELU(),nn.Dropout(0.2),nn.Linear(1024, output_dim))def forward(self, images):features = self.feature_extractor(images).squeeze(-1).squeeze(-1)return self.projection(features)class CrossModalFusion(nn.Module):def __init__(self, hidden_dim=768, num_heads=8):super().__init__()self.text_norm = nn.LayerNorm(hidden_dim)self.image_norm = nn.LayerNorm(hidden_dim)self.attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)self.fusion = nn.Linear(hidden_dim * 2, hidden_dim)self.activation = nn.GELU()def forward(self, text_features, image_features):batch_size, seq_len, _ = text_features.shapeimage_features = self.image_norm(image_features)image_expanded = image_features.unsqueeze(1).expand(-1, seq_len, -1)text_attn, _ = self.attention(query=text_features, key=image_expanded, value=image_expanded)text_attn = self.text_norm(text_features + text_attn)fused = self.activation(self.fusion(torch.cat([text_features, text_attn], dim=-1)))return fusedclass MultimodalLLM(nn.Module):def __init__(self, hidden_dim=768):super().__init__()self.text_encoder = GPT2LMHeadModel.from_pretrained("gpt2")for param in list(self.text_encoder.parameters())[:3]:param.requires_grad = Falseself.image_encoder = ImageEncoder(output_dim=hidden_dim)self.cross_modal_fusion = CrossModalFusion(hidden_dim=hidden_dim)self.final_norm = nn.LayerNorm(hidden_dim)def forward(self, images, input_ids, attention_mask=None):image_features = self.image_encoder(images)text_outputs = self.text_encoder.transformer(input_ids=input_ids, attention_mask=attention_mask)text_features = text_outputs.last_hidden_statefused_features = self.cross_modal_fusion(text_features, image_features)fused_features = self.final_norm(fused_features)return self.text_encoder.lm_head(fused_features)# ---------------------------- 训练与生成函数(优化生成策略) ----------------------------
def train_model(model, train_loader, val_loader, epochs=10, lr=5e-5):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)optimizer = optim.AdamW(model.parameters(), lr=lr)criterion = nn.CrossEntropyLoss(ignore_index=50256)scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=len(train_loader), num_training_steps=epochs*len(train_loader))train_losses = []val_losses = []for epoch in range(epochs):model.train()total_train_loss = 0progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")for batch in progress_bar:images = batch["image"].to(device)input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)outputs = model(images, input_ids, attention_mask)shift_logits = outputs[..., :-1, :].contiguous()shift_labels = input_ids[..., 1:].contiguous()loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()scheduler.step()total_train_loss += loss.item()progress_bar.set_postfix(loss=f"{loss.item():.4f}")model.eval()total_val_loss = 0with torch.no_grad():for batch in val_loader:images = batch["image"].to(device)input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)outputs = model(images, input_ids, attention_mask)shift_logits = outputs[..., :-1, :].contiguous()shift_labels = input_ids[..., 1:].contiguous()total_val_loss += criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).item()avg_train = total_train_loss / len(train_loader)avg_val = total_val_loss / len(val_loader)train_losses.append(avg_train)val_losses.append(avg_val)print(f"Epoch {epoch+1} | Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")return model, train_losses, val_lossesdef generate_text(model, image, tokenizer, category, max_length=60):"""核心修复2:传递attention_mask,优化生成策略"""device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.eval()# 更具体的引导提示(避免模型生成列表)prompt = f"Describe the {category} image in detail: "# 生成input_ids和attention_mask(解决警告)inputs = tokenizer(prompt,return_tensors="pt",padding="max_length",max_length=len(prompt) + 5, # 足够容纳提示词truncation=True)input_ids = inputs["input_ids"].to(device)attention_mask = inputs["attention_mask"].to(device) # 传递注意力掩码# 提取图像特征image = image.unsqueeze(0).to(device)with torch.no_grad():image_features = model.image_encoder(image)# 核心修复3:优化生成参数(减少重复,提高相关性)output = model.text_encoder.generate(input_ids=input_ids,attention_mask=attention_mask, # 传入掩码,消除警告max_length=max_length,temperature=0.5, # 降低随机性,更聚焦输入num_beams=3,no_repeat_ngram_size=3, # 避免3字词重复early_stopping=True,encoder_hidden_states=image_features.unsqueeze(1))# 解码并移除提示词generated_text = tokenizer.decode(output[0], skip_special_tokens=True)return generated_text.replace(prompt, "").strip()# ---------------------------- 对比图生成(修复中文显示) ----------------------------
def plot_loss_curves(train_losses, val_losses):plt.figure(figsize=(10, 5))plt.plot(range(1, len(train_losses)+1), train_losses, label="训练损失", marker='o')plt.plot(range(1, len(val_losses)+1), val_losses, label="验证损失", marker='s')plt.xlabel("轮次")plt.ylabel("损失值")plt.title("训练与验证损失对比")plt.legend()plt.grid(alpha=0.3)plt.savefig("loss_curves.png")print("Loss对比图已保存为 loss_curves.png")plt.close()def generate_results_table(results):plt.figure(figsize=(12, 9))ax = plt.gca()ax.axis('off')table = Table(ax, bbox=[0, 0, 1, 1])# 表头(中文显示正常)table.add_cell(0, 0, 0.1, 0.1, text="类别", loc='center', facecolor='lightgray')table.add_cell(0, 1, 0.3, 0.1, text="原始文本", loc='center', facecolor='lightgray')table.add_cell(0, 2, 0.6, 0.1, text="生成文本", loc='center', facecolor='lightgray')# 添加内容for i, res in enumerate(results[:8]):table.add_cell(i+1, 0, 0.1, 0.15, text=res["category"], loc='center') # 增加行高,避免文本溢出table.add_cell(i+1, 1, 0.3, 0.15, text=res["original"], loc='left')table.add_cell(i+1, 2, 0.6, 0.15, text=res["generated"], loc='left')ax.add_table(table)plt.savefig("results_table.png", bbox_inches='tight')print("结果对比表已保存为 results_table.png")plt.close()# ---------------------------- 主函数 ----------------------------
if __name__ == "__main__":print("准备平衡数据集...")full_dataset = BalancedDemoDataset()train_size = int(0.8 * len(full_dataset))train_dataset, val_dataset = random_split(full_dataset, [train_size, len(full_dataset)-train_size])batch_size = 4 if torch.cuda.is_available() else 1train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)print("训练模型...")model = MultimodalLLM()model, train_losses, val_losses = train_model(model, train_loader, val_loader, epochs=12)torch.save(model.state_dict(), "optimized_model.pth")plot_loss_curves(train_losses, val_losses)print("生成测试结果...")tokenizer = full_dataset.text_tokenizerresults = []for category in [cat["name"] for cat in full_dataset.categories]:sample_idx = next(i for i, item in enumerate(full_dataset.data) if item["category"] == category)sample = full_dataset[sample_idx]generated = generate_text(model, sample["image"], tokenizer, category)results.append({"category": category,"original": sample["text"],"generated": generated})generate_results_table(results)print("\n部分生成结果:")for res in results[:5]:print(f"\n类别: {res['category']}")print(f"原始文本: {res['original']}")print(f"生成文本: {res['generated']}")
7、多模态LLM模型(图像-文本生成+问答系统)(简化版)
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms # 视觉模型与图像预处理工具
from torch.utils.data import Dataset, DataLoader, random_split # 数据加载与划分
from transformers import GPT2LMHeadModel, GPT2Tokenizer, get_linear_schedule_with_warmup # 文本模型与工具
import os
from PIL import Image # 图像处理库
import requests # 网络请求(下载图像)
from io import BytesIO # 内存中处理二进制数据
import numpy as np
from tqdm import tqdm # 进度条显示
import warnings # 警告处理
import matplotlib.pyplot as plt # 可视化工具
from matplotlib.table import Table # 生成结果表格
import seaborn as sns # 美化图表sns.set_style("whitegrid") # 设置图表风格# ---------------------------- 配置环境(解决中文显示与警告问题) ----------------------------
# 设置支持中文的字体,解决图表中文乱码
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC", "Arial Unicode MS"]
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示异常# 过滤无关警告,保持输出简洁
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated")
warnings.filterwarnings("ignore", category=UserWarning,message="Arguments other than a weight enum or `None` for 'weights' are deprecated")# ---------------------------- 多模态数据集类(核心数据处理) ----------------------------
class BalancedMultimodalDataset(Dataset):"""平衡的多模态数据集:包含图像、描述文本和英文问答对特点:10个类别,每个类别样本数量均等,避免模型偏向某类数据"""def __init__(self, img_size=224, max_text_length=128):# 定义10个类别及其对应的图像URL(使用picsum生成可重复的随机图像)self.categories = [{"name": "cat", "url": "https://picsum.photos/seed/cat1/500/300"},{"name": "dog", "url": "https://picsum.photos/seed/dog1/500/300"},{"name": "bird", "url": "https://picsum.photos/seed/bird1/500/300"},{"name": "city", "url": "https://picsum.photos/seed/city1/500/300"},{"name": "mountain", "url": "https://picsum.photos/seed/mountains1/500/300"},{"name": "beach", "url": "https://picsum.photos/seed/beach1/500/300"},{"name": "forest", "url": "https://picsum.photos/seed/forest1/500/300"},{"name": "library", "url": "https://picsum.photos/seed/library1/500/300"},{"name": "restaurant", "url": "https://picsum.photos/seed/restaurant1/500/300"},{"name": "airport", "url": "https://picsum.photos/seed/airport1/500/300"}]# 构建数据集:每个类别包含5种描述和3组问答对self.data = []for cat in self.categories:# 为每个类别生成5种不同的图像描述(增强数据多样性)descriptions = [f"A {cat['name']} scene with typical features",f"The {cat['name']} showing natural details",f"An image of {cat['name']} with clear views",f"View of {cat['name']} in daylight",f"Close-up of {cat['name']} key elements"]# 为每个类别设计3组英文问答对(覆盖不同类型的问题)qa_pairs = [{"question": f"What is the main subject of this {cat['name']} image?", # 主体识别"answer": f"The main subject is a {cat['name']}."},{"question": f"What features are typical of this {cat['name']}?", # 特征描述"answer": f"Typical features include {cat['name']}-specific characteristics."},{"question": f"Where might this {cat['name']} be located?", # 位置推测"answer": f"This {cat['name']} might be located in its natural environment."}]# 组合描述和问答对,生成最终数据集for desc in descriptions:for qa in qa_pairs:self.data.append({"image_url": cat["url"], # 图像URL"description": desc, # 图像描述"question": qa["question"], # 问题"answer": qa["answer"], # 答案"category": cat["name"] # 类别标签})self.img_size = img_size # 图像统一尺寸self.max_text_length = max_text_length # 文本最大长度(防止输入过长)# 图像预处理管道(含数据增强)self.image_transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size)), # 缩放至指定尺寸transforms.RandomHorizontalFlip(p=0.3), # 30%概率水平翻转(数据增强)transforms.RandomRotation(5), # 随机旋转±5度(增强视角鲁棒性)transforms.ColorJitter(brightness=0.1, contrast=0.1), # 微调亮度和对比度transforms.ToTensor(), # 转换为Tensor格式(通道×高度×宽度)transforms.Normalize( # 标准化(使用ImageNet的均值和标准差)mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])# 初始化文本分词器(适配GPT2模型)self.text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")self.text_tokenizer.pad_token = self.text_tokenizer.eos_token # 使用终止符作为填充符self.text_tokenizer.padding_side = "left" # 左填充(适合自回归模型)self.cached_images = {} # 缓存已下载的图像(避免重复网络请求)def __len__(self):"""返回数据集样本总数"""return len(self.data)def __getitem__(self, idx):"""获取单个样本:图像+文本+标签"""item = self.data[idx]# 下载并缓存图像(若未缓存)if item["image_url"] not in self.cached_images:try:# 尝试下载图像response = requests.get(item["image_url"], timeout=10)image = Image.open(BytesIO(response.content)).convert("RGB") # 转换为RGB格式self.cached_images[item["image_url"]] = imageexcept Exception as e:# 下载失败时,生成与类别相关的纯色图(避免程序崩溃)print(f"图像下载失败({item['category']}): {e},使用替代图")# 为每个类别分配独特颜色(便于调试)color_map = {"cat": (255, 200, 200), "dog": (200, 255, 200), "bird": (200, 200, 255),"city": (255, 255, 200), "mountain": (200, 255, 255), "beach": (255, 220, 180),"forest": (180, 255, 180), "library": (220, 220, 220),"restaurant": (255, 180, 180), "airport": (200, 200, 200)}color = color_map.get(item["category"], (255, 255, 255)) # 默认白色image = Image.new('RGB', (self.img_size, self.img_size), color=color)self.cached_images[item["image_url"]] = image# 预处理图像image = self.cached_images[item["image_url"]]image_tensor = self.image_transform(image)# 预处理文本(将问答对转换为模型输入格式)input_text = f"Question: {item['question']} Answer: {item['answer']}" # 拼接问题和答案text_tokens = self.text_tokenizer(input_text,padding="max_length", # 填充至最大长度truncation=True, # 超长则截断max_length=self.max_text_length,return_tensors="pt" # 返回PyTorch张量)return {"image": image_tensor, # 预处理后的图像张量"input_ids": text_tokens["input_ids"].squeeze(0), # 文本ID序列(去除batch维度)"attention_mask": text_tokens["attention_mask"].squeeze(0), # 注意力掩码(1表示有效token)"question": item["question"], # 原始问题(用于测试)"answer": item["answer"], # 原始答案(用于对比)"category": item["category"] # 类别标签}# ---------------------------- 多模态模型架构(核心组件) ----------------------------
class ImageEncoder(nn.Module):"""图像编码器:将图像转换为与文本兼容的特征向量输入:图像(3×224×224)输出:特征向量(768维)"""def __init__(self, output_dim=768):super().__init__()# 使用预训练的ResNet50作为基础特征提取器self.base_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 移除最后一层全连接层(保留卷积特征提取部分)# ResNet50的最后一层是fc层,输出1000类,这里只保留前面的特征提取部分feature_extractor_modules = list(self.base_model.children())[:-1]self.feature_extractor = nn.Sequential(*feature_extractor_modules)# 投影层:将ResNet输出的2048维特征映射到768维(与文本特征维度一致)self.projection = nn.Sequential(nn.Linear(2048, 1024), # 降维至1024nn.GELU(), # 高斯误差线性单元(比ReLU更平滑)nn.Dropout(0.2), # Dropout层(防止过拟合)nn.Linear(1024, output_dim) # 最终投影至768维)def forward(self, images):"""前向传播:图像→特征向量"""# 提取卷积特征:ResNet50输出为[batch_size, 2048, 1, 1]conv_features = self.feature_extractor(images)# 展平为[batch_size, 2048]flattened_features = conv_features.squeeze(-1).squeeze(-1)# 投影至768维return self.projection(flattened_features)class CrossModalFusion(nn.Module):"""跨模态融合模块:实现文本特征与图像特征的交互核心:通过注意力机制让文本关注图像的关键信息"""def __init__(self, hidden_dim=768, num_heads=8):super().__init__()self.text_norm = nn.LayerNorm(hidden_dim) # 文本特征归一化self.image_norm = nn.LayerNorm(hidden_dim) # 图像特征归一化# 多头注意力机制(并行处理多个特征子空间)self.attention = nn.MultiheadAttention(embed_dim=hidden_dim, # 特征维度(768)num_heads=num_heads, # 注意力头数(8,768/8=96,每个头处理96维)batch_first=True # 输入格式为[batch, seq_len, dim])# 特征融合层:将文本特征与注意力输出融合self.fusion = nn.Linear(hidden_dim * 2, hidden_dim) # 1536→768self.activation = nn.GELU() # 激活函数def forward(self, text_features, image_features):"""输入:text_features: [batch_size, seq_len, hidden_dim](文本序列特征)image_features: [batch_size, hidden_dim](图像全局特征)输出:fused_features: [batch_size, seq_len, hidden_dim](融合特征)"""batch_size, seq_len, _ = text_features.shape # 获取文本序列长度# 图像特征归一化并扩展至序列长度(每个文本token都能关注图像)image_features = self.image_norm(image_features)# 扩展为[batch_size, seq_len, hidden_dim]image_expanded = image_features.unsqueeze(1).expand(-1, seq_len, -1)# 文本特征通过注意力关注图像特征(交叉注意力)# query=文本特征,key=图像特征,value=图像特征text_attn, _ = self.attention(query=text_features,key=image_expanded,value=image_expanded)# 残差连接+层归一化(缓解梯度消失,加速训练)text_attn = self.text_norm(text_features + text_attn)# 融合原始文本特征和注意力增强特征fused = self.activation(self.fusion(torch.cat([text_features, text_attn], dim=-1)))return fusedclass MultimodalLLM(nn.Module):"""多模态大语言模型:整合图像编码器、文本编码器和跨模态融合模块功能:根据图像生成描述文本,或回答与图像相关的问题"""def __init__(self, hidden_dim=768):super().__init__()# 文本编码器(基于GPT2,预训练语言模型)self.text_encoder = GPT2LMHeadModel.from_pretrained("gpt2")# 冻结前3层参数(减少训练量,保留预训练语言知识)for param in list(self.text_encoder.parameters())[:3]:param.requires_grad = False# 图像编码器(见上文)self.image_encoder = ImageEncoder(output_dim=hidden_dim)# 跨模态融合模块(见上文)self.cross_modal_fusion = CrossModalFusion(hidden_dim=hidden_dim)self.final_norm = nn.LayerNorm(hidden_dim) # 最终归一化层def forward(self, images, input_ids, attention_mask=None):"""前向传播流程:1. 提取图像特征2. 提取文本特征3. 跨模态融合4. 生成预测结果"""# 1. 图像特征提取image_features = self.image_encoder(images) # [batch_size, hidden_dim]# 2. 文本特征提取(通过GPT2的Transformer层)text_outputs = self.text_encoder.transformer(input_ids=input_ids,attention_mask=attention_mask)text_features = text_outputs.last_hidden_state # [batch_size, seq_len, hidden_dim]# 3. 跨模态融合(文本特征+图像特征)fused_features = self.cross_modal_fusion(text_features, image_features)fused_features = self.final_norm(fused_features) # 归一化# 4. 通过GPT2的语言模型头生成下一个token的概率分布return self.text_encoder.lm_head(fused_features) # [batch, seq_len, vocab_size]# ---------------------------- 训练与生成函数(模型应用) ----------------------------
def train_model(model, train_loader, val_loader, epochs=10, lr=5e-5):"""训练多模态模型参数:model: 待训练的模型train_loader: 训练数据加载器val_loader: 验证数据加载器epochs: 训练轮次lr: 学习率返回:model: 训练好的模型train_losses: 训练损失曲线val_losses: 验证损失曲线"""# 选择计算设备(GPU优先)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device) # 模型移至设备# 优化器(AdamW:带权重衰减的Adam,减轻过拟合)optimizer = optim.AdamW(model.parameters(), lr=lr)# 损失函数(交叉熵损失,忽略填充token的损失)# 50256是GPT2的eos_token_id(即填充符)criterion = nn.CrossEntropyLoss(ignore_index=50256)# 学习率调度器(线性预热+衰减)scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=len(train_loader), # 预热步数=1个epoch的迭代次数num_training_steps=epochs * len(train_loader) # 总训练步数)train_losses = [] # 记录训练损失val_losses = [] # 记录验证损失for epoch in range(epochs):# 训练阶段model.train() # 开启训练模式(启用dropout等)total_train_loss = 0# 进度条显示训练过程progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")for batch in progress_bar:# 数据移至设备images = batch["image"].to(device)input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)# 前向传播:获取模型输出outputs = model(images, input_ids, attention_mask)# 计算损失(预测下一个token)# 输出和标签都偏移一位(预测第i+1个token,基于第1..i个token)shift_logits = outputs[..., :-1, :].contiguous() # 预测序列shift_labels = input_ids[..., 1:].contiguous() # 目标序列loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), # 展平为[batch*(seq_len-1), vocab_size]shift_labels.view(-1) # 展平为[batch*(seq_len-1)])# 反向传播与参数更新optimizer.zero_grad() # 梯度清零loss.backward() # 计算梯度# 梯度裁剪(防止梯度爆炸,大模型训练必备)torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)optimizer.step() # 更新参数scheduler.step() # 更新学习率total_train_loss += loss.item()# 显示当前批次损失progress_bar.set_postfix(loss=f"{loss.item():.4f}")# 计算平均训练损失avg_train_loss = total_train_loss / len(train_loader)train_losses.append(avg_train_loss)# 验证阶段(不更新参数)model.eval() # 开启评估模式(关闭dropout等)total_val_loss = 0with torch.no_grad(): # 禁用梯度计算(节省内存)for batch in val_loader:# 数据移至设备images = batch["image"].to(device)input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)# 前向传播outputs = model(images, input_ids, attention_mask)# 计算损失(同训练阶段)shift_logits = outputs[..., :-1, :].contiguous()shift_labels = input_ids[..., 1:].contiguous()total_val_loss += criterion(shift_logits.view(-1, shift_logits.size(-1)),shift_labels.view(-1)).item()# 计算平均验证损失avg_val_loss = total_val_loss / len(val_loader)val_losses.append(avg_val_loss)print(f"Epoch {epoch + 1} | 训练损失: {avg_train_loss:.4f} | 验证损失: {avg_val_loss:.4f}")return model, train_losses, val_lossesdef generate_description(model, image, tokenizer, category, max_new_tokens=40):"""根据图像生成描述文本参数:model: 训练好的模型image: 预处理后的图像tokenizer: 文本分词器category: 图像类别(用于提示词)max_new_tokens: 最大新增token数返回:生成的描述文本"""device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.eval() # 评估模式# 构建提示词(引导模型生成与类别相关的描述)prompt = f"Describe the {category} image in detail: "# 编码提示词(不填充,保留原始长度)inputs = tokenizer(prompt,return_tensors="pt",padding="do_not_pad",truncation=False)input_ids = inputs["input_ids"].to(device)attention_mask = inputs["attention_mask"].to(device)# 提取图像特征image = image.unsqueeze(0).to(device) # 增加batch维度with torch.no_grad():image_features = model.image_encoder(image) # [1, hidden_dim]# 生成文本(核心步骤)output = model.text_encoder.generate(input_ids=input_ids, # 提示词IDattention_mask=attention_mask, # 注意力掩码max_new_tokens=max_new_tokens, # 最多生成40个新tokentemperature=0.6, # 温度参数(控制随机性,值越小越确定)num_beams=3, # Beam搜索宽度(保留3个最优候选)no_repeat_ngram_size=2, # 禁止2-gram重复(减少冗余)early_stopping=True, # 生成终止符时停止encoder_hidden_states=image_features.unsqueeze(1) # 传入图像特征(关键))# 解码并清理生成的文本(去除特殊字符和提示词)generated_text = tokenizer.decode(output[0], skip_special_tokens=True)# 替换非-breaking空格为普通空格,去除提示词return generated_text.replace(u'\xa0', ' ').replace(prompt, "").strip()def answer_question(model, image, question, tokenizer, max_new_tokens=40):"""基于图像回答英文问题参数:model: 训练好的模型image: 预处理后的图像question: 英文问题tokenizer: 文本分词器max_new_tokens: 最大新增token数返回:生成的答案"""device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.eval() # 评估模式# 构建问答格式的提示词prompt = f"Question: {question} Answer: "# 编码提示词inputs = tokenizer(prompt,return_tensors="pt",padding="do_not_pad",truncation=False)input_ids = inputs["input_ids"].to(device)attention_mask = inputs["attention_mask"].to(device)# 提取图像特征image = image.unsqueeze(0).to(device)with torch.no_grad():image_features = model.image_encoder(image)# 生成答案output = model.text_encoder.generate(input_ids=input_ids,attention_mask=attention_mask,max_new_tokens=max_new_tokens,temperature=0.5, # 更低的温度(回答需更确定)num_beams=3,no_repeat_ngram_size=3, # 禁止3-gram重复(进一步减少冗余)early_stopping=True,encoder_hidden_states=image_features.unsqueeze(1))# 解码并清理答案generated_answer = tokenizer.decode(output[0], skip_special_tokens=True)return generated_answer.replace(u'\xa0', ' ').replace(prompt, "").strip()# ---------------------------- 可视化与评估函数 ----------------------------
def plot_loss_curves(train_losses, val_losses):"""绘制训练和验证损失曲线,评估模型训练效果"""plt.figure(figsize=(10, 5))# 绘制训练损失plt.plot(range(1, len(train_losses) + 1), train_losses, label="训练损失", marker='o')# 绘制验证损失plt.plot(range(1, len(val_losses) + 1), val_losses, label="验证损失", marker='s')plt.xlabel("训练轮次")plt.ylabel("损失值")plt.title("训练与验证损失对比")plt.legend()plt.grid(alpha=0.3) # 网格线(增强可读性)plt.savefig("loss_curves.png", bbox_inches='tight') # 保存图像print("损失对比图已保存为 loss_curves.png")plt.close()def generate_results_table(desc_results, qa_results):"""生成结果对比表格(包含描述和问答结果)"""plt.figure(figsize=(14, 10))ax = plt.gca()ax.axis('off') # 关闭坐标轴# 创建表格table = Table(ax, bbox=[0, 0, 1, 1]) # 表格占满整个图# 添加表头table.add_cell(0, 0, 0.1, 0.1, text="类别", loc='center', facecolor='lightgray')table.add_cell(0, 1, 0.25, 0.1, text="生成描述", loc='center', facecolor='lightgray')table.add_cell(0, 2, 0.3, 0.1, text="问题", loc='center', facecolor='lightgray')table.add_cell(0, 3, 0.35, 0.1, text="生成答案", loc='center', facecolor='lightgray')# 填充表格内容(前6个类别)for i in range(min(6, len(desc_results))):desc = desc_results[i]qa = qa_results[i]# 清理文本中的特殊字符clean_desc = desc["generated"].replace(u'\xa0', ' ')clean_question = qa["question"].replace(u'\xa0', ' ')clean_answer = qa["answer"].replace(u'\xa0', ' ')# 添加单元格内容table.add_cell(i + 1, 0, 0.1, 0.15, text=desc["category"], loc='center')table.add_cell(i + 1, 1, 0.25, 0.15, text=clean_desc, loc='left')table.add_cell(i + 1, 2, 0.3, 0.15, text=clean_question, loc='left')table.add_cell(i + 1, 3, 0.35, 0.15, text=clean_answer, loc='left')ax.add_table(table)plt.savefig("results_table.png", bbox_inches='tight') # 保存表格print("结果对比表已保存为 results_table.png")plt.close()# ---------------------------- 主函数(完整流程执行) ----------------------------
if __name__ == "__main__":# 1. 准备数据集print("准备多模态数据集(含问答)...")full_dataset = BalancedMultimodalDataset()# 划分训练集(80%)和验证集(20%)train_size = int(0.8 * len(full_dataset))train_dataset, val_dataset = random_split(full_dataset, [train_size, len(full_dataset) - train_size])# 创建数据加载器(批量加载数据)batch_size = 4 if torch.cuda.is_available() else 1 # GPU可用时使用更大批量train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)# 2. 训练模型print("开始训练多模态模型(支持问答)...")model = MultimodalLLM()model, train_losses, val_losses = train_model(model, train_loader, val_loader, epochs=10 # 训练10个轮次)# 保存训练好的模型权重torch.save(model.state_dict(), "multimodal_qa_model.pth")# 3. 绘制损失曲线(评估训练效果)plot_loss_curves(train_losses, val_losses)# 4. 生成测试结果(描述+问答)print("生成测试结果...")tokenizer = full_dataset.text_tokenizerdesc_results = [] # 存储描述生成结果qa_results = [] # 存储问答结果# 每个类别选1个样本测试categories = list(set(item["category"] for item in full_dataset.data)) # 去重类别for category in categories[:6]: # 测试前6个类别# 找到该类别的样本索引sample_idx = next(i for i, item in enumerate(full_dataset.data) if item["category"] == category)sample = full_dataset[sample_idx] # 获取样本# 生成图像描述description = generate_description(model, sample["image"], tokenizer, category)desc_results.append({"category": category, "generated": description})# 生成问答结果question = sample["question"] # 原始问题answer = answer_question(model, sample["image"], question, tokenizer)qa_results.append({"category": category,"question": question,"answer": answer})# 5. 生成结果对比表generate_results_table(desc_results, qa_results)# 6. 打印部分结果(展示效果)print("\n英文问答示例:")for i in range(3):print(f"\n示例 {i + 1}:")print(f"类别: {qa_results[i]['category']}")print(f"问题: {qa_results[i]['question']}")print(f"生成答案: {qa_results[i]['answer']}")