From 1ff999085f3873b6128dcc738c3e69049be5928e Mon Sep 17 00:00:00 2001 From: hailin Date: Fri, 9 May 2025 20:25:30 +0800 Subject: [PATCH] . --- scripts/rag_build_query.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index 37b6e78..901c922 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -21,12 +21,18 @@ class CustomEmbedding(BaseEmbedding): from sentence_transformers import SentenceTransformer self.model = SentenceTransformer(model_name) - def embed(self, text: str): + def _get_text_embedding(self, text: str) -> list[float]: return self.model.encode(text).tolist() - def embed_batch(self, texts: list): + def _get_query_embedding(self, query: str) -> list[float]: + return self.model.encode(query).tolist() + + def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]: return self.model.encode(texts).tolist() + def _get_query_embeddings(self, queries: list[str]) -> list[list[float]]: + return self.model.encode(queries).tolist() + def build_user_index(user_id: str): doc_dir = os.path.join(USER_DOC_PATH, user_id) if not os.path.exists(doc_dir):