From 4dec03c815d9d7e8696e0b4555aa99e78e7192c8 Mon Sep 17 00:00:00 2001 From: hailin Date: Sat, 10 May 2025 00:03:41 +0800 Subject: [PATCH] . --- scripts/rag_build_query.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index df898df..8445070 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -1,34 +1,37 @@ -from typing import List +from fastapi import HTTPException # 导入 HTTPException 用于错误处理 +from typing import List # 导入 List 用于类型注解 import os import logging import faiss from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.vector_stores.faiss import FaissVectorStore -from app.core.config import settings # 确保配置导入 -from scripts.permissions import get_user_allowed_indexes # 导入权限函数 +from app.core.config import settings # 导入应用配置 +from scripts.permissions import get_user_allowed_indexes # 导入权限函数,管理用户索引 -USER_INDEX_PATH = "index_data" -USER_DOC_PATH = "docs" +USER_INDEX_PATH = "index_data" # 用户索引存储路径 +USER_DOC_PATH = "docs" # 用户文档存储路径 -# BGEEmbedding 类 +# BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 class BGEEmbedding(HuggingFaceEmbedding): def _get_query_embedding(self, query: str) -> List[float]: + # 在查询前加上前缀,生成嵌入向量 prefix = "Represent this sentence for searching relevant passages: " return super()._get_query_embedding(prefix + query) def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]: + # 批量生成嵌入向量 prefix = "Represent this sentence for searching relevant passages: " return super()._get_query_embeddings([prefix + q for q in queries]) def build_user_index(user_id: str): - # 设置日志 + # 设置日志记录器,记录操作步骤和错误信息 logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger.info(f"开始为用户 {user_id} 构建索引...") - # 确认文档目录 + # 确认文档目录是否存在 doc_dir = os.path.join(USER_DOC_PATH, user_id) if not os.path.exists(doc_dir): logger.error(f"文档目录不存在: {doc_dir}") @@ -36,7 +39,7 @@ def build_user_index(user_id: str): logger.info(f"发现文档目录: {doc_dir}") - # 载入文档 + # 使用 SimpleDirectoryReader 载入文档 try: documents = SimpleDirectoryReader(doc_dir).load_data() logger.info(f"载入文档数量: {len(documents)}") @@ -44,19 +47,22 @@ def build_user_index(user_id: str): logger.error(f"加载文档时出错: {e}") raise HTTPException(status_code=500, detail="文档加载失败") + # 设置嵌入模型 embed_model = BGEEmbedding(model_name=settings.MODEL_NAME) logger.info(f"使用模型: {settings.MODEL_NAME}") + # 创建服务上下文 service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) + # 使用 Faiss 向量存储 faiss_index = faiss.IndexFlatL2(1024) vector_store = FaissVectorStore(faiss_index=faiss_index) - # 确保索引路径存在 + # 确保索引保存路径存在 persist_dir = os.path.join(USER_INDEX_PATH, user_id) logger.info(f"索引保存路径: {persist_dir}") - # 如果目录不存在,创建它 + # 如果目录不存在,则创建 if not os.path.exists(persist_dir): logger.info(f"目录 {persist_dir} 不存在,准备创建") os.makedirs(persist_dir, exist_ok=True) @@ -76,20 +82,20 @@ def build_user_index(user_id: str): else: logger.info(f"已找到 index_store.json,跳过创建") - # 创建 storage context + # 创建 StorageContext,用于存储和管理索引数据 storage_context = StorageContext.from_defaults( persist_dir=persist_dir, vector_store=vector_store, ) try: - # 构建索引 + # 构建索引,并使用 `save()` 方法保存索引 index = VectorStoreIndex.from_documents( documents, service_context=service_context, storage_context=storage_context ) - index.persist(persist_dir=persist_dir) + index.save(persist_dir=persist_dir) # 使用 `save()` 方法代替 `persist()` logger.info(f"索引已保存到 {persist_dir}") except Exception as e: logger.error(f"索引构建失败: {e}")