From 31458b7ac73a1af6a6413c14b25cc77b708c239d Mon Sep 17 00:00:00 2001
From: hailin
Date: Fri, 9 May 2025 22:36:29 +0800
Subject: [PATCH] .

---
 scripts/rag_build_query.py | 52 +++++++++++++-------------------------
 1 file changed, 18 insertions(+), 34 deletions(-)

diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py
index d1e9777..ce0547a 100644
--- a/scripts/rag_build_query.py
+++ b/scripts/rag_build_query.py
@@ -1,4 +1,8 @@
 import os
+from typing import List
+import asyncio
+import faiss
+from sentence_transformers import SentenceTransformer
 
 from llama_index import (
     SimpleDirectoryReader,
@@ -7,47 +11,23 @@ from llama_index import (
     PromptTemplate,
     ServiceContext,
 )
-from llama_index.embeddings.base import BaseEmbedding
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from llama_index.vector_stores.faiss import FaissVectorStore
 from app.core.config import settings
 from scripts.permissions import get_user_allowed_indexes
-import faiss
-from typing import List
-import asyncio
-from sentence_transformers import SentenceTransformer
-from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 
 USER_INDEX_PATH = "index_data"
 USER_DOC_PATH = "docs"
-
-class CustomEmbedding(BaseEmbedding):
-    model: SentenceTransformer  # ✅ explicitly declared as a field
-
-    def __init__(self, model_name: str):
-        from sentence_transformers import SentenceTransformer
-        self.model = SentenceTransformer(model_name)
-
-    # Synchronous methods (must be implemented)
-    def _get_text_embedding(self, text: str) -> List[float]:
-        return self.model.encode(text).tolist()
-
+# ✅ Replaces CustomEmbedding; for the bge-m3 model, prepends the query prefix automatically
+class BGEEmbedding(HuggingFaceEmbedding):
     def _get_query_embedding(self, query: str) -> List[float]:
-        return self.model.encode(query).tolist()
-
-    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        return self.model.encode(texts).tolist()
+        prefix = "Represent this sentence for searching relevant passages: "
+        return super()._get_query_embedding(prefix + query)
 
     def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
-        return self.model.encode(queries).tolist()
-
-    # Asynchronous methods (must be implemented, even if they just wrap the sync ones)
-    async def _aget_query_embedding(self, query: str) -> List[float]:
-        return self._get_query_embedding(query)
-
-    async def _aget_query_embeddings(self, queries: List[str]) -> List[List[float]]:
-        return self._get_query_embeddings(queries)
-
+        prefix = "Represent this sentence for searching relevant passages: "
+        return super()._get_query_embeddings([prefix + q for q in queries])
 
 def build_user_index(user_id: str):
     doc_dir = os.path.join(USER_DOC_PATH, user_id)
@@ -55,13 +35,17 @@ def build_user_index(user_id: str):
         raise FileNotFoundError(f"Document directory does not exist: {doc_dir}")
 
     documents = SimpleDirectoryReader(doc_dir).load_data()
-    embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
+    embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
 
+    # ✅ Use the correct dimension: bge-m3 embeddings are 1024-dimensional
+    faiss_index = faiss.IndexFlatL2(1024)
+    vector_store = FaissVectorStore(faiss_index=faiss_index)
+
     index = VectorStoreIndex.from_documents(
         documents,
-        vector_store=FaissVectorStore(),
+        vector_store=vector_store,
         service_context=service_context
     )
 
@@ -70,7 +54,7 @@ def build_user_index(user_id: str):
     print(f"[BUILD] Built and saved the index for user {user_id} → {index_path}")
 
 def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
-    embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
+    embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
 
     all_nodes = []
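
Reviewer note: below is a minimal, standalone sketch of the two ideas this patch relies on, i.e. prepending the BGE retrieval instruction to queries and matching the FAISS index dimension to the embedding model. It deliberately uses sentence-transformers and faiss directly rather than llama_index, so it runs without the app's settings module; the model name "BAAI/bge-m3" is an assumption standing in for settings.MODEL_NAME, and the dimension is read from the model at runtime instead of being hard-coded to 1024.

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

MODEL_NAME = "BAAI/bge-m3"  # assumption: settings.MODEL_NAME points at a BGE-style model
QUERY_PREFIX = "Represent this sentence for searching relevant passages: "  # same prefix as BGEEmbedding

model = SentenceTransformer(MODEL_NAME)

# Derive the dimension from the model rather than hard-coding 1024,
# so the FAISS index stays consistent if the model is ever swapped.
dim = model.get_sentence_embedding_dimension()
print(f"embedding dimension: {dim}")  # 1024 for bge-m3

passages = [
    "FAISS is a library for efficient similarity search.",
    "LlamaIndex builds retrieval-augmented generation pipelines.",
]
index = faiss.IndexFlatL2(dim)  # L2 index, mirroring the patch
index.add(np.asarray(model.encode(passages), dtype="float32"))

# Queries get the BGE retrieval instruction prepended, mirroring BGEEmbedding.
query_vec = model.encode([QUERY_PREFIX + "What is FAISS used for?"])
distances, ids = index.search(np.asarray(query_vec, dtype="float32"), 1)
print(passages[ids[0][0]], distances[0][0])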