diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index ce0547a..c311443 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -19,7 +19,7 @@ from scripts.permissions import get_user_allowed_indexes USER_INDEX_PATH = "index_data" USER_DOC_PATH = "docs" -# ✅ 替代 CustomEmbedding,用于 bge-m3 模型,自动加前缀 +# ✅ BGE-m3 模型嵌入类,加前缀 class BGEEmbedding(HuggingFaceEmbedding): def _get_query_embedding(self, query: str) -> List[float]: prefix = "Represent this sentence for searching relevant passages: " @@ -36,11 +36,9 @@ def build_user_index(user_id: str): documents = SimpleDirectoryReader(doc_dir).load_data() embed_model = BGEEmbedding(model_name=settings.MODEL_NAME) - service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) - # ✅ 指定正确维度:bge-m3 是 1024 - faiss_index = faiss.IndexFlatL2(1024) + faiss_index = faiss.IndexFlatL2(1024) # ✅ bge-m3 是 1024维 vector_store = FaissVectorStore(faiss_index=faiss_index) index = VectorStoreIndex.from_documents( @@ -50,7 +48,7 @@ def build_user_index(user_id: str): ) index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") - faiss.write_index(index.vector_store.index, index_path) + faiss.write_index(vector_store.index, index_path) # ✅ 用传入的 vector_store print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}") def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: @@ -59,7 +57,6 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: all_nodes = [] - # 加载用户主索引 index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") if not os.path.exists(index_path): raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在") @@ -67,7 +64,6 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context) all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question) - # 加载共享索引 shared_indexes = get_user_allowed_indexes(user_id) if shared_indexes: for shared_name in shared_indexes: @@ -79,7 +75,6 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: else: print(f"[INFO] 用户 {user_id} 没有共享索引权限") - # 合并 + 按 score 排序 sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0)) top_nodes = sorted_nodes[:top_k]