85 lines
3.4 KiB
Python
85 lines
3.4 KiB
Python
import logging
import os

import chardet
from fastapi import APIRouter, HTTPException, Query
from llama_index import ServiceContext, StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.vector_stores.faiss import FaissVectorStore
from pydantic import BaseModel

from app.core.config import settings
from app.core.embedding import embedder
|
|
|
|
# Logging setup: install an INFO-level root handler (basicConfig does nothing
# if the root logger already has handlers), then take this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Router on which this module's endpoints are registered.
router = APIRouter()
|
|
|
|
# JSON request body of the `/search` endpoint.
class QueryRequest(BaseModel):
    # Free-text query passed to the retriever in search_docs.
    query: str
|
|
|
|
def read_file_with_detected_encoding(file_path: str, *, fallback_encoding: str = "utf-8") -> str:
    """Read a text file of unknown encoding, guessing the encoding with chardet.

    The file is read once as bytes and ``chardet.detect`` guesses its encoding.
    The bytes are then decoded with that guess.

    Args:
        file_path: Path of the file to read.
        fallback_encoding: Encoding used when chardet cannot make a guess.
            chardet reports ``None`` e.g. for empty input.

    Returns:
        The decoded text, with ``\\r\\n`` and ``\\r`` line endings normalised
        to ``\\n``. This is the same universal-newline behaviour as text-mode
        ``open``.

    Raises:
        OSError: If the file cannot be read.
        LookupError: If chardet reports an encoding Python does not know.
        UnicodeDecodeError: If the bytes are invalid in the chosen encoding.
    """
    with open(file_path, "rb") as f:
        raw_data = f.read()

    # When detection fails, chardet returns encoding=None. Passing None to open()
    # would silently fall back to the platform locale encoding, making results
    # machine-dependent. Use an explicit, deterministic fallback instead.
    encoding = chardet.detect(raw_data)["encoding"] or fallback_encoding

    # Decode the bytes already in memory instead of opening the file a second time.
    text = raw_data.decode(encoding)
    # Reproduce text-mode open()'s universal-newline translation.
    return text.replace("\r\n", "\n").replace("\r", "\n")
|
|
|
|
def _resolve_user_index_dir(user_id: str, root: str = "index_data") -> str:
    """Return the absolute path of ``<root>/<user_id>``, refusing ids that escape ``root``.

    ``user_id`` arrives straight from the query string, so the joined path must
    stay strictly inside ``root``. The following are all rejected:
    ``"../other"``, absolute paths, ``""`` and ``"."``.

    Raises:
        HTTPException: 400 if the id resolves outside of, or onto, ``root``.
    """
    root_abs = os.path.abspath(root)
    # abspath normalises ".." lexically. An absolute user_id replaces root in join().
    candidate = os.path.abspath(os.path.join(root_abs, user_id))
    try:
        inside = os.path.commonpath([root_abs, candidate]) == root_abs
    except ValueError:
        # commonpath raises when the paths are on different drives (Windows).
        inside = False
    if not inside or candidate == root_abs:
        raise HTTPException(status_code=400, detail="非法的用户ID")
    return candidate


@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")):
    """Return the top ``settings.TOP_K`` chunks matching a query in a user's FAISS index.

    The index is loaded from ``index_data/<user_id>``. The FAISS index comes from
    ``vector_store.json``. The node text comes from the other stores persisted
    in the same directory.

    Args:
        request: Body carrying the query text.
        user_id: Owner of the index to search (query parameter).

    Returns:
        ``{"user_id": ..., "query": ..., "results": [{"score": float, "text": str}, ...]}``.

    Raises:
        HTTPException:
            400 when ``user_id`` would escape the index root.
            404 when the index directory or ``vector_store.json`` is missing.
            500 for any other failure while loading or querying the index.
    """
    try:
        logger.info(f"Received search request from user: {user_id} with query: {request.query}")

        # Each user's index lives in its own directory under index_data/.
        index_path = _resolve_user_index_dir(user_id)
        logger.info(f"Looking for index at path: {index_path}")

        if not os.path.exists(index_path):
            logger.error(f"Index not found for user: {user_id} at {index_path}")
            raise HTTPException(status_code=404, detail="用户索引不存在")

        vector_store_file = os.path.join(index_path, "vector_store.json")
        if not os.path.exists(vector_store_file):
            logger.error(f"vector_store.json not found at {vector_store_file}")
            raise HTTPException(status_code=404, detail="vector_store.json not found")

        # from_persist_path expects the path of the persisted FAISS index file.
        # The previous code passed a StorageContext here. It also decoded this
        # file as text, which is pointless for a FAISS index, and then discarded the result.
        logger.info(f"Loading Faiss vector store from path: {vector_store_file}")
        faiss_store = FaissVectorStore.from_persist_path(vector_store_file)

        # FAISS stores only vectors, not node text, so VectorStoreIndex.from_vector_store
        # refuses it. Instead, load the persisted docstore/index store from the
        # same directory, with the FAISS store plugged in, and rebuild the index
        # from that storage.
        storage_context = StorageContext.from_defaults(
            vector_store=faiss_store, persist_dir=index_path
        )
        service_context = ServiceContext.from_defaults(embed_model=embedder, llm=None)
        logger.info("Service context created successfully.")

        index = load_index_from_storage(storage_context, service_context=service_context)
        logger.info("VectorStoreIndex created successfully.")

        retriever = index.as_retriever(similarity_top_k=settings.TOP_K)
        logger.info(f"Retrieving top {settings.TOP_K} results for query: {request.query}")
        nodes = retriever.retrieve(request.query)

        result = {
            "user_id": user_id,
            "query": request.query,
            "results": [
                {"score": float(node.score or 0), "text": node.get_content()}
                for node in nodes
            ],
        }

        # Retrieved text can be large and user-private: log only the count at INFO.
        logger.info(f"Returning {len(result['results'])} results for user {user_id}")
        logger.debug("Search results for user %s: %s", user_id, result)
        return result

    except HTTPException:
        # Deliberate 4xx responses raised above must reach the client unchanged
        # rather than being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Error processing search request: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
|