This commit is contained in:
parent
31458b7ac7
commit
344e98a73c
|
|
@ -19,7 +19,7 @@ from scripts.permissions import get_user_allowed_indexes
|
||||||
USER_INDEX_PATH = "index_data"
|
USER_INDEX_PATH = "index_data"
|
||||||
USER_DOC_PATH = "docs"
|
USER_DOC_PATH = "docs"
|
||||||
|
|
||||||
# ✅ 替代 CustomEmbedding,用于 bge-m3 模型,自动加前缀
|
# ✅ BGE-m3 模型嵌入类,加前缀
|
||||||
class BGEEmbedding(HuggingFaceEmbedding):
|
class BGEEmbedding(HuggingFaceEmbedding):
|
||||||
def _get_query_embedding(self, query: str) -> List[float]:
|
def _get_query_embedding(self, query: str) -> List[float]:
|
||||||
prefix = "Represent this sentence for searching relevant passages: "
|
prefix = "Represent this sentence for searching relevant passages: "
|
||||||
|
|
@ -36,11 +36,9 @@ def build_user_index(user_id: str):
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(doc_dir).load_data()
|
documents = SimpleDirectoryReader(doc_dir).load_data()
|
||||||
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
|
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
|
||||||
|
|
||||||
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
|
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
|
||||||
|
|
||||||
# ✅ 指定正确维度:bge-m3 是 1024
|
faiss_index = faiss.IndexFlatL2(1024) # ✅ bge-m3 是 1024维
|
||||||
faiss_index = faiss.IndexFlatL2(1024)
|
|
||||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
|
|
||||||
index = VectorStoreIndex.from_documents(
|
index = VectorStoreIndex.from_documents(
|
||||||
|
|
@ -50,7 +48,7 @@ def build_user_index(user_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
|
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
|
||||||
faiss.write_index(index.vector_store.index, index_path)
|
faiss.write_index(vector_store.index, index_path) # ✅ 用传入的 vector_store
|
||||||
print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")
|
print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")
|
||||||
|
|
||||||
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
|
|
@ -59,7 +57,6 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
|
|
||||||
all_nodes = []
|
all_nodes = []
|
||||||
|
|
||||||
# 加载用户主索引
|
|
||||||
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
|
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
|
||||||
if not os.path.exists(index_path):
|
if not os.path.exists(index_path):
|
||||||
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
|
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
|
||||||
|
|
@ -67,7 +64,6 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
|
user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
|
||||||
all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
||||||
|
|
||||||
# 加载共享索引
|
|
||||||
shared_indexes = get_user_allowed_indexes(user_id)
|
shared_indexes = get_user_allowed_indexes(user_id)
|
||||||
if shared_indexes:
|
if shared_indexes:
|
||||||
for shared_name in shared_indexes:
|
for shared_name in shared_indexes:
|
||||||
|
|
@ -79,7 +75,6 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
else:
|
else:
|
||||||
print(f"[INFO] 用户 {user_id} 没有共享索引权限")
|
print(f"[INFO] 用户 {user_id} 没有共享索引权限")
|
||||||
|
|
||||||
# 合并 + 按 score 排序
|
|
||||||
sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
|
sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
|
||||||
top_nodes = sorted_nodes[:top_k]
|
top_nodes = sorted_nodes[:top_k]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue