diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index d6f241c..69444c6 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -1,40 +1,14 @@ import os import logging -from typing import List -import faiss -from llama_index import ( - SimpleDirectoryReader, - VectorStoreIndex, - ServiceContext, - PromptTemplate, - StorageContext, -) -from llama_index.embeddings.huggingface import HuggingFaceEmbedding -from llama_index.vector_stores.faiss import FaissVectorStore -from app.core.config import settings -from scripts.permissions import get_user_allowed_indexes - -USER_INDEX_PATH = "index_data" -USER_DOC_PATH = "docs" - -# 设置日志配置 -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -# ✅ 自动加前缀的 BGE-m3 embedding 封装类 -class BGEEmbedding(HuggingFaceEmbedding): - def _get_query_embedding(self, query: str) -> List[float]: - prefix = "Represent this sentence for searching relevant passages: " - return super()._get_query_embedding(prefix + query) - - def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]: - prefix = "Represent this sentence for searching relevant passages: " - return super()._get_query_embeddings([prefix + q for q in queries]) def build_user_index(user_id: str): + # 设置日志 + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + logger.info(f"开始为用户 {user_id} 构建索引...") - doc_dir = os.path.join(USER_DOC_PATH, user_id) + doc_dir = os.path.join(USER_DOC_PATH, user_id) if not os.path.exists(doc_dir): raise FileNotFoundError(f"文档目录不存在: {doc_dir}") @@ -54,13 +28,21 @@ def build_user_index(user_id: str): persist_dir = os.path.join(USER_INDEX_PATH, user_id) os.makedirs(persist_dir, exist_ok=True) logger.info(f"索引保存路径: {persist_dir}") - + + # 检查目录中是否存在 index_store.json 文件 + index_store_path = os.path.join(persist_dir, "index_store.json") + if not os.path.exists(index_store_path): + logger.info(f"未找到 index_store.json,准备创建") + else: + logger.info(f"已找到 index_store.json,跳过创建") + storage_context = StorageContext.from_defaults( persist_dir=persist_dir, vector_store=vector_store, ) try: + # 构建索引 index = VectorStoreIndex.from_documents( documents, service_context=service_context, @@ -71,53 +53,3 @@ def build_user_index(user_id: str): except Exception as e: logger.error(f"索引构建失败: {e}") raise HTTPException(status_code=500, detail="索引构建失败") - -def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: - logger.info(f"为用户 {user_id} 查询问题:{question}") - - embed_model = BGEEmbedding(model_name=settings.MODEL_NAME) - service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) - - persist_dir = os.path.join(USER_INDEX_PATH, user_id) - if not os.path.exists(persist_dir): - raise FileNotFoundError(f"用户 {user_id} 的索引目录不存在") - - logger.info(f"加载索引目录: {persist_dir}") - - storage_context = StorageContext.from_defaults(persist_dir=persist_dir) - index = VectorStoreIndex.load_from_storage(storage_context, service_context=service_context) - - logger.info(f"加载索引成功,开始检索相关文档...") - - all_nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question) - - shared_indexes = get_user_allowed_indexes(user_id) - logger.info(f"用户 {user_id} 被允许共享的索引:{shared_indexes}") - - for shared_name in shared_indexes: - shared_dir = os.path.join(USER_INDEX_PATH, shared_name) - if os.path.exists(shared_dir) and shared_dir != persist_dir: - shared_context = StorageContext.from_defaults(persist_dir=shared_dir) - shared_index = VectorStoreIndex.load_from_storage(shared_context, service_context=service_context) - all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) - - logger.info(f"共检索到 {len(all_nodes)} 个相关文档") - - sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0)) - top_nodes = sorted_nodes[:top_k] - - context_str = "\n\n".join([n.get_text() for n in top_nodes]) - prompt_template = PromptTemplate( - "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}" - ) - final_prompt = prompt_template.format(context=context_str, query=question) - - logger.info("[PROMPT构建完成]") - return final_prompt - -if __name__ == "__main__": - uid = "user_001" - build_user_index(uid) - prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?") - logger.info("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n") - logger.info(prompt)