faiss_rag_enterprise/app/api/search.py

160 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from app.core.embedding import embedder # 使用已加载的本地嵌入模型BGEEmbedding
from app.core.config import settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index import VectorStoreIndex, ServiceContext, StorageContext, load_index_from_storage
from llama_index import set_global_service_context, ServiceContext
import os
import logging
import faiss # 引入faiss
# Configure logging for this module (executed once, at import time).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI router holding the endpoint(s) defined in this module.
router = APIRouter()
class QueryRequest(BaseModel):
    """JSON request body accepted by the ``/search`` endpoint."""

    # Free-text query that is searched against the user's index.
    query: str
def _resolve_index_dir(user_id: str, base_dir: str = "index_data") -> str:
    """Return the absolute index directory for *user_id* under *base_dir*.

    *user_id* comes straight from the request query string, so it is only
    accepted when it normalizes to a direct child of *base_dir*.

    Raises:
        ValueError: if *user_id* is empty or would resolve anywhere other
            than a direct child of *base_dir* (e.g. "..", "a/../../x",
            "a/b", or an absolute path).
    """
    base = os.path.abspath(base_dir)
    candidate = os.path.abspath(os.path.join(base, user_id))
    # Requiring dirname == base rejects traversal, absolute paths, nested
    # paths, and base_dir itself (user_id "." or "").
    if not user_id or os.path.dirname(candidate) != base:
        raise ValueError(f"invalid user_id: {user_id!r}")
    return candidate


def _load_user_index(index_path: str, faiss_index_file: str):
    """Load the persisted VectorStoreIndex stored in *index_path*.

    The vectors are read from the Faiss file *faiss_index_file*; node text and
    index metadata are read by ``StorageContext.from_defaults(persist_dir=...)``.
    The index is bound to the local ``embedder`` with the LLM disabled.
    """
    faiss_index = faiss.read_index(faiss_index_file)
    logger.info("Faiss index loaded successfully.")
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    logger.info("FaissVectorStore created successfully.")

    # Use the local embedding model only; the LLM is explicitly disabled.
    service_context = ServiceContext.from_defaults(
        embed_model=embedder, llm=None, llm_predictor=None
    )
    logger.info("Service context created successfully.")
    # NOTE(review): kept in case other code relies on the global context;
    # this function passes service_context to the index explicitly below.
    set_global_service_context(service_context)

    storage_context = StorageContext.from_defaults(
        persist_dir=index_path, vector_store=vector_store
    )
    logger.info("Storage context created successfully.")
    index = load_index_from_storage(storage_context, service_context=service_context)
    logger.info("VectorStoreIndex loaded successfully.")
    return index


def _nodes_to_results(nodes) -> list:
    """Convert retrieved nodes into a list of ``{"score", "text"}`` dicts.

    Nodes without ``get_content`` are skipped with a warning; a missing or
    ``None`` score is reported as ``0.0``.
    """
    results = []
    for rank, node in enumerate(nodes or [], start=1):
        if not hasattr(node, "get_content"):
            logger.warning("Result %d is missing necessary attributes.", rank)
            continue
        text = node.get_content()
        score = float(getattr(node, "score", None) or 0)
        results.append({"score": score, "text": text})
        # Nodes reloaded from the docstore may carry no embedding (None), so
        # only report the dimension when one is present.
        embedding = getattr(node, "embedding", None)
        logger.debug(
            "Result %d: score=%s, embedding_dim=%s, text=%s",
            rank,
            score,
            len(embedding) if embedding is not None else None,
            text,
        )
    return results


@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")):
    """Retrieve the passages in *user_id*'s Faiss index most similar to the query.

    Args:
        request: Request body carrying the query text.
        user_id: Id of the user whose index directory under ``index_data/``
            (relative to the working directory) is searched.

    Returns:
        A dict with ``user_id``, ``query`` and ``results``, where ``results``
        is a list of ``{"score": float, "text": str}`` (empty when nothing
        was retrieved).

    Raises:
        HTTPException: 400 if ``user_id`` is not a plain directory name,
            404 if the user's index directory or its ``index.faiss`` file
            is missing, 500 for any other failure.
    """
    try:
        logger.info("Received search request from user: %s with query: %s", user_id, request.query)

        try:
            index_path = _resolve_index_dir(user_id)
        except ValueError:
            logger.warning("Rejected invalid user_id: %r", user_id)
            raise HTTPException(status_code=400, detail="用户ID不合法")
        logger.info("Looking for index at path: %s", index_path)

        if not os.path.isdir(index_path):
            logger.error("Index not found for user: %s at %s", user_id, index_path)
            raise HTTPException(status_code=404, detail="用户索引不存在")

        faiss_index_file = os.path.join(index_path, "index.faiss")
        if not os.path.exists(faiss_index_file):
            logger.error("Faiss index not found at %s", faiss_index_file)
            raise HTTPException(status_code=404, detail="Faiss索引文件未找到")

        index = _load_user_index(index_path, faiss_index_file)

        # Retrieve nodes directly rather than via a query engine: the LLM is
        # disabled, so only the scored source nodes are wanted, and a query
        # engine returns a Response object (not a list of nodes).
        nodes = index.as_retriever().retrieve(request.query)

        result = {
            "user_id": user_id,
            "query": request.query,
            "results": _nodes_to_results(nodes),
        }
        if not result["results"]:
            logger.warning("No valid results found for the query.")
        logger.info("Search results for user %s: %s", user_id, result)
        return result
    except HTTPException:
        # Preserve deliberate 4xx responses instead of turning them into 500s.
        raise
    except Exception as e:
        logger.error("Error processing search request: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e