faiss_rag_enterprise/scripts/rag_build_query.py

96 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import os
from typing import List

import faiss
from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

from app.core.config import settings  # make sure the configuration is imported
from scripts.permissions import get_user_allowed_indexes  # import the permission helper
# Root directory for persisted indexes; build_user_index joins it with the user id.
USER_INDEX_PATH = "index_data"
# Root directory for source documents; build_user_index joins it with the user id.
USER_DOC_PATH = "docs"
class BGEEmbedding(HuggingFaceEmbedding):
    """HuggingFace embedding that prepends a retrieval instruction to queries.

    Only the query-side hooks are overridden: every query string is prefixed
    with a fixed instruction sentence before being handed to the parent
    implementation. Nothing else from the parent class is changed here.
    """

    @staticmethod
    def _with_query_prefix(query: str) -> str:
        """Return *query* with the retrieval instruction prepended."""
        return "Represent this sentence for searching relevant passages: " + query

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query after adding the retrieval instruction."""
        return super()._get_query_embedding(self._with_query_prefix(query))

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        """Embed a batch of queries after adding the retrieval instruction.

        NOTE(review): this relies on the parent class providing
        ``_get_query_embeddings``; confirm against the installed llama_index
        version.
        """
        prefixed = [self._with_query_prefix(q) for q in queries]
        return super()._get_query_embeddings(prefixed)
def build_user_index(user_id: str):
    """Build and persist a FAISS-backed vector index for one user's documents.

    Documents are read from ``USER_DOC_PATH/<user_id>``, embedded with
    ``BGEEmbedding(settings.MODEL_NAME)``, stored in a fresh flat-L2 FAISS
    index and persisted under ``USER_INDEX_PATH/<user_id>``.

    Args:
        user_id: Name of the per-user sub-directory, used for both the
            document directory and the index directory.

    Raises:
        FileNotFoundError: If the user's document directory does not exist.
        RuntimeError: If loading the documents or building/persisting the
            index fails. The underlying exception is chained as the cause.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    logger.info(f"开始为用户 {user_id} 构建索引...")

    # NOTE(review): user_id is joined into filesystem paths unchecked; if it
    # can originate from a request, validate it before calling this function.
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.isdir(doc_dir):
        logger.error(f"文档目录不存在: {doc_dir}")
        raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
    logger.info(f"发现文档目录: {doc_dir}")

    # Load the documents. The original raised an un-imported HTTPException
    # here (a NameError at runtime); a chained RuntimeError keeps the message
    # and preserves the root cause.
    try:
        documents = SimpleDirectoryReader(doc_dir).load_data()
    except Exception as e:
        logger.exception(f"加载文档时出错: {e}")
        raise RuntimeError("文档加载失败") from e
    logger.info(f"载入文档数量: {len(documents)}")

    embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
    logger.info(f"使用模型: {settings.MODEL_NAME}")
    service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)

    # NOTE(review): 1024 is assumed to equal the embedding size produced by
    # settings.MODEL_NAME; confirm whenever the configured model changes.
    faiss_index = faiss.IndexFlatL2(1024)
    vector_store = FaissVectorStore(faiss_index=faiss_index)

    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    logger.info(f"索引保存路径: {persist_dir}")
    if not os.path.isdir(persist_dir):
        logger.info(f"目录 {persist_dir} 不存在,准备创建")
        os.makedirs(persist_dir, exist_ok=True)
        logger.info(f"目录 {persist_dir} 已创建")

    # Start from an empty storage context. Passing persist_dir here would
    # make llama_index load previously persisted stores from disk, which
    # fails on a first build and mixes stale stores with the new, empty
    # FAISS index on a rebuild.
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    try:
        index = VectorStoreIndex.from_documents(
            documents,
            service_context=service_context,
            storage_context=storage_context,
        )
        # Persistence goes through the index's storage context.
        index.storage_context.persist(persist_dir=persist_dir)
    except Exception as e:
        logger.exception(f"索引构建失败: {e}")
        raise RuntimeError("索引构建失败") from e
    logger.info(f"索引已保存到 {persist_dir}")