This commit is contained in:
parent
c6bb1a381a
commit
256646f426
|
|
@ -4,7 +4,7 @@ class Settings(BaseSettings):
|
|||
EMBEDDING_DIM: int = 768 # 嵌入维度(取决于模型)
|
||||
TOP_K: int = 5 # 默认检索 top K 个段落
|
||||
DOC_PATH: str = "docs/" # 默认文档根目录
|
||||
DEVICE: str = "cpu" # 可改为 "cuda" 使用 GPU
|
||||
DEVICE: str = "cuda" # 可改为 "cuda" 使用 GPU
|
||||
MODEL_NAME: str = "BAAI/bge-m3" # 多语种语义嵌入模型
|
||||
|
||||
settings = Settings()
|
||||
|
|
|
|||
|
|
@ -12,23 +12,47 @@ from scripts.permissions import get_user_allowed_indexes # 导入权限函数
|
|||
USER_INDEX_PATH = "index_data" # 用户索引存储路径
|
||||
USER_DOC_PATH = "docs" # 用户文档存储路径
|
||||
|
||||
|
||||
# 设置日志记录器,记录操作步骤和错误信息
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
# # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入
|
||||
# class BGEEmbedding(HuggingFaceEmbedding):
|
||||
# def _get_query_embedding(self, query: str) -> List[float]:
|
||||
# # 在查询前加上前缀,生成嵌入向量
|
||||
# prefix = "Represent this sentence for searching relevant passages: "
|
||||
# return super()._get_query_embedding(prefix + query)
|
||||
|
||||
# def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
||||
# # 批量生成嵌入向量
|
||||
# prefix = "Represent this sentence for searching relevant passages: "
|
||||
# return super()._get_query_embeddings([prefix + q for q in queries])
|
||||
|
||||
|
||||
# BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入
|
||||
class BGEEmbedding(HuggingFaceEmbedding):
|
||||
def _get_query_embedding(self, query: str) -> List[float]:
|
||||
# 在查询前加上前缀,生成嵌入向量
|
||||
prefix = "Represent this sentence for searching relevant passages: "
|
||||
return super()._get_query_embedding(prefix + query)
|
||||
embedding = super()._get_query_embedding(prefix + query)
|
||||
|
||||
# 使用logger打印数据类型
|
||||
logger.info(f"Query embedding dtype: {embedding.dtype}")
|
||||
return embedding
|
||||
|
||||
def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
||||
# 批量生成嵌入向量
|
||||
prefix = "Represent this sentence for searching relevant passages: "
|
||||
return super()._get_query_embeddings([prefix + q for q in queries])
|
||||
embeddings = super()._get_query_embeddings([prefix + q for q in queries])
|
||||
|
||||
# 使用logger打印数据类型
|
||||
logger.info(f"Batch query embeddings dtype: {embeddings[0].dtype}")
|
||||
return embeddings
|
||||
|
||||
|
||||
def build_user_index(user_id: str):
|
||||
# 设置日志记录器,记录操作步骤和错误信息
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
logger.info(f"开始为用户 {user_id} 构建索引...")
|
||||
|
||||
# 确认文档目录是否存在
|
||||
|
|
|
|||
Loading…
Reference in New Issue