faiss_rag_enterprise/scripts/rag_build_query.py

82 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import logging
from typing import List
import faiss
from llama_index import (
SimpleDirectoryReader,
VectorStoreIndex,
ServiceContext,
StorageContext,
PromptTemplate
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from app.core.config import settings # 确保配置导入
from scripts.permissions import get_user_allowed_indexes # 确保权限导入
# Root directories for per-user artefacts. Both are relative paths, so they
# resolve against the process's current working directory; build_user_index
# appends the user_id as a sub-directory under each of them.
USER_DOC_PATH: str = "docs"  # source documents read by SimpleDirectoryReader
USER_INDEX_PATH: str = "index_data"  # persisted llama_index / FAISS storage
# Instruction prepended to every *query* before it is embedded. Only the
# query hooks below use it; passage/text embeddings are left untouched.
_BGE_QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "


class BGEEmbedding(HuggingFaceEmbedding):
    """HuggingFace embedding that automatically prefixes queries with an instruction.

    Only the single and batch query-embedding hooks are overridden; every
    other embedding path is inherited unchanged from ``HuggingFaceEmbedding``.

    NOTE(review): the original comment describes this as a BGE-m3 wrapper.
    BGE-M3's model card reportedly says queries need no instruction (the
    prefix above is the English BGE v1.5 style prompt) — confirm it is wanted,
    and keep it identical between index build and query time either way.
    """

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed one query after prepending the retrieval instruction."""
        instructed_query = _BGE_QUERY_INSTRUCTION + query
        return super()._get_query_embedding(instructed_query)

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        """Embed a batch of queries, each prefixed with the retrieval instruction."""
        instructed_queries = [_BGE_QUERY_INSTRUCTION + text for text in queries]
        return super()._get_query_embeddings(instructed_queries)
def build_user_index(user_id: str):
# 设置日志
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.info(f"开始为用户 {user_id} 构建索引...")
doc_dir = os.path.join(USER_DOC_PATH, user_id)
if not os.path.exists(doc_dir):
raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
logger.info(f"发现文档目录: {doc_dir}")
documents = SimpleDirectoryReader(doc_dir).load_data()
logger.info(f"载入文档数量: {len(documents)}")
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
logger.info(f"使用模型: {settings.MODEL_NAME}")
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
faiss_index = faiss.IndexFlatL2(1024)
vector_store = FaissVectorStore(faiss_index=faiss_index)
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
os.makedirs(persist_dir, exist_ok=True)
logger.info(f"索引保存路径: {persist_dir}")
# 检查目录中是否存在 index_store.json 文件
index_store_path = os.path.join(persist_dir, "index_store.json")
if not os.path.exists(index_store_path):
logger.info(f"未找到 index_store.json准备创建")
else:
logger.info(f"已找到 index_store.json跳过创建")
storage_context = StorageContext.from_defaults(
persist_dir=persist_dir,
vector_store=vector_store,
)
try:
# 构建索引
index = VectorStoreIndex.from_documents(
documents,
service_context=service_context,
storage_context=storage_context
)
index.persist(persist_dir=persist_dir)
logger.info(f"索引已保存到 {persist_dir}")
except Exception as e:
logger.error(f"索引构建失败: {e}")
raise HTTPException(status_code=500, detail="索引构建失败")