diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index 4af3de0..6698576 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -1,15 +1,12 @@ import os from typing import List -import asyncio import faiss -from sentence_transformers import SentenceTransformer - from llama_index import ( SimpleDirectoryReader, VectorStoreIndex, - PromptHelper, - PromptTemplate, ServiceContext, + PromptTemplate, + StorageContext, ) from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.vector_stores.faiss import FaissVectorStore @@ -19,7 +16,7 @@ from scripts.permissions import get_user_allowed_indexes USER_INDEX_PATH = "index_data" USER_DOC_PATH = "docs" -# ✅ BGE-m3 模型嵌入类,加前缀 +# ✅ 自动加前缀的 bge-m3 embedding class BGEEmbedding(HuggingFaceEmbedding): def _get_query_embedding(self, query: str) -> List[float]: prefix = "Represent this sentence for searching relevant passages: " @@ -33,44 +30,47 @@ def build_user_index(user_id: str): doc_dir = os.path.join(USER_DOC_PATH, user_id) if not os.path.exists(doc_dir): raise FileNotFoundError(f"文档目录不存在: {doc_dir}") - + documents = SimpleDirectoryReader(doc_dir).load_data() embed_model = BGEEmbedding(model_name=settings.MODEL_NAME) service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) - faiss_index = faiss.IndexFlatL2(1024) # ✅ bge-m3 是 1024维 + faiss_index = faiss.IndexFlatL2(1024) vector_store = FaissVectorStore(faiss_index=faiss_index) + persist_dir = os.path.join(USER_INDEX_PATH, user_id) + storage_context = StorageContext.from_defaults( + persist_dir=persist_dir, + vector_store=vector_store + ) index = VectorStoreIndex.from_documents( documents, - vector_store=vector_store, + storage_context=storage_context, service_context=service_context ) - - index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") - faiss.write_index(faiss_index, index_path) # ✅ 改这里,直接写 faiss_index - print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}") + index.persist(persist_dir=persist_dir) + print(f"[BUILD] 为用户 {user_id} 构建并保存了完整索引 → {persist_dir}") def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: embed_model = BGEEmbedding(model_name=settings.MODEL_NAME) service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) - all_nodes = [] + persist_dir = os.path.join(USER_INDEX_PATH, user_id) + if not os.path.exists(persist_dir): + raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引目录不存在") - index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") - if not os.path.exists(index_path): - raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在") - user_store = FaissVectorStore.from_persist_path(index_path) - user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context) - all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question) + storage_context = StorageContext.from_defaults(persist_dir=persist_dir) + index = VectorStoreIndex.load_from_storage(storage_context, service_context=service_context) + + all_nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question) shared_indexes = get_user_allowed_indexes(user_id) if shared_indexes: for shared_name in shared_indexes: shared_path = os.path.join(USER_INDEX_PATH, shared_name) - if os.path.exists(shared_path) and shared_path != index_path: - shared_store = FaissVectorStore.from_persist_path(shared_path) - shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context) + if os.path.exists(shared_path) and shared_path != persist_dir: + shared_storage = StorageContext.from_defaults(persist_dir=shared_path) + shared_index = VectorStoreIndex.load_from_storage(shared_storage, service_context=service_context) all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) else: print(f"[INFO] 用户 {user_id} 没有共享索引权限") @@ -82,15 +82,11 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: prompt_template = PromptTemplate( "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}" ) - final_prompt = prompt_template.format( - context=context_str, - query=question, - ) + final_prompt = prompt_template.format(context=context_str, query=question) print("[PROMPT构建完成]") return final_prompt -# 示例: if __name__ == "__main__": uid = "user_001" build_user_index(uid)