import math
import os
from functools import lru_cache
from typing import List, Optional

import faiss
from llama_index import (
    PromptTemplate,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

from app.core.config import settings
from scripts.permissions import get_user_allowed_indexes

# Root directories, resolved relative to the current working directory.
# NOTE(review): user_id and shared index names are joined into these paths
# without validation -- confirm they never come from untrusted input
# (path traversal, absolute paths).
USER_INDEX_PATH = "index_data"
USER_DOC_PATH = "docs"

# Instruction prepended to every query (not to documents) before embedding.
# Kept at module level: HuggingFaceEmbedding is a pydantic model, so a plain
# class attribute would be turned into a model field.
_BGE_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

# Template that wraps the retrieved context and the user's question.
_QA_PROMPT = PromptTemplate(
    "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}"
)


class BGEEmbedding(HuggingFaceEmbedding):
    """HuggingFace embedding that prepends a retrieval instruction to queries.

    Only query embeddings get the prefix; document/text embeddings are
    produced by the parent class unchanged.

    NOTE(review): the original comment says this wraps BGE-m3. BGE-m3's model
    card says queries need no instruction (this prefix is the bge-*-en-v1.5
    convention) -- confirm which model settings.MODEL_NAME points to.
    """

    def _get_query_embedding(self, query: str) -> List[float]:
        return super()._get_query_embedding(_BGE_QUERY_PREFIX + query)

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        # Delegate per query so the batch path applies the prefix exactly
        # like the single-query path, without relying on the parent class
        # exposing a batch query method.
        return [self._get_query_embedding(q) for q in queries]


@lru_cache(maxsize=4)
def _get_service_context(model_name: str) -> ServiceContext:
    """Return a cached ServiceContext using BGEEmbedding(model_name).

    Loading the HuggingFace model is expensive, so repeated build/query calls
    in one process reuse it instead of reloading it each time. No LLM is
    configured (llm=None) because this module only builds prompts.
    """
    embed_model = BGEEmbedding(model_name=model_name)
    return ServiceContext.from_defaults(embed_model=embed_model, llm=None)


def build_user_index(user_id: str, embed_dim: Optional[int] = None):
    """Embed every document in ``docs/<user_id>`` into a FAISS index and persist it.

    The docstore, index store and FAISS file are written to
    ``index_data/<user_id>``, overwriting files from a previous build.

    Args:
        user_id: Sub-directory name under USER_DOC_PATH and USER_INDEX_PATH.
        embed_dim: Dimensionality of the FAISS index. When None, it is
            measured from the configured embedding model so the index always
            matches settings.MODEL_NAME.

    Raises:
        FileNotFoundError: if the user's document directory does not exist.
    """
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.exists(doc_dir):
        raise FileNotFoundError(f"文档目录不存在: {doc_dir}")

    documents = SimpleDirectoryReader(doc_dir).load_data()
    service_context = _get_service_context(settings.MODEL_NAME)

    if embed_dim is None:
        embed_dim = len(
            service_context.embed_model.get_text_embedding("dimension probe")
        )
    # Exact (brute-force) L2 index: retrieval scores coming back from this
    # store are squared L2 distances, i.e. smaller means more similar.
    faiss_index = faiss.IndexFlatL2(embed_dim)
    vector_store = FaissVectorStore(faiss_index=faiss_index)

    # No persist_dir here: from_defaults would try to *load* an existing
    # docstore/index store from it, which fails on a first build and mixes
    # stale nodes into a rebuild.
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(
        documents, service_context=service_context, storage_context=storage_context
    )

    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    os.makedirs(persist_dir, exist_ok=True)
    index.storage_context.persist(persist_dir=persist_dir)
    print(f"[BUILD] 为用户 {user_id} 构建并保存了完整索引 → {persist_dir}")


def _load_faiss_index(persist_dir: str, service_context: ServiceContext):
    """Load an index previously written by build_user_index from *persist_dir*.

    The FAISS store must be loaded explicitly; a StorageContext built from
    persist_dir alone would try to parse the binary FAISS file as a
    SimpleVectorStore JSON file.
    """
    vector_store = FaissVectorStore.from_persist_dir(persist_dir)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir=persist_dir
    )
    return load_index_from_storage(storage_context, service_context=service_context)


def _node_distance(node) -> float:
    """Sort key: the node's L2 distance, with missing scores ranked last."""
    # Explicit None check: a perfect match has distance 0.0, which is falsy.
    return node.score if node.score is not None else math.inf


def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
    """Retrieve context for *question* and return a prompt ready for an LLM.

    Searches the user's own index plus every shared index named by
    get_user_allowed_indexes(user_id) that exists on disk, merges the hits,
    keeps the top_k closest chunks overall and formats them into the QA
    prompt. No LLM is called here.

    Args:
        user_id: Owner of the primary index under USER_INDEX_PATH.
        question: The user's question; used both for retrieval and in the prompt.
        top_k: Number of chunks fetched per index and kept after merging.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: if top_k is smaller than 1.
        FileNotFoundError: if the user's own index directory is missing.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    if not os.path.exists(persist_dir):
        raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引目录不存在")

    service_context = _get_service_context(settings.MODEL_NAME)

    # The user's own index first, then each existing shared index once;
    # normalised paths catch duplicates and aliases of the user's own dir.
    index_dirs = [persist_dir]
    seen = {os.path.normpath(persist_dir)}
    for shared_name in get_user_allowed_indexes(user_id):
        shared_dir = os.path.join(USER_INDEX_PATH, shared_name)
        key = os.path.normpath(shared_dir)
        if key in seen or not os.path.exists(shared_dir):
            continue
        seen.add(key)
        index_dirs.append(shared_dir)

    all_nodes = []
    for index_dir in index_dirs:
        index = _load_faiss_index(index_dir, service_context)
        all_nodes.extend(index.as_retriever(similarity_top_k=top_k).retrieve(question))

    # Scores are squared L2 distances (IndexFlatL2), so the best matches have
    # the SMALLEST scores; a descending sort would keep the worst chunks.
    # Merging across indexes assumes they were all built with the same model.
    top_nodes = sorted(all_nodes, key=_node_distance)[:top_k]
    context_str = "\n\n".join(n.get_text() for n in top_nodes)

    final_prompt = _QA_PROMPT.format(context=context_str, query=question)
    print("[PROMPT构建完成]")
    return final_prompt


if __name__ == "__main__":
    uid = "user_001"
    build_user_index(uid)
    prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?")
    print("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
    print(prompt)