"""Per-user RAG index building and prompt construction (llama_index + FAISS).

Each user's source documents live in ``docs/<user_id>/``. ``build_user_index``
embeds them and persists a llama_index storage directory at
``index_data/<user_id>.index``. ``query_user_rag`` retrieves from that index
plus any shared indexes the user may read. It returns the final prompt text for
an LLM; it does not call one.
"""
import asyncio
import functools
import os
from typing import Any, List

import faiss
from llama_index import (
    PromptHelper,
    PromptTemplate,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.embeddings.base import BaseEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

from app.core.config import settings
from scripts.permissions import get_user_allowed_indexes

USER_INDEX_PATH = "index_data"
USER_DOC_PATH = "docs"


class CustomEmbedding(BaseEmbedding):
    """llama_index embedding adapter backed by a sentence-transformers model.

    Vectors are L2-normalised, so inner product equals cosine similarity. The
    FAISS indexes built in this module use ``IndexFlatIP``, which means a larger
    score is a better match. Scores from different indexes can then be merged
    and sorted directly.
    """

    # BaseEmbedding is a pydantic model, so the SentenceTransformer instance is
    # kept in a private attribute. A public field of a non-pydantic type fails
    # validation. A class-level ``model: SentenceTransformer`` annotation would
    # also raise NameError, because the class is lazily imported in __init__.
    _model: Any = PrivateAttr()

    def __init__(self, model_name: str, **kwargs: Any) -> None:
        """Load the sentence-transformers model called *model_name*.

        Extra keyword arguments are forwarded to ``BaseEmbedding``.
        """
        # Imported lazily so the heavy dependency is only loaded on first use.
        from sentence_transformers import SentenceTransformer

        # Initialise the pydantic model first, then attach the private model.
        super().__init__(model_name=model_name, **kwargs)
        self._model = SentenceTransformer(model_name)

    def _encode(self, texts):
        """Encode one string or a list of strings into normalised float lists."""
        return self._model.encode(texts, normalize_embeddings=True).tolist()

    def get_dimension(self) -> int:
        """Return the length of the vectors this model produces."""
        dim = self._model.get_sentence_embedding_dimension()
        if dim is None:
            # Some models do not report a dimension; probe with one encode.
            dim = len(self._encode(""))
        return dim

    # Synchronous hooks required by BaseEmbedding.
    def _get_text_embedding(self, text: str) -> List[float]:
        return self._encode(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        return self._encode(query)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        return self._encode(texts)

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        return self._encode(queries)

    # Async hooks: the model has no async API, so these wrap the sync versions.
    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        return self._get_query_embeddings(queries)


@functools.lru_cache(maxsize=4)
def _get_embed_model(model_name: str) -> CustomEmbedding:
    """Return a process-wide cached CustomEmbedding for *model_name*.

    Loading a SentenceTransformer is expensive. The original code reloaded it
    on every build and every query.
    """
    return CustomEmbedding(model_name=model_name)


def _load_index(persist_dir: str, service_context: ServiceContext) -> VectorStoreIndex:
    """Reload a VectorStoreIndex that ``build_user_index`` persisted to *persist_dir*.

    Both the FAISS vectors and the docstore are restored. Without the docstore,
    retrieved node ids could not be mapped back to their text.
    """
    vector_store = FaissVectorStore.from_persist_dir(persist_dir)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir=persist_dir
    )
    return load_index_from_storage(
        storage_context=storage_context, service_context=service_context
    )


def build_user_index(user_id: str):
    """Embed every document in ``docs/<user_id>/`` and persist the index.

    The index is written as a llama_index storage directory at
    ``index_data/<user_id>.index``. It contains the FAISS vector store, the
    docstore and the index store.

    Raises:
        FileNotFoundError: if the user's document directory does not exist.
    """
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.exists(doc_dir):
        raise FileNotFoundError(f"文档目录不存在: {doc_dir}")

    documents = SimpleDirectoryReader(doc_dir).load_data()

    embed_model = _get_embed_model(settings.MODEL_NAME)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)

    # FaissVectorStore requires a concrete FAISS index. Inner product over the
    # normalised embeddings gives cosine similarity (higher score = better).
    faiss_index = faiss.IndexFlatIP(embed_model.get_dimension())
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
        service_context=service_context,
    )

    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
    os.makedirs(index_path, exist_ok=True)
    index.storage_context.persist(persist_dir=index_path)
    print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")


def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
    """Build an LLM prompt from the chunks most relevant to *question*.

    The user's own index is searched first, then every shared index returned by
    ``get_user_allowed_indexes(user_id)`` that exists under ``index_data/``.
    Each index contributes up to *top_k* hits. All hits are merged and the
    overall *top_k* by descending score are kept.

    NOTE(review): shared index names are joined onto ``index_data/`` and must
    name directories produced by ``build_user_index`` (e.g. ``"team.index"``);
    confirm against ``scripts.permissions``.

    Returns:
        The formatted prompt string (context + question).

    Raises:
        FileNotFoundError: if the user's own index has not been built.
    """
    embed_model = _get_embed_model(settings.MODEL_NAME)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)

    # The user's main index is mandatory.
    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
    if not os.path.exists(index_path):
        raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")

    # Collect shared indexes, skipping missing ones, the user's own index and
    # duplicates. A duplicate would inject the same chunks into the context twice.
    index_paths = [index_path]
    seen = {os.path.normpath(index_path)}
    shared_indexes = get_user_allowed_indexes(user_id)
    if shared_indexes:
        for shared_name in shared_indexes:
            shared_path = os.path.join(USER_INDEX_PATH, shared_name)
            key = os.path.normpath(shared_path)
            if key not in seen and os.path.exists(shared_path):
                seen.add(key)
                index_paths.append(shared_path)
    else:
        print(f"[INFO] 用户 {user_id} 没有共享索引权限")

    all_nodes = []
    for path in index_paths:
        index = _load_index(path, service_context)
        all_nodes.extend(index.as_retriever(similarity_top_k=top_k).retrieve(question))

    # Merge and keep the global top_k. Scores are cosine similarities, so
    # higher is better; a missing score ranks as 0.
    top_nodes = sorted(all_nodes, key=lambda n: n.score or 0, reverse=True)[:top_k]
    context_str = "\n\n".join(n.get_text() for n in top_nodes)

    prompt_template = PromptTemplate(
        "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}"
    )
    final_prompt = prompt_template.format(
        context=context_str,
        query=question,
    )
    print("[PROMPT构建完成]")
    return final_prompt


# Example usage:
if __name__ == "__main__":
    uid = "user_001"
    build_user_index(uid)
    prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?")
    print("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
    print(prompt)