diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py
index 901c922..ed77ba1 100644
--- a/scripts/rag_build_query.py
+++ b/scripts/rag_build_query.py
@@ -12,27 +12,39 @@ from llama_index.vector_stores.faiss import FaissVectorStore
 from app.core.config import settings
 from scripts.permissions import get_user_allowed_indexes
 import faiss
+from typing import List
+import asyncio
 
 USER_INDEX_PATH = "index_data"
 USER_DOC_PATH = "docs"
 
+
 class CustomEmbedding(BaseEmbedding):
     def __init__(self, model_name: str):
         from sentence_transformers import SentenceTransformer
         self.model = SentenceTransformer(model_name)
 
-    def _get_text_embedding(self, text: str) -> list[float]:
+    # Synchronous methods (must be implemented)
+    def _get_text_embedding(self, text: str) -> List[float]:
         return self.model.encode(text).tolist()
 
-    def _get_query_embedding(self, query: str) -> list[float]:
+    def _get_query_embedding(self, query: str) -> List[float]:
         return self.model.encode(query).tolist()
 
-    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
+    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         return self.model.encode(texts).tolist()
 
-    def _get_query_embeddings(self, queries: list[str]) -> list[list[float]]:
+    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
         return self.model.encode(queries).tolist()
-    
+
+    # Asynchronous methods (must be implemented, even if they just wrap the synchronous versions)
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        return self._get_query_embedding(query)
+
+    async def _aget_query_embeddings(self, queries: List[str]) -> List[List[float]]:
+        return self._get_query_embeddings(queries)
+
+
 def build_user_index(user_id: str):
     doc_dir = os.path.join(USER_DOC_PATH, user_id)
     if not os.path.exists(doc_dir):