This commit is contained in:
parent
979e14430e
commit
cdc9650067
|
|
@ -34,22 +34,38 @@ logger.setLevel(logging.INFO)
|
|||
# BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入
|
||||
class BGEEmbedding(HuggingFaceEmbedding):
|
||||
def _get_query_embedding(self, query: str) -> List[float]:
|
||||
try:
|
||||
# 在查询前加上前缀,生成嵌入向量
|
||||
logger.info("Calling _get_query_embedding method...")
|
||||
prefix = "Represent this sentence for searching relevant passages: "
|
||||
embedding = super()._get_query_embedding(prefix + query)
|
||||
|
||||
# 转换为 float32 类型
|
||||
embedding = np.array(embedding, dtype=np.float32)
|
||||
|
||||
# 使用 logger 打印数据类型
|
||||
logger.info(f"Query embedding dtype: {embedding.dtype}")
|
||||
logger.info(f"Query embedding dtype after conversion: {embedding.dtype}")
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _get_query_embedding: {e}")
|
||||
raise
|
||||
|
||||
def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
||||
try:
|
||||
# 批量生成嵌入向量
|
||||
logger.info("Calling _get_query_embeddings method...")
|
||||
prefix = "Represent this sentence for searching relevant passages: "
|
||||
embeddings = super()._get_query_embeddings([prefix + q for q in queries])
|
||||
|
||||
# 转换为 float32 类型
|
||||
embeddings = [np.array(embedding, dtype=np.float32) for embedding in embeddings]
|
||||
|
||||
# 使用 logger 打印数据类型
|
||||
logger.info(f"Batch query embeddings dtype: {embeddings[0].dtype}")
|
||||
logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}")
|
||||
return embeddings
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _get_query_embeddings: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def build_user_index(user_id: str):
|
||||
|
|
|
|||
Loading…
Reference in New Issue