commit 2afc6e6a95 (parent b1cf0e61e1)
Author: hailin
Date: 2025-05-11 11:45:48 +08:00
2 changed files with 7 additions and 114 deletions

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
-from app.core.embedding import embedder
+from app.core.embedding import embedder  # use the already-loaded local embedding model (BGEEmbedding)
from app.core.config import settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index import VectorStoreIndex, ServiceContext, StorageContext, load_index_from_storage
@@ -56,15 +56,15 @@ def search_docs(request: QueryRequest, user_id: str = Query(..., description="
        index = load_index_from_storage(storage_context)
        logger.info("VectorStoreIndex loaded successfully.")
-       # Retrieve the results (text included)
-       retriever = index.as_retriever(similarity_top_k=settings.TOP_K)
+       # Generate the query embedding with the local model
+       query_vector = embedder.encode([request.query])  # use the local model to embed the query
+       logger.info(f"Generated query embedding: {query_vector}")
+       # Use the FaissVectorStore to retrieve the most similar nodes
+       retriever = vector_store.as_retriever(similarity_top_k=settings.TOP_K, query_embedding=query_vector)
        logger.info(f"Retrieving top {settings.TOP_K} results for query: {request.query}")
        nodes = retriever.retrieve(request.query)
        # Log each result's vector and text
        for i, node in enumerate(nodes):
            # log the text
@@ -76,19 +76,6 @@
            logger.info(f" Embedding (Vector): {embedding}")  # log the vector
            logger.info(f" Embedding Length: {len(embedding)}")  # log the vector length (i.e., the embedding dimensionality)
        # Return the retrieval results
        result = {
            "user_id": user_id,

View File

@@ -1,94 +0,0 @@
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from app.core.embedding import embedder  # use the already-loaded local embedding model (BGEEmbedding)
from app.core.config import settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index import VectorStoreIndex, ServiceContext, StorageContext, load_index_from_storage
import os
import logging
import faiss  # FAISS, used to load the on-disk vector index

router = APIRouter()

# Configure the logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class QueryRequest(BaseModel):
    query: str
@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="User ID")):
    try:
        logger.info(f"Received search request from user: {user_id} with query: {request.query}")

        # Corrected index path: make sure it points to the whole directory, not a single file
        index_path = os.path.join("index_data", user_id)  # use the full directory path
        logger.info(f"Looking for index at path: {index_path}")

        # Check whether the index directory exists
        if not os.path.exists(index_path):
            logger.error(f"Index not found for user: {user_id} at {index_path}")
            raise HTTPException(status_code=404, detail="User index not found")

        # Load the Faiss index
        faiss_index_file = os.path.join(index_path, "index.faiss")  # path to the faiss index file
        if not os.path.exists(faiss_index_file):
            logger.error(f"Faiss index not found at {faiss_index_file}")
            raise HTTPException(status_code=404, detail="Faiss index file not found")
        faiss_index = faiss.read_index(faiss_index_file)  # load the index file with faiss
        logger.info("Faiss index loaded successfully.")

        # Create the FaissVectorStore instance
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        logger.info("FaissVectorStore created successfully.")

        # Create the StorageContext (make sure both text and vectors are loaded)
        storage_context = StorageContext.from_defaults(persist_dir=index_path, vector_store=vector_store)
        logger.info("Storage context created successfully.")

        # Create the ServiceContext instance
        service_context = ServiceContext.from_defaults(embed_model=embedder, llm=None)
        logger.info("Service context created successfully.")

        # Load the index with load_index_from_storage
        index = load_index_from_storage(storage_context)
        logger.info("VectorStoreIndex loaded successfully.")

        # Generate the query embedding with the local model
        query_vector = embedder.encode([request.query])  # use the local model to embed the query
        logger.info(f"Generated query embedding: {query_vector}")

        # Use the FaissVectorStore to retrieve the most similar nodes
        retriever = vector_store.as_retriever(similarity_top_k=settings.TOP_K, query_embedding=query_vector)
        logger.info(f"Retrieving top {settings.TOP_K} results for query: {request.query}")
        nodes = retriever.retrieve(request.query)

        # Log each result's vector and text
        for i, node in enumerate(nodes):
            # log the text
            logger.info(f"Result {i+1}:")
            logger.info(f" Text: {node.get_content()}")  # log the text
            # log the vector and its length
            embedding = node.embedding
            logger.info(f" Embedding (Vector): {embedding}")  # log the vector
            logger.info(f" Embedding Length: {len(embedding)}")  # log the vector length (i.e., the embedding dimensionality)

        # Return the retrieval results
        result = {
            "user_id": user_id,
            "query": request.query,
            "results": [
                {"score": float(node.score or 0), "text": node.get_content()}  # take the text from the Node itself
                for node in nodes
            ]
        }
        logger.info(f"Search results for user {user_id}: {result}")
        return result
    except Exception as e:
        logger.error(f"Error processing search request: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
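Editorial note (not part of the commit): a quick way to exercise a /search endpoint of this shape is FastAPI's TestClient. The wiring below is assumed rather than taken from the diff: the module path app.main, an app object that includes this router with no prefix, and the demo-user value are placeholders.

from fastapi.testclient import TestClient
from app.main import app  # assumed application module; adjust to the real app and any router prefix

client = TestClient(app)

response = client.post(
    "/search",
    params={"user_id": "demo-user"},      # user_id is passed as a query parameter
    json={"query": "example question"},   # body matches the QueryRequest model
)
print(response.status_code)  # 200 on success, 404 if index_data/demo-user does not exist
print(response.json())       # {"user_id": ..., "query": ..., "results": [{"score": ..., "text": ...}, ...]}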