120 lines
4.9 KiB
Python
120 lines
4.9 KiB
Python
from fastapi import APIRouter, HTTPException, Query
|
||
from pydantic import BaseModel
|
||
from app.core.embedding import embedder # 使用已加载的本地嵌入模型(BGEEmbedding)
|
||
from app.core.config import settings
|
||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||
from llama_index import VectorStoreIndex, ServiceContext, StorageContext, load_index_from_storage
|
||
|
||
from llama_index import set_global_service_context, ServiceContext
|
||
|
||
|
||
import os
|
||
import logging
|
||
import faiss # 引入faiss
|
||
|
||
# Logging: configure the root logger at INFO level (this happens once, at
# import time) and create the module-level logger used by this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Router on which this module registers its endpoint(s).
router = APIRouter()
|
||
|
||
class QueryRequest(BaseModel):
    """Request body of the search endpoint."""

    # Free-text search query; embedded and matched against the user's index.
    query: str
|
||
|
||
def _resolve_index_dir(user_id: str, root: str = "index_data") -> str:
    """Map a caller-supplied user id to that user's index directory under *root*.

    ``user_id`` arrives straight from the query string, so it is untrusted. It
    must be a single plain path component. Otherwise ``os.path.join`` could
    point outside *root*: ``".."``, ``"a/../b"``, or an absolute path, which
    makes ``os.path.join`` discard *root* entirely.

    Args:
        user_id: Identifier taken from the request.
        root: Directory holding one sub-directory per user.

    Returns:
        ``os.path.join(root, user_id)``.

    Raises:
        ValueError: If *user_id* is empty, ``"."``/``".."``, absolute, carries a
            drive prefix, contains a path separator or NUL, or otherwise does
            not name an immediate child of *root*.
    """
    root_abs = os.path.abspath(root)
    target = os.path.abspath(os.path.join(root_abs, user_id))
    # After normalization the target must be a direct child of root whose name
    # is exactly user_id. The basename check catches inputs that normalization
    # would otherwise silently rewrite, such as "a/../b" or "x/".
    if (
        "\x00" in user_id
        or os.path.dirname(target) != root_abs
        or os.path.basename(target) != user_id
    ):
        raise ValueError(f"invalid user_id: {user_id!r}")
    return os.path.join(root, user_id)


def _log_retrieved_node(position: int, node) -> None:
    """Log the text and, when available, the embedding of one retrieved node."""
    if not (hasattr(node, 'get_content') and hasattr(node, 'embedding')):
        logger.warning(f"Result {position} is missing necessary attributes.")
        return

    logger.info(f"Result {position}:")
    logger.info(f" Text: {node.get_content()}")

    embedding = node.embedding
    # A retrieved node may carry no embedding (None); calling len() on it
    # unconditionally would raise TypeError and fail the whole request.
    if embedding is None:
        logger.info(" Embedding (Vector): <not available>")
        return
    logger.info(f" Embedding (Vector): {embedding}")
    logger.info(f" Embedding Length: {len(embedding)}")  # vector dimension


@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")):
    """Run a vector similarity search over the calling user's persisted index.

    The user's index is expected in ``index_data/<user_id>/``. That directory
    must contain an ``index.faiss`` file (read with ``faiss.read_index``) plus
    the LlamaIndex storage files that ``StorageContext.from_defaults`` loads
    from it. The local ``embedder`` is the embedding model, and the LLM is
    explicitly disabled.

    Args:
        request: Request body carrying the query text.
        user_id: Query parameter naming the user whose index is searched.

    Returns:
        ``{"user_id": str, "query": str, "results": [{"score": float, "text": str}, ...]}``.
        A missing score is reported as 0.

    Raises:
        HTTPException: 400 if ``user_id`` is not a plain directory name.
            404 if the user's index directory or FAISS file is missing.
            500 for any other failure, with the error text as the detail.
    """
    try:
        logger.info(f"Received search request from user: {user_id} with query: {request.query}")

        try:
            index_path = _resolve_index_dir(user_id)
        except ValueError:
            logger.error(f"Rejected invalid user_id: {user_id!r}")
            raise HTTPException(status_code=400, detail="无效的用户ID") from None
        logger.info(f"Looking for index at path: {index_path}")

        if not os.path.exists(index_path):
            logger.error(f"Index not found for user: {user_id} at {index_path}")
            raise HTTPException(status_code=404, detail="用户索引不存在")

        faiss_index_file = os.path.join(index_path, "index.faiss")
        if not os.path.exists(faiss_index_file):
            logger.error(f"Faiss index not found at {faiss_index_file}")
            raise HTTPException(status_code=404, detail="Faiss索引文件未找到")

        faiss_index = faiss.read_index(faiss_index_file)
        logger.info("Faiss index loaded successfully.")

        vector_store = FaissVectorStore(faiss_index=faiss_index)
        logger.info("FaissVectorStore created successfully.")

        # Use the local embedding model only; llm=None explicitly disables the LLM.
        service_context = ServiceContext.from_defaults(embed_model=embedder, llm=None, llm_predictor=None)
        logger.info("Service context created successfully.")
        set_global_service_context(service_context)

        # Stored node data is read from persist_dir; vectors come from the FAISS store.
        storage_context = StorageContext.from_defaults(persist_dir=index_path, vector_store=vector_store)
        logger.info("Storage context created successfully.")

        # Pass the service context explicitly rather than relying only on the global one.
        index = load_index_from_storage(storage_context, service_context=service_context)
        logger.info("VectorStoreIndex loaded successfully.")

        # Retrieval only. A query engine returns a Response object, which is
        # not a list of nodes and cannot be iterated. The retriever returns
        # the scored nodes directly and skips answer synthesis, which is
        # pointless with the LLM disabled.
        retriever = index.as_retriever()
        logger.info("Retriever created successfully.")
        nodes = retriever.retrieve(request.query)

        if nodes:
            for position, node in enumerate(nodes, start=1):
                _log_retrieved_node(position, node)
        else:
            logger.warning("No results retrieved for the query.")

        result = {
            "user_id": user_id,
            "query": request.query,
            "results": [
                {"score": float(node.score or 0), "text": node.get_content()}
                for node in nodes
            ],
        }
        logger.info(f"Search results for user {user_id}: {result}")
        return result

    except HTTPException:
        # Deliberate 4xx responses raised above must reach the client unchanged
        # instead of being rewrapped as 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Error processing search request: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
|