faiss_rag_enterprise/app/api/search.py

64 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from app.core.embedding import embedder
from app.core.config import settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index import VectorStoreIndex, ServiceContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import os
import logging
# Configure root logging at INFO level. logging.basicConfig does nothing if
# the root logger already has handlers, so this only takes effect when no
# other code (e.g. the server) has configured logging first.
logging.basicConfig(level=logging.INFO)
# Logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Router that this module's endpoints (e.g. POST /search) are registered on.
router: APIRouter = APIRouter()
class QueryRequest(BaseModel):
    """Request body accepted by the ``/search`` endpoint."""

    # Free-text query to retrieve matching chunks for.
    query: str
# Directory (relative to the working directory) holding one index folder per user.
_INDEX_ROOT = "index_data"


def _resolve_index_path(user_id: str) -> str:
    """Return ``index_data/<user_id>``, rejecting user IDs that escape it.

    The index folder is named by the user ID alone (no ``.index`` suffix).
    ``user_id`` arrives straight from the query string, so it must be a single
    plain path component. Otherwise values such as ``../other`` would climb out
    of ``index_data``. An absolute path would make ``os.path.join`` discard
    ``index_data`` entirely, and either case could open another location's files.

    Raises:
        HTTPException: 400 if ``user_id`` is empty, ``.``/``..``, contains a
            NUL byte, or is not a single path component.
    """
    candidate = os.path.join(_INDEX_ROOT, user_id)
    # A safe ID joins to exactly "<root>/<user_id>": its parent is the root
    # and its last component is the ID itself (rejects separators, absolute
    # paths and drive prefixes).
    is_single_component = (
        os.path.dirname(candidate) == _INDEX_ROOT
        and os.path.basename(candidate) == user_id
    )
    if not is_single_component or user_id in ("", ".", "..") or "\x00" in user_id:
        raise HTTPException(status_code=400, detail="无效的用户ID")
    return candidate


@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")):
    """Retrieve the top ``settings.TOP_K`` chunks from a user's FAISS index.

    Args:
        request: Request body; ``request.query`` is the text to search for.
        user_id: Query parameter selecting the index folder
            ``index_data/<user_id>``.

    Returns:
        A dict with ``user_id``, ``query`` and ``results``. ``results`` is a
        list of ``{"score": float, "text": str}`` in the order the retriever
        returned them. A missing score is reported as ``0.0``.

    Raises:
        HTTPException: 400 for an unsafe ``user_id``, 404 if the user's index
            does not exist, 500 for any other failure while loading or
            querying the index.
    """
    try:
        logger.info("Received search request from user: %s with query: %s", user_id, request.query)

        index_path = _resolve_index_path(user_id)
        logger.info("Looking for index at path: %s", index_path)
        if not os.path.exists(index_path):
            logger.error("Index not found for user: %s at %s", user_id, index_path)
            raise HTTPException(status_code=404, detail="用户索引不存在")

        # Build the LlamaIndex retriever over the persisted FAISS store.
        # NOTE(review): the existence check treats index_path as a per-user
        # folder, while FaissVectorStore.from_persist_path appears to expect
        # the path of a persisted FAISS file — confirm against the indexer.
        # NOTE(review): FAISS keeps only vectors; depending on the llama_index
        # version, VectorStoreIndex.from_vector_store may reject a store that
        # does not keep node text — verify this path works end to end.
        # NOTE(review): the index is reloaded from disk on every request;
        # consider caching per user if latency matters.
        logger.info("Loading Faiss vector store from path: %s", index_path)
        faiss_store = FaissVectorStore.from_persist_path(index_path)
        service_context = ServiceContext.from_defaults(embed_model=embedder)
        logger.info("Service context created successfully.")
        index = VectorStoreIndex.from_vector_store(faiss_store, service_context=service_context)
        logger.info("VectorStoreIndex created successfully.")

        # Retrieve the matching chunks (the actual text, not just ids).
        retriever = index.as_retriever(similarity_top_k=settings.TOP_K)
        logger.info("Retrieving top %s results for query: %s", settings.TOP_K, request.query)
        nodes = retriever.retrieve(request.query)

        result = {
            "user_id": user_id,
            "query": request.query,
            "results": [
                {"score": float(node.score or 0), "text": node.get_content()}
                for node in nodes
            ],
        }
        # Log only the count: the payload contains the users' document text,
        # which should not be written to application logs at INFO level.
        logger.info("Returning %d search results for user %s", len(result["results"]), user_id)
        return result
    except HTTPException:
        # Deliberate HTTP errors (400/404 above) must reach the client with
        # their own status code instead of being rewrapped as a 500.
        raise
    except Exception as e:
        logger.exception("Error processing search request: %s", e)
        # NOTE(review): str(e) exposes internal error details to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e