diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index bc28d29..c268217 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -22,19 +22,6 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -# # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 -# class BGEEmbedding(HuggingFaceEmbedding): -# def _get_query_embedding(self, query: str) -> List[float]: -# # 在查询前加上前缀,生成嵌入向量 -# prefix = "Represent this sentence for searching relevant passages: " -# return super()._get_query_embedding(prefix + query) - -# def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]: -# # 批量生成嵌入向量 -# prefix = "Represent this sentence for searching relevant passages: " -# return super()._get_query_embeddings([prefix + q for q in queries]) - - # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 class BGEEmbedding(HuggingFaceEmbedding): def _get_query_embedding(self, query: str) -> List[float]: @@ -44,14 +31,14 @@ class BGEEmbedding(HuggingFaceEmbedding): prefix = "Represent this sentence for searching relevant passages: " embedding = super()._get_query_embedding(prefix + query) - # 转换为 float32 类型 + # 转换为 numpy 数组并记录 dtype embedding = np.array(embedding, dtype=np.float32) - # 将 numpy 数组转换为列表,确保 embedding 是一个列表 - embedding = embedding.tolist() - - # 使用 logger 打印数据类型 + # 使用 logger 打印数据类型(dtype 在 numpy 数组上有效) logger.info(f"Query embedding dtype after conversion: {embedding.dtype}") + + # 转换为列表返回 + embedding = embedding.tolist() # 转换为 Python 列表 return embedding except Exception as e: logger.error(f"Error in _get_query_embedding: {e}") @@ -64,17 +51,23 @@ class BGEEmbedding(HuggingFaceEmbedding): prefix = "Represent this sentence for searching relevant passages: " embeddings = super()._get_query_embeddings([prefix + q for q in queries]) - # 转换为 float32 类型并转换为列表 - embeddings = [np.array(embedding, dtype=np.float32).tolist() for embedding in embeddings] + # 转换为 numpy 数组并记录 dtype + embeddings = [np.array(embedding, dtype=np.float32) for embedding in embeddings] - # 使用 logger 打印数据类型 + # 使用 logger 打印数据类型(dtype 在 numpy 数组上有效) logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}") + + # 转换为列表返回 + embeddings = [embedding.tolist() for embedding in embeddings] return embeddings except Exception as e: logger.error(f"Error in _get_query_embeddings: {e}") raise - + + + + def build_user_index(user_id: str): logger.info(f"开始为用户 {user_id} 构建索引...")