From 01f2317bcc562449c932ffd34617bd580cfb015d Mon Sep 17 00:00:00 2001 From: hailin Date: Fri, 9 May 2025 20:24:04 +0800 Subject: [PATCH] . --- scripts/rag_build_query.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index 9b86d70..37b6e78 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -1,12 +1,13 @@ import os -from llama_index.core import ( +from llama_index import ( SimpleDirectoryReader, VectorStoreIndex, + PromptHelper, PromptTemplate, - Settings, + ServiceContext, ) -from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from llama_index.embeddings.base import BaseEmbedding from llama_index.vector_stores.faiss import FaissVectorStore from app.core.config import settings from scripts.permissions import get_user_allowed_indexes @@ -15,18 +16,31 @@ import faiss USER_INDEX_PATH = "index_data" USER_DOC_PATH = "docs" +class CustomEmbedding(BaseEmbedding): + def __init__(self, model_name: str): + from sentence_transformers import SentenceTransformer + self.model = SentenceTransformer(model_name) + + def embed(self, text: str): + return self.model.encode(text).tolist() + + def embed_batch(self, texts: list): + return self.model.encode(texts).tolist() + def build_user_index(user_id: str): doc_dir = os.path.join(USER_DOC_PATH, user_id) if not os.path.exists(doc_dir): raise FileNotFoundError(f"文档目录不存在: {doc_dir}") documents = SimpleDirectoryReader(doc_dir).load_data() - embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) - Settings.embed_model = embed_model # ✅ 新式配置 + embed_model = CustomEmbedding(model_name=settings.MODEL_NAME) + + service_context = ServiceContext.from_defaults(embed_model=embed_model) index = VectorStoreIndex.from_documents( documents, - vector_store=FaissVectorStore() + vector_store=FaissVectorStore(), + service_context=service_context ) index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") @@ -34,8 +48,8 @@ def build_user_index(user_id: str): print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}") def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: - embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) - Settings.embed_model = embed_model # ✅ 全局设置一次即可 + embed_model = CustomEmbedding(model_name=settings.MODEL_NAME) + service_context = ServiceContext.from_defaults(embed_model=embed_model) all_nodes = [] @@ -44,7 +58,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: if not os.path.exists(index_path): raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在") user_store = FaissVectorStore.from_persist_path(index_path) - user_index = VectorStoreIndex.from_vector_store(user_store) + user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context) all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question) # 加载共享索引 @@ -54,7 +68,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: shared_path = os.path.join(USER_INDEX_PATH, shared_name) if os.path.exists(shared_path) and shared_path != index_path: shared_store = FaissVectorStore.from_persist_path(shared_path) - shared_index = VectorStoreIndex.from_vector_store(shared_store) + shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context) all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) else: print(f"[INFO] 用户 {user_id} 没有共享索引权限") @@ -63,7 +77,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0)) top_nodes = sorted_nodes[:top_k] - context_str = "\n\n".join([n.get_content() for n in top_nodes]) + context_str = "\n\n".join([n.get_text() for n in top_nodes]) prompt_template = PromptTemplate( "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}" )