diff --git a/app/core/config.py b/app/core/config.py index ec4cfbe..f6b786d 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,4 +1,4 @@ -from pydantic import BaseSettings +from pydantic_settings import BaseSettings import os class Settings(BaseSettings): @@ -6,7 +6,7 @@ class Settings(BaseSettings): EMBEDDING_DIM: int = 768 TOP_K: int = 5 DOC_PATH: str = "docs/" - DEVICE: str = "cpu" # 可设置为 cuda:0 + DEVICE: str = "cpu" MODEL_NAME: str = "BAAI/bge-m3" -settings = Settings() \ No newline at end of file +settings = Settings() diff --git a/requirements.txt b/requirements.txt index 56f4a46..1ac1eac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ gunicorn pydantic numpy transformers -torch \ No newline at end of file +torch +llama-index==0.12.34 \ No newline at end of file diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index 897a856..be26149 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -9,10 +9,9 @@ from llama_index import ( ) from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.vector_stores.faiss import FaissVectorStore -from llama_index.llms.base import ChatMessage +from app.core.config import settings +from scripts.permissions import get_user_allowed_indexes -# 假设你要用的本地嵌入模型 -EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" USER_INDEX_PATH = "index_data" USER_DOC_PATH = "docs" @@ -22,7 +21,7 @@ def build_user_index(user_id: str): raise FileNotFoundError(f"文档目录不存在: {doc_dir}") documents = SimpleDirectoryReader(doc_dir).load_data() - embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME) + embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) service_context = ServiceContext.from_defaults(embed_model=embed_model) # 构建向量索引 @@ -37,22 +36,28 @@ def build_user_index(user_id: str): faiss.write_index(index.vector_store.index, index_path) print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}") -from scripts.permissions import get_user_allowed_indexes - def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") if not os.path.exists(index_path): raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在") - embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME) + embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) service_context = ServiceContext.from_defaults(embed_model=embed_model) - # 加载索引 + # 加载主索引 vector_store = FaissVectorStore.from_persist_path(index_path) index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context) - retriever = index.as_retriever(similarity_top_k=top_k) - nodes = retriever.retrieve(question) + nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question) + + # 加载权限范围内的共享索引 + shared_indexes = get_user_allowed_indexes(user_id) + for shared_name in shared_indexes: + shared_path = os.path.join(USER_INDEX_PATH, shared_name) + if os.path.exists(shared_path): + shared_store = FaissVectorStore.from_persist_path(shared_path) + shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context) + nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) # 构造 Prompt context_str = "\n\n".join([n.get_content() for n in nodes]) @@ -73,4 +78,4 @@ if __name__ == "__main__": build_user_index(uid) prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?") print("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n") - print(prompt) \ No newline at end of file + print(prompt)