This commit is contained in:
hailin 2025-05-09 20:24:04 +08:00
parent 301795569c
commit 01f2317bcc
1 changed files with 25 additions and 11 deletions

View File

@ -1,12 +1,13 @@
import os import os
from llama_index.core import ( from llama_index import (
SimpleDirectoryReader, SimpleDirectoryReader,
VectorStoreIndex, VectorStoreIndex,
PromptHelper,
PromptTemplate, PromptTemplate,
Settings, ServiceContext,
) )
from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.embeddings.base import BaseEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore from llama_index.vector_stores.faiss import FaissVectorStore
from app.core.config import settings from app.core.config import settings
from scripts.permissions import get_user_allowed_indexes from scripts.permissions import get_user_allowed_indexes
@ -15,18 +16,31 @@ import faiss
USER_INDEX_PATH = "index_data" USER_INDEX_PATH = "index_data"
USER_DOC_PATH = "docs" USER_DOC_PATH = "docs"
class CustomEmbedding(BaseEmbedding):
def __init__(self, model_name: str):
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(model_name)
def embed(self, text: str):
return self.model.encode(text).tolist()
def embed_batch(self, texts: list):
return self.model.encode(texts).tolist()
def build_user_index(user_id: str): def build_user_index(user_id: str):
doc_dir = os.path.join(USER_DOC_PATH, user_id) doc_dir = os.path.join(USER_DOC_PATH, user_id)
if not os.path.exists(doc_dir): if not os.path.exists(doc_dir):
raise FileNotFoundError(f"文档目录不存在: {doc_dir}") raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
documents = SimpleDirectoryReader(doc_dir).load_data() documents = SimpleDirectoryReader(doc_dir).load_data()
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) embed_model = CustomEmbedding(model_name=settings.MODEL_NAME)
Settings.embed_model = embed_model # ✅ 新式配置
service_context = ServiceContext.from_defaults(embed_model=embed_model)
index = VectorStoreIndex.from_documents( index = VectorStoreIndex.from_documents(
documents, documents,
vector_store=FaissVectorStore() vector_store=FaissVectorStore(),
service_context=service_context
) )
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
@ -34,8 +48,8 @@ def build_user_index(user_id: str):
print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}") print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) embed_model = CustomEmbedding(model_name=settings.MODEL_NAME)
Settings.embed_model = embed_model # ✅ 全局设置一次即可 service_context = ServiceContext.from_defaults(embed_model=embed_model)
all_nodes = [] all_nodes = []
@ -44,7 +58,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
if not os.path.exists(index_path): if not os.path.exists(index_path):
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在") raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
user_store = FaissVectorStore.from_persist_path(index_path) user_store = FaissVectorStore.from_persist_path(index_path)
user_index = VectorStoreIndex.from_vector_store(user_store) user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question) all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
# 加载共享索引 # 加载共享索引
@ -54,7 +68,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
shared_path = os.path.join(USER_INDEX_PATH, shared_name) shared_path = os.path.join(USER_INDEX_PATH, shared_name)
if os.path.exists(shared_path) and shared_path != index_path: if os.path.exists(shared_path) and shared_path != index_path:
shared_store = FaissVectorStore.from_persist_path(shared_path) shared_store = FaissVectorStore.from_persist_path(shared_path)
shared_index = VectorStoreIndex.from_vector_store(shared_store) shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
else: else:
print(f"[INFO] 用户 {user_id} 没有共享索引权限") print(f"[INFO] 用户 {user_id} 没有共享索引权限")
@ -63,7 +77,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0)) sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
top_nodes = sorted_nodes[:top_k] top_nodes = sorted_nodes[:top_k]
context_str = "\n\n".join([n.get_content() for n in top_nodes]) context_str = "\n\n".join([n.get_text() for n in top_nodes])
prompt_template = PromptTemplate( prompt_template = PromptTemplate(
"请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}" "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}"
) )