This commit is contained in:
hailin 2025-05-11 00:49:12 +08:00
parent 668da474aa
commit bca56cc457
1 changed files with 15 additions and 22 deletions

View File

@ -22,19 +22,6 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
# # BGEEmbedding 类,继承自 HuggingFaceEmbedding用于生成查询的嵌入
# class BGEEmbedding(HuggingFaceEmbedding):
# def _get_query_embedding(self, query: str) -> List[float]:
# # 在查询前加上前缀,生成嵌入向量
# prefix = "Represent this sentence for searching relevant passages: "
# return super()._get_query_embedding(prefix + query)
# def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
# # 批量生成嵌入向量
# prefix = "Represent this sentence for searching relevant passages: "
# return super()._get_query_embeddings([prefix + q for q in queries])
# BGEEmbedding 类,继承自 HuggingFaceEmbedding用于生成查询的嵌入 # BGEEmbedding 类,继承自 HuggingFaceEmbedding用于生成查询的嵌入
class BGEEmbedding(HuggingFaceEmbedding): class BGEEmbedding(HuggingFaceEmbedding):
def _get_query_embedding(self, query: str) -> List[float]: def _get_query_embedding(self, query: str) -> List[float]:
@ -44,14 +31,14 @@ class BGEEmbedding(HuggingFaceEmbedding):
prefix = "Represent this sentence for searching relevant passages: " prefix = "Represent this sentence for searching relevant passages: "
embedding = super()._get_query_embedding(prefix + query) embedding = super()._get_query_embedding(prefix + query)
# 转换为 float32 类型 # 转换为 numpy 数组并记录 dtype
embedding = np.array(embedding, dtype=np.float32) embedding = np.array(embedding, dtype=np.float32)
# 将 numpy 数组转换为列表,确保 embedding 是一个列表 # 使用 logger 打印数据类型dtype 在 numpy 数组上有效)
embedding = embedding.tolist()
# 使用 logger 打印数据类型
logger.info(f"Query embedding dtype after conversion: {embedding.dtype}") logger.info(f"Query embedding dtype after conversion: {embedding.dtype}")
# 转换为列表返回
embedding = embedding.tolist() # 转换为 Python 列表
return embedding return embedding
except Exception as e: except Exception as e:
logger.error(f"Error in _get_query_embedding: {e}") logger.error(f"Error in _get_query_embedding: {e}")
@ -64,17 +51,23 @@ class BGEEmbedding(HuggingFaceEmbedding):
prefix = "Represent this sentence for searching relevant passages: " prefix = "Represent this sentence for searching relevant passages: "
embeddings = super()._get_query_embeddings([prefix + q for q in queries]) embeddings = super()._get_query_embeddings([prefix + q for q in queries])
# 转换为 float32 类型并转换为列表 # 转换为 numpy 数组并记录 dtype
embeddings = [np.array(embedding, dtype=np.float32).tolist() for embedding in embeddings] embeddings = [np.array(embedding, dtype=np.float32) for embedding in embeddings]
# 使用 logger 打印数据类型 # 使用 logger 打印数据类型dtype 在 numpy 数组上有效)
logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}") logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}")
# 转换为列表返回
embeddings = [embedding.tolist() for embedding in embeddings]
return embeddings return embeddings
except Exception as e: except Exception as e:
logger.error(f"Error in _get_query_embeddings: {e}") logger.error(f"Error in _get_query_embeddings: {e}")
raise raise
def build_user_index(user_id: str): def build_user_index(user_id: str):
logger.info(f"开始为用户 {user_id} 构建索引...") logger.info(f"开始为用户 {user_id} 构建索引...")