From 301795569cc0248d843582392e7e249097da551c Mon Sep 17 00:00:00 2001 From: hailin Date: Fri, 9 May 2025 20:01:38 +0800 Subject: [PATCH] . --- scripts/rag_build_query.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index 714c2b0..9b86d70 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -1,13 +1,16 @@ import os -from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, PromptTemplate +from llama_index.core import ( + SimpleDirectoryReader, + VectorStoreIndex, + PromptTemplate, + Settings, +) from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.vector_stores.faiss import FaissVectorStore from app.core.config import settings from scripts.permissions import get_user_allowed_indexes - - - +import faiss USER_INDEX_PATH = "index_data" USER_DOC_PATH = "docs" @@ -19,11 +22,10 @@ def build_user_index(user_id: str): documents = SimpleDirectoryReader(doc_dir).load_data() embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) - service_context = ServiceContext.from_defaults(embed_model=embed_model) - + Settings.embed_model = embed_model # ✅ 新式配置 + index = VectorStoreIndex.from_documents( documents, - service_context=service_context, vector_store=FaissVectorStore() ) @@ -33,7 +35,7 @@ def build_user_index(user_id: str): def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) - service_context = ServiceContext.from_defaults(embed_model=embed_model) + Settings.embed_model = embed_model # ✅ 全局设置一次即可 all_nodes = [] @@ -42,7 +44,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: if not os.path.exists(index_path): raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在") user_store = FaissVectorStore.from_persist_path(index_path) - user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context) + user_index = VectorStoreIndex.from_vector_store(user_store) all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question) # 加载共享索引 @@ -52,7 +54,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: shared_path = os.path.join(USER_INDEX_PATH, shared_name) if os.path.exists(shared_path) and shared_path != index_path: shared_store = FaissVectorStore.from_persist_path(shared_path) - shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context) + shared_index = VectorStoreIndex.from_vector_store(shared_store) all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) else: print(f"[INFO] 用户 {user_id} 没有共享索引权限")