This commit is contained in:
parent
c6bb1a381a
commit
256646f426
|
|
@ -4,7 +4,7 @@ class Settings(BaseSettings):
|
||||||
EMBEDDING_DIM: int = 768 # 嵌入维度(取决于模型)
|
EMBEDDING_DIM: int = 768 # 嵌入维度(取决于模型)
|
||||||
TOP_K: int = 5 # 默认检索 top K 个段落
|
TOP_K: int = 5 # 默认检索 top K 个段落
|
||||||
DOC_PATH: str = "docs/" # 默认文档根目录
|
DOC_PATH: str = "docs/" # 默认文档根目录
|
||||||
DEVICE: str = "cpu" # 可改为 "cuda" 使用 GPU
|
DEVICE: str = "cuda" # 可改为 "cuda" 使用 GPU
|
||||||
MODEL_NAME: str = "BAAI/bge-m3" # 多语种语义嵌入模型
|
MODEL_NAME: str = "BAAI/bge-m3" # 多语种语义嵌入模型
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
|
||||||
|
|
@ -12,23 +12,47 @@ from scripts.permissions import get_user_allowed_indexes # 导入权限函数
|
||||||
USER_INDEX_PATH = "index_data" # 用户索引存储路径
|
USER_INDEX_PATH = "index_data" # 用户索引存储路径
|
||||||
USER_DOC_PATH = "docs" # 用户文档存储路径
|
USER_DOC_PATH = "docs" # 用户文档存储路径
|
||||||
|
|
||||||
|
|
||||||
|
# 设置日志记录器,记录操作步骤和错误信息
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
# # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入
|
||||||
|
# class BGEEmbedding(HuggingFaceEmbedding):
|
||||||
|
# def _get_query_embedding(self, query: str) -> List[float]:
|
||||||
|
# # 在查询前加上前缀,生成嵌入向量
|
||||||
|
# prefix = "Represent this sentence for searching relevant passages: "
|
||||||
|
# return super()._get_query_embedding(prefix + query)
|
||||||
|
|
||||||
|
# def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
||||||
|
# # 批量生成嵌入向量
|
||||||
|
# prefix = "Represent this sentence for searching relevant passages: "
|
||||||
|
# return super()._get_query_embeddings([prefix + q for q in queries])
|
||||||
|
|
||||||
|
|
||||||
# BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入
|
# BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入
|
||||||
class BGEEmbedding(HuggingFaceEmbedding):
|
class BGEEmbedding(HuggingFaceEmbedding):
|
||||||
def _get_query_embedding(self, query: str) -> List[float]:
|
def _get_query_embedding(self, query: str) -> List[float]:
|
||||||
# 在查询前加上前缀,生成嵌入向量
|
# 在查询前加上前缀,生成嵌入向量
|
||||||
prefix = "Represent this sentence for searching relevant passages: "
|
prefix = "Represent this sentence for searching relevant passages: "
|
||||||
return super()._get_query_embedding(prefix + query)
|
embedding = super()._get_query_embedding(prefix + query)
|
||||||
|
|
||||||
|
# 使用logger打印数据类型
|
||||||
|
logger.info(f"Query embedding dtype: {embedding.dtype}")
|
||||||
|
return embedding
|
||||||
|
|
||||||
def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
||||||
# 批量生成嵌入向量
|
# 批量生成嵌入向量
|
||||||
prefix = "Represent this sentence for searching relevant passages: "
|
prefix = "Represent this sentence for searching relevant passages: "
|
||||||
return super()._get_query_embeddings([prefix + q for q in queries])
|
embeddings = super()._get_query_embeddings([prefix + q for q in queries])
|
||||||
|
|
||||||
|
# 使用logger打印数据类型
|
||||||
|
logger.info(f"Batch query embeddings dtype: {embeddings[0].dtype}")
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def build_user_index(user_id: str):
|
def build_user_index(user_id: str):
|
||||||
# 设置日志记录器,记录操作步骤和错误信息
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
logger.info(f"开始为用户 {user_id} 构建索引...")
|
logger.info(f"开始为用户 {user_id} 构建索引...")
|
||||||
|
|
||||||
# 确认文档目录是否存在
|
# 确认文档目录是否存在
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue