# faiss_rag_enterprise/app/api/search.py
#
# 64 lines
# 2.5 KiB
# Python

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from app.core.embedding import embedder
from app.core.config import settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index import VectorStoreIndex, ServiceContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import os
import logging
# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Router on which this module's endpoints are registered.
router = APIRouter()
class QueryRequest(BaseModel):
    """Request body for the search endpoint."""

    # Query text passed to the retriever.
    query: str
@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")):
    """Retrieve the top-matching text chunks from a user's FAISS index.

    Args:
        request: Request body carrying the query string.
        user_id: Query parameter naming the per-user sub-directory of
            ``index_data`` that holds ``index.faiss``.

    Returns:
        A dict with ``user_id``, ``query`` and ``results``; each result holds
        a float ``score`` (0 when the node has no score) and the node ``text``.

    Raises:
        HTTPException: 404 when ``user_id`` is not a plain directory name or
            no index file exists for it; 500 for any other failure, with the
            exception text as the detail.
    """
    try:
        logger.info(
            "Received search request from user: %s with query: %s",
            user_id,
            request.query,
        )

        # user_id comes straight from the client and is joined into a
        # filesystem path, so it must be exactly one path component:
        # no separators, no drive/absolute prefix, and not "", "." or "..".
        if user_id in ("", ".", "..") or os.path.basename(user_id) != user_id:
            logger.warning("Rejected unsafe user_id: %r", user_id)
            raise HTTPException(status_code=404, detail="用户索引不存在")

        # Points at the index.faiss file itself, not its parent directory.
        index_path = os.path.join("index_data", user_id, "index.faiss")
        logger.info("Looking for index at path: %s", index_path)

        # Bail out early when the user has no index file.
        if not os.path.exists(index_path):
            logger.error("Index not found for user: %s at %s", user_id, index_path)
            raise HTTPException(status_code=404, detail="用户索引不存在")

        # Build the LlamaIndex retriever on top of the persisted FAISS store.
        # NOTE(review): only the FAISS file is loaded here; confirm that node
        # text is available to the retriever without a persisted docstore.
        logger.info("Loading Faiss vector store from path: %s", index_path)
        faiss_store = FaissVectorStore.from_persist_path(index_path)
        service_context = ServiceContext.from_defaults(embed_model=embedder)
        logger.info("Service context created successfully.")
        index = VectorStoreIndex.from_vector_store(
            faiss_store, service_context=service_context
        )
        logger.info("VectorStoreIndex created successfully.")

        # Run the retrieval.
        retriever = index.as_retriever(similarity_top_k=settings.TOP_K)
        logger.info(
            "Retrieving top %s results for query: %s",
            settings.TOP_K,
            request.query,
        )
        nodes = retriever.retrieve(request.query)

        # Shape the response payload.
        result = {
            "user_id": user_id,
            "query": request.query,
            "results": [
                {"score": float(node.score or 0), "text": node.get_content()}
                for node in nodes
            ],
        }
        logger.info("Search results for user %s: %s", user_id, result)
        return result
    except HTTPException:
        # Deliberate HTTP errors (such as the 404s above) must reach the
        # client unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        logger.error("Error processing search request: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e