This commit is contained in:
parent
ce38e21219
commit
810ba37217
|
|
@ -1,23 +1,16 @@
|
|||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
import faiss
|
||||
from llama_index import (
|
||||
SimpleDirectoryReader,
|
||||
VectorStoreIndex,
|
||||
ServiceContext,
|
||||
StorageContext,
|
||||
PromptTemplate
|
||||
)
|
||||
from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
from app.core.config import settings # 确保配置导入
|
||||
from scripts.permissions import get_user_allowed_indexes # 确保权限导入
|
||||
from scripts.permissions import get_user_allowed_indexes # 导入权限函数
|
||||
|
||||
USER_INDEX_PATH = "index_data"
|
||||
USER_DOC_PATH = "docs"
|
||||
|
||||
# ✅ 自动加前缀的 BGE-m3 embedding 封装类
|
||||
# BGEEmbedding 类
|
||||
class BGEEmbedding(HuggingFaceEmbedding):
|
||||
def _get_query_embedding(self, query: str) -> List[float]:
|
||||
prefix = "Represent this sentence for searching relevant passages: "
|
||||
|
|
@ -34,14 +27,21 @@ def build_user_index(user_id: str):
|
|||
|
||||
logger.info(f"开始为用户 {user_id} 构建索引...")
|
||||
|
||||
# 确认文档目录
|
||||
doc_dir = os.path.join(USER_DOC_PATH, user_id)
|
||||
if not os.path.exists(doc_dir):
|
||||
logger.error(f"文档目录不存在: {doc_dir}")
|
||||
raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
|
||||
|
||||
logger.info(f"发现文档目录: {doc_dir}")
|
||||
|
||||
documents = SimpleDirectoryReader(doc_dir).load_data()
|
||||
logger.info(f"载入文档数量: {len(documents)}")
|
||||
# 载入文档
|
||||
try:
|
||||
documents = SimpleDirectoryReader(doc_dir).load_data()
|
||||
logger.info(f"载入文档数量: {len(documents)}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载文档时出错: {e}")
|
||||
raise HTTPException(status_code=500, detail="文档加载失败")
|
||||
|
||||
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
|
||||
logger.info(f"使用模型: {settings.MODEL_NAME}")
|
||||
|
|
@ -51,17 +51,31 @@ def build_user_index(user_id: str):
|
|||
faiss_index = faiss.IndexFlatL2(1024)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
|
||||
# 确保索引路径存在
|
||||
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
|
||||
os.makedirs(persist_dir, exist_ok=True)
|
||||
logger.info(f"索引保存路径: {persist_dir}")
|
||||
|
||||
# 检查目录中是否存在 index_store.json 文件
|
||||
|
||||
# 如果目录不存在,创建它
|
||||
if not os.path.exists(persist_dir):
|
||||
logger.info(f"目录 {persist_dir} 不存在,准备创建")
|
||||
os.makedirs(persist_dir, exist_ok=True)
|
||||
logger.info(f"目录 {persist_dir} 已创建")
|
||||
|
||||
# 确保 index_store.json 文件路径存在
|
||||
index_store_path = os.path.join(persist_dir, "index_store.json")
|
||||
if not os.path.exists(index_store_path):
|
||||
logger.info(f"未找到 index_store.json,准备创建")
|
||||
try:
|
||||
with open(index_store_path, "w") as f:
|
||||
f.write("{}") # 创建空的 index_store.json 文件
|
||||
logger.info(f"已创建 index_store.json 文件")
|
||||
except Exception as e:
|
||||
logger.error(f"创建 index_store.json 时出错: {e}")
|
||||
raise HTTPException(status_code=500, detail="创建 index_store.json 文件失败")
|
||||
else:
|
||||
logger.info(f"已找到 index_store.json,跳过创建")
|
||||
|
||||
# 创建 storage context
|
||||
storage_context = StorageContext.from_defaults(
|
||||
persist_dir=persist_dir,
|
||||
vector_store=vector_store,
|
||||
|
|
|
|||
Loading…
Reference in New Issue