diff --git a/rag_build_query.py b/rag_build_query.py deleted file mode 100644 index 56c0b01..0000000 --- a/rag_build_query.py +++ /dev/null @@ -1,174 +0,0 @@ -from fastapi import HTTPException # 导入 HTTPException 用于错误处理 -from typing import List # 导入 List 用于类型注解 -import os -import logging -import faiss -from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext -from llama_index.embeddings.huggingface import HuggingFaceEmbedding -from llama_index.vector_stores.faiss import FaissVectorStore -from app.core.config import settings # 导入应用配置 -from scripts.permissions import get_user_allowed_indexes # 导入权限函数,管理用户索引 - -USER_INDEX_PATH = "index_data" # 用户索引存储路径 -USER_DOC_PATH = "docs" # 用户文档存储路径 - - -# 设置日志记录器,记录操作步骤和错误信息 -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -# # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 -# class BGEEmbedding(HuggingFaceEmbedding): -# def _get_query_embedding(self, query: str) -> List[float]: -# # 在查询前加上前缀,生成嵌入向量 -# prefix = "Represent this sentence for searching relevant passages: " -# return super()._get_query_embedding(prefix + query) - -# def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]: -# # 批量生成嵌入向量 -# prefix = "Represent this sentence for searching relevant passages: " -# return super()._get_query_embeddings([prefix + q for q in queries]) - - -# BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 -class BGEEmbedding(HuggingFaceEmbedding): - def _get_query_embedding(self, query: str) -> List[float]: - try: - # 在查询前加上前缀,生成嵌入向量 - logger.info("Calling _get_query_embedding method...") - prefix = "Represent this sentence for searching relevant passages: " - embedding = super()._get_query_embedding(prefix + query) - - # 转换为 float32 类型 - embedding = np.array(embedding, dtype=np.float32) - - # 使用 logger 打印数据类型 - logger.info(f"Query embedding dtype after conversion: {embedding.dtype}") - return embedding - except Exception as e: - logger.error(f"Error in _get_query_embedding: {e}") - raise - - def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]: - try: - # 批量生成嵌入向量 - logger.info("Calling _get_query_embeddings method...") - prefix = "Represent this sentence for searching relevant passages: " - embeddings = super()._get_query_embeddings([prefix + q for q in queries]) - - # 转换为 float32 类型 - embeddings = [np.array(embedding, dtype=np.float32) for embedding in embeddings] - - # 使用 logger 打印数据类型 - logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}") - return embeddings - except Exception as e: - logger.error(f"Error in _get_query_embeddings: {e}") - raise - - -def build_user_index(user_id: str): - logger.info(f"开始为用户 {user_id} 构建索引...") - - # 确认文档目录是否存在 - doc_dir = os.path.join(USER_DOC_PATH, user_id) - if not os.path.exists(doc_dir): - logger.error(f"文档目录不存在: {doc_dir}") - raise FileNotFoundError(f"文档目录不存在: {doc_dir}") - - logger.info(f"发现文档目录: {doc_dir}") - - # 使用 SimpleDirectoryReader 载入文档 - try: - documents = SimpleDirectoryReader(doc_dir).load_data() - logger.info(f"载入文档数量: {len(documents)}") - except Exception as e: - logger.error(f"加载文档时出错: {e}") - raise HTTPException(status_code=500, detail="文档加载失败") - - # 设置嵌入模型 - embed_model = BGEEmbedding(model_name=settings.MODEL_NAME) - logger.info(f"使用模型: {settings.MODEL_NAME}") - - # 创建服务上下文 - service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) - - # 直接检查模型嵌入方法是否被调用 - logger.info(f"Embedding method being used: {embed_model._get_query_embedding('test query')}") - - - # 使用 Faiss 向量存储 - faiss_index = faiss.IndexFlatL2(1024) - vector_store = FaissVectorStore(faiss_index=faiss_index) - - # 确保索引保存路径存在,使用用户 ID 区分索引文件 - persist_dir = os.path.join(USER_INDEX_PATH, user_id) - logger.info(f"索引保存路径: {persist_dir}") - - # 如果目录不存在,则创建 - if not os.path.exists(persist_dir): - logger.info(f"目录 {persist_dir} 不存在,准备创建") - os.makedirs(persist_dir, exist_ok=True) - logger.info(f"目录 {persist_dir} 已创建") - - # 确保 index_store.json 文件路径存在 - index_store_path = os.path.join(persist_dir, "index_store.json") - if not os.path.exists(index_store_path): - logger.info(f"未找到 index_store.json,准备创建") - try: - with open(index_store_path, "w") as f: - f.write("{}") # 创建空的 index_store.json 文件 - logger.info(f"已创建 index_store.json 文件") - except Exception as e: - logger.error(f"创建 index_store.json 时出错: {e}") - raise HTTPException(status_code=500, detail="创建 index_store.json 文件失败") - else: - logger.info(f"已找到 index_store.json,跳过创建") - - # 创建 StorageContext,用于存储和管理索引数据 - storage_context = StorageContext.from_defaults( - persist_dir=persist_dir, - vector_store=vector_store, - ) - - try: - # 构建索引,并使用 `storage_context.persist()` 方法保存索引 - index = VectorStoreIndex.from_documents( - documents, - service_context=service_context, - storage_context=storage_context - ) - - # 保存 Faiss 索引为文件,使用正确的路径 - faiss_index_file = os.path.join(persist_dir, "index.faiss") - faiss.write_index(faiss_index, faiss_index_file) # 使用 Faiss 的 write_index 方法保存索引 - logger.info(f"Faiss 索引已保存到 {faiss_index_file}") - - # 使用 storage_context.persist() 保存其他索引数据 - storage_context.persist(persist_dir=persist_dir) - logger.info(f"索引数据已保存到 {persist_dir}") - - - - # 持久化存储之后,加载已保存的存储上下文信息 - logger.info(f"开始加载持久化存储的数据...") - - # 创建一个新的 StorageContext 实例,使用相同的目录 - loaded_storage_context = StorageContext.from_defaults( - persist_dir=persist_dir, # 使用与之前相同的目录 - vector_store=FaissVectorStore(faiss_index=faiss_index) # 使用之前保存的 FAISS 索引 - ) - - # 确认存储是否加载成功,检查索引数据 - logger.info("已成功加载存储上下文。") - - # 加载索引,进行检查 - loaded_index = VectorStoreIndex.from_storage_context(loaded_storage_context) - logger.info(f"加载的索引数量: {len(loaded_index.get_documents())}") - - - - except Exception as e: - logger.error(f"索引构建失败: {e}") - raise HTTPException(status_code=500, detail="索引构建失败") diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index 2f0d24e..56c0b01 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -8,28 +8,67 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.vector_stores.faiss import FaissVectorStore from app.core.config import settings # 导入应用配置 from scripts.permissions import get_user_allowed_indexes # 导入权限函数,管理用户索引 -import numpy as np USER_INDEX_PATH = "index_data" # 用户索引存储路径 USER_DOC_PATH = "docs" # 用户文档存储路径 + +# 设置日志记录器,记录操作步骤和错误信息 +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +# # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 +# class BGEEmbedding(HuggingFaceEmbedding): +# def _get_query_embedding(self, query: str) -> List[float]: +# # 在查询前加上前缀,生成嵌入向量 +# prefix = "Represent this sentence for searching relevant passages: " +# return super()._get_query_embedding(prefix + query) + +# def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]: +# # 批量生成嵌入向量 +# prefix = "Represent this sentence for searching relevant passages: " +# return super()._get_query_embeddings([prefix + q for q in queries]) + + # BGEEmbedding 类,继承自 HuggingFaceEmbedding,用于生成查询的嵌入 class BGEEmbedding(HuggingFaceEmbedding): def _get_query_embedding(self, query: str) -> List[float]: - # 在查询前加上前缀,生成嵌入向量 - prefix = "Represent this sentence for searching relevant passages: " - return super()._get_query_embedding(prefix + query) + try: + # 在查询前加上前缀,生成嵌入向量 + logger.info("Calling _get_query_embedding method...") + prefix = "Represent this sentence for searching relevant passages: " + embedding = super()._get_query_embedding(prefix + query) + + # 转换为 float32 类型 + embedding = np.array(embedding, dtype=np.float32) + + # 使用 logger 打印数据类型 + logger.info(f"Query embedding dtype after conversion: {embedding.dtype}") + return embedding + except Exception as e: + logger.error(f"Error in _get_query_embedding: {e}") + raise def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]: - # 批量生成嵌入向量 - prefix = "Represent this sentence for searching relevant passages: " - return super()._get_query_embeddings([prefix + q for q in queries]) + try: + # 批量生成嵌入向量 + logger.info("Calling _get_query_embeddings method...") + prefix = "Represent this sentence for searching relevant passages: " + embeddings = super()._get_query_embeddings([prefix + q for q in queries]) + # 转换为 float32 类型 + embeddings = [np.array(embedding, dtype=np.float32) for embedding in embeddings] + + # 使用 logger 打印数据类型 + logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}") + return embeddings + except Exception as e: + logger.error(f"Error in _get_query_embeddings: {e}") + raise + + def build_user_index(user_id: str): - # 设置日志记录器,记录操作步骤和错误信息 - logger = logging.getLogger(__name__) - logger.setLevel(logging.INFO) - logger.info(f"开始为用户 {user_id} 构建索引...") # 确认文档目录是否存在 @@ -55,8 +94,12 @@ def build_user_index(user_id: str): # 创建服务上下文 service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None) - # 创建 Faiss 向量存储 - faiss_index = faiss.IndexFlatL2(1024) # 假设每个向量是1024维度 + # 直接检查模型嵌入方法是否被调用 + logger.info(f"Embedding method being used: {embed_model._get_query_embedding('test query')}") + + + # 使用 Faiss 向量存储 + faiss_index = faiss.IndexFlatL2(1024) vector_store = FaissVectorStore(faiss_index=faiss_index) # 确保索引保存路径存在,使用用户 ID 区分索引文件 @@ -90,21 +133,14 @@ def build_user_index(user_id: str): ) try: - # 生成嵌入向量 - embeddings = [embed_model._get_query_embedding(doc.text) for doc in documents] - logger.info(f"已生成 {len(embeddings)} 个嵌入向量") - - # 将嵌入向量存储到 Faiss - faiss_index.add(np.array(embeddings)) # 将嵌入向量添加到 Faiss 索引中 - - # 创建 VectorStoreIndex + # 构建索引,并使用 `storage_context.persist()` 方法保存索引 index = VectorStoreIndex.from_documents( documents, service_context=service_context, storage_context=storage_context ) - # 保存 Faiss 索引为文件 + # 保存 Faiss 索引为文件,使用正确的路径 faiss_index_file = os.path.join(persist_dir, "index.faiss") faiss.write_index(faiss_index, faiss_index_file) # 使用 Faiss 的 write_index 方法保存索引 logger.info(f"Faiss 索引已保存到 {faiss_index_file}") @@ -113,6 +149,26 @@ def build_user_index(user_id: str): storage_context.persist(persist_dir=persist_dir) logger.info(f"索引数据已保存到 {persist_dir}") + + + # 持久化存储之后,加载已保存的存储上下文信息 + logger.info(f"开始加载持久化存储的数据...") + + # 创建一个新的 StorageContext 实例,使用相同的目录 + loaded_storage_context = StorageContext.from_defaults( + persist_dir=persist_dir, # 使用与之前相同的目录 + vector_store=FaissVectorStore(faiss_index=faiss_index) # 使用之前保存的 FAISS 索引 + ) + + # 确认存储是否加载成功,检查索引数据 + logger.info("已成功加载存储上下文。") + + # 加载索引,进行检查 + loaded_index = VectorStoreIndex.from_storage_context(loaded_storage_context) + logger.info(f"加载的索引数量: {len(loaded_index.get_documents())}") + + + except Exception as e: logger.error(f"索引构建失败: {e}") raise HTTPException(status_code=500, detail="索引构建失败")