103 lines
4.4 KiB
Python
103 lines
4.4 KiB
Python
from fastapi import HTTPException # 导入 HTTPException 用于错误处理
|
||
from typing import List # 导入 List 用于类型注解
|
||
import os
|
||
import logging
|
||
import faiss
|
||
from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext
|
||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||
from app.core.config import settings # 导入应用配置
|
||
from scripts.permissions import get_user_allowed_indexes # 导入权限函数,管理用户索引
|
||
|
||
# Base directory under which each user's persisted index is stored.
USER_INDEX_PATH = "index_data"

# Base directory under which each user's source documents are stored.
USER_DOC_PATH = "docs"
|
||
|
||
class BGEEmbedding(HuggingFaceEmbedding):
    """HuggingFaceEmbedding variant that prepends a fixed instruction to queries.

    Only the query-side hooks are overridden; every query string is prefixed
    with the retrieval instruction before being handed to the parent class.
    """

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query after prepending the retrieval instruction."""
        instruction = "Represent this sentence for searching relevant passages: "
        prefixed_query = instruction + query
        return super()._get_query_embedding(prefixed_query)

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        """Embed a batch of queries, each prefixed with the retrieval instruction."""
        instruction = "Represent this sentence for searching relevant passages: "
        prefixed_queries = []
        for text in queries:
            prefixed_queries.append(instruction + text)
        return super()._get_query_embeddings(prefixed_queries)
|
||
|
||
def build_user_index(user_id: str, *, embed_dim: int = 1024) -> None:
    """Build and persist a FAISS-backed vector index over one user's documents.

    Documents are read from ``USER_DOC_PATH/<user_id>``, embedded with
    ``BGEEmbedding`` (model name taken from ``settings.MODEL_NAME``), added to
    a new ``faiss.IndexFlatL2`` and persisted under
    ``USER_INDEX_PATH/<user_id>``. Each call builds the index from scratch;
    previously persisted store files are not loaded.

    Args:
        user_id: Identifier used as the sub-directory name below both base
            directories. Must be non-empty, relative, and free of ``..``
            components.
        embed_dim: Dimension of the FAISS index. It has to match the size of
            the vectors produced by the embedding model; the default 1024 is
            the value that used to be hard-coded here.

    Raises:
        ValueError: If ``user_id`` is empty or could resolve outside the
            per-user base directories.
        FileNotFoundError: If the user's document directory does not exist.
        HTTPException: With status 500 if loading the documents or building /
            persisting the index fails.
    """
    logger = logging.getLogger(__name__)
    # Kept from the original: sets this module logger's threshold to INFO.
    logger.setLevel(logging.INFO)

    # user_id is used verbatim as a path component. An empty value would
    # select the whole documents directory, and absolute or ".." values
    # would leave it, so reject those before touching the filesystem.
    components = user_id.replace("\\", "/").split("/") if user_id else []
    if not user_id or os.path.isabs(user_id) or ".." in components:
        raise ValueError(f"非法的用户 ID: {user_id!r}")

    logger.info(f"开始为用户 {user_id} 构建索引...")

    # The source must be an actual directory, not merely an existing path.
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.isdir(doc_dir):
        logger.error(f"文档目录不存在: {doc_dir}")
        raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
    logger.info(f"发现文档目录: {doc_dir}")

    # Load every document found in the user's directory.
    try:
        documents = SimpleDirectoryReader(doc_dir).load_data()
    except Exception as e:
        logger.exception(f"加载文档时出错: {e}")
        raise HTTPException(status_code=500, detail="文档加载失败") from e
    logger.info(f"载入文档数量: {len(documents)}")

    # Embedding model with the query-instruction prefix.
    embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
    logger.info(f"使用模型: {settings.MODEL_NAME}")

    # NOTE(review): llm=None presumably opts out of the default LLM because
    # only embeddings are needed for indexing - confirm for the pinned version.
    service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)

    # Exact (flat) L2 index wrapped as the vector store.
    vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(embed_dim))

    # One persistence directory per user.
    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    logger.info(f"索引保存路径: {persist_dir}")
    os.makedirs(persist_dir, exist_ok=True)

    # Start from empty in-memory stores. Passing persist_dir to from_defaults
    # asks it to load previously persisted store files, which is meant for
    # reading an index back, not for building a new one, and used to require
    # a hand-written placeholder index_store.json.
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    try:
        VectorStoreIndex.from_documents(
            documents,
            service_context=service_context,
            storage_context=storage_context,
        )
        # Writes the docstore, index store and FAISS data into persist_dir.
        storage_context.persist(persist_dir=persist_dir)
    except Exception as e:
        logger.exception(f"索引构建失败: {e}")
        raise HTTPException(status_code=500, detail="索引构建失败") from e
    logger.info(f"索引已保存到 {persist_dir}")
|