This commit is contained in:
parent
92282c2060
commit
8ffe24de94
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
import faiss
|
import faiss
|
||||||
from llama_index import (
|
from llama_index import (
|
||||||
|
|
@ -16,6 +17,10 @@ from scripts.permissions import get_user_allowed_indexes
|
||||||
USER_INDEX_PATH = "index_data"
|
USER_INDEX_PATH = "index_data"
|
||||||
USER_DOC_PATH = "docs"
|
USER_DOC_PATH = "docs"
|
||||||
|
|
||||||
|
# 设置日志配置
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ✅ 自动加前缀的 BGE-m3 embedding 封装类
|
# ✅ 自动加前缀的 BGE-m3 embedding 封装类
|
||||||
class BGEEmbedding(HuggingFaceEmbedding):
|
class BGEEmbedding(HuggingFaceEmbedding):
|
||||||
def _get_query_embedding(self, query: str) -> List[float]:
|
def _get_query_embedding(self, query: str) -> List[float]:
|
||||||
|
|
@ -27,47 +32,68 @@ class BGEEmbedding(HuggingFaceEmbedding):
|
||||||
return super()._get_query_embeddings([prefix + q for q in queries])
|
return super()._get_query_embeddings([prefix + q for q in queries])
|
||||||
|
|
||||||
def build_user_index(user_id: str):
|
def build_user_index(user_id: str):
|
||||||
|
logger.info(f"开始为用户 {user_id} 构建索引...")
|
||||||
doc_dir = os.path.join(USER_DOC_PATH, user_id)
|
doc_dir = os.path.join(USER_DOC_PATH, user_id)
|
||||||
|
|
||||||
if not os.path.exists(doc_dir):
|
if not os.path.exists(doc_dir):
|
||||||
raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
|
raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
|
||||||
|
|
||||||
|
logger.info(f"发现文档目录: {doc_dir}")
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(doc_dir).load_data()
|
documents = SimpleDirectoryReader(doc_dir).load_data()
|
||||||
|
logger.info(f"载入文档数量: {len(documents)}")
|
||||||
|
|
||||||
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
|
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
|
||||||
|
logger.info(f"使用模型: {settings.MODEL_NAME}")
|
||||||
|
|
||||||
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
|
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
|
||||||
|
|
||||||
faiss_index = faiss.IndexFlatL2(1024)
|
faiss_index = faiss.IndexFlatL2(1024)
|
||||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
|
|
||||||
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
|
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
|
||||||
os.makedirs(persist_dir, exist_ok=True)
|
os.makedirs(persist_dir, exist_ok=True)
|
||||||
|
logger.info(f"索引保存路径: {persist_dir}")
|
||||||
|
|
||||||
storage_context = StorageContext.from_defaults(
|
storage_context = StorageContext.from_defaults(
|
||||||
persist_dir=persist_dir,
|
persist_dir=persist_dir,
|
||||||
vector_store=vector_store,
|
vector_store=vector_store,
|
||||||
)
|
)
|
||||||
|
|
||||||
index = VectorStoreIndex.from_documents(
|
try:
|
||||||
documents,
|
index = VectorStoreIndex.from_documents(
|
||||||
service_context=service_context,
|
documents,
|
||||||
storage_context=storage_context
|
service_context=service_context,
|
||||||
)
|
storage_context=storage_context
|
||||||
|
)
|
||||||
index.persist(persist_dir=persist_dir)
|
index.persist(persist_dir=persist_dir)
|
||||||
print(f"[BUILD] 为用户 {user_id} 构建并保存了完整索引 → {persist_dir}")
|
logger.info(f"索引已保存到 {persist_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"索引构建失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="索引构建失败")
|
||||||
|
|
||||||
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
|
logger.info(f"为用户 {user_id} 查询问题:{question}")
|
||||||
|
|
||||||
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
|
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
|
||||||
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
|
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
|
||||||
|
|
||||||
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
|
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
|
||||||
if not os.path.exists(persist_dir):
|
if not os.path.exists(persist_dir):
|
||||||
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引目录不存在")
|
raise FileNotFoundError(f"用户 {user_id} 的索引目录不存在")
|
||||||
|
|
||||||
|
logger.info(f"加载索引目录: {persist_dir}")
|
||||||
|
|
||||||
storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
|
storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
|
||||||
index = VectorStoreIndex.load_from_storage(storage_context, service_context=service_context)
|
index = VectorStoreIndex.load_from_storage(storage_context, service_context=service_context)
|
||||||
|
|
||||||
|
logger.info(f"加载索引成功,开始检索相关文档...")
|
||||||
|
|
||||||
all_nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
all_nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
||||||
|
|
||||||
shared_indexes = get_user_allowed_indexes(user_id)
|
shared_indexes = get_user_allowed_indexes(user_id)
|
||||||
|
logger.info(f"用户 {user_id} 被允许共享的索引:{shared_indexes}")
|
||||||
|
|
||||||
for shared_name in shared_indexes:
|
for shared_name in shared_indexes:
|
||||||
shared_dir = os.path.join(USER_INDEX_PATH, shared_name)
|
shared_dir = os.path.join(USER_INDEX_PATH, shared_name)
|
||||||
if os.path.exists(shared_dir) and shared_dir != persist_dir:
|
if os.path.exists(shared_dir) and shared_dir != persist_dir:
|
||||||
|
|
@ -75,6 +101,8 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
shared_index = VectorStoreIndex.load_from_storage(shared_context, service_context=service_context)
|
shared_index = VectorStoreIndex.load_from_storage(shared_context, service_context=service_context)
|
||||||
all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
||||||
|
|
||||||
|
logger.info(f"共检索到 {len(all_nodes)} 个相关文档")
|
||||||
|
|
||||||
sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
|
sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
|
||||||
top_nodes = sorted_nodes[:top_k]
|
top_nodes = sorted_nodes[:top_k]
|
||||||
|
|
||||||
|
|
@ -84,12 +112,12 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
)
|
)
|
||||||
final_prompt = prompt_template.format(context=context_str, query=question)
|
final_prompt = prompt_template.format(context=context_str, query=question)
|
||||||
|
|
||||||
print("[PROMPT构建完成]")
|
logger.info("[PROMPT构建完成]")
|
||||||
return final_prompt
|
return final_prompt
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uid = "user_001"
|
uid = "user_001"
|
||||||
build_user_index(uid)
|
build_user_index(uid)
|
||||||
prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?")
|
prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?")
|
||||||
print("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
|
logger.info("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
|
||||||
print(prompt)
|
logger.info(prompt)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue