This commit is contained in:
hailin 2025-05-10 12:23:46 +08:00
parent c6bb1a381a
commit 256646f426
2 changed files with 31 additions and 7 deletions

View File

@ -4,7 +4,7 @@ class Settings(BaseSettings):
EMBEDDING_DIM: int = 768 # 嵌入维度(取决于模型)
TOP_K: int = 5 # 默认检索 top K 个段落
DOC_PATH: str = "docs/" # 默认文档根目录
DEVICE: str = "cpu" # 可改为 "cuda" 使用 GPU
DEVICE: str = "cuda" # 可改为 "cuda" 使用 GPU
MODEL_NAME: str = "BAAI/bge-m3" # 多语种语义嵌入模型
settings = Settings()

View File

@ -12,23 +12,47 @@ from scripts.permissions import get_user_allowed_indexes # 导入权限函数
USER_INDEX_PATH = "index_data" # 用户索引存储路径
USER_DOC_PATH = "docs" # 用户文档存储路径
# 设置日志记录器,记录操作步骤和错误信息
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# # BGEEmbedding 类,继承自 HuggingFaceEmbedding用于生成查询的嵌入
# class BGEEmbedding(HuggingFaceEmbedding):
# def _get_query_embedding(self, query: str) -> List[float]:
# # 在查询前加上前缀,生成嵌入向量
# prefix = "Represent this sentence for searching relevant passages: "
# return super()._get_query_embedding(prefix + query)
# def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
# # 批量生成嵌入向量
# prefix = "Represent this sentence for searching relevant passages: "
# return super()._get_query_embeddings([prefix + q for q in queries])
# BGEEmbedding 类,继承自 HuggingFaceEmbedding用于生成查询的嵌入
class BGEEmbedding(HuggingFaceEmbedding):
def _get_query_embedding(self, query: str) -> List[float]:
# 在查询前加上前缀,生成嵌入向量
prefix = "Represent this sentence for searching relevant passages: "
return super()._get_query_embedding(prefix + query)
embedding = super()._get_query_embedding(prefix + query)
# 使用logger打印数据类型
logger.info(f"Query embedding dtype: {embedding.dtype}")
return embedding
def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
# 批量生成嵌入向量
prefix = "Represent this sentence for searching relevant passages: "
return super()._get_query_embeddings([prefix + q for q in queries])
embeddings = super()._get_query_embeddings([prefix + q for q in queries])
# 使用logger打印数据类型
logger.info(f"Batch query embeddings dtype: {embeddings[0].dtype}")
return embeddings
def build_user_index(user_id: str):
# 设置日志记录器,记录操作步骤和错误信息
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.info(f"开始为用户 {user_id} 构建索引...")
# 确认文档目录是否存在