diff --git a/app/api/search copy.py b/app/api/search copy.py
index cb3c4ce..ae12f5f 100644
--- a/app/api/search copy.py
+++ b/app/api/search copy.py
@@ -1,6 +1,6 @@
 from fastapi import APIRouter, HTTPException, Query
 from pydantic import BaseModel
-from app.core.embedding import embedder
+from app.core.embedding import embedder  # use the already-loaded local embedding model (BGEEmbedding)
 from app.core.config import settings
 from llama_index.vector_stores.faiss import FaissVectorStore
 from llama_index import VectorStoreIndex, ServiceContext, StorageContext, load_index_from_storage
@@ -56,15 +56,15 @@ def search_docs(request: QueryRequest, user_id: str = Query(..., description="
         index = load_index_from_storage(storage_context)
         logger.info("VectorStoreIndex loaded successfully.")
 
-        # retrieve the results (including the text)
-        retriever = index.as_retriever(similarity_top_k=settings.TOP_K)
+        # generate an embedding for the user query with the local model
+        query_vector = embedder.encode([request.query])  # embed the query with the local model
+        logger.info(f"Generated query embedding: {query_vector}")
+
+        # retrieve the most similar nodes via the FaissVectorStore
+        retriever = vector_store.as_retriever(similarity_top_k=settings.TOP_K, query_embedding=query_vector)
         logger.info(f"Retrieving top {settings.TOP_K} results for query: {request.query}")
         nodes = retriever.retrieve(request.query)
-
-
-
-
 
         # log the text and embedding of each result
         for i, node in enumerate(nodes):
            # log the text
@@ -76,19 +76,6 @@ def search_docs(request: QueryRequest, user_id: str = Query(..., description="
             logger.info(f"  Embedding (Vector): {embedding}")  # log the embedding
             logger.info(f"  Embedding Length: {len(embedding)}")  # log the embedding length (i.e. its dimensionality)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
         # return the retrieval results
         result = {
             "user_id": user_id,
diff --git a/app/api/search.py b/app/api/search.py
index ae12f5f..e69de29 100644
--- a/app/api/search.py
+++ b/app/api/search.py
@@ -1,94 +0,0 @@
-from fastapi import APIRouter, HTTPException, Query
-from pydantic import BaseModel
-from app.core.embedding import embedder  # use the already-loaded local embedding model (BGEEmbedding)
-from app.core.config import settings
-from llama_index.vector_stores.faiss import FaissVectorStore
-from llama_index import VectorStoreIndex, ServiceContext, StorageContext, load_index_from_storage
-import os
-import logging
-import faiss  # faiss is used to read the persisted index
-
-router = APIRouter()
-
-# configure the logger
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-class QueryRequest(BaseModel):
-    query: str
-
-@router.post("/search")
-def search_docs(request: QueryRequest, user_id: str = Query(..., description="User ID")):
-    try:
-        logger.info(f"Received search request from user: {user_id} with query: {request.query}")
-
-        # corrected index path: point at the whole directory rather than a single file
-        index_path = os.path.join("index_data", user_id)  # use the directory path
-        logger.info(f"Looking for index at path: {index_path}")
-
-        # make sure the index directory exists
-        if not os.path.exists(index_path):
-            logger.error(f"Index not found for user: {user_id} at {index_path}")
-            raise HTTPException(status_code=404, detail="Index for this user does not exist")
-
-        # load the Faiss index
-        faiss_index_file = os.path.join(index_path, "index.faiss")  # path to the faiss index file
-        if not os.path.exists(faiss_index_file):
-            logger.error(f"Faiss index not found at {faiss_index_file}")
-            raise HTTPException(status_code=404, detail="Faiss index file not found")
-
-        faiss_index = faiss.read_index(faiss_index_file)  # load the index file with faiss
-        logger.info("Faiss index loaded successfully.")
-
-        # create a FaissVectorStore instance
-        vector_store = FaissVectorStore(faiss_index=faiss_index)
-        logger.info("FaissVectorStore created successfully.")
-
-        # create a StorageContext (loads both the document store and the vector store)
-        storage_context = StorageContext.from_defaults(persist_dir=index_path, vector_store=vector_store)
-        logger.info("Storage context created successfully.")
-
-        # create a ServiceContext instance
-        service_context = ServiceContext.from_defaults(embed_model=embedder, llm=None)
-        logger.info("Service context created successfully.")
-
-        # load the index with load_index_from_storage
-        index = load_index_from_storage(storage_context)
-        logger.info("VectorStoreIndex loaded successfully.")
-
-        # generate an embedding for the user query with the local model
-        query_vector = embedder.encode([request.query])  # embed the query with the local model
-        logger.info(f"Generated query embedding: {query_vector}")
-
-        # retrieve the most similar nodes via the FaissVectorStore
-        retriever = vector_store.as_retriever(similarity_top_k=settings.TOP_K, query_embedding=query_vector)
-        logger.info(f"Retrieving top {settings.TOP_K} results for query: {request.query}")
-        nodes = retriever.retrieve(request.query)
-
-        # log the text and embedding of each result
-        for i, node in enumerate(nodes):
-            # log the text
-            logger.info(f"Result {i+1}:")
-            logger.info(f"  Text: {node.get_content()}")  # log the text
-
-            # log the embedding and its length
-            embedding = node.embedding
-            logger.info(f"  Embedding (Vector): {embedding}")  # log the embedding
-            logger.info(f"  Embedding Length: {len(embedding)}")  # log the embedding length (i.e. its dimensionality)
-
-        # return the retrieval results
-        result = {
-            "user_id": user_id,
-            "query": request.query,
-            "results": [
-                {"score": float(node.score or 0), "text": node.get_content()}  # take the text from the node itself
-                for node in nodes
-            ]
-        }
-
-        logger.info(f"Search results for user {user_id}: {result}")
-        return result
-
-    except Exception as e:
-        logger.error(f"Error processing search request: {e}", exc_info=True)
-        raise HTTPException(status_code=500, detail=str(e))