This commit is contained in:
hailin 2025-05-10 12:41:12 +08:00
parent 979e14430e
commit cdc9650067
1 changed file with 29 additions and 13 deletions

View File

@ -34,24 +34,40 @@ logger.setLevel(logging.INFO)
# BGEEmbedding: subclass of HuggingFaceEmbedding used to embed search queries.
class BGEEmbedding(HuggingFaceEmbedding):
    """HuggingFaceEmbedding variant that prefixes queries before embedding.

    Every query is prepended with a fixed retrieval instruction, embedded by
    the parent class, and the result is converted to a NumPy ``float32``
    array.

    NOTE(review): both methods return NumPy arrays although their signatures
    are annotated ``List[float]`` / ``List[List[float]]``. The conversion is
    kept as-is because callers may rely on it; presumably it exists for the
    Faiss index used elsewhere in this file -- confirm before changing.
    """

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query with the retrieval prefix applied.

        Args:
            query: Raw query text.

        Returns:
            The embedding as a ``float32`` NumPy array.

        Raises:
            Exception: Any error from the parent embedding call or the
                conversion is logged with its traceback and re-raised.
        """
        try:
            logger.info("Calling _get_query_embedding method...")
            prefix = "Represent this sentence for searching relevant passages: "
            raw_embedding = super()._get_query_embedding(prefix + query)

            # Convert to float32.
            embedding = np.array(raw_embedding, dtype=np.float32)

            logger.info(
                "Query embedding dtype after conversion: %s", embedding.dtype
            )
            return embedding
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Error in _get_query_embedding: %s", e)
            raise

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        """Embed a batch of queries with the retrieval prefix applied.

        Args:
            queries: Raw query texts; may be empty.

        Returns:
            One ``float32`` NumPy array per embedding returned by the parent
            class.

        Raises:
            Exception: Any error from the parent embedding call or the
                conversion is logged with its traceback and re-raised.
        """
        try:
            logger.info("Calling _get_query_embeddings method...")
            prefix = "Represent this sentence for searching relevant passages: "
            raw_embeddings = super()._get_query_embeddings(
                [prefix + q for q in queries]
            )

            # Convert every vector to float32.
            embeddings = [
                np.array(vector, dtype=np.float32) for vector in raw_embeddings
            ]

            # Only inspect the first vector when there is one; indexing an
            # empty result would raise IndexError purely because of logging.
            if embeddings:
                logger.info(
                    "Batch query embeddings dtype after conversion: %s",
                    embeddings[0].dtype,
                )
            return embeddings
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Error in _get_query_embeddings: %s", e)
            raise

def build_user_index(user_id: str): def build_user_index(user_id: str):
logger.info(f"开始为用户 {user_id} 构建索引...") logger.info(f"开始为用户 {user_id} 构建索引...")
@ -81,7 +97,7 @@ def build_user_index(user_id: str):
# 直接检查模型嵌入方法是否被调用 # 直接检查模型嵌入方法是否被调用
logger.info(f"Embedding method being used: {embed_model._get_query_embedding('test query')}") logger.info(f"Embedding method being used: {embed_model._get_query_embedding('test query')}")
# 使用 Faiss 向量存储 # 使用 Faiss 向量存储
faiss_index = faiss.IndexFlatL2(1024) faiss_index = faiss.IndexFlatL2(1024)
vector_store = FaissVectorStore(faiss_index=faiss_index) vector_store = FaissVectorStore(faiss_index=faiss_index)