diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py
index a1e793e..d6f241c 100644
--- a/scripts/rag_build_query.py
+++ b/scripts/rag_build_query.py
@@ -1,4 +1,5 @@
 import os
+import logging
 from typing import List
 import faiss
 from llama_index import (
@@ -16,6 +17,10 @@ from scripts.permissions import get_user_allowed_indexes
 USER_INDEX_PATH = "index_data"
 USER_DOC_PATH = "docs"
 
+# 设置日志配置
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
 # ✅ 自动加前缀的 BGE-m3 embedding 封装类
 class BGEEmbedding(HuggingFaceEmbedding):
     def _get_query_embedding(self, query: str) -> List[float]:
@@ -27,47 +32,68 @@ class BGEEmbedding(HuggingFaceEmbedding):
         return super()._get_query_embeddings([prefix + q for q in queries])
 
 def build_user_index(user_id: str):
+    logger.info(f"开始为用户 {user_id} 构建索引...")
     doc_dir = os.path.join(USER_DOC_PATH, user_id)
+
     if not os.path.exists(doc_dir):
         raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
-
+
+    logger.info(f"发现文档目录: {doc_dir}")
+
     documents = SimpleDirectoryReader(doc_dir).load_data()
+    logger.info(f"载入文档数量: {len(documents)}")
+
     embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
+    logger.info(f"使用模型: {settings.MODEL_NAME}")
+
     service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
 
     faiss_index = faiss.IndexFlatL2(1024)
     vector_store = FaissVectorStore(faiss_index=faiss_index)
+
     persist_dir = os.path.join(USER_INDEX_PATH, user_id)
     os.makedirs(persist_dir, exist_ok=True)
-
+    logger.info(f"索引保存路径: {persist_dir}")
+
     storage_context = StorageContext.from_defaults(
         persist_dir=persist_dir,
         vector_store=vector_store,
     )
-    index = VectorStoreIndex.from_documents(
-        documents,
-        service_context=service_context,
-        storage_context=storage_context
-    )
-
-    index.persist(persist_dir=persist_dir)
-    print(f"[BUILD] 为用户 {user_id} 构建并保存了完整索引 → {persist_dir}")
+    try:
+        index = VectorStoreIndex.from_documents(
+            documents,
+            service_context=service_context,
+            storage_context=storage_context
+        )
+        index.persist(persist_dir=persist_dir)
+        logger.info(f"索引已保存到 {persist_dir}")
+    except Exception as e:
+        logger.exception(f"索引构建失败: {e}")
+        raise RuntimeError("索引构建失败") from e
 
 def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
+    logger.info(f"为用户 {user_id} 查询问题:{question}")
+
     embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
 
     persist_dir = os.path.join(USER_INDEX_PATH, user_id)
     if not os.path.exists(persist_dir):
-        raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引目录不存在")
+        raise FileNotFoundError(f"用户 {user_id} 的索引目录不存在")
 
+    logger.info(f"加载索引目录: {persist_dir}")
+
     storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
     index = VectorStoreIndex.load_from_storage(storage_context, service_context=service_context)
+    logger.info("加载索引成功,开始检索相关文档...")
+
     all_nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
 
     shared_indexes = get_user_allowed_indexes(user_id)
+    logger.info(f"用户 {user_id} 被允许共享的索引:{shared_indexes}")
+
     for shared_name in shared_indexes:
         shared_dir = os.path.join(USER_INDEX_PATH, shared_name)
         if os.path.exists(shared_dir) and shared_dir != persist_dir:
@@ -75,6 +101,8 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
             shared_index = VectorStoreIndex.load_from_storage(shared_context, service_context=service_context)
             all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
 
+    logger.info(f"共检索到 {len(all_nodes)} 个相关文档")
+
     sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
     top_nodes = sorted_nodes[:top_k]
 
@@ -84,12 +112,12 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
     )
     final_prompt = prompt_template.format(context=context_str, query=question)
-    print("[PROMPT构建完成]")
+    logger.info("[PROMPT构建完成]")
 
     return final_prompt
 
 if __name__ == "__main__":
     uid = "user_001"
     build_user_index(uid)
     prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?")
-    print("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
-    print(prompt)
+    logger.info("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
+    logger.info(prompt)