This commit is contained in:
hailin 2025-05-09 22:44:13 +08:00
parent 79c176c1d3
commit a3b6021762
1 changed files with 24 additions and 28 deletions

View File

@ -1,15 +1,12 @@
import os import os
from typing import List from typing import List
import asyncio
import faiss import faiss
from sentence_transformers import SentenceTransformer
from llama_index import ( from llama_index import (
SimpleDirectoryReader, SimpleDirectoryReader,
VectorStoreIndex, VectorStoreIndex,
PromptHelper,
PromptTemplate,
ServiceContext, ServiceContext,
PromptTemplate,
StorageContext,
) )
from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore from llama_index.vector_stores.faiss import FaissVectorStore
@ -19,7 +16,7 @@ from scripts.permissions import get_user_allowed_indexes
USER_INDEX_PATH = "index_data" USER_INDEX_PATH = "index_data"
USER_DOC_PATH = "docs" USER_DOC_PATH = "docs"
# ✅ BGE-m3 模型嵌入类,加前缀 # ✅ 自动加前缀的 bge-m3 embedding
class BGEEmbedding(HuggingFaceEmbedding): class BGEEmbedding(HuggingFaceEmbedding):
def _get_query_embedding(self, query: str) -> List[float]: def _get_query_embedding(self, query: str) -> List[float]:
prefix = "Represent this sentence for searching relevant passages: " prefix = "Represent this sentence for searching relevant passages: "
@ -33,44 +30,47 @@ def build_user_index(user_id: str):
doc_dir = os.path.join(USER_DOC_PATH, user_id) doc_dir = os.path.join(USER_DOC_PATH, user_id)
if not os.path.exists(doc_dir): if not os.path.exists(doc_dir):
raise FileNotFoundError(f"文档目录不存在: {doc_dir}") raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
documents = SimpleDirectoryReader(doc_dir).load_data() documents = SimpleDirectoryReader(doc_dir).load_data()
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME) embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
faiss_index = faiss.IndexFlatL2(1024) # ✅ bge-m3 是 1024维 faiss_index = faiss.IndexFlatL2(1024)
vector_store = FaissVectorStore(faiss_index=faiss_index) vector_store = FaissVectorStore(faiss_index=faiss_index)
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
storage_context = StorageContext.from_defaults(
persist_dir=persist_dir,
vector_store=vector_store
)
index = VectorStoreIndex.from_documents( index = VectorStoreIndex.from_documents(
documents, documents,
vector_store=vector_store, storage_context=storage_context,
service_context=service_context service_context=service_context
) )
index.persist(persist_dir=persist_dir)
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") print(f"[BUILD] 为用户 {user_id} 构建并保存了完整索引 → {persist_dir}")
faiss.write_index(faiss_index, index_path) # ✅ 改这里,直接写 faiss_index
print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME) embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
all_nodes = [] persist_dir = os.path.join(USER_INDEX_PATH, user_id)
if not os.path.exists(persist_dir):
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引目录不存在")
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
if not os.path.exists(index_path): index = VectorStoreIndex.load_from_storage(storage_context, service_context=service_context)
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
user_store = FaissVectorStore.from_persist_path(index_path) all_nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
shared_indexes = get_user_allowed_indexes(user_id) shared_indexes = get_user_allowed_indexes(user_id)
if shared_indexes: if shared_indexes:
for shared_name in shared_indexes: for shared_name in shared_indexes:
shared_path = os.path.join(USER_INDEX_PATH, shared_name) shared_path = os.path.join(USER_INDEX_PATH, shared_name)
if os.path.exists(shared_path) and shared_path != index_path: if os.path.exists(shared_path) and shared_path != persist_dir:
shared_store = FaissVectorStore.from_persist_path(shared_path) shared_storage = StorageContext.from_defaults(persist_dir=shared_path)
shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context) shared_index = VectorStoreIndex.load_from_storage(shared_storage, service_context=service_context)
all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
else: else:
print(f"[INFO] 用户 {user_id} 没有共享索引权限") print(f"[INFO] 用户 {user_id} 没有共享索引权限")
@ -82,15 +82,11 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
prompt_template = PromptTemplate( prompt_template = PromptTemplate(
"请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}" "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}"
) )
final_prompt = prompt_template.format( final_prompt = prompt_template.format(context=context_str, query=question)
context=context_str,
query=question,
)
print("[PROMPT构建完成]") print("[PROMPT构建完成]")
return final_prompt return final_prompt
# 示例:
if __name__ == "__main__": if __name__ == "__main__":
uid = "user_001" uid = "user_001"
build_user_index(uid) build_user_index(uid)