This commit is contained in:
hailin 2025-05-09 23:45:35 +08:00
parent 8ffe24de94
commit a572143f62
1 changed file with 14 additions and 82 deletions

View File

@@ -1,40 +1,14 @@
import os
import logging
from typing import List
import faiss
from llama_index import (
SimpleDirectoryReader,
VectorStoreIndex,
ServiceContext,
PromptTemplate,
StorageContext,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from app.core.config import settings
from scripts.permissions import get_user_allowed_indexes
# Per-user on-disk locations: persisted vector indexes and raw documents.
USER_INDEX_PATH = "index_data"
USER_DOC_PATH = "docs"
# Configure logging: timestamped INFO-level messages for the whole module.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class BGEEmbedding(HuggingFaceEmbedding):
    """HuggingFace embedding wrapper for BGE-style models that automatically
    prepends the retrieval instruction prefix to queries.

    BGE models expect queries (but not passages) to carry this instruction
    for best retrieval quality; passage embedding is inherited unchanged.
    """

    # Single source of truth for the prefix (was duplicated in both methods).
    QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed one query with the BGE instruction prefix prepended."""
        return super()._get_query_embedding(self.QUERY_PREFIX + query)

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        """Embed a batch of queries, each with the prefix prepended."""
        return super()._get_query_embeddings([self.QUERY_PREFIX + q for q in queries])
def build_user_index(user_id: str):
    # 设置日志
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    logger.info(f"开始为用户 {user_id} 构建索引...")
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.exists(doc_dir):
        raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
@@ -54,13 +28,21 @@ def build_user_index(user_id: str):
    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    os.makedirs(persist_dir, exist_ok=True)
    logger.info(f"索引保存路径: {persist_dir}")
    # 检查目录中是否存在 index_store.json 文件
    index_store_path = os.path.join(persist_dir, "index_store.json")
    if not os.path.exists(index_store_path):
        logger.info(f"未找到 index_store.json准备创建")
    else:
        logger.info(f"已找到 index_store.json跳过创建")
    storage_context = StorageContext.from_defaults(
        persist_dir=persist_dir,
        vector_store=vector_store,
    )
    try:
        # 构建索引
        index = VectorStoreIndex.from_documents(
            documents,
            service_context=service_context,
@@ -71,53 +53,3 @@ def build_user_index(user_id: str):
    except Exception as e:
        logger.error(f"索引构建失败: {e}")
        raise HTTPException(status_code=500, detail="索引构建失败")
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
    """Build the final LLM prompt for *question* on behalf of *user_id*.

    Retrieves up to *top_k* chunks from the user's own index plus every
    index they are allowed to share, keeps the overall best-scoring ones,
    and formats them into the prompt template.

    Raises FileNotFoundError if the user's index directory is missing.
    """
    logger.info(f"为用户 {user_id} 查询问题:{question}")

    # One service context (BGE embeddings, no LLM) reused for every index load.
    ctx = ServiceContext.from_defaults(
        embed_model=BGEEmbedding(model_name=settings.MODEL_NAME),
        llm=None,
    )

    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    if not os.path.exists(persist_dir):
        raise FileNotFoundError(f"用户 {user_id} 的索引目录不存在")
    logger.info(f"加载索引目录: {persist_dir}")

    own_index = VectorStoreIndex.load_from_storage(
        StorageContext.from_defaults(persist_dir=persist_dir),
        service_context=ctx,
    )
    logger.info(f"加载索引成功,开始检索相关文档...")
    all_nodes = own_index.as_retriever(similarity_top_k=top_k).retrieve(question)

    # Fold in hits from every index shared with this user (skip our own).
    shared_indexes = get_user_allowed_indexes(user_id)
    logger.info(f"用户 {user_id} 被允许共享的索引:{shared_indexes}")
    for allowed_name in shared_indexes:
        candidate_dir = os.path.join(USER_INDEX_PATH, allowed_name)
        if not os.path.exists(candidate_dir) or candidate_dir == persist_dir:
            continue
        shared = VectorStoreIndex.load_from_storage(
            StorageContext.from_defaults(persist_dir=candidate_dir),
            service_context=ctx,
        )
        all_nodes += shared.as_retriever(similarity_top_k=top_k).retrieve(question)
    logger.info(f"共检索到 {len(all_nodes)} 个相关文档")

    # Best-scoring nodes across all sources (missing scores count as 0).
    best = sorted(all_nodes, key=lambda n: n.score or 0, reverse=True)[:top_k]
    context_str = "\n\n".join(n.get_text() for n in best)

    template = PromptTemplate(
        "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}"
    )
    prompt = template.format(context=context_str, query=question)
    logger.info("[PROMPT构建完成]")
    return prompt
if __name__ == "__main__":
uid = "user_001"
build_user_index(uid)
prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?")
logger.info("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
logger.info(prompt)