From 3f0f27a42f1798f3d4cb9e153c4682b30ab57e76 Mon Sep 17 00:00:00 2001 From: hailin Date: Sat, 10 May 2025 11:01:16 +0800 Subject: [PATCH] . --- rag_build_query.py | 110 +++++++++++++++++++++++++++++++++++++ scripts/rag_build_query.py | 15 +++-- 2 files changed, 121 insertions(+), 4 deletions(-) create mode 100644 rag_build_query.py diff --git a/rag_build_query.py b/rag_build_query.py new file mode 100644 index 0000000..4340685 --- /dev/null +++ b/rag_build_query.py @@ -0,0 +1,110 @@ +from fastapi import HTTPException # 导入 HTTPException 用于错误处理 +from typing import List # 导入 List 用于类型注解 +import os +import logging +import faiss +from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from llama_index.vector_stores.faiss import FaissVectorStore +from app.core.config import settings # 导入应用配置 +from scripts.permissions import get_user_allowed_indexes # 导入权限函数,管理用户索引 + +USER_INDEX_PATH = "index_data" # 用户索引存储路径 +USER_DOC_PATH = "docs" # 用户文档存储路径 + +# BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 +class BGEEmbedding(HuggingFaceEmbedding): + def _get_query_embedding(self, query: str) -> List[float]: + # 在查询前加上前缀,生成嵌入向量 + prefix = "Represent this sentence for searching relevant passages: " + return super()._get_query_embedding(prefix + query) + + def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]: + # 批量生成嵌入向量 + prefix = "Represent this sentence for searching relevant passages: " + return super()._get_query_embeddings([prefix + q for q in queries]) + +def build_user_index(user_id: str): + # 设置日志记录器,记录操作步骤和错误信息 + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + logger.info(f"开始为用户 {user_id} 构建索引...") + + # 确认文档目录是否存在 + doc_dir = os.path.join(USER_DOC_PATH, user_id) + if not os.path.exists(doc_dir): + logger.error(f"文档目录不存在: {doc_dir}") + raise FileNotFoundError(f"文档目录不存在: {doc_dir}") + + logger.info(f"发现文档目录: {doc_dir}") + + # 使用 SimpleDirectoryReader 载入文档 + try: + documents = SimpleDirectoryReader(doc_dir).load_data() + logger.info(f"载入文档数量: {len(documents)}") + except Exception as e: + logger.error(f"加载文档时出错: {e}") + raise HTTPException(status_code=500, detail="文档加载失败") + + # 设置嵌入模型 + embed_model = BGEEmbedding(model_name=settings.MODEL_NAME) + logger.info(f"使用模型: {settings.MODEL_NAME}") + + # 创建服务上下文 + service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) + + # 使用 Faiss 向量存储 + faiss_index = faiss.IndexFlatL2(1024) + vector_store = FaissVectorStore(faiss_index=faiss_index) + + # 确保索引保存路径存在,使用用户 ID 区分索引文件 + persist_dir = os.path.join(USER_INDEX_PATH, user_id) + logger.info(f"索引保存路径: {persist_dir}") + + # 如果目录不存在,则创建 + if not os.path.exists(persist_dir): + logger.info(f"目录 {persist_dir} 不存在,准备创建") + os.makedirs(persist_dir, exist_ok=True) + logger.info(f"目录 {persist_dir} 已创建") + + # 确保 index_store.json 文件路径存在 + index_store_path = os.path.join(persist_dir, "index_store.json") + if not os.path.exists(index_store_path): + logger.info(f"未找到 index_store.json,准备创建") + try: + with open(index_store_path, "w") as f: + f.write("{}") # 创建空的 index_store.json 文件 + logger.info(f"已创建 index_store.json 文件") + except Exception as e: + logger.error(f"创建 index_store.json 时出错: {e}") + raise HTTPException(status_code=500, detail="创建 index_store.json 文件失败") + else: + logger.info(f"已找到 index_store.json,跳过创建") + + # 创建 StorageContext,用于存储和管理索引数据 + storage_context = StorageContext.from_defaults( + persist_dir=persist_dir, + vector_store=vector_store, + ) + + try: + # 构建索引,并使用 `storage_context.persist()` 方法保存索引 + index = VectorStoreIndex.from_documents( + documents, + service_context=service_context, + storage_context=storage_context + ) + + # 保存 Faiss 索引为文件,使用正确的路径 + faiss_index_file = os.path.join(persist_dir, "index.faiss") + faiss.write_index(faiss_index, faiss_index_file) # 使用 Faiss 的 write_index 方法保存索引 + logger.info(f"Faiss 索引已保存到 {faiss_index_file}") + + # 使用 storage_context.persist() 保存其他索引数据 + storage_context.persist(persist_dir=persist_dir) + logger.info(f"索引数据已保存到 {persist_dir}") + + except Exception as e: + logger.error(f"索引构建失败: {e}") + raise HTTPException(status_code=500, detail="索引构建失败") diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index 4340685..997759e 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -54,8 +54,8 @@ def build_user_index(user_id: str): # 创建服务上下文 service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) - # 使用 Faiss 向量存储 - faiss_index = faiss.IndexFlatL2(1024) + # 创建 Faiss 向量存储 + faiss_index = faiss.IndexFlatL2(1024) # 假设每个向量是1024维度 vector_store = FaissVectorStore(faiss_index=faiss_index) # 确保索引保存路径存在,使用用户 ID 区分索引文件 @@ -89,14 +89,21 @@ def build_user_index(user_id: str): ) try: - # 构建索引,并使用 `storage_context.persist()` 方法保存索引 + # 生成嵌入向量 + embeddings = [embed_model._get_query_embedding(doc.text) for doc in documents] + logger.info(f"已生成 {len(embeddings)} 个嵌入向量") + + # 将嵌入向量存储到 Faiss + faiss_index.add(np.array(embeddings)) # 将嵌入向量添加到 Faiss 索引中 + + # 创建 VectorStoreIndex index = VectorStoreIndex.from_documents( documents, service_context=service_context, storage_context=storage_context ) - # 保存 Faiss 索引为文件,使用正确的路径 + # 保存 Faiss 索引为文件 faiss_index_file = os.path.join(persist_dir, "index.faiss") faiss.write_index(faiss_index, faiss_index_file) # 使用 Faiss 的 write_index 方法保存索引 logger.info(f"Faiss 索引已保存到 {faiss_index_file}")