faiss_rag_enterprise/scripts/rag_build_query.py

import os
from typing import List

import faiss
from llama_index import (
    PromptTemplate,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

from app.core.config import settings
from scripts.permissions import get_user_allowed_indexes
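
# Assumed contract (the permissions helper lives in scripts/permissions.py and is not
# shown here): get_user_allowed_indexes(user_id) is taken to return a list of FAISS
# file names under USER_INDEX_PATH (e.g. "team_shared.index") that the user may
# search, which is how query_user_rag() consumes it below.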

USER_INDEX_PATH = "index_data"
USER_DOC_PATH = "docs"


# ✅ Replaces CustomEmbedding: tailored to the bge-m3 model, automatically prepends
# the retrieval prefix to query text.
class BGEEmbedding(HuggingFaceEmbedding):
    def _get_query_embedding(self, query: str) -> List[float]:
        prefix = "Represent this sentence for searching relevant passages: "
        return super()._get_query_embedding(prefix + query)

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        prefix = "Represent this sentence for searching relevant passages: "
        return super()._get_query_embeddings([prefix + q for q in queries])
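
# Hedged usage sketch (not part of the pipeline below): the class can also be queried
# directly; the model name here is illustrative, while the rest of this script uses
# settings.MODEL_NAME.
#
#     embed = BGEEmbedding(model_name="BAAI/bge-m3")
#     vec = embed.get_query_embedding("Which key points does the handbook cover?")
#     assert len(vec) == 1024  # bge-m3 embeddings are 1024-dimensional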


def build_user_index(user_id: str):
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.exists(doc_dir):
        raise FileNotFoundError(f"Document directory does not exist: {doc_dir}")

    documents = SimpleDirectoryReader(doc_dir).load_data()
    embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
    service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)

    # ✅ Use the correct dimensionality: bge-m3 embeddings are 1024-d.
    faiss_index = faiss.IndexFlatL2(1024)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    # Attach the FAISS store through a StorageContext so that from_documents actually
    # writes the embeddings into faiss_index.
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
        service_context=service_context,
    )

    os.makedirs(USER_INDEX_PATH, exist_ok=True)
    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
    faiss.write_index(faiss_index, index_path)
    print(f"[BUILD] Built and saved the index for user {user_id} → {index_path}")


def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
    embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
    service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
    all_nodes = []

    # Load the user's own index.
    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
    if not os.path.exists(index_path):
        raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")
    user_store = FaissVectorStore.from_persist_path(index_path)
    user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
    all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)

    # Load any shared indexes the user is allowed to read.
    shared_indexes = get_user_allowed_indexes(user_id)
    if shared_indexes:
        for shared_name in shared_indexes:
            shared_path = os.path.join(USER_INDEX_PATH, shared_name)
            if os.path.exists(shared_path) and shared_path != index_path:
                shared_store = FaissVectorStore.from_persist_path(shared_path)
                shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
                all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
    else:
        print(f"[INFO] User {user_id} has no shared-index permissions")

    # Merge the candidates and keep the top_k highest-scoring nodes.
    sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
    top_nodes = sorted_nodes[:top_k]
    context_str = "\n\n".join([n.get_text() for n in top_nodes])

    prompt_template = PromptTemplate(
        "Answer the user's question based on the following content:\n\n{context}\n\nQuestion: {query}"
    )
    final_prompt = prompt_template.format(
        context=context_str,
        query=question,
    )
    print("[PROMPT built]")
    return final_prompt
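
# Hedged sketch (assumption; the LLM client is not part of this script): the prompt
# returned above can be sent to whichever chat model the service uses, for example
# with the OpenAI Python client; the model name below is illustrative only.
#
#     from openai import OpenAI
#
#     client = OpenAI()
#     answer = client.chat.completions.create(
#         model="gpt-4o-mini",  # illustrative
#         messages=[{"role": "user", "content": query_user_rag("user_001", "...")}],
#     ).choices[0].message.content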


# Example:
if __name__ == "__main__":
    uid = "user_001"
    build_user_index(uid)
    prompt = query_user_rag(uid, "What key points are mentioned in this material?")
    print("\n------ Final prompt to pass to the LLM ------\n")
    print(prompt)