This commit is contained in:
parent
a3b6021762
commit
cb949ac721
|
|
@ -16,7 +16,7 @@ from scripts.permissions import get_user_allowed_indexes
|
||||||
USER_INDEX_PATH = "index_data"
|
USER_INDEX_PATH = "index_data"
|
||||||
USER_DOC_PATH = "docs"
|
USER_DOC_PATH = "docs"
|
||||||
|
|
||||||
# ✅ 自动加前缀的 bge-m3 embedding
|
# ✅ 自动加前缀的 BGE-m3 embedding 封装类
|
||||||
class BGEEmbedding(HuggingFaceEmbedding):
|
class BGEEmbedding(HuggingFaceEmbedding):
|
||||||
def _get_query_embedding(self, query: str) -> List[float]:
|
def _get_query_embedding(self, query: str) -> List[float]:
|
||||||
prefix = "Represent this sentence for searching relevant passages: "
|
prefix = "Represent this sentence for searching relevant passages: "
|
||||||
|
|
@ -38,16 +38,19 @@ def build_user_index(user_id: str):
|
||||||
faiss_index = faiss.IndexFlatL2(1024)
|
faiss_index = faiss.IndexFlatL2(1024)
|
||||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
|
persist_dir = os.path.join(USER_INDEX_PATH, user_id)
|
||||||
|
os.makedirs(persist_dir, exist_ok=True)
|
||||||
|
|
||||||
storage_context = StorageContext.from_defaults(
|
storage_context = StorageContext.from_defaults(
|
||||||
persist_dir=persist_dir,
|
persist_dir=persist_dir,
|
||||||
vector_store=vector_store
|
vector_store=vector_store,
|
||||||
)
|
)
|
||||||
|
|
||||||
index = VectorStoreIndex.from_documents(
|
index = VectorStoreIndex.from_documents(
|
||||||
documents,
|
documents,
|
||||||
storage_context=storage_context,
|
service_context=service_context,
|
||||||
service_context=service_context
|
storage_context=storage_context
|
||||||
)
|
)
|
||||||
|
|
||||||
index.persist(persist_dir=persist_dir)
|
index.persist(persist_dir=persist_dir)
|
||||||
print(f"[BUILD] 为用户 {user_id} 构建并保存了完整索引 → {persist_dir}")
|
print(f"[BUILD] 为用户 {user_id} 构建并保存了完整索引 → {persist_dir}")
|
||||||
|
|
||||||
|
|
@ -65,15 +68,12 @@ def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||||
all_nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
all_nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
||||||
|
|
||||||
shared_indexes = get_user_allowed_indexes(user_id)
|
shared_indexes = get_user_allowed_indexes(user_id)
|
||||||
if shared_indexes:
|
for shared_name in shared_indexes:
|
||||||
for shared_name in shared_indexes:
|
shared_dir = os.path.join(USER_INDEX_PATH, shared_name)
|
||||||
shared_path = os.path.join(USER_INDEX_PATH, shared_name)
|
if os.path.exists(shared_dir) and shared_dir != persist_dir:
|
||||||
if os.path.exists(shared_path) and shared_path != persist_dir:
|
shared_context = StorageContext.from_defaults(persist_dir=shared_dir)
|
||||||
shared_storage = StorageContext.from_defaults(persist_dir=shared_path)
|
shared_index = VectorStoreIndex.load_from_storage(shared_context, service_context=service_context)
|
||||||
shared_index = VectorStoreIndex.load_from_storage(shared_storage, service_context=service_context)
|
all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
||||||
all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
|
|
||||||
else:
|
|
||||||
print(f"[INFO] 用户 {user_id} 没有共享索引权限")
|
|
||||||
|
|
||||||
sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
|
sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
|
||||||
top_nodes = sorted_nodes[:top_k]
|
top_nodes = sorted_nodes[:top_k]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue