parent ef7871eabd
commit 3a1fc39c48
@@ -1,4 +1,4 @@
-from pydantic import BaseSettings
+from pydantic_settings import BaseSettings
 import os
 
 class Settings(BaseSettings):
@@ -6,7 +6,7 @@ class Settings(BaseSettings):
     EMBEDDING_DIM: int = 768
     TOP_K: int = 5
     DOC_PATH: str = "docs/"
-    DEVICE: str = "cpu"  # can be set to cuda:0
+    DEVICE: str = "cpu"
     MODEL_NAME: str = "BAAI/bge-m3"
 
 settings = Settings()
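The first two hunks migrate the settings module to pydantic v2, where BaseSettings lives in the separate pydantic-settings package. A minimal sketch of the migrated module under that assumption (the SettingsConfigDict/.env usage below is illustrative, not part of this commit):

# sketch: pydantic-settings v2 style, values overridable via environment variables
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    EMBEDDING_DIM: int = 768
    TOP_K: int = 5
    DOC_PATH: str = "docs/"
    DEVICE: str = "cpu"          # e.g. export DEVICE=cuda:0 to override
    MODEL_NAME: str = "BAAI/bge-m3"

    model_config = SettingsConfigDict(env_file=".env")  # optional .env support, assumed here

settings = Settings()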
@@ -5,4 +5,5 @@ gunicorn
 pydantic
 numpy
 transformers
 torch
+llama-index==0.12.34
@@ -9,10 +9,9 @@ from llama_index import (
 )
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from llama_index.vector_stores.faiss import FaissVectorStore
-from llama_index.llms.base import ChatMessage
+from app.core.config import settings
+from scripts.permissions import get_user_allowed_indexes
 
-# the local embedding model to use
-EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 USER_INDEX_PATH = "index_data"
 USER_DOC_PATH = "docs"
 
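The new import pulls get_user_allowed_indexes from scripts.permissions, which is not shown in this diff. A hypothetical stub illustrating the interface the query path below relies on (the JSON-backed lookup and file name are assumptions, not the project's actual implementation):

# scripts/permissions.py -- hypothetical stub for illustration only
import json
import os

def get_user_allowed_indexes(user_id: str) -> list[str]:
    """Return index filenames (e.g. "team_docs.index") this user may also query."""
    path = os.path.join("index_data", "permissions.json")  # assumed storage location
    if not os.path.exists(path):
        return []
    with open(path, encoding="utf-8") as f:
        mapping = json.load(f)
    return mapping.get(user_id, [])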
@@ -22,7 +21,7 @@ def build_user_index(user_id: str):
         raise FileNotFoundError(f"Document directory does not exist: {doc_dir}")
 
     documents = SimpleDirectoryReader(doc_dir).load_data()
-    embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME)
+    embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model)
 
     # build the vector index
@@ -37,22 +36,28 @@ def build_user_index(user_id: str):
     faiss.write_index(index.vector_store.index, index_path)
     print(f"[BUILD] Built and saved the index for user {user_id} → {index_path}")
 
-from scripts.permissions import get_user_allowed_indexes
 
 def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
     index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
     if not os.path.exists(index_path):
         raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")
 
-    embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME)
+    embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model)
 
-    # load the index
+    # load the user's primary index
     vector_store = FaissVectorStore.from_persist_path(index_path)
     index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)
 
-    retriever = index.as_retriever(similarity_top_k=top_k)
-    nodes = retriever.retrieve(question)
+    nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
+    # load the shared indexes within the user's permission scope
+    shared_indexes = get_user_allowed_indexes(user_id)
+    for shared_name in shared_indexes:
+        shared_path = os.path.join(USER_INDEX_PATH, shared_name)
+        if os.path.exists(shared_path):
+            shared_store = FaissVectorStore.from_persist_path(shared_path)
+            shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
+            nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
 
     # build the prompt
     context_str = "\n\n".join([n.get_content() for n in nodes])
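One caveat worth noting: requirements.txt now pins llama-index==0.12.34, yet the script still imports from the legacy top-level llama_index namespace and builds a ServiceContext, which, as far as I recall, the 0.10+ core releases removed in favor of passing components directly (or via the global Settings object). A rough sketch of the load-and-retrieve step against the 0.12-style API, under that assumption (paths and values are illustrative, not from this commit):

# sketch only: llama-index 0.12-style equivalent of loading a persisted FAISS index
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-m3")
vector_store = FaissVectorStore.from_persist_path("index_data/alice.index")  # illustrative path
index = VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
nodes = index.as_retriever(similarity_top_k=5).retrieve("What key points are mentioned in this material?")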
@@ -73,4 +78,4 @@ if __name__ == "__main__":
     build_user_index(uid)
     prompt = query_user_rag(uid, "What key points are mentioned in this material?")
     print("\n------ Final prompt passed to the LLM ------\n")
     print(prompt)