From cdc965006716db2a560b735811af3e5ecffdb15d Mon Sep 17 00:00:00 2001 From: hailin Date: Sat, 10 May 2025 12:41:12 +0800 Subject: [PATCH] . --- rag_build_query.py | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/rag_build_query.py b/rag_build_query.py index 71a9fec..56c0b01 100644 --- a/rag_build_query.py +++ b/rag_build_query.py @@ -34,24 +34,40 @@ logger.setLevel(logging.INFO) # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 class BGEEmbedding(HuggingFaceEmbedding): def _get_query_embedding(self, query: str) -> List[float]: - # 在查询前加上前缀,生成嵌入向量 - prefix = "Represent this sentence for searching relevant passages: " - embedding = super()._get_query_embedding(prefix + query) + try: + # 在查询前加上前缀,生成嵌入向量 + logger.info("Calling _get_query_embedding method...") + prefix = "Represent this sentence for searching relevant passages: " + embedding = super()._get_query_embedding(prefix + query) - # 使用logger打印数据类型 - logger.info(f"Query embedding dtype: {embedding.dtype}") - return embedding + # 转换为 float32 类型 + embedding = np.array(embedding, dtype=np.float32) + + # 使用 logger 打印数据类型 + logger.info(f"Query embedding dtype after conversion: {embedding.dtype}") + return embedding + except Exception as e: + logger.error(f"Error in _get_query_embedding: {e}") + raise def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]: - # 批量生成嵌入向量 - prefix = "Represent this sentence for searching relevant passages: " - embeddings = super()._get_query_embeddings([prefix + q for q in queries]) + try: + # 批量生成嵌入向量 + logger.info("Calling _get_query_embeddings method...") + prefix = "Represent this sentence for searching relevant passages: " + embeddings = super()._get_query_embeddings([prefix + q for q in queries]) - # 使用logger打印数据类型 - logger.info(f"Batch query embeddings dtype: {embeddings[0].dtype}") - return embeddings + # 转换为 float32 类型 + embeddings = [np.array(embedding, dtype=np.float32) for embedding in embeddings] + # 使用 logger 打印数据类型 + logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}") + return embeddings + except Exception as e: + logger.error(f"Error in _get_query_embeddings: {e}") + raise + def build_user_index(user_id: str): logger.info(f"开始为用户 {user_id} 构建索引...") @@ -81,7 +97,7 @@ def build_user_index(user_id: str): # 直接检查模型嵌入方法是否被调用 logger.info(f"Embedding method being used: {embed_model._get_query_embedding('test query')}") - + # 使用 Faiss 向量存储 faiss_index = faiss.IndexFlatL2(1024) vector_store = FaissVectorStore(faiss_index=faiss_index)