diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index 3540dd4..2b8035c 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -15,6 +15,7 @@ import faiss from typing import List import asyncio from sentence_transformers import SentenceTransformer +from llama_index.embeddings.huggingface import HuggingFaceEmbedding USER_INDEX_PATH = "index_data" USER_DOC_PATH = "docs" @@ -54,7 +55,7 @@ def build_user_index(user_id: str): raise FileNotFoundError(f"文档目录不存在: {doc_dir}") documents = SimpleDirectoryReader(doc_dir).load_data() - embed_model = CustomEmbedding(model_name=settings.MODEL_NAME) + embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) service_context = ServiceContext.from_defaults(embed_model=embed_model) @@ -69,7 +70,7 @@ def build_user_index(user_id: str): print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}") def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: - embed_model = CustomEmbedding(model_name=settings.MODEL_NAME) + embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) service_context = ServiceContext.from_defaults(embed_model=embed_model) all_nodes = []