From 256646f4262448ca7e6816d3862c1891e9510505 Mon Sep 17 00:00:00 2001
From: hailin
Date: Sat, 10 May 2025 12:23:46 +0800
Subject: [PATCH] Default DEVICE to cuda and log query embedding types

---
 app/core/config.py |  2 +-
 rag_build_query.py | 36 ++++++++++++++++++++++++++++++------
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/app/core/config.py b/app/core/config.py
index 14e3405..f09fdd4 100644
--- a/app/core/config.py
+++ b/app/core/config.py
@@ -4,7 +4,7 @@ class Settings(BaseSettings):
     EMBEDDING_DIM: int = 768 # 嵌入维度(取决于模型)
     TOP_K: int = 5 # 默认检索 top K 个段落
     DOC_PATH: str = "docs/" # 默认文档根目录
-    DEVICE: str = "cpu" # 可改为 "cuda" 使用 GPU
+    DEVICE: str = "cuda" # "cuda" runs on the GPU; use "cpu" when CUDA is unavailable
     MODEL_NAME: str = "BAAI/bge-m3" # 多语种语义嵌入模型
 
 settings = Settings()
diff --git a/rag_build_query.py b/rag_build_query.py
index ff45e71..39264f6 100644
--- a/rag_build_query.py
+++ b/rag_build_query.py
@@ -12,23 +12,47 @@ from scripts.permissions import get_user_allowed_indexes # 导入权限函数
 USER_INDEX_PATH = "index_data" # 用户索引存储路径
 USER_DOC_PATH = "docs" # 用户文档存储路径
 
+
+# Module-level logger shared by the embedding class and build_user_index
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+# # Previous BGEEmbedding implementation (no type logging), kept for reference:
+# class BGEEmbedding(HuggingFaceEmbedding):
+#     def _get_query_embedding(self, query: str) -> List[float]:
+#         # Prepend the retrieval prefix to the query before embedding it
+#         prefix = "Represent this sentence for searching relevant passages: "
+#         return super()._get_query_embedding(prefix + query)
+
+#     def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
+#         # Embed a batch of queries, each with the retrieval prefix
+#         prefix = "Represent this sentence for searching relevant passages: "
+#         return super()._get_query_embeddings([prefix + q for q in queries])
+
+
 # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入
 class BGEEmbedding(HuggingFaceEmbedding):
     def _get_query_embedding(self, query: str) -> List[float]:
         # 在查询前加上前缀,生成嵌入向量
         prefix = "Represent this sentence for searching relevant passages: "
-        return super()._get_query_embedding(prefix + query)
+        embedding = super()._get_query_embedding(prefix + query)
+        # A plain list has no ``.dtype``; fall back to the container type name.
+        dtype = getattr(embedding, "dtype", type(embedding).__name__)
+        logger.info("Query embedding dtype: %s", dtype)
+        return embedding
 
     def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
         # 批量生成嵌入向量
         prefix = "Represent this sentence for searching relevant passages: "
-        return super()._get_query_embeddings([prefix + q for q in queries])
+        embeddings = super()._get_query_embeddings([prefix + q for q in queries])
+        # Guard the empty batch; a plain list element has no ``.dtype`` either.
+        first = embeddings[0] if embeddings else None
+        logger.info("Batch query embeddings dtype: %s", getattr(first, "dtype", type(first).__name__))
+        return embeddings
+
 
 def build_user_index(user_id: str):
-    # 设置日志记录器,记录操作步骤和错误信息
-    logger = logging.getLogger(__name__)
-    logger.setLevel(logging.INFO)
-
     logger.info(f"开始为用户 {user_id} 构建索引...")
 
     # 确认文档目录是否存在