RAG principle: Retrieve (fetch relevant information from a knowledge base), Augment (feed that information to the model as context), Generate (the model answers based on that context). These three steps address the staleness and domain limitations of an LLM's built-in knowledge.
The goal is to implement the following scenario:
1. The vector database is accessed through an API
2. The LLM is also called through an API
3. Documents are loaded into the vector store
4. The retrieve, augment and generate steps are completed end to end
The principles come first; the code is pasted afterwards.
Reference: https://zhuanlan.zhihu.com/p/1895505027537298602
① Read documents and store them in the vector store
Principle (a condensed sketch follows this list):
1. Read the document content (plain text, PDF/Word/Excel, ...)
2. Split the document into chunks
3. Vectorize the chunks with an embedding model
4. Store the vectors in the vector database
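A minimal sketch of those four steps, condensed from the full project code further below (the directory path, Chroma host and collection name are placeholders, not fixed values from the original post):
import chromadb
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

# 1. Read the documents
documents = SimpleDirectoryReader("./temp_data").load_data()
# 2. Split them into chunks
nodes = SimpleNodeParser.from_defaults(chunk_size=512).get_nodes_from_documents(documents)
# 3. Embedding model that vectorizes the chunks
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
# 4. Store the vectors in a Chroma collection through LlamaIndex
collection = chromadb.HttpClient(host="10.10.10.10", port=8000).get_or_create_collection("document_collection")
vector_store = ChromaVectorStore(chroma_collection=collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex(nodes=nodes, storage_context=storage_context, embed_model=embed_model, show_progress=True)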
② Retrieve (fetch relevant content from the vector store)
1. The retriever first converts the user's query text into a vector with the embedding model.
2. In the vector index (the Chroma collection), it computes the similarity between the query vector and every document vector (usually cosine similarity).
3. It sorts the results by similarity score from high to low and keeps the requested number of chunks (a small sketch follows).
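Condensed from the app code below, a sketch of this retrieval step with LlamaIndex's VectorIndexRetriever (the query string is only an example; index is the one built during ingestion, and Settings.embed_model is assumed to be configured already):
from llama_index.core.retrievers import VectorIndexRetriever

retriever = VectorIndexRetriever(index=index, similarity_top_k=3)
retrieved_nodes = retriever.retrieve("How do I build a vector index?")
for node in retrieved_nodes:
    # Each result carries the chunk text plus its similarity score, already sorted high to low
    print(node.score, node.text[:80])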
③ Augment
Combine the retrieved content and the question into a prompt.
④ Generate
Call the LLM to generate the answer (see the short sketch below).
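Steps ③ and ④ boil down to a prompt template plus one LLM call. This sketch reuses the DEFAULT_PROMPT_TEMPLATE, llm and retrieved_nodes names that appear in app.py below:
# Augment: stitch the retrieved chunks and the user question into one prompt
query = "How do I build a vector index?"  # same example query as above
context_str = "\n\n".join(node.text for node in retrieved_nodes)
formatted_prompt = DEFAULT_PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)

# Generate: ask the LLM to answer based on that augmented prompt
response = llm.complete(formatted_prompt)
print(response.text)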
Code reference: https://blog.csdn.net/lucassu/article/details/146897774
The code is organized into the following files.
embedding.py: the embedding model
The local model can be downloaded from ModelScope.
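One possible way to do that download with the modelscope SDK (the package usage and the model id "AI-ModelScope/bge-small-en-v1.5" are assumptions; check the exact id on modelscope.cn):
from modelscope import snapshot_download

# Downloads the model files and returns the local directory;
# point HuggingFaceEmbedding's model_name at this path afterwards.
local_dir = snapshot_download("AI-ModelScope/bge-small-en-v1.5", cache_dir="./model")
print(local_dir)
The embedding.py module itself follows: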
import os
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings


# Initialize a HuggingFaceEmbedding object that turns text into vector representations
def configure_embedding():
    """Configure the text embedding model"""
    # Path to a pre-trained sentence-transformer model
    app_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    # e.g. d:\MyWork\GitProject\MySource\LLM\LlamaIndex\LlamaIndex\demo1
    model_path = os.path.abspath(os.path.join(app_path, ".."))
    print(model_path)
    model = HuggingFaceEmbedding(
        model_name=model_path + "/model/embedding-bge-small-en-v1.5",
        cache_folder=app_path + "/temp_data/cache",  # model cache directory
    )
    # The global embedding model must be set
    Settings.embed_model = model
    return model
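A quick sanity check for this module (get_text_embedding is the standard LlamaIndex embedding call; the 384-dimension figure applies to bge-small-en-v1.5):
from configs.embedding import configure_embedding

embed_model = configure_embedding()
vec = embed_model.get_text_embedding("hello world")
print(len(vec))  # 384 dimensions for bge-small-en-v1.5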
RestAPILLM.py: a custom LLM class.
It does not actually need to be this elaborate; a class that can call the API endpoint would be enough (see the minimal sketch right below). The version in this project is long only because it implements the LLM interface from llama_index.
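A minimal sketch of that simpler approach: call an OpenAI-compatible chat-completions endpoint directly with requests, reusing the same API_URL / API_KEY / MODEL environment variables as the class below (this is not the project's class, only an illustration):
import os
import requests
from dotenv import load_dotenv

load_dotenv()

def ask_llm(prompt: str) -> str:
    resp = requests.post(
        url=os.getenv("API_URL"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.getenv('API_KEY')}",
        },
        json={
            "model": os.getenv("MODEL"),
            "messages": [{"role": "user", "content": prompt}],
            "stream": False,
        },
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]
The full RestAPILLM.py follows: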
import json
import os
from dotenv import load_dotenv
from typing import Any, Optional, List, Dict, Generator, AsyncGenerator, Sequence
import requests
import aiohttp
from llama_index.core.llms import LLM
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.core import Settings

# Load environment variables
load_dotenv()
API_KEY = os.getenv("API_KEY")
API_URL = os.getenv("API_URL")
MODEL = os.getenv("MODEL")


class RestAPILLM(LLM):
    """LLM wrapper for a REST API.

    Inherits from the LlamaIndex LLM base class and talks to a custom REST API.
    """

    # Basic API configuration
    api_endpoint: str = API_URL
    api_key: Optional[str] = API_KEY  # API key
    timeout: int = 60  # request timeout in seconds
    headers: Dict[str, str] = {}  # custom request headers

    # Model parameters
    model_name: str = MODEL
    temperature: float = 0.7
    max_tokens: int = 1024
    is_chat_model: bool = True

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        if not self.headers:
            self.headers = {"Content-Type": "application/json"}
        if self.api_key:
            self.headers["Authorization"] = f"Bearer {self.api_key}"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            model_name=self.model_name,
            is_chat_model=self.is_chat_model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

    def _format_chat_payload(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]:
        return {
            "model": self.model_name,
            "messages": [
                {"role": msg.role.value, "content": msg.content} for msg in messages
            ],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": False,
        }

    def _format_completion_payload(self, prompt: str) -> Dict[str, Any]:
        print("prompt: " + prompt)
        data = {
            "model": self.model_name,
            "messages": [
                {"content": [{"text": prompt, "type": "text"}], "role": "user"}
            ],
            "stream": False,  # non-streaming request
        }
        return data

    def _parse_chat_response(self, response_data: Dict[str, Any]) -> ChatResponse:
        choice = response_data["choices"][0]
        return ChatResponse(
            message=ChatMessage(
                role=choice["message"]["role"], content=choice["message"]["content"]
            ),
            raw=response_data,
        )

    def _parse_completion_response(self, response_data: Dict[str, Any]) -> CompletionResponse:
        result = ""
        if response_data:  # check the end-of-stream marker
            try:
                chunk = json.loads(response_data)
                # Extract the text content of the current chunk
                result = chunk["choices"][0]["message"]["content"]
                # (print the current chunk in real time, without a newline)
            except (KeyError, json.JSONDecodeError) as e:
                print(f"\nFailed to parse response chunk: {str(response_data)}")
                print("response_data: " + response_data)
        return CompletionResponse(text=result, raw=response_data)

    # ()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        try:
            converted_messages = self.convert_chat_messages(messages)
            response = requests.post(
                url=self.api_endpoint,
                json=self._format_chat_payload(converted_messages),
                headers=self.headers,
                timeout=self.timeout,
                **kwargs,
            )
            response.raise_for_status()
            return self._parse_chat_response(response.json())
        except Exception as e:
            raise ValueError(f"Synchronous chat request failed: {str(e)}") from e

    # ()
    def complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
        """Non-streaming text generation"""
        try:
            response = requests.post(
                url=self.api_endpoint,
                json=self._format_completion_payload(prompt),
                headers=self.headers,
                timeout=self.timeout,
                **kwargs,
            )
            response.raise_for_status()
            return self._parse_completion_response(response.text)
        except Exception as e:
            raise ValueError(f"Synchronous completion request failed: {str(e)}") from e

    # ()
    def stream_chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponseGen:
        def generate() -> Generator[ChatResponse, None, None]:
            try:
                converted_messages = self.convert_chat_messages(messages)
                payload = self._format_chat_payload(converted_messages)
                payload["stream"] = True
                with requests.post(
                    url=self.api_endpoint,
                    json=payload,
                    headers=self.headers,
                    stream=True,
                    timeout=self.timeout,
                    **kwargs,
                ) as response:
                    response.raise_for_status()
                    for line in response.iter_lines():
                        if line:
                            line_str = line.decode("utf-8").strip()
                            if line_str.startswith("data: "):
                                line_str = line_str[len("data: "):]
                            if line_str == "[DONE]":
                                break
                            chunk = json.loads(line_str)
                            delta = chunk["choices"][0]["delta"]
                            yield ChatResponse(
                                message=ChatMessage(
                                    role=delta.get("role", "assistant"),
                                    content=delta.get("content", ""),
                                ),
                                delta=delta.get("content", ""),
                                raw=chunk,
                            )
            except Exception as e:
                raise ValueError(f"Synchronous streaming chat failed: {str(e)}") from e

        return generate()

    # ()
    def stream_complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponseGen:
        def generate() -> Generator[CompletionResponse, None, None]:
            try:
                payload = self._format_completion_payload(prompt)
                payload["stream"] = True
                with requests.post(
                    url=self.api_endpoint,
                    json=payload,
                    headers=self.headers,
                    stream=True,
                    timeout=self.timeout,
                    **kwargs,
                ) as response:
                    response.raise_for_status()
                    for line in response.iter_lines():
                        if line:
                            line_str = line.decode("utf-8").strip()
                            # print("----" + line_str)
                            # ----data: {"choices":[{"delta":{"content":"更","role":"assistant"},"index":0}],"created":1755065926,"id":"02175506592656071349a582c393a7766c3954399cf7bc9546e9d","model":"doubao-seed-1-6-flash-250715","service_tier":"default","object":"chat.completion.chunk","usage":null}
                            if line_str.startswith("data: "):
                                line_str = line_str[len("data: "):]
                            if line_str == "[DONE]":
                                break
                            chunk = json.loads(line_str)
                            text = chunk["choices"][0]["delta"]["content"]
                            yield CompletionResponse(text=text, delta=text, raw=chunk)
            except Exception as e:
                raise ValueError(f"Synchronous streaming completion failed: {str(e)}") from e

        return generate()

    # ()
    async def achat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        try:
            converted_messages = self.convert_chat_messages(messages)
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url=self.api_endpoint,
                    json=self._format_chat_payload(converted_messages),
                    headers=self.headers,
                    timeout=self.timeout,
                    **kwargs,
                ) as response:
                    response.raise_for_status()
                    return self._parse_chat_response(await response.json())
        except Exception as e:
            raise ValueError(f"Asynchronous chat request failed: {str(e)}") from e

    # ()
    async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url=self.api_endpoint,
                    json=self._format_completion_payload(prompt),
                    headers=self.headers,
                    timeout=self.timeout,
                    **kwargs,
                ) as response:
                    response.raise_for_status()
                    return self._parse_completion_response(await response.json())
        except Exception as e:
            raise ValueError(f"Asynchronous completion request failed: {str(e)}") from e

    # ()
    async def astream_chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponseAsyncGen:
        async def generate() -> AsyncGenerator[ChatResponse, None]:
            try:
                converted_messages = self.convert_chat_messages(messages)
                payload = self._format_chat_payload(converted_messages)
                payload["stream"] = True
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        url=self.api_endpoint,
                        json=payload,
                        headers=self.headers,
                        timeout=self.timeout,
                        **kwargs,
                    ) as response:
                        response.raise_for_status()
                        async for line in response.content.iter_lines():
                            if line:
                                line_str = line.decode("utf-8").strip()
                                if line_str.startswith("data: "):
                                    line_str = line_str[len("data: "):]
                                if line_str == "[DONE]":
                                    break
                                chunk = json.loads(line_str)
                                delta = chunk["choices"][0]["delta"]
                                yield ChatResponse(
                                    message=ChatMessage(
                                        role=delta.get("role", "assistant"),
                                        content=delta.get("content", ""),
                                    ),
                                    delta=delta.get("content", ""),
                                    raw=chunk,
                                )
            except Exception as e:
                raise ValueError(f"Asynchronous streaming chat failed: {str(e)}") from e

        return generate()

    # ()
    async def astream_complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponseAsyncGen:
        async def generate() -> AsyncGenerator[CompletionResponse, None]:
            try:
                payload = self._format_completion_payload(prompt)
                payload["stream"] = True
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        url=self.api_endpoint,
                        json=payload,
                        headers=self.headers,
                        timeout=self.timeout,
                        **kwargs,
                    ) as response:
                        response.raise_for_status()
                        async for line in response.content.iter_lines():
                            if line:
                                line_str = line.decode("utf-8").strip()
                                if line_str.startswith("data: "):
                                    line_str = line_str[len("data: "):]
                                if line_str == "[DONE]":
                                    break
                                chunk = json.loads(line_str)
                                text = chunk["choices"][0].get("text", "")
                                yield CompletionResponse(text=text, delta=text, raw=chunk)
            except Exception as e:
                raise ValueError(f"Asynchronous streaming completion failed: {str(e)}") from e

        return generate()


def configure_llm():
    """Configure the LLM (backed by the REST API)"""
    llm = RestAPILLM()  # replace with the actual LLM API address
    # Settings.llm = llm
    return llm


if __name__ == "__main__":
    # Test case
    # llm = RestAPILLM(
    #     model_name="doubao-seed-1-6-flash-250715", temperature=0.5, max_tokens=1024
    # )
    llm = configure_llm()
    prompt = "Write a short article about spring travel, within 200 words"
    response = llm.complete(prompt)
    print("Non-streaming output: \n\n")
    print(response.text)
    response_gen = llm.stream_complete(prompt)
    print("Streaming output: \n\n")
    for chunk in response_gen:
        # print(chunk.text)
        print(chunk.text, end="", flush=True)
chroma_indexer.py: the vector database layer. Deploying the vector store itself was covered in an earlier article.
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import VectorStoreIndex, StorageContext
import chromadb
from chromadb.config import Settings as ChromaSettings

from configs.embedding import configure_embedding


def init_chroma_client():
    """Initialize the Chroma client"""
    return chromadb.HttpClient(host="10.10.10.10", port=8000)


def build_chroma_index(nodes, collection_name: str = "document_collection"):
    """Build a vector index with Chroma"""
    # Initialize the Chroma client
    chroma_client = init_chroma_client()
    # Delete an existing collection (optional; decide whether to keep historical data)
    if chroma_client.get_or_create_collection(collection_name).count() > 0:
        chroma_client.delete_collection(collection_name)
    # Get or create the collection
    chroma_collection = chroma_client.get_or_create_collection(collection_name)
    # Create the vector store:
    # a "vector store adapter" that LlamaIndex can work with, linking the Chroma collection
    # to LlamaIndex's index system. chroma_collection is the "warehouse" inside Chroma,
    # while vector_store is the tool LlamaIndex uses to operate on that warehouse.
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    # Create the storage context (bound to the vector store)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    embed_model = configure_embedding()
    # Build the index (the embedding model converts the nodes to vectors and stores them)
    index = VectorStoreIndex(
        nodes=nodes,
        storage_context=storage_context,
        embed_model=embed_model,  # explicit embedding model; defaults to Settings.embed_model
        show_progress=True,
    )
    # Persist the data
    # chroma_client.persist()
    return index


def load_chroma_index(collection_name: str = "document_collection"):
    """Load an existing Chroma index"""
    chroma_client = init_chroma_client()
    # Check whether the collection exists
    try:
        chroma_collection = chroma_client.get_collection(collection_name)
    except ValueError:
        return None
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    return VectorStoreIndex.from_vector_store(vector_store)
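A minimal way to exercise these two functions outside of Streamlit (the document directory is a placeholder; the module paths match the imports used in app.py below):
from configs.embedding import configure_embedding
from modules.data_loader import load_and_process_data
from modules.chroma_indexer import build_chroma_index, load_chroma_index

configure_embedding()  # set the global embedding model first
documents, nodes = load_and_process_data("./temp_data")
index = build_chroma_index(nodes)  # build and persist the collection
index = load_chroma_index()        # later runs can reload it instead of rebuilding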
data_loader.py: the data loader; it reads documents and splits them into chunks.
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SimpleNodeParser
import os


def load_and_process_data(data_dir: str, chunk_size: int = 512):
    """Load and process document data"""
    # Make sure the directory exists
    if not os.path.exists(data_dir):
        raise ValueError(f"Data directory {data_dir} does not exist")
    # Read the documents
    documents = SimpleDirectoryReader(data_dir).load_data()
    # Split into chunks
    node_parser = SimpleNodeParser.from_defaults(chunk_size=chunk_size)
    nodes = node_parser.get_nodes_from_documents(documents)
    return documents, nodes


if __name__ == "__main__":
    file_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    temp_dir = file_path + "/temp_data"
    print(f"Data directory: {temp_dir}")
    documents, nodes = load_and_process_data(temp_dir)
    print(f"Loaded {len(documents)} documents, which produced {len(nodes)} nodes after chunking")
    for node in nodes:
        print(f"Node: {node}")
app.py: the Streamlit main UI
import streamlit as st
import os
import shutil
from configs.RestAPILLM import configure_llm
from configs.embedding import configure_embedding
from modules.chroma_indexer import build_chroma_index, load_chroma_index
from modules.data_loader import load_and_process_data
from llama_index.core.prompts import PromptTemplate
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import Settings

# from configs.llm import configure_llm
# from modules.data_loader import load_and_process_data

# Page configuration
st.set_page_config(
    page_title="xiaoSu Intelligent Document Q&A System",
    page_icon="",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Prompt template (the core step: inject the retrieved context into the model input)
DEFAULT_PROMPT_TEMPLATE = PromptTemplate(
    """You are a document-grounded Q&A assistant. Answer the user's question strictly based on the context provided below and do not make anything up.
If the context is not sufficient to answer the question, state clearly: "The provided documents cannot answer this question."

Context:
{context_str}

User question:
{query_str}

Answer:"""
)


# Initialize the system
@st.cache_resource
def init_system():
    """Initialize the embedding model and the LLM"""
    configure_embedding()
    return configure_llm()


def clear_temp_data():
    """Clean up temporary data"""
    if os.path.exists("./temp_data"):
        shutil.rmtree("./temp_data")


def create_index(uploaded_files):
    # File processing flow
    if uploaded_files:
        # Create the temporary directory
        temp_dir = "./temp_data"
        os.makedirs(temp_dir, exist_ok=True)
        # Save the uploaded files
        for file in uploaded_files:
            with open(os.path.join(temp_dir, file.name), "wb") as f:
                f.write(file.getbuffer())
        # Load and process the data
        try:
            documents, nodes = load_and_process_data(temp_dir)
            # Build the index
            with st.spinner("Building the index..."):
                index = build_chroma_index(nodes)
                st.session_state.index = index
                st.success(
                    f"Chroma index built: loaded {len(documents)} documents, "
                    f"which produced {len(nodes)} nodes after chunking"
                )
        except Exception as e:
            st.error(f"Data processing failed: {str(e)}")


def main():
    # Initialize the models
    llm = init_system()

    # Sidebar configuration
    with st.sidebar:
        st.title("System console")
        uploaded_files = st.file_uploader(
            "Upload documents (PDF/TXT)", type=["pdf", "txt"], accept_multiple_files=True
        )
        if st.button("Build index"):
            create_index(uploaded_files)
        st.divider()
        st.subheader("Retrieval settings")
        top_k = st.slider("Number of relevant chunks to return", 1, 10, 3)
        st.divider()
        if st.button("Clear system cache"):
            clear_temp_data()
            st.success("Temporary data cleared")

    # Main area
    st.title("📚 xiaoSu Intelligent Document Q&A System")
    st.caption("A document understanding and Q&A system built on a local LLM")

    # Q&A interaction area
    if "index" in st.session_state:
        st.subheader("Document Q&A")
        query = st.text_input("Enter your question:", placeholder="Ask something about the document content...")
        if query:
            try:
                with st.spinner("Retrieving relevant information and generating the answer..."):
                    # 1. Retriever: fetch relevant chunks (context) from the vector index
                    # Load the embedding model; can be omitted, defaults to Settings.embed_model
                    model = Settings.embed_model
                    retriever = VectorIndexRetriever(
                        index=st.session_state.index,
                        similarity_top_k=top_k,  # return the top_k most relevant chunks
                        embed_model=model,  # embedding model
                    )
                    # 2. Build the context string (concatenate the retrieved chunks)
                    # . The retriever first converts the query text (e.g. "How do I build a vector index?") into a vector with the embedding model.
                    # . Inside the vector index (the Chroma collection) it computes the similarity between the query vector and every document vector (usually cosine similarity).
                    # . It sorts by similarity score from high to low, takes the top_k chunks and returns them as Node objects.
                    retrieved_nodes = retriever.retrieve(query)
                    context_str = "\n\n".join([node.text for node in retrieved_nodes])
                    # 3. Explicitly inject the context and the question into the prompt
                    formatted_prompt = DEFAULT_PROMPT_TEMPLATE.format(
                        context_str=context_str, query_str=query
                    )
                    # 4. Call the LLM to generate the answer (using the manually built prompt)
                    response = llm.complete(formatted_prompt)

                    # Show the answer
                    st.markdown("### Answer")
                    st.info(response.text)

                    # Show the reference sources
                    st.markdown("### References")
                    for idx, node in enumerate(retrieved_nodes):
                        with st.expander(f"Reference chunk {idx + 1} (similarity: {node.score:.2f})"):
                            st.write(node.text)
            except Exception as e:
                import traceback

                traceback.print_exc()
                st.error(f"Error while generating the answer: {str(e)}")


if __name__ == "__main__":
    main()

# streamlit run app.py
Missing packages: be patient and install them one by one.
Python 3.12.8
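Judging from the imports above, the environment likely needs at least: llama-index, llama-index-embeddings-huggingface, llama-index-vector-stores-chroma, chromadb, streamlit, python-dotenv, requests and aiohttp (package names inferred from the code, not an official list from the original post).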