This commit is contained in:
parent
ad59446d14
commit
668da474aa
|
|
@ -36,7 +36,7 @@ logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
# BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入
|
# BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入
|
||||||
class BGEEmbedding(HuggingFaceEmbedding):
|
class BGEEmbedding(HuggingFaceEmbedding):
|
||||||
def _get_query_embedding(self, query: str) -> List[float]:
|
def _get_query_embedding(self, query: str) -> List[float]:
|
||||||
try:
|
try:
|
||||||
# 在查询前加上前缀,生成嵌入向量
|
# 在查询前加上前缀,生成嵌入向量
|
||||||
|
|
@ -47,6 +47,9 @@ class BGEEmbedding(HuggingFaceEmbedding):
|
||||||
# 转换为 float32 类型
|
# 转换为 float32 类型
|
||||||
embedding = np.array(embedding, dtype=np.float32)
|
embedding = np.array(embedding, dtype=np.float32)
|
||||||
|
|
||||||
|
# 将 numpy 数组转换为列表,确保 embedding 是一个列表
|
||||||
|
embedding = embedding.tolist()
|
||||||
|
|
||||||
# 使用 logger 打印数据类型
|
# 使用 logger 打印数据类型
|
||||||
logger.info(f"Query embedding dtype after conversion: {embedding.dtype}")
|
logger.info(f"Query embedding dtype after conversion: {embedding.dtype}")
|
||||||
return embedding
|
return embedding
|
||||||
|
|
@ -61,8 +64,8 @@ class BGEEmbedding(HuggingFaceEmbedding):
|
||||||
prefix = "Represent this sentence for searching relevant passages: "
|
prefix = "Represent this sentence for searching relevant passages: "
|
||||||
embeddings = super()._get_query_embeddings([prefix + q for q in queries])
|
embeddings = super()._get_query_embeddings([prefix + q for q in queries])
|
||||||
|
|
||||||
# 转换为 float32 类型
|
# 转换为 float32 类型并转换为列表
|
||||||
embeddings = [np.array(embedding, dtype=np.float32) for embedding in embeddings]
|
embeddings = [np.array(embedding, dtype=np.float32).tolist() for embedding in embeddings]
|
||||||
|
|
||||||
# 使用 logger 打印数据类型
|
# 使用 logger 打印数据类型
|
||||||
logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}")
|
logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue