parent 3a1fc39c48
commit 2d0a6c11ee
@@ -1,12 +1,13 @@
 from fastapi import APIRouter, HTTPException, Query
 from pydantic import BaseModel
-import numpy as np
-from app.core.index_manager import UserIndexManager
 from app.core.embedding import embedder
 from app.core.config import settings
+from llama_index.vector_stores.faiss import FaissVectorStore
+from llama_index import VectorStoreIndex, ServiceContext
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+import os
 
 router = APIRouter()
-index_manager = UserIndexManager()
 
 class QueryRequest(BaseModel):
     query: str
@@ -14,13 +15,27 @@ class QueryRequest(BaseModel):
 @router.post("/search")
 def search_docs(request: QueryRequest, user_id: str = Query(..., description="User ID")):
     try:
-        query_vector = embedder.encode([request.query])
-        index = index_manager.get_index(user_id)
-        D, I = index.search(query_vector, settings.TOP_K)
+        index_path = os.path.join("index_data", f"{user_id}.index")
+        if not os.path.exists(index_path):
+            raise HTTPException(status_code=404, detail="User index does not exist")
+
+        # Build a LlamaIndex retriever over the persisted FAISS store
+        faiss_store = FaissVectorStore.from_persist_path(index_path)
+        service_context = ServiceContext.from_defaults(embed_model=embedder)
+        index = VectorStoreIndex.from_vector_store(faiss_store, service_context=service_context)
+
+        # Retrieve results (actual text, not just vector ids)
+        retriever = index.as_retriever(similarity_top_k=settings.TOP_K)
+        nodes = retriever.retrieve(request.query)
+
         return {
             "user_id": user_id,
             "query": request.query,
-            "results": [{"id": int(idx), "score": float(dist)} for dist, idx in zip(D, I)]
+            "results": [
+                {"score": float(node.score or 0), "text": node.get_content()}
+                for node in nodes
+            ]
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
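Note: in the new handler, the 404 raised when the index file is missing is immediately caught by the blanket except Exception below it and re-raised as a 500. A minimal sketch (not part of the commit) of letting deliberate HTTP errors through, placed before the generic handler:

    try:
        ...
    except HTTPException:
        # let the deliberate 404 above propagate with its original status
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))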
@@ -1,23 +1,32 @@
-from fastapi import APIRouter, UploadFile, File, Form
+from fastapi import APIRouter, UploadFile, File, Form, HTTPException
 import os
-from shutil import copyfileobj
+import shutil
 from scripts.rag_build_query import build_user_index
 
 router = APIRouter()
+ALLOWED_SUFFIXES = {".txt", ".md", ".pdf", ".docx"}
 
 @router.post("/upload")
 def upload_user_file(user_id: str = Form(...), file: UploadFile = File(...)):
+    filename = os.path.basename(file.filename)
+    suffix = os.path.splitext(filename)[-1].lower()
+    if suffix not in ALLOWED_SUFFIXES:
+        raise HTTPException(status_code=400, detail="Unsupported file type")
+
     user_doc_dir = os.path.join("docs", user_id)
     os.makedirs(user_doc_dir, exist_ok=True)
 
-    file_path = os.path.join(user_doc_dir, file.filename)
-    with open(file_path, "wb") as f:
-        copyfileobj(file.file, f)
-
-    print(f"[UPLOAD] File saved to {file_path}")
-
-    # Automatically rebuild the user's index
-    build_user_index(user_id)
-    print(f"[UPLOAD] Index rebuilt for user {user_id}")
+    file_path = os.path.join(user_doc_dir, filename)
+    try:
+        with open(file_path, "wb") as f:
+            shutil.copyfileobj(file.file, f)
+        print(f"[UPLOAD] File saved to {file_path}")
+
+        build_user_index(user_id)
+        print(f"[UPLOAD] Index rebuilt for user {user_id}")
+    except Exception as e:
+        print(f"[UPLOAD ERROR] {e}")
+        raise HTTPException(status_code=500, detail="Index build failed")
 
-    return {"status": "ok", "filename": file.filename}
+    return {"status": "ok", "filename": filename}
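A quick smoke test of the new validation path (hypothetical client call; assumes the router is mounted at the application root on localhost:8000 and that requests is installed):

    import requests

    # a .md upload passes the suffix check; a .exe would get a 400 back
    with open("notes.md", "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/upload",
            data={"user_id": "user_001"},
            files={"file": ("notes.md", fh)},
        )
    print(resp.status_code, resp.json())  # 200 {'status': 'ok', 'filename': 'notes.md'}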
@@ -1,12 +1,10 @@
 from pydantic_settings import BaseSettings
-import os
 
 class Settings(BaseSettings):
-    INDEX_FILE: str = os.getenv("INDEX_FILE", "index_data/index.faiss")
-    EMBEDDING_DIM: int = 768
-    TOP_K: int = 5
-    DOC_PATH: str = "docs/"
-    DEVICE: str = "cpu"
-    MODEL_NAME: str = "BAAI/bge-m3"
+    EMBEDDING_DIM: int = 768  # embedding dimension (depends on the model)
+    TOP_K: int = 5  # default number of passages to retrieve
+    DOC_PATH: str = "docs/"  # default document root directory
+    DEVICE: str = "cpu"  # change to "cuda" to use a GPU
+    MODEL_NAME: str = "BAAI/bge-m3"  # multilingual semantic embedding model
 
 settings = Settings()
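Even with the explicit os.getenv call gone, pydantic's BaseSettings still lets each field be overridden per deployment via a same-named environment variable (matched case-insensitively by default). A minimal sketch, assuming the module path app.core.config used elsewhere in this commit:

    import os

    os.environ["DEVICE"] = "cuda"  # must be set before Settings() is instantiated

    from app.core.config import settings  # module-level Settings() picks it up
    assert settings.DEVICE == "cuda"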
@@ -5,21 +5,26 @@ from app.core.config import settings
 
 class BGEEmbedding:
     def __init__(self):
         self.device = torch.device(settings.DEVICE)
         self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
-        self.model = AutoModel.from_pretrained(settings.MODEL_NAME)
+        self.model = AutoModel.from_pretrained(settings.MODEL_NAME).to(self.device)
         self.model.eval()
 
-    def encode(self, texts):
-        with torch.no_grad():
-            inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
-            outputs = self.model(**inputs)
-            embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
-            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-            return embeddings.cpu().numpy()
+    def encode(self, texts, batch_size=8):
+        all_embeddings = []
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i+batch_size]
+            with torch.no_grad():
+                inputs = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(self.device)
+                outputs = self.model(**inputs)
+                embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+                all_embeddings.append(embeddings.cpu().numpy())
+        return np.vstack(all_embeddings)
 
     def mean_pooling(self, model_output, attention_mask):
         token_embeddings = model_output[0]
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
 embedder = BGEEmbedding()
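The batched encode now ends with np.vstack, so the module needs numpy imported at its top (not visible in this hunk), and np.vstack([]) raises on an empty input list. A hypothetical shape check; note that BAAI/bge-m3 is commonly reported as producing 1024-dimensional embeddings, so the EMBEDDING_DIM of 768 in the config hunk above may be worth verifying against the model card:

    import numpy as np
    from app.core.embedding import embedder

    vecs = embedder.encode(["hello", "world"], batch_size=2)
    print(vecs.shape)                 # (2, hidden_size) — compare against EMBEDDING_DIM
    print(np.sum(vecs ** 2, axis=1))  # ~1.0 per row, since encode() L2-normalizes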
@@ -1,31 +1,28 @@
-
 import faiss
 import numpy as np
 import threading
-from app.core.config import settings
 
 class FaissIndexWrapper:
-    def __init__(self):
+    def __init__(self, index_path: str):
         self.index_lock = threading.Lock()
         self.index = None
-        self.load_index(settings.INDEX_FILE)
+        self.index_path = index_path
+        self.load_index(index_path)
 
-    def load_index(self, path):
+    def load_index(self, path: str = None):
         with self.index_lock:
+            path = path or self.index_path
             self.index = faiss.read_index(path)
             print(f"[FAISS] Index loaded from {path}")
 
-    def update_index(self, path):
+    def update_index(self, path: str = None):
         """Hot-swap the currently loaded index"""
         with self.index_lock:
-            new_index = faiss.read_index(path)
-            self.index = new_index
+            path = path or self.index_path
+            self.index = faiss.read_index(path)
             print(f"[FAISS] Index hot-swapped from {path}")
 
     def search(self, vector: np.ndarray, top_k: int = 5):
         with self.index_lock:
             D, I = self.index.search(vector.astype(np.float32), top_k)
             return D[0], I[0]
-
-# Singleton, referenced by the API layer
-faiss_index = FaissIndexWrapper()
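A hypothetical usage sketch of the reworked wrapper (the path and query are examples); every read and swap goes through the same lock, so a hot-swap cannot interleave with a search:

    from app.core.index import FaissIndexWrapper
    from app.core.embedding import embedder

    wrapper = FaissIndexWrapper("index_data/user_001.index")
    scores, ids = wrapper.search(embedder.encode(["hello"]), top_k=5)
    wrapper.update_index()  # re-reads index_data/user_001.index under the lock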
@@ -1,21 +1,38 @@
 import os
 import faiss
 import threading
 
 class UserIndexManager:
     def __init__(self, base_dir="index_data/"):
         self.base_dir = base_dir
-        self.index_map = {}  # user_id -> faiss index
+        self.index_map = {}  # user_id -> faiss.Index
         self.index_lock = threading.Lock()
         self._load_all_indexes()
 
     def _load_all_indexes(self):
         for fname in os.listdir(self.base_dir):
             if fname.endswith(".index"):
                 user_id = fname.replace(".index", "")
-                path = os.path.join(self.base_dir, fname)
-                self.index_map[user_id] = faiss.read_index(path)
-                print(f"[INIT] Loaded index for user {user_id}")
+                self._load_single_index(user_id)
+
+    def _load_single_index(self, user_id):
+        path = os.path.join(self.base_dir, f"{user_id}.index")
+        if os.path.exists(path):
+            index = faiss.read_index(path)
+            self.index_map[user_id] = index
+            print(f"[INIT] Loaded index for user: {user_id}")
 
     def get_index(self, user_id):
-        if user_id not in self.index_map:
-            raise ValueError(f"Index for user {user_id} not loaded.")
-        return self.index_map[user_id]
+        with self.index_lock:
+            if user_id not in self.index_map:
+                raise FileNotFoundError(f"No FAISS index loaded for user: {user_id}")
+            return self.index_map[user_id]
+
+    def update_index(self, user_id):
+        """Reload the user's index file from disk"""
+        path = os.path.join(self.base_dir, f"{user_id}.index")
+        if not os.path.exists(path):
+            raise FileNotFoundError(f"No index file found for user: {user_id}")
+        with self.index_lock:
+            self.index_map[user_id] = faiss.read_index(path)
+        print(f"[UPDATE] Index for user {user_id} has been hot-reloaded.")
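One wiring detail worth noting: build_user_index rewrites the .index file on disk, but get_index keeps serving the copy loaded at startup until update_index is called. A sketch of the refresh step, assuming the upload route can reach the same index_manager instance the search route uses:

    # after build_user_index(user_id) has rewritten index_data/<user_id>.index
    index_manager.update_index(user_id)       # hot-reload from disk
    index = index_manager.get_index(user_id)  # now the fresh index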
@@ -1,8 +1,8 @@
 import os
 import faiss
 import numpy as np
+import sys
 from app.core.embedding import embedder
-from app.core.config import settings
 
 def load_documents(doc_folder):
     texts = []
@@ -20,12 +20,20 @@ def build_faiss_index(docs, dim):
     return index
 
 if __name__ == "__main__":
-    print("[BUILD] Loading documents...")
-    docs = load_documents(settings.DOC_PATH)
+    if len(sys.argv) < 2:
+        print("Usage: python -m scripts.build_index <user_id>")
+        sys.exit(1)
+
+    user_id = sys.argv[1]
+    doc_dir = os.path.join("docs", user_id)
+    index_path = os.path.join("index_data", f"{user_id}.index")
+
+    print(f"[BUILD] Loading documents from {doc_dir} ...")
+    docs = load_documents(doc_dir)
     print(f"[BUILD] Loaded {len(docs)} documents")
 
     print("[BUILD] Building FAISS index...")
-    index = build_faiss_index(docs, settings.EMBEDDING_DIM)
+    index = build_faiss_index(docs, dim=768)  # or set dynamically from the model config
 
-    print(f"[BUILD] Saving index to {settings.INDEX_FILE}")
-    faiss.write_index(index, settings.INDEX_FILE)
+    print(f"[BUILD] Saving index to {index_path}")
+    faiss.write_index(index, index_path)
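Rather than hardcoding dim=768, the dimension can be derived from the embedder itself so the index can never drift from the model's actual output width. A small sketch using only the imports already in this script:

    # probe the embedder once; the result has shape (1, hidden_size)
    probe = embedder.encode(["dimension probe"])
    index = build_faiss_index(docs, dim=probe.shape[1])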
@@ -1,5 +1,5 @@
 {
-  "user_001": ["common.index", "engineering.index"],
-  "user_002": ["common.index", "hr.index"],
-  "user_003": []
-}
+  "user_001": ["common.index", "engineering.index"], // engineering staff
+  "user_002": ["common.index", "hr.index"], // HR staff
+  "user_003": [] // regular user, no shared libraries
+}
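Caution: strict JSON has no // comment syntax, so json.load on the annotated file above will raise JSONDecodeError, and the loader in the next hunk will then return [] for every user, silently disabling all shared libraries. One comment-free alternative that parses cleanly and leaves lookups by user_id unaffected is to keep the notes as data under a reserved key:

    {
      "_comments": {
        "user_001": "engineering staff",
        "user_002": "HR staff",
        "user_003": "regular user, no shared libraries"
      },
      "user_001": ["common.index", "engineering.index"],
      "user_002": ["common.index", "hr.index"],
      "user_003": []
    }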
@@ -1,12 +1,24 @@
 import json
 import os
 
 PERMISSION_PATH = "scripts/index_permissions.json"
 
 def get_user_allowed_indexes(user_id: str):
+    if not os.path.exists(PERMISSION_PATH):
+        print(f"[Permission Warning] Permission config file does not exist: {PERMISSION_PATH}")
+        return []
+
     try:
         with open(PERMISSION_PATH, "r", encoding="utf-8") as f:
             permission_map = json.load(f)
+
+        if user_id not in permission_map:
+            print(f"[Permission Info] User {user_id} has no shared-library configuration; using default permissions.")
+        return permission_map.get(user_id, [])
+
+    except json.JSONDecodeError:
+        print("[Permission Error] Permission config file is not valid JSON.")
+        return []
     except Exception as e:
-        print(f"[Permission Error] {e}")
-        return []
+        print(f"[Permission Error] Unknown error: {e}")
+        return []
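A quick check against the sample permission file above (assuming it is valid JSON; see the note on // comments):

    print(get_user_allowed_indexes("user_001"))  # ['common.index', 'engineering.index']
    print(get_user_allowed_indexes("user_999"))  # [] — unknown users get no shared libraries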
@@ -24,43 +24,47 @@ def build_user_index(user_id: str):
     embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model)
 
     # Build the vector index
     index = VectorStoreIndex.from_documents(
         documents,
         service_context=service_context,
         vector_store=FaissVectorStore()
     )
 
     # Save as a user-specific .index file
     index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
     faiss.write_index(index.vector_store.index, index_path)
     print(f"[BUILD] Built and saved index for user {user_id} → {index_path}")
 
 def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
-    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
-    if not os.path.exists(index_path):
-        raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")
-
     embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model)
+    all_nodes = []
 
-    # Load the main index
-    vector_store = FaissVectorStore.from_persist_path(index_path)
-    index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)
-
-    nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
-
-    # Load shared indexes within the user's permission scope
+    # Load the user's own index
+    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
+    if not os.path.exists(index_path):
+        raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")
+    user_store = FaissVectorStore.from_persist_path(index_path)
+    user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
+    all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
+
+    # Load shared indexes
     shared_indexes = get_user_allowed_indexes(user_id)
-    for shared_name in shared_indexes:
-        shared_path = os.path.join(USER_INDEX_PATH, shared_name)
-        if os.path.exists(shared_path):
-            shared_store = FaissVectorStore.from_persist_path(shared_path)
-            shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
-            nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
+    if shared_indexes:
+        for shared_name in shared_indexes:
+            shared_path = os.path.join(USER_INDEX_PATH, shared_name)
+            if os.path.exists(shared_path) and shared_path != index_path:
+                shared_store = FaissVectorStore.from_persist_path(shared_path)
+                shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
+                all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
+    else:
+        print(f"[INFO] User {user_id} has no shared-index permissions")
 
-    # Build the prompt
-    context_str = "\n\n".join([n.get_content() for n in nodes])
+    # Merge and sort by score
+    sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
+    top_nodes = sorted_nodes[:top_k]
+
+    context_str = "\n\n".join([n.get_content() for n in top_nodes])
     prompt_template = PromptTemplate(
         "Answer the user's question based on the following content:\n\n{context}\n\nQuestion: {query}"
     )
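Merging nodes from several FAISS indexes and ranking them by node.score assumes the scores are mutually comparable; that holds only if every index was built with the same metric over the L2-normalized embeddings (inner product then equals cosine similarity). If common.index can overlap a user's own index, deduplicating before ranking avoids feeding the same passage to the prompt twice; a minimal sketch:

    # drop duplicate chunks before ranking (keeps one copy of each passage)
    seen = set()
    unique_nodes = []
    for n in all_nodes:
        text = n.get_content()
        if text not in seen:
            seen.add(text)
            unique_nodes.append(n)
    sorted_nodes = sorted(unique_nodes, key=lambda n: -(n.score or 0))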
@@ -1,23 +1,30 @@
 import time
 import os
+import sys
 from app.core.index import faiss_index
-from app.core.config import settings
 
-def watch_and_reload(interval=600):
-    print("[HOT-RELOAD] Watching index file for updates...")
+def watch_and_reload(user_id: str, interval=600):
+    index_file = os.path.join("index_data", f"{user_id}.index")
+    print(f"[HOT-RELOAD] Watching {index_file} for updates...")
     last_mtime = None
 
     while True:
         try:
-            mtime = os.path.getmtime(settings.INDEX_FILE)
+            mtime = os.path.getmtime(index_file)
             if last_mtime is None:
                 last_mtime = mtime
             elif mtime != last_mtime:
                 print("[HOT-RELOAD] Detected new index file. Reloading...")
-                faiss_index.update_index(settings.INDEX_FILE)
+                faiss_index.update_index(index_file)
                 last_mtime = mtime
         except Exception as e:
             print(f"[HOT-RELOAD] Error: {e}")
         time.sleep(interval)
 
 if __name__ == "__main__":
-    watch_and_reload()
+    if len(sys.argv) < 2:
+        print("Usage: python -m scripts.update_index <user_id>")
+        sys.exit(1)
+
+    user_id = sys.argv[1]
+    watch_and_reload(user_id)
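Two caveats on this script as committed. First, it still imports faiss_index from app.core.index, whose module-level singleton appears to be removed in this commit, so the import may now fail at startup. Second, a watcher started as its own process reloads only its own in-memory copy of the index; the API process never sees the swap. A sketch of running it inside the API process instead (the user id and placement are illustrative):

    import threading
    from scripts.update_index import watch_and_reload

    # daemon thread inside the FastAPI process, so update_index() refreshes
    # the same in-memory index the API actually serves
    threading.Thread(
        target=watch_and_reload,
        args=("user_001",),
        kwargs={"interval": 600},
        daemon=True,
    ).start()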