diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index bdf878d..bc28d29 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -36,7 +36,7 @@ logger.setLevel(logging.INFO) # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 -class BGEEmbedding(HuggingFaceEmbedding): +class BGEEmbedding(HuggingFaceEmbedding): def _get_query_embedding(self, query: str) -> List[float]: try: # 在查询前加上前缀,生成嵌入向量 @@ -47,6 +47,9 @@ class BGEEmbedding(HuggingFaceEmbedding): # 转换为 float32 类型 embedding = np.array(embedding, dtype=np.float32) + # 将 numpy 数组转换为列表,确保 embedding 是一个列表 + embedding = embedding.tolist() + # 使用 logger 打印数据类型 logger.info(f"Query embedding dtype after conversion: {embedding.dtype}") return embedding @@ -61,8 +64,8 @@ class BGEEmbedding(HuggingFaceEmbedding): prefix = "Represent this sentence for searching relevant passages: " embeddings = super()._get_query_embeddings([prefix + q for q in queries]) - # 转换为 float32 类型 - embeddings = [np.array(embedding, dtype=np.float32) for embedding in embeddings] + # 转换为 float32 类型并转换为列表 + embeddings = [np.array(embedding, dtype=np.float32).tolist() for embedding in embeddings] # 使用 logger 打印数据类型 logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}")