# Origin: app/api/search.py, added by commit
# b89f81229ec706cdc9348bd871277181f06ad760 (hailin, 2025-05-11).
"""Per-user document search endpoint backed by a persisted Faiss index."""

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from app.core.embedding import embedder  # already-loaded local embedding model
from app.core.config import settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index import VectorStoreIndex, ServiceContext, StorageContext, load_index_from_storage
import os
import logging
import faiss

router = APIRouter()

# Module-level logger for this endpoint.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory (relative to the working directory) holding one sub-directory
# of persisted index files per user.
_INDEX_ROOT = "index_data"


class QueryRequest(BaseModel):
    """Request body for ``POST /search``."""

    # Free-text query to run against the user's index.
    query: str


def _resolve_index_dir(user_id: str) -> str:
    """Return the index directory for *user_id*, refusing path escapes.

    ``user_id`` comes straight from the query string, so it must not be
    allowed to point outside ``_INDEX_ROOT`` (e.g. ``../../etc``).

    Raises:
        HTTPException: 400 if the resolved path is not strictly inside
            ``_INDEX_ROOT``.
    """
    index_path = os.path.join(_INDEX_ROOT, user_id)
    try:
        root = os.path.realpath(_INDEX_ROOT)
        target = os.path.realpath(index_path)
        inside = target != root and os.path.commonpath([root, target]) == root
    except ValueError:
        # Raised for embedded NUL bytes or paths on different drives.
        inside = False
    if not inside:
        logger.error("Rejected user_id that escapes the index root: %r", user_id)
        raise HTTPException(status_code=400, detail="无效的用户ID")
    return index_path


def _load_user_index(index_path: str):
    """Load the persisted index stored under *index_path*.

    Reads ``index.faiss`` with faiss, wraps it in a ``FaissVectorStore``
    and loads the remaining persisted state from the same directory.

    Raises:
        HTTPException: 404 if ``index.faiss`` is missing.
    """
    faiss_index_file = os.path.join(index_path, "index.faiss")
    if not os.path.exists(faiss_index_file):
        logger.error("Faiss index not found at %s", faiss_index_file)
        raise HTTPException(status_code=404, detail="Faiss索引文件未找到")

    faiss_index = faiss.read_index(faiss_index_file)
    logger.info("Faiss index loaded successfully.")

    vector_store = FaissVectorStore(faiss_index=faiss_index)
    logger.info("FaissVectorStore created successfully.")

    # Passing persist_dir together with the vector store loads the
    # persisted state from that directory alongside the Faiss vectors.
    storage_context = StorageContext.from_defaults(persist_dir=index_path, vector_store=vector_store)
    logger.info("Storage context created successfully.")

    service_context = ServiceContext.from_defaults(embed_model=embedder, llm=None)
    logger.info("Service context created successfully.")

    # The service context must be handed to the loader; otherwise the
    # index is built with library defaults instead of the local embedder.
    index = load_index_from_storage(storage_context, service_context=service_context)
    logger.info("VectorStoreIndex loaded successfully.")
    return index


@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")):
    """Search the calling user's index and return the matching chunks.

    Args:
        request: Body carrying the query text.
        user_id: Name of the user's sub-directory under ``index_data``.

    Returns:
        A dict with ``user_id``, ``query`` and ``results``; each result has
        a ``score`` (0.0 when the hit carries none) and its ``text``.

    Raises:
        HTTPException: 400 for a ``user_id`` that escapes the index root,
            404 when the user directory or ``index.faiss`` is missing,
            500 for any other failure.
    """
    try:
        logger.info("Received search request from user: %s with query: %s", user_id, request.query)

        index_path = _resolve_index_dir(user_id)
        logger.info("Looking for index at path: %s", index_path)

        if not os.path.exists(index_path):
            logger.error("Index not found for user: %s at %s", user_id, index_path)
            raise HTTPException(status_code=404, detail="用户索引不存在")

        index = _load_user_index(index_path)

        query_engine = index.as_query_engine()
        logger.info("Query engine created successfully.")

        response = query_engine.query(request.query)
        logger.info("Query response: %s", response)

        # The retrieved hits are exposed on `source_nodes`; the response
        # object itself is not a sequence of nodes.
        hits = list(getattr(response, "source_nodes", None) or [])

        for position, hit in enumerate(hits, start=1):
            logger.info("Result %d:", position)
            logger.info("  Text: %s", hit.node.get_content())

            # Embeddings may be absent on loaded nodes; guard so that a
            # diagnostic log line cannot fail the request.
            embedding = getattr(hit.node, "embedding", None)
            if embedding is not None:
                logger.debug("  Embedding (Vector): %s", embedding)
                logger.info("  Embedding Length: %d", len(embedding))

        result = {
            "user_id": user_id,
            "query": request.query,
            "results": [
                {"score": float(hit.score or 0), "text": hit.node.get_content()}
                for hit in hits
            ],
        }

        logger.info("Search results for user %s: %s", user_id, result)
        return result

    except HTTPException:
        # Deliberate 4xx responses must not be rewritten as 500 below.
        raise
    except Exception as e:
        logger.error("Error processing search request: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e