From cb949ac721997139d02dc93bbcebe5a2470d2219 Mon Sep 17 00:00:00 2001 From: hailin Date: Fri, 9 May 2025 22:46:39 +0800 Subject: [PATCH] . --- scripts/rag_build_query.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index 6698576..a1e793e 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -16,7 +16,7 @@ from scripts.permissions import get_user_allowed_indexes USER_INDEX_PATH = "index_data" USER_DOC_PATH = "docs" -# ✅ 自动加前缀的 bge-m3 embedding +# ✅ 自动加前缀的 BGE-m3 embedding 封装类 class BGEEmbedding(HuggingFaceEmbedding): def _get_query_embedding(self, query: str) -> List[float]: prefix = "Represent this sentence for searching relevant passages: " @@ -38,16 +38,19 @@ def build_user_index(user_id: str): faiss_index = faiss.IndexFlatL2(1024) vector_store = FaissVectorStore(faiss_index=faiss_index) persist_dir = os.path.join(USER_INDEX_PATH, user_id) + os.makedirs(persist_dir, exist_ok=True) + storage_context = StorageContext.from_defaults( persist_dir=persist_dir, - vector_store=vector_store + vector_store=vector_store, ) index = VectorStoreIndex.from_documents( documents, - storage_context=storage_context, - service_context=service_context + service_context=service_context, + storage_context=storage_context ) + index.persist(persist_dir=persist_dir) print(f"[BUILD] 为用户 {user_id} 构建并保存了完整索引 → {persist_dir}") @@ -65,15 +68,12 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: all_nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question) shared_indexes = get_user_allowed_indexes(user_id) - if shared_indexes: - for shared_name in shared_indexes: - shared_path = os.path.join(USER_INDEX_PATH, shared_name) - if os.path.exists(shared_path) and shared_path != persist_dir: - shared_storage = StorageContext.from_defaults(persist_dir=shared_path) - shared_index = VectorStoreIndex.load_from_storage(shared_storage, service_context=service_context) - all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) - else: - print(f"[INFO] 用户 {user_id} 没有共享索引权限") + for shared_name in shared_indexes: + shared_dir = os.path.join(USER_INDEX_PATH, shared_name) + if os.path.exists(shared_dir) and shared_dir != persist_dir: + shared_context = StorageContext.from_defaults(persist_dir=shared_dir) + shared_index = VectorStoreIndex.load_from_storage(shared_context, service_context=service_context) + all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0)) top_nodes = sorted_nodes[:top_k]