This commit is contained in:
parent
ae60c49fd1
commit
301795569c
|
|
@ -1,13 +1,16 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, PromptTemplate
|
from llama_index.core import (
|
||||||
|
SimpleDirectoryReader,
|
||||||
|
VectorStoreIndex,
|
||||||
|
PromptTemplate,
|
||||||
|
Settings,
|
||||||
|
)
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from scripts.permissions import get_user_allowed_indexes
|
from scripts.permissions import get_user_allowed_indexes
|
||||||
|
import faiss
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
USER_INDEX_PATH = "index_data"
|
USER_INDEX_PATH = "index_data"
|
||||||
USER_DOC_PATH = "docs"
|
USER_DOC_PATH = "docs"
|
||||||
|
|
@ -19,11 +22,10 @@ def build_user_index(user_id: str):
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(doc_dir).load_data()
|
documents = SimpleDirectoryReader(doc_dir).load_data()
|
||||||
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
|
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
|
||||||
service_context = ServiceContext.from_defaults(embed_model=embed_model)
|
Settings.embed_model = embed_model # ✅ 新式配置
|
||||||
|
|
||||||
index = VectorStoreIndex.from_documents(
|
index = VectorStoreIndex.from_documents(
|
||||||
documents,
|
documents,
|
||||||
service_context=service_context,
|
|
||||||
vector_store=FaissVectorStore()
|
vector_store=FaissVectorStore()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -33,7 +35,7 @@ def build_user_index(user_id: str):
|
||||||
|
|
||||||
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
|
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
|
||||||
service_context = ServiceContext.from_defaults(embed_model=embed_model)
|
Settings.embed_model = embed_model # ✅ 全局设置一次即可
|
||||||
|
|
||||||
all_nodes = []
|
all_nodes = []
|
||||||
|
|
||||||
|
|
@ -42,7 +44,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
if not os.path.exists(index_path):
|
if not os.path.exists(index_path):
|
||||||
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
|
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
|
||||||
user_store = FaissVectorStore.from_persist_path(index_path)
|
user_store = FaissVectorStore.from_persist_path(index_path)
|
||||||
user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
|
user_index = VectorStoreIndex.from_vector_store(user_store)
|
||||||
all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
||||||
|
|
||||||
# 加载共享索引
|
# 加载共享索引
|
||||||
|
|
@ -52,7 +54,7 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
shared_path = os.path.join(USER_INDEX_PATH, shared_name)
|
shared_path = os.path.join(USER_INDEX_PATH, shared_name)
|
||||||
if os.path.exists(shared_path) and shared_path != index_path:
|
if os.path.exists(shared_path) and shared_path != index_path:
|
||||||
shared_store = FaissVectorStore.from_persist_path(shared_path)
|
shared_store = FaissVectorStore.from_persist_path(shared_path)
|
||||||
shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
|
shared_index = VectorStoreIndex.from_vector_store(shared_store)
|
||||||
all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
||||||
else:
|
else:
|
||||||
print(f"[INFO] 用户 {user_id} 没有共享索引权限")
|
print(f"[INFO] 用户 {user_id} 没有共享索引权限")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue