import asyncio
import os
from typing import List

import faiss
from llama_index import (
    PromptHelper,
    PromptTemplate,
    ServiceContext,
    SimpleDirectoryReader,
    VectorStoreIndex,
)
from llama_index.embeddings.base import BaseEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from sentence_transformers import SentenceTransformer

from app.core.config import settings
from scripts.permissions import get_user_allowed_indexes
# Directory holding one persisted FAISS index file per user ("<user_id>.index").
USER_INDEX_PATH = "index_data"
# Root directory containing one sub-directory of source documents per user.
USER_DOC_PATH = "docs"
class CustomEmbedding(BaseEmbedding):
    """llama_index embedding adapter backed by a SentenceTransformer model.

    Implements the synchronous embedding hooks required by ``BaseEmbedding``
    plus async counterparts that simply delegate to the sync versions (the
    underlying SentenceTransformer has no native async API).
    """

    # Explicit pydantic field declaration so BaseEmbedding accepts the
    # attribute.  This annotation is evaluated when the class body runs, so
    # SentenceTransformer must be imported at module scope (it was previously
    # imported only inside __init__, which made this line raise NameError at
    # import time).
    model: SentenceTransformer

    def __init__(self, model_name: str):
        """Load the SentenceTransformer checkpoint named *model_name*."""
        # NOTE(review): BaseEmbedding is a pydantic model; assigning the field
        # without calling super().__init__() mirrors the original code —
        # confirm this works with the pinned llama_index version.
        self.model = SentenceTransformer(model_name)

    # --- synchronous hooks (required by BaseEmbedding) ---------------------

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document text."""
        return self.model.encode(text).tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string."""
        return self.model.encode(query).tolist()

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of document texts in one model call."""
        return self.model.encode(texts).tolist()

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        """Embed a batch of query strings in one model call."""
        return self.model.encode(queries).tolist()

    # --- async hooks (sync pass-throughs; the model itself is blocking) ----

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        return self._get_query_embeddings(queries)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        return self._get_text_embeddings(texts)
def build_user_index(user_id: str):
    """Build a FAISS vector index from a user's documents and persist it.

    Loads every file under ``docs/<user_id>``, embeds the content with the
    model configured in ``settings.MODEL_NAME``, and writes the resulting
    FAISS index to ``index_data/<user_id>.index``.

    Args:
        user_id: Identifier used both for the source document directory and
            for the persisted index filename.

    Raises:
        FileNotFoundError: If the user's document directory does not exist.
    """
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.exists(doc_dir):
        raise FileNotFoundError(f"文档目录不存在: {doc_dir}")

    documents = SimpleDirectoryReader(doc_dir).load_data()
    embed_model = CustomEmbedding(model_name=settings.MODEL_NAME)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)

    # NOTE(review): FaissVectorStore is constructed with no arguments here —
    # confirm the pinned llama_index version does not require an explicit
    # faiss_index argument.
    index = VectorStoreIndex.from_documents(
        documents,
        vector_store=FaissVectorStore(),
        service_context=service_context,
    )

    # faiss.write_index does not create missing directories; ensure the
    # target directory exists so the first build does not crash.
    os.makedirs(USER_INDEX_PATH, exist_ok=True)
    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
    faiss.write_index(index.vector_store.index, index_path)
    print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
    """Retrieve context for *question* and assemble the final LLM prompt.

    Gathers candidate nodes from the user's own FAISS index plus any shared
    indexes the permission service grants access to, keeps the ``top_k``
    highest-scored nodes overall, and formats them into the prompt string.

    Args:
        user_id: Owner of the primary index; also used for permission lookup.
        question: The user's natural-language question.
        top_k: How many nodes to retrieve per index and to keep after merging.

    Returns:
        The fully rendered prompt string, ready to send to an LLM.

    Raises:
        FileNotFoundError: If the user's primary index file is missing.
    """
    embed_model = CustomEmbedding(model_name=settings.MODEL_NAME)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)

    # The user's own index is mandatory — fail fast if it was never built.
    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
    if not os.path.exists(index_path):
        raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")

    user_store = FaissVectorStore.from_persist_path(index_path)
    user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
    candidates = list(user_index.as_retriever(similarity_top_k=top_k).retrieve(question))

    # Shared indexes are optional and gated by the permission service.
    shared_indexes = get_user_allowed_indexes(user_id)
    if shared_indexes:
        for shared_name in shared_indexes:
            shared_path = os.path.join(USER_INDEX_PATH, shared_name)
            # Skip missing files and the user's own index (already queried).
            if not os.path.exists(shared_path) or shared_path == index_path:
                continue
            shared_store = FaissVectorStore.from_persist_path(shared_path)
            shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
            candidates.extend(shared_index.as_retriever(similarity_top_k=top_k).retrieve(question))
    else:
        print(f"[INFO] 用户 {user_id} 没有共享索引权限")

    # Merge all candidates and keep the best-scored top_k; nodes without a
    # score sort as 0.
    top_nodes = sorted(candidates, key=lambda node: node.score or 0, reverse=True)[:top_k]

    context_str = "\n\n".join(node.get_text() for node in top_nodes)
    prompt_template = PromptTemplate(
        "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}"
    )
    final_prompt = prompt_template.format(
        context=context_str,
        query=question,
    )

    print("[PROMPT构建完成]")
    return final_prompt
def _demo() -> None:
    """Example run: build one user's index, then print the generated prompt."""
    uid = "user_001"
    build_user_index(uid)
    prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?")
    print("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
    print(prompt)


if __name__ == "__main__":
    _demo()