faiss_rag_enterprise/scripts/rag_build_query.py

import os

import faiss
from llama_index import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    ServiceContext,
    StorageContext,
    PromptTemplate,
    load_index_from_storage,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

from app.core.config import settings
from scripts.permissions import get_user_allowed_indexes
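
# `app.core.config` and `scripts.permissions` are not included in this file.
# A rough sketch of the interfaces this script assumes (hypothetical shapes,
# adjust to the real modules):
#
#   settings.MODEL_NAME           -- HuggingFace embedding model id
#   get_user_allowed_indexes(uid) -- names of shared indexes under
#                                    USER_INDEX_PATH the user may query
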
USER_INDEX_PATH = "index_data"
USER_DOC_PATH = "docs"


def build_user_index(user_id: str) -> None:
    """Build a FAISS-backed vector index over a user's documents and persist it."""
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.exists(doc_dir):
        raise FileNotFoundError(f"Document directory does not exist: {doc_dir}")

    documents = SimpleDirectoryReader(doc_dir).load_data()
    embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
    # No LLM is needed at build time; we only embed documents.
    service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)

    # FAISS needs the embedding dimensionality up front, so probe the model once.
    dim = len(embed_model.get_text_embedding("dimension probe"))
    vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(dim))
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Build the vector index
    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
        service_context=service_context,
    )

    # Persist the index (FAISS vectors plus docstore) to a per-user directory
    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    index.storage_context.persist(persist_dir=persist_dir)
    print(f"[BUILD] Built and saved index for user {user_id} → {persist_dir}")


def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
    """Retrieve context from the user's own index plus any shared indexes the
    user may access, then assemble the final prompt for the LLM."""
    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    if not os.path.exists(persist_dir):
        raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")

    embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
    # Retrieval only needs the embedding model; the LLM is called elsewhere.
    service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)

    def load_index(path: str):
        # Reload both the FAISS vectors and the docstore that holds node text.
        vector_store = FaissVectorStore.from_persist_dir(path)
        storage_context = StorageContext.from_defaults(
            vector_store=vector_store, persist_dir=path
        )
        return load_index_from_storage(storage_context, service_context=service_context)

    # Load the user's own index and retrieve the top-k matching nodes
    index = load_index(persist_dir)
    nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)

    # Load the shared indexes within the user's permission scope
    for shared_name in get_user_allowed_indexes(user_id):
        shared_path = os.path.join(USER_INDEX_PATH, shared_name)
        if not os.path.exists(shared_path):
            continue
        shared_index = load_index(shared_path)
        nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)

    # Assemble the prompt from the retrieved context
    context_str = "\n\n".join(n.get_content() for n in nodes)
    prompt_template = PromptTemplate(
        "Answer the user's question based on the following context:\n\n"
        "{context}\n\nQuestion: {query}"
    )
    final_prompt = prompt_template.format(context=context_str, query=question)
    print("[PROMPT] Prompt construction complete")
    return final_prompt


# Example:
if __name__ == "__main__":
    uid = "user_001"
    build_user_index(uid)
    prompt = query_user_rag(uid, "What key points are mentioned in this material?")
    print("\n------ Final prompt to pass to the LLM ------\n")
    print(prompt)
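
# The prompt returned above is plain text and can be sent to whichever LLM the
# project uses.  A minimal sketch, assuming the legacy llama_index OpenAI
# wrapper and an OPENAI_API_KEY in the environment (hypothetical, not part of
# this repo):
#
#   from llama_index.llms import OpenAI
#   llm = OpenAI(model="gpt-3.5-turbo")
#   print(llm.complete(prompt).text)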