faiss_rag_enterprise/scripts/rag_build_query.py

176 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import HTTPException # 导入 HTTPException 用于错误处理
from typing import List # 导入 List 用于类型注解
import os
import logging
import faiss
from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from app.core.config import settings # 导入应用配置
from scripts.permissions import get_user_allowed_indexes # 导入权限函数,管理用户索引
import numpy as np
USER_INDEX_PATH = "index_data" # 用户索引存储路径
USER_DOC_PATH = "docs" # 用户文档存储路径
# 设置日志记录器,记录操作步骤和错误信息
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# # BGEEmbedding 类,继承自 HuggingFaceEmbedding用于生成查询的嵌入
# class BGEEmbedding(HuggingFaceEmbedding):
# def _get_query_embedding(self, query: str) -> List[float]:
# # 在查询前加上前缀,生成嵌入向量
# prefix = "Represent this sentence for searching relevant passages: "
# return super()._get_query_embedding(prefix + query)
# def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
# # 批量生成嵌入向量
# prefix = "Represent this sentence for searching relevant passages: "
# return super()._get_query_embeddings([prefix + q for q in queries])
# BGEEmbedding 类,继承自 HuggingFaceEmbedding用于生成查询的嵌入
class BGEEmbedding(HuggingFaceEmbedding):
def _get_query_embedding(self, query: str) -> List[float]:
try:
# 在查询前加上前缀,生成嵌入向量
logger.info("Calling _get_query_embedding method...")
prefix = "Represent this sentence for searching relevant passages: "
embedding = super()._get_query_embedding(prefix + query)
# 转换为 float32 类型
embedding = np.array(embedding, dtype=np.float32)
# 使用 logger 打印数据类型
logger.info(f"Query embedding dtype after conversion: {embedding.dtype}")
return embedding
except Exception as e:
logger.error(f"Error in _get_query_embedding: {e}")
raise
def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
try:
# 批量生成嵌入向量
logger.info("Calling _get_query_embeddings method...")
prefix = "Represent this sentence for searching relevant passages: "
embeddings = super()._get_query_embeddings([prefix + q for q in queries])
# 转换为 float32 类型
embeddings = [np.array(embedding, dtype=np.float32) for embedding in embeddings]
# 使用 logger 打印数据类型
logger.info(f"Batch query embeddings dtype after conversion: {embeddings[0].dtype}")
return embeddings
except Exception as e:
logger.error(f"Error in _get_query_embeddings: {e}")
raise
def build_user_index(user_id: str):
logger.info(f"开始为用户 {user_id} 构建索引...")
# 确认文档目录是否存在
doc_dir = os.path.join(USER_DOC_PATH, user_id)
if not os.path.exists(doc_dir):
logger.error(f"文档目录不存在: {doc_dir}")
raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
logger.info(f"发现文档目录: {doc_dir}")
# 使用 SimpleDirectoryReader 载入文档
try:
documents = SimpleDirectoryReader(doc_dir).load_data()
logger.info(f"载入文档数量: {len(documents)}")
except Exception as e:
logger.error(f"加载文档时出错: {e}")
raise HTTPException(status_code=500, detail="文档加载失败")
# 设置嵌入模型
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
logger.info(f"使用模型: {settings.MODEL_NAME}")
# 创建服务上下文
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
# 直接检查模型嵌入方法是否被调用
logger.info(f"Embedding method being used: {embed_model._get_query_embedding('test query')}")
# 使用 Faiss 向量存储
faiss_index = faiss.IndexFlatL2(1024)
vector_store = FaissVectorStore(faiss_index=faiss_index)
# 确保索引保存路径存在,使用用户 ID 区分索引文件
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
logger.info(f"索引保存路径: {persist_dir}")
# 如果目录不存在,则创建
if not os.path.exists(persist_dir):
logger.info(f"目录 {persist_dir} 不存在,准备创建")
os.makedirs(persist_dir, exist_ok=True)
logger.info(f"目录 {persist_dir} 已创建")
# 确保 index_store.json 文件路径存在
index_store_path = os.path.join(persist_dir, "index_store.json")
if not os.path.exists(index_store_path):
logger.info(f"未找到 index_store.json准备创建")
try:
with open(index_store_path, "w") as f:
f.write("{}") # 创建空的 index_store.json 文件
logger.info(f"已创建 index_store.json 文件")
except Exception as e:
logger.error(f"创建 index_store.json 时出错: {e}")
raise HTTPException(status_code=500, detail="创建 index_store.json 文件失败")
else:
logger.info(f"已找到 index_store.json跳过创建")
# 创建 StorageContext用于存储和管理索引数据
storage_context = StorageContext.from_defaults(
persist_dir=persist_dir,
vector_store=vector_store,
)
try:
# 构建索引,并使用 `storage_context.persist()` 方法保存索引
index = VectorStoreIndex.from_documents(
documents,
service_context=service_context,
storage_context=storage_context
)
# 保存 Faiss 索引为文件,使用正确的路径
faiss_index_file = os.path.join(persist_dir, "index.faiss")
faiss.write_index(faiss_index, faiss_index_file) # 使用 Faiss 的 write_index 方法保存索引
logger.info(f"Faiss 索引已保存到 {faiss_index_file}")
# 使用 storage_context.persist() 保存其他索引数据
storage_context.persist(persist_dir=persist_dir)
logger.info(f"索引数据已保存到 {persist_dir}")
# 持久化存储之后,加载已保存的存储上下文信息
logger.info(f"开始加载持久化存储的数据...")
# 创建一个新的 StorageContext 实例,使用相同的目录
loaded_storage_context = StorageContext.from_defaults(
persist_dir=persist_dir, # 使用与之前相同的目录
vector_store=FaissVectorStore(faiss_index=faiss_index) # 使用之前保存的 FAISS 索引
)
# 确认存储是否加载成功,检查索引数据
logger.info("已成功加载存储上下文。")
# 加载索引,进行检查
loaded_index = VectorStoreIndex.from_storage_context(loaded_storage_context)
logger.info(f"加载的索引数量: {len(loaded_index.get_documents())}")
except Exception as e:
logger.error(f"索引构建失败: {e}")
raise HTTPException(status_code=500, detail="索引构建失败")