faiss_rag_enterprise/app/api/search.py

107 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from app.core.embedding import embedder # 使用已加载的本地嵌入模型BGEEmbedding
from app.core.config import settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index import VectorStoreIndex, ServiceContext, StorageContext, load_index_from_storage
from llama_index import set_global_service_context, ServiceContext
import os
import logging
import faiss # 引入faiss
# Router object on which this module's search endpoint is registered.
router = APIRouter()
# Set up logging: configure the root logger at INFO level (this runs at import
# time) and create the logger used throughout this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Request body for the search endpoint: a single free-text query string.
# (Documented with comments rather than a docstring so the generated schema
# for this model stays exactly as it was.)
class QueryRequest(BaseModel):
    # Text of the query to run against the user's index.
    query: str
def _is_single_path_component(name: str) -> bool:
    """Return True when *name* is one plain directory name (no separators, not '.'/'..')."""
    return name not in ("", ".", "..") and os.path.basename(name) == name


@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")):
    """Search one user's persisted FAISS index and return the retrieved text chunks.

    The index is loaded from ``index_data/<user_id>``: the raw FAISS file
    ``index.faiss`` is read with ``faiss.read_index`` and the rest of the
    storage context is loaded from the same directory.

    Args:
        request: Request body carrying the ``query`` string.
        user_id: Query-string parameter selecting which user's index to search.

    Returns:
        A dict with ``user_id``, ``query`` and ``results``; each result holds
        the hit's ``score`` (0.0 when the score is missing) and its ``text``.

    Raises:
        HTTPException: 400 when ``user_id`` is not a plain directory name,
            404 when the index directory or ``index.faiss`` is missing,
            500 for any other failure while loading or querying the index.
    """
    try:
        logger.info(f"Received search request from user: {user_id} with query: {request.query}")

        # user_id comes straight from the query string and is used to build a
        # filesystem path, so refuse anything that could leave index_data/
        # (e.g. "../other" or an absolute path).
        if not _is_single_path_component(user_id):
            logger.error(f"Rejected unsafe user_id: {user_id!r}")
            raise HTTPException(status_code=400, detail="无效的用户ID")

        # The index path is the user's whole directory, not a single file.
        index_path = os.path.join("index_data", user_id)
        logger.info(f"Looking for index at path: {index_path}")

        # Make sure the index directory exists.
        if not os.path.exists(index_path):
            logger.error(f"Index not found for user: {user_id} at {index_path}")
            raise HTTPException(status_code=404, detail="用户索引不存在")

        # Load the raw FAISS index file.
        faiss_index_file = os.path.join(index_path, "index.faiss")
        if not os.path.exists(faiss_index_file):
            logger.error(f"Faiss index not found at {faiss_index_file}")
            raise HTTPException(status_code=404, detail="Faiss索引文件未找到")
        faiss_index = faiss.read_index(faiss_index_file)
        logger.info("Faiss index loaded successfully.")

        # Wrap the FAISS index in a vector store.
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        logger.info("FaissVectorStore created successfully.")

        # Build a ServiceContext that uses the local embedding model and
        # explicitly disables the LLM.
        service_context = ServiceContext.from_defaults(embed_model=embedder, llm=None, llm_predictor=None)
        logger.info("Service context created successfully.")
        set_global_service_context(service_context)

        # Build the StorageContext so both the stored text and the vectors are loaded.
        storage_context = StorageContext.from_defaults(persist_dir=index_path, vector_store=vector_store)
        logger.info("Storage context created successfully.")

        # Load the index from storage.
        index = load_index_from_storage(storage_context)
        logger.info("VectorStoreIndex loaded successfully.")

        # Attach the ServiceContext to the loaded index.
        # NOTE(review): confirm set_service_context exists in the pinned
        # llama_index version; the global context above is already set.
        index.set_service_context(service_context)
        logger.info("Service context set to index.")

        # The query engine embeds the query itself, so no separate manual
        # embedding of the query is computed here.
        query_engine = index.as_query_engine(service_context=service_context)
        logger.info("Query engine created successfully.")

        # Run the query.
        response = query_engine.query(request.query)
        logger.info(f"Query response: {response}")

        # The response object is not the list of hits; the retrieved nodes
        # (each paired with a score) are exposed via source_nodes.
        hits = list(response.source_nodes or [])

        # Log the text and, when available, the vector of every hit.
        for position, hit in enumerate(hits, start=1):
            logger.info(f"Result {position}:")
            logger.info(f" Text: {hit.node.get_content()}")
            embedding = hit.node.embedding
            # The embedding may be absent on a retrieved node; only log the
            # vector and its dimension when it is actually present.
            if embedding is not None:
                logger.info(f" Embedding (Vector): {embedding}")
                logger.info(f" Embedding Length: {len(embedding)}")

        # Assemble the retrieval results returned to the caller.
        result = {
            "user_id": user_id,
            "query": request.query,
            "results": [
                {"score": float(hit.score or 0), "text": hit.node.get_content()}
                for hit in hits
            ],
        }
        logger.info(f"Search results for user {user_id}: {result}")
        return result
    except HTTPException:
        # Deliberate 4xx responses raised above must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error processing search request: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e