commit 2d0a6c11ee (parent 3a1fc39c48)
Author: hailin
Date:   2025-05-08 16:38:30 +08:00

11 changed files with 162 additions and 90 deletions

View File

@@ -1,12 +1,13 @@
 from fastapi import APIRouter, HTTPException, Query
 from pydantic import BaseModel
-import numpy as np
-from app.core.index_manager import UserIndexManager
 from app.core.embedding import embedder
 from app.core.config import settings
+from llama_index.vector_stores.faiss import FaissVectorStore
+from llama_index import VectorStoreIndex, ServiceContext
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+import os

 router = APIRouter()
-index_manager = UserIndexManager()

 class QueryRequest(BaseModel):
     query: str
@@ -14,13 +15,27 @@ class QueryRequest(BaseModel):
 @router.post("/search")
 def search_docs(request: QueryRequest, user_id: str = Query(..., description="User ID")):
     try:
-        query_vector = embedder.encode([request.query])
-        index = index_manager.get_index(user_id)
-        D, I = index.search(query_vector, settings.TOP_K)
+        index_path = os.path.join("index_data", f"{user_id}.index")
+        if not os.path.exists(index_path):
+            raise HTTPException(status_code=404, detail="User index does not exist")
+
+        # Build the LlamaIndex retriever
+        faiss_store = FaissVectorStore.from_persist_path(index_path)
+        service_context = ServiceContext.from_defaults(embed_model=embedder)
+        index = VectorStoreIndex.from_vector_store(faiss_store, service_context=service_context)
+
+        # Retrieve results (actual passage text)
+        retriever = index.as_retriever(similarity_top_k=settings.TOP_K)
+        nodes = retriever.retrieve(request.query)
+
         return {
             "user_id": user_id,
             "query": request.query,
-            "results": [{"id": int(idx), "score": float(dist)} for dist, idx in zip(D, I)]
+            "results": [
+                {"score": float(node.score or 0), "text": node.get_content()}
+                for node in nodes
+            ]
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
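For reference, a minimal smoke test of the reworked /search endpoint; the host, port, and route prefix are assumptions, since the router mount point is not shown in this diff:

import requests

resp = requests.post(
    "http://localhost:8000/search",               # assumed mount point
    params={"user_id": "user_001"},               # user_id is a query parameter
    json={"query": "What is the leave policy?"},  # body matches QueryRequest
)
resp.raise_for_status()
for hit in resp.json()["results"]:
    print(f"{hit['score']:.4f}  {hit['text'][:80]}")

Note that results now carry the retrieved passage text rather than raw FAISS row ids, so callers no longer need a separate id-to-text lookup.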

View File

@@ -1,23 +1,32 @@
-from fastapi import APIRouter, UploadFile, File, Form
+from fastapi import APIRouter, UploadFile, File, Form, HTTPException
 import os
-from shutil import copyfileobj
+import shutil
 from scripts.rag_build_query import build_user_index

 router = APIRouter()

+ALLOWED_SUFFIXES = {".txt", ".md", ".pdf", ".docx"}
+
 @router.post("/upload")
 def upload_user_file(user_id: str = Form(...), file: UploadFile = File(...)):
+    filename = os.path.basename(file.filename)
+    suffix = os.path.splitext(filename)[-1].lower()
+    if suffix not in ALLOWED_SUFFIXES:
+        raise HTTPException(status_code=400, detail="Unsupported file type")
+
     user_doc_dir = os.path.join("docs", user_id)
     os.makedirs(user_doc_dir, exist_ok=True)
-    file_path = os.path.join(user_doc_dir, file.filename)
-    with open(file_path, "wb") as f:
-        copyfileobj(file.file, f)
-
-    print(f"[UPLOAD] File saved to {file_path}")
-
-    # Automatically rebuild the user's index
-    build_user_index(user_id)
-    print(f"[UPLOAD] Index for user {user_id} has been rebuilt")
-
-    return {"status": "ok", "filename": file.filename}
+    file_path = os.path.join(user_doc_dir, filename)
+
+    try:
+        with open(file_path, "wb") as f:
+            shutil.copyfileobj(file.file, f)
+        print(f"[UPLOAD] File saved to {file_path}")
+
+        build_user_index(user_id)
+        print(f"[UPLOAD] Index for user {user_id} has been rebuilt")
+    except Exception as e:
+        print(f"[UPLOAD ERROR] {e}")
+        raise HTTPException(status_code=500, detail="Index build failed")
+
+    return {"status": "ok", "filename": filename}

View File

@@ -1,12 +1,10 @@
 from pydantic_settings import BaseSettings
-import os

 class Settings(BaseSettings):
-    INDEX_FILE: str = os.getenv("INDEX_FILE", "index_data/index.faiss")
-    EMBEDDING_DIM: int = 768
-    TOP_K: int = 5
-    DOC_PATH: str = "docs/"
-    DEVICE: str = "cpu"
-    MODEL_NAME: str = "BAAI/bge-m3"
+    EMBEDDING_DIM: int = 768         # Embedding dimension (depends on the model)
+    TOP_K: int = 5                   # Default number of passages to retrieve
+    DOC_PATH: str = "docs/"          # Default document root directory
+    DEVICE: str = "cpu"              # Switch to "cuda" to use a GPU
+    MODEL_NAME: str = "BAAI/bge-m3"  # Multilingual semantic embedding model

 settings = Settings()
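Even with the os.getenv call gone, every field can still be overridden through the environment, since pydantic-settings matches env vars to field names; a minimal sketch:

import os
os.environ["TOP_K"] = "10"  # pydantic-settings reads env vars by field name

from app.core.config import Settings
print(Settings().TOP_K)  # 10, parsed to int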

View File

@@ -5,21 +5,26 @@ from app.core.config import settings
 class BGEEmbedding:
     def __init__(self):
+        self.device = torch.device(settings.DEVICE)
         self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
-        self.model = AutoModel.from_pretrained(settings.MODEL_NAME)
+        self.model = AutoModel.from_pretrained(settings.MODEL_NAME).to(self.device)
         self.model.eval()

-    def encode(self, texts):
-        with torch.no_grad():
-            inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
-            outputs = self.model(**inputs)
-            embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
-            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-        return embeddings.cpu().numpy()
+    def encode(self, texts, batch_size=8):
+        all_embeddings = []
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i+batch_size]
+            with torch.no_grad():
+                inputs = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(self.device)
+                outputs = self.model(**inputs)
+                embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+            all_embeddings.append(embeddings.cpu().numpy())
+        return np.vstack(all_embeddings)

     def mean_pooling(self, model_output, attention_mask):
         token_embeddings = model_output[0]
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

 embedder = BGEEmbedding()
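For a sense of the new batched API, a minimal sketch (the model downloads on first use):

from app.core.embedding import embedder

vectors = embedder.encode(
    ["annual leave policy", "expense reimbursement process"],
    batch_size=8,  # texts are encoded in chunks of this size
)
print(vectors.shape)  # (2, hidden_size); rows are L2-normalized

The return value is now a single numpy array stacked across batches, so downstream FAISS code can consume it directly.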

View File

@@ -1,31 +1,28 @@
 import faiss
 import numpy as np
 import threading
-from app.core.config import settings

 class FaissIndexWrapper:
-    def __init__(self):
+    def __init__(self, index_path: str):
         self.index_lock = threading.Lock()
         self.index = None
-        self.load_index(settings.INDEX_FILE)
+        self.index_path = index_path
+        self.load_index(index_path)

-    def load_index(self, path):
+    def load_index(self, path: str = None):
         with self.index_lock:
+            path = path or self.index_path
             self.index = faiss.read_index(path)
             print(f"[FAISS] Index loaded from {path}")

-    def update_index(self, path):
+    def update_index(self, path: str = None):
         """Hot-swap the currently loaded index"""
         with self.index_lock:
-            new_index = faiss.read_index(path)
-            self.index = new_index
+            path = path or self.index_path
+            self.index = faiss.read_index(path)
             print(f"[FAISS] Index hot-swapped from {path}")

     def search(self, vector: np.ndarray, top_k: int = 5):
         with self.index_lock:
             D, I = self.index.search(vector.astype(np.float32), top_k)
             return D[0], I[0]
-
-# Singleton for use by the API layer
-faiss_index = FaissIndexWrapper()
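With the module-level singleton removed, each caller now owns its wrapper; a minimal per-user sketch, where the index path is an assumption:

import numpy as np
from app.core.index import FaissIndexWrapper
from app.core.embedding import embedder

wrapper = FaissIndexWrapper("index_data/user_001.index")
query_vec = np.asarray(embedder.encode(["vacation policy"]))  # shape (1, dim)
distances, ids = wrapper.search(query_vec, top_k=5)

Calling wrapper.update_index() with no argument re-reads the path given at construction, which is what lets the hot-reload script later in this commit watch a fixed file.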

View File

@@ -1,21 +1,38 @@
 import os
 import faiss
+import threading

 class UserIndexManager:
     def __init__(self, base_dir="index_data/"):
         self.base_dir = base_dir
-        self.index_map = {}  # user_id -> faiss index
+        self.index_map = {}  # user_id -> faiss.Index
+        self.index_lock = threading.Lock()
         self._load_all_indexes()

     def _load_all_indexes(self):
         for fname in os.listdir(self.base_dir):
             if fname.endswith(".index"):
                 user_id = fname.replace(".index", "")
-                path = os.path.join(self.base_dir, fname)
-                self.index_map[user_id] = faiss.read_index(path)
-                print(f"[INIT] Loaded index for user {user_id}")
+                self._load_single_index(user_id)
+
+    def _load_single_index(self, user_id):
+        path = os.path.join(self.base_dir, f"{user_id}.index")
+        if os.path.exists(path):
+            index = faiss.read_index(path)
+            self.index_map[user_id] = index
+            print(f"[INIT] Loaded index for user: {user_id}")

     def get_index(self, user_id):
-        if user_id not in self.index_map:
-            raise ValueError(f"Index for user {user_id} not loaded.")
-        return self.index_map[user_id]
+        with self.index_lock:
+            if user_id not in self.index_map:
+                raise FileNotFoundError(f"No FAISS index loaded for user: {user_id}")
+            return self.index_map[user_id]
+
+    def update_index(self, user_id):
+        """Reload the index file for a user"""
+        path = os.path.join(self.base_dir, f"{user_id}.index")
+        if not os.path.exists(path):
+            raise FileNotFoundError(f"No index file found for user: {user_id}")
+        with self.index_lock:
+            self.index_map[user_id] = faiss.read_index(path)
+        print(f"[UPDATE] Index for user {user_id} has been hot-reloaded.")

View File

@@ -1,8 +1,8 @@
 import os
 import faiss
 import numpy as np
+import sys
 from app.core.embedding import embedder
-from app.core.config import settings

 def load_documents(doc_folder):
     texts = []
@@ -20,12 +20,20 @@ def build_faiss_index(docs, dim):
     return index

 if __name__ == "__main__":
-    print("[BUILD] Loading documents...")
-    docs = load_documents(settings.DOC_PATH)
+    if len(sys.argv) < 2:
+        print("Usage: python -m scripts.build_index <user_id>")
+        sys.exit(1)
+
+    user_id = sys.argv[1]
+    doc_dir = os.path.join("docs", user_id)
+    index_path = os.path.join("index_data", f"{user_id}.index")
+
+    print(f"[BUILD] Loading documents from {doc_dir} ...")
+    docs = load_documents(doc_dir)
     print(f"[BUILD] Loaded {len(docs)} documents")
     print("[BUILD] Building FAISS index...")
-    index = build_faiss_index(docs, settings.EMBEDDING_DIM)
-    print(f"[BUILD] Saving index to {settings.INDEX_FILE}")
-    faiss.write_index(index, settings.INDEX_FILE)
+    index = build_faiss_index(docs, dim=768)  # or derive dynamically from the model config
+    print(f"[BUILD] Saving index to {index_path}")
+    faiss.write_index(index, index_path)

View File

@ -1,5 +1,5 @@
{ {
"user_001": ["common.index", "engineering.index"], "user_001": ["common.index", "engineering.index"], //
"user_002": ["common.index", "hr.index"], "user_002": ["common.index", "hr.index"], //
"user_003": [] "user_003": [] //
} }

View File

@ -1,12 +1,24 @@
import json import json
import os
PERMISSION_PATH = "scripts/index_permissions.json" PERMISSION_PATH = "scripts/index_permissions.json"
def get_user_allowed_indexes(user_id: str): def get_user_allowed_indexes(user_id: str):
if not os.path.exists(PERMISSION_PATH):
print(f"[Permission Warning] 权限配置文件不存在: {PERMISSION_PATH}")
return []
try: try:
with open(PERMISSION_PATH, "r", encoding="utf-8") as f: with open(PERMISSION_PATH, "r", encoding="utf-8") as f:
permission_map = json.load(f) permission_map = json.load(f)
if user_id not in permission_map:
print(f"[Permission Info] 用户 {user_id} 无共享库配置,使用默认权限。")
return permission_map.get(user_id, []) return permission_map.get(user_id, [])
except json.JSONDecodeError:
print(f"[Permission Error] 权限配置文件 JSON 格式错误。")
return []
except Exception as e: except Exception as e:
print(f"[Permission Error] {e}") print(f"[Permission Error] 未知错误: {e}")
return [] return []
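Usage is a single call; the module path below is an assumption, since the diff does not show this file's name:

from scripts.permissions import get_user_allowed_indexes  # hypothetical module path

allowed = get_user_allowed_indexes("user_001")
print(allowed)  # e.g. ["common.index", "engineering.index"] per index_permissions.json

Every failure mode now degrades to an empty list, so a broken permission file silently disables shared-index retrieval rather than raising.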

View File

@@ -24,43 +24,47 @@ def build_user_index(user_id: str):
     embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model)

+    # Build the vector index
     index = VectorStoreIndex.from_documents(
         documents,
         service_context=service_context,
         vector_store=FaissVectorStore()
     )

+    # Persist as the user's own .index file
     index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
     faiss.write_index(index.vector_store.index, index_path)
     print(f"[BUILD] Built and saved index for user {user_id} → {index_path}")

 def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
-    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
-    if not os.path.exists(index_path):
-        raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")
-
     embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model)

-    # Load the main index
-    vector_store = FaissVectorStore.from_persist_path(index_path)
-    index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)
-    nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
+    all_nodes = []

-    # Load shared indexes within the user's permission scope
+    # Load the user's primary index
+    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
+    if not os.path.exists(index_path):
+        raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")
+    user_store = FaissVectorStore.from_persist_path(index_path)
+    user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
+    all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
+
+    # Load shared indexes
     shared_indexes = get_user_allowed_indexes(user_id)
-    for shared_name in shared_indexes:
-        shared_path = os.path.join(USER_INDEX_PATH, shared_name)
-        if os.path.exists(shared_path):
-            shared_store = FaissVectorStore.from_persist_path(shared_path)
-            shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
-            nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
+    if shared_indexes:
+        for shared_name in shared_indexes:
+            shared_path = os.path.join(USER_INDEX_PATH, shared_name)
+            if os.path.exists(shared_path) and shared_path != index_path:
+                shared_store = FaissVectorStore.from_persist_path(shared_path)
+                shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
+                all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
+    else:
+        print(f"[INFO] User {user_id} has no shared-index permissions")

-    # Build the prompt
-    context_str = "\n\n".join([n.get_content() for n in nodes])
+    # Merge and sort by score
+    sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
+    top_nodes = sorted_nodes[:top_k]
+    context_str = "\n\n".join([n.get_content() for n in top_nodes])

     prompt_template = PromptTemplate(
         "Answer the user's question based on the following content:\n\n{context}\n\nQuestion: {query}"
     )
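End to end, the two public functions compose like this; the question text is illustrative:

from scripts.rag_build_query import build_user_index, query_user_rag

build_user_index("user_001")  # writes USER_INDEX_PATH/user_001.index
answer = query_user_rag("user_001", "How many vacation days do I get?", top_k=4)
print(answer)

Since each retriever already returns at most top_k nodes per index, the merge-then-truncate step caps the prompt context at top_k passages overall rather than top_k per index.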

View File

@@ -1,23 +1,30 @@
 import time
 import os
-from app.core.index import faiss_index
-from app.core.config import settings
+import sys
+from app.core.index import FaissIndexWrapper

-def watch_and_reload(interval=600):
-    print("[HOT-RELOAD] Watching index file for updates...")
+def watch_and_reload(user_id: str, interval=600):
+    # app.core.index no longer exports a module-level singleton,
+    # so the watcher builds its own per-user wrapper here.
+    index_file = os.path.join("index_data", f"{user_id}.index")
+    faiss_index = FaissIndexWrapper(index_file)
+    print(f"[HOT-RELOAD] Watching {index_file} for updates...")
     last_mtime = None
     while True:
         try:
-            mtime = os.path.getmtime(settings.INDEX_FILE)
+            mtime = os.path.getmtime(index_file)
             if last_mtime is None:
                 last_mtime = mtime
             elif mtime != last_mtime:
                 print("[HOT-RELOAD] Detected new index file. Reloading...")
-                faiss_index.update_index(settings.INDEX_FILE)
+                faiss_index.update_index(index_file)
                 last_mtime = mtime
         except Exception as e:
             print(f"[HOT-RELOAD] Error: {e}")
         time.sleep(interval)

 if __name__ == "__main__":
-    watch_and_reload()
+    if len(sys.argv) < 2:
+        print("Usage: python -m scripts.update_index <user_id>")
+        sys.exit(1)
+    user_id = sys.argv[1]
+    watch_and_reload(user_id)