# faiss_rag_enterprise/scripts/rag_build_query.py
#
# (Header captured from a repository web viewer: 103 lines, 4.4 KiB, Python.
# The viewer flagged that this file contains ambiguous Unicode characters that
# might be confused with other characters.)

from fastapi import HTTPException # 导入 HTTPException 用于错误处理
from typing import List # 导入 List 用于类型注解
import os
import logging
import faiss
from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from app.core.config import settings # 导入应用配置
from scripts.permissions import get_user_allowed_indexes # 导入权限函数,管理用户索引
# Base directory for persisted per-user indexes (relative paths resolve
# against the process's current working directory).
USER_INDEX_PATH: str = "index_data"
# Base directory for per-user source documents (same relative-path caveat).
USER_DOC_PATH: str = "docs"
# BGEEmbedding 类,继承自 HuggingFaceEmbedding用于生成查询的嵌入
class BGEEmbedding(HuggingFaceEmbedding):
def _get_query_embedding(self, query: str) -> List[float]:
# 在查询前加上前缀,生成嵌入向量
prefix = "Represent this sentence for searching relevant passages: "
return super()._get_query_embedding(prefix + query)
def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
# 批量生成嵌入向量
prefix = "Represent this sentence for searching relevant passages: "
return super()._get_query_embeddings([prefix + q for q in queries])
def build_user_index(user_id: str, *, embed_dim: int = 1024):
    """Build a Faiss-backed vector index over one user's documents and persist it.

    Documents are read from ``USER_DOC_PATH/<user_id>``; the built index is
    persisted to ``USER_INDEX_PATH/<user_id>``. Each call builds from empty
    in-memory stores, so a rebuild replaces the previous index rather than
    being merged into it.

    Args:
        user_id: Name of the per-user sub-directory used for both the source
            documents and the persisted index. NOTE(review): it is joined into
            filesystem paths unchecked; if it can originate from a client,
            validate it (reject ``..`` / absolute paths) before calling.
        embed_dim: Vector dimensionality of ``settings.MODEL_NAME``. It must
            match the model's output size or Faiss will reject the vectors.
            Defaults to 1024, the value previously hard-coded (TODO confirm
            against the configured model).

    Returns:
        The freshly built ``VectorStoreIndex``.

    Raises:
        FileNotFoundError: If the user's document directory does not exist.
        HTTPException: With status 500 if loading documents or building /
            persisting the index fails (the original error is chained).
    """
    logger = logging.getLogger(__name__)
    # Kept from the original: force INFO so the progress messages below are
    # emitted. NOTE(review): this overrides any level set by the application.
    logger.setLevel(logging.INFO)
    logger.info("开始为用户 %s 构建索引...", user_id)

    # Source documents live in a per-user sub-directory.
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.isdir(doc_dir):
        logger.error("文档目录不存在: %s", doc_dir)
        raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
    logger.info("发现文档目录: %s", doc_dir)

    try:
        documents = SimpleDirectoryReader(doc_dir).load_data()
    except Exception as e:
        # logger.exception also records the traceback, which logger.error dropped.
        logger.exception("加载文档时出错: %s", e)
        raise HTTPException(status_code=500, detail="文档加载失败") from e
    logger.info("载入文档数量: %s", len(documents))

    embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
    logger.info("使用模型: %s", settings.MODEL_NAME)
    service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)

    # Exact (flat) L2 index; its dimension must equal the embedding size.
    faiss_index = faiss.IndexFlatL2(embed_dim)
    vector_store = FaissVectorStore(faiss_index=faiss_index)

    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    logger.info("索引保存路径: %s", persist_dir)
    if not os.path.exists(persist_dir):
        logger.info("目录 %s 不存在,准备创建", persist_dir)
        os.makedirs(persist_dir, exist_ok=True)
        logger.info("目录 %s 已创建", persist_dir)

    # Build from fresh in-memory stores. Passing persist_dir here would *load*
    # whatever stores already exist on disk (and, depending on the llama_index
    # version, fail when some are missing), so a rebuild would append a new
    # index to the previous docstore/index store instead of replacing it.
    # This also removes the need to pre-create an empty index_store.json.
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    try:
        index = VectorStoreIndex.from_documents(
            documents,
            service_context=service_context,
            storage_context=storage_context,
        )
        # Serialise the stores used to build the index into persist_dir.
        storage_context.persist(persist_dir=persist_dir)
    except Exception as e:
        logger.exception("索引构建失败: %s", e)
        raise HTTPException(status_code=500, detail="索引构建失败") from e
    logger.info("索引已保存到 %s", persist_dir)
    return index