This commit is contained in:
hailin 2025-05-08 16:38:30 +08:00
parent 3a1fc39c48
commit 2d0a6c11ee
11 changed files with 162 additions and 90 deletions

View File

@@ -1,12 +1,13 @@
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
import numpy as np
from app.core.index_manager import UserIndexManager
from app.core.embedding import embedder
from app.core.config import settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index import VectorStoreIndex, ServiceContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import os
router = APIRouter()
index_manager = UserIndexManager()
class QueryRequest(BaseModel):
query: str
@@ -14,13 +15,27 @@ class QueryRequest(BaseModel):
@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="User ID")):
try:
query_vector = embedder.encode([request.query])
index = index_manager.get_index(user_id)
D, I = index.search(query_vector, settings.TOP_K)
index_path = os.path.join("index_data", f"{user_id}.index")
if not os.path.exists(index_path):
raise HTTPException(status_code=404, detail="User index does not exist")
# Build the LlamaIndex retriever
faiss_store = FaissVectorStore.from_persist_path(index_path)
# embedder (BGEEmbedding) is not a LlamaIndex embed model; use the imported HuggingFaceEmbedding wrapper
service_context = ServiceContext.from_defaults(embed_model=HuggingFaceEmbedding(model_name=settings.MODEL_NAME))
index = VectorStoreIndex.from_vector_store(faiss_store, service_context=service_context)
# Retrieval results (actual text)
retriever = index.as_retriever(similarity_top_k=settings.TOP_K)
nodes = retriever.retrieve(request.query)
return {
"user_id": user_id,
"query": request.query,
"results": [{"id": int(idx), "score": float(dist)} for dist, idx in zip(D, I)]
"results": [
{"score": float(node.score or 0), "text": node.get_content()}
for node in nodes
]
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
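A quick way to exercise the reworked endpoint is FastAPI's TestClient. A minimal smoke-test sketch, assuming an app.main:app entry point with this router mounted at the root and an existing index for the illustrative user "user_001":

    # Smoke test for POST /search (entry-point module and user ID are assumptions)
    from fastapi.testclient import TestClient
    from app.main import app

    client = TestClient(app)
    resp = client.post(
        "/search",
        params={"user_id": "user_001"},
        json={"query": "annual leave policy"},
    )
    print(resp.status_code)  # 200, or 404 if index_data/user_001.index is missing
    print(resp.json())       # {"user_id": ..., "query": ..., "results": [{"score": ..., "text": ...}, ...]}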

View File

@@ -1,23 +1,32 @@
from fastapi import APIRouter, UploadFile, File, Form
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
import os
from shutil import copyfileobj
import shutil
from scripts.rag_build_query import build_user_index
router = APIRouter()
ALLOWED_SUFFIXES = {".txt", ".md", ".pdf", ".docx"}
@router.post("/upload")
def upload_user_file(user_id: str = Form(...), file: UploadFile = File(...)):
filename = os.path.basename(file.filename)
suffix = os.path.splitext(filename)[-1].lower()
if suffix not in ALLOWED_SUFFIXES:
raise HTTPException(status_code=400, detail="Unsupported file type")
user_doc_dir = os.path.join("docs", user_id)
os.makedirs(user_doc_dir, exist_ok=True)
file_path = os.path.join(user_doc_dir, file.filename)
with open(file_path, "wb") as f:
copyfileobj(file.file, f)
file_path = os.path.join(user_doc_dir, filename)
try:
with open(file_path, "wb") as f:
shutil.copyfileobj(file.file, f)
print(f"[UPLOAD] 文件已保存至 {file_path}")
print(f"[UPLOAD] 文件已保存至 {file_path}")
build_user_index(user_id)
print(f"[UPLOAD] 用户 {user_id} 的索引已重建")
# 自动重建用户索引
build_user_index(user_id)
print(f"[UPLOAD] 用户 {user_id} 的索引已重建")
except Exception as e:
print(f"[UPLOAD ERROR] {e}")
raise HTTPException(status_code=500, detail="索引构建失败")
return {"status": "ok", "filename": file.filename}
return {"status": "ok", "filename": filename}

View File

@@ -1,12 +1,10 @@
from pydantic_settings import BaseSettings
import os
class Settings(BaseSettings):
INDEX_FILE: str = os.getenv("INDEX_FILE", "index_data/index.faiss")
EMBEDDING_DIM: int = 768
TOP_K: int = 5
DOC_PATH: str = "docs/"
DEVICE: str = "cpu"
MODEL_NAME: str = "BAAI/bge-m3"
EMBEDDING_DIM: int = 768 # embedding dimension (depends on the model; note BAAI/bge-m3 outputs 1024-dim vectors, so verify this value)
TOP_K: int = 5 # default number of top-K passages to retrieve
DOC_PATH: str = "docs/" # default document root directory
DEVICE: str = "cpu" # set to "cuda" to use a GPU
MODEL_NAME: str = "BAAI/bge-m3" # multilingual semantic embedding model
settings = Settings()
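Because Settings extends pydantic's BaseSettings, each field can still be overridden through an environment variable of the same (case-insensitive) name even after the os.getenv call was dropped. A sketch, assuming the variable is set before app.core.config is first imported:

    import os
    os.environ["DEVICE"] = "cuda"  # must happen before Settings() is instantiated

    from app.core.config import settings
    print(settings.DEVICE)  # "cuda"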

View File

@@ -5,21 +5,26 @@ from app.core.config import settings
class BGEEmbedding:
def __init__(self):
self.device = torch.device(settings.DEVICE)
self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
self.model = AutoModel.from_pretrained(settings.MODEL_NAME)
self.model = AutoModel.from_pretrained(settings.MODEL_NAME).to(self.device)
self.model.eval()
def encode(self, texts):
with torch.no_grad():
inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
outputs = self.model(**inputs)
embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
return embeddings.cpu().numpy()
def encode(self, texts, batch_size=8):
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i+batch_size]
with torch.no_grad():
inputs = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(self.device)
outputs = self.model(**inputs)
embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
all_embeddings.append(embeddings.cpu().numpy())
return np.vstack(all_embeddings)
def mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
embedder = BGEEmbedding()
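The batched encode ends in np.vstack(...), so the module needs numpy in scope; the hunk starts at line 5 and hides the import block, so this assumes import numpy as np sits at the top of the file. A behavior sketch:

    import numpy as np
    from app.core.embedding import embedder

    vecs = embedder.encode(["How do I request annual leave?", "What is the expense process?"], batch_size=2)
    print(vecs.shape)                    # (2, hidden_dim); 1024 for BAAI/bge-m3
    print(np.linalg.norm(vecs, axis=1))  # ~[1.0, 1.0], rows are L2-normalized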

View File

@@ -1,31 +1,28 @@
import faiss
import numpy as np
import threading
from app.core.config import settings
class FaissIndexWrapper:
def __init__(self):
def __init__(self, index_path: str):
self.index_lock = threading.Lock()
self.index = None
self.load_index(settings.INDEX_FILE)
self.index_path = index_path
self.load_index(index_path)
def load_index(self, path):
def load_index(self, path: str = None):
with self.index_lock:
path = path or self.index_path
self.index = faiss.read_index(path)
print(f"[FAISS] Index loaded from {path}")
def update_index(self, path):
def update_index(self, path: str = None):
"""热更新替换当前索引"""
with self.index_lock:
new_index = faiss.read_index(path)
self.index = new_index
path = path or self.index_path
self.index = faiss.read_index(path)
print(f"[FAISS] Index hot-swapped from {path}")
def search(self, vector: np.ndarray, top_k: int = 5):
with self.index_lock:
D, I = self.index.search(vector.astype(np.float32), top_k)
return D[0], I[0]
# Singleton, referenced by the API layer (default path matches the old settings.INDEX_FILE fallback)
faiss_index = FaissIndexWrapper("index_data/index.faiss")
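Usage sketch for the parameterized wrapper (the per-user path is illustrative, and the file must already exist on disk):

    import numpy as np
    from app.core.index import FaissIndexWrapper

    wrapper = FaissIndexWrapper("index_data/user_001.index")
    query = np.random.rand(1, wrapper.index.d).astype(np.float32)  # match the index dimension
    scores, ids = wrapper.search(query, top_k=5)
    wrapper.update_index()  # hot-swap: re-reads the stored path under the lock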

View File

@@ -1,21 +1,38 @@
import os
import faiss
import threading
class UserIndexManager:
def __init__(self, base_dir="index_data/"):
self.base_dir = base_dir
self.index_map = {} # user_id -> faiss index
self.index_map = {} # user_id -> faiss.Index
self.index_lock = threading.Lock()
self._load_all_indexes()
def _load_all_indexes(self):
for fname in os.listdir(self.base_dir):
if fname.endswith(".index"):
user_id = fname.replace(".index", "")
path = os.path.join(self.base_dir, fname)
self.index_map[user_id] = faiss.read_index(path)
print(f"[INIT] Loaded index for user {user_id}")
self._load_single_index(user_id)
def _load_single_index(self, user_id):
path = os.path.join(self.base_dir, f"{user_id}.index")
if os.path.exists(path):
index = faiss.read_index(path)
self.index_map[user_id] = index
print(f"[INIT] Loaded index for user: {user_id}")
def get_index(self, user_id):
if user_id not in self.index_map:
raise ValueError(f"Index for user {user_id} not loaded.")
return self.index_map[user_id]
with self.index_lock:
if user_id not in self.index_map:
raise FileNotFoundError(f"No FAISS index loaded for user: {user_id}")
return self.index_map[user_id]
def update_index(self, user_id):
"""重新加载用户索引文件"""
path = os.path.join(self.base_dir, f"{user_id}.index")
if not os.path.exists(path):
raise FileNotFoundError(f"No index file found for user: {user_id}")
with self.index_lock:
self.index_map[user_id] = faiss.read_index(path)
print(f"[UPDATE] Index for user {user_id} has been hot-reloaded.")

View File

@@ -1,8 +1,8 @@
import os
import faiss
import numpy as np
import sys
from app.core.embedding import embedder
from app.core.config import settings
def load_documents(doc_folder):
texts = []
@@ -20,12 +20,20 @@ def build_faiss_index(docs, dim):
return index
if __name__ == "__main__":
print("[BUILD] Loading documents...")
docs = load_documents(settings.DOC_PATH)
if len(sys.argv) < 2:
print("Usage: python -m scripts.build_index <user_id>")
sys.exit(1)
user_id = sys.argv[1]
doc_dir = os.path.join("docs", user_id)
index_path = os.path.join("index_data", f"{user_id}.index")
print(f"[BUILD] Loading documents from {doc_dir} ...")
docs = load_documents(doc_dir)
print(f"[BUILD] Loaded {len(docs)} documents")
print("[BUILD] Building FAISS index...")
index = build_faiss_index(docs, settings.EMBEDDING_DIM)
index = build_faiss_index(docs, dim=768) # or derive dynamically from the model config
print(f"[BUILD] Saving index to {settings.INDEX_FILE}")
faiss.write_index(index, settings.INDEX_FILE)
print(f"[BUILD] Saving index to {index_path}")
faiss.write_index(index, index_path)
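One way to avoid the hard-coded dim=768 is to probe the embedder once and reuse its output width; a sketch, relying on the embedder import this script already has:

    # Derive the FAISS dimension from the model instead of a constant
    probe = embedder.encode(["dimension probe"])
    index = build_faiss_index(docs, dim=probe.shape[1])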

View File

@@ -1,5 +1,5 @@
{
"user_001": ["common.index", "engineering.index"],
"user_002": ["common.index", "hr.index"],
"user_003": []
}
"user_001": ["common.index", "engineering.index"], //
"user_002": ["common.index", "hr.index"], //
"user_003": [] //
}

View File

@@ -1,12 +1,24 @@
import json
import os
PERMISSION_PATH = "scripts/index_permissions.json"
def get_user_allowed_indexes(user_id: str):
if not os.path.exists(PERMISSION_PATH):
print(f"[Permission Warning] 权限配置文件不存在: {PERMISSION_PATH}")
return []
try:
with open(PERMISSION_PATH, "r", encoding="utf-8") as f:
permission_map = json.load(f)
if user_id not in permission_map:
print(f"[Permission Info] 用户 {user_id} 无共享库配置,使用默认权限。")
return permission_map.get(user_id, [])
except json.JSONDecodeError:
print(f"[Permission Error] 权限配置文件 JSON 格式错误。")
return []
except Exception as e:
print(f"[Permission Error] {e}")
return []
print(f"[Permission Error] 未知错误: {e}")
return []
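Given the permissions JSON above, the lookup behaves as follows (the module path scripts.permission_utils is an assumption based on PERMISSION_PATH living under scripts/):

    from scripts.permission_utils import get_user_allowed_indexes

    print(get_user_allowed_indexes("user_001"))  # ["common.index", "engineering.index"]
    print(get_user_allowed_indexes("user_003"))  # []
    print(get_user_allowed_indexes("user_999"))  # [] after the "no shared-library configuration" info log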

View File

@@ -24,43 +24,47 @@ def build_user_index(user_id: str):
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
service_context = ServiceContext.from_defaults(embed_model=embed_model)
# Build the vector index
index = VectorStoreIndex.from_documents(
documents,
service_context=service_context,
vector_store=FaissVectorStore()
)
# Save as a user-specific .index file
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
faiss.write_index(index.vector_store.index, index_path)
print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
if not os.path.exists(index_path):
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
service_context = ServiceContext.from_defaults(embed_model=embed_model)
# Load the main index
vector_store = FaissVectorStore.from_persist_path(index_path)
index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)
nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
all_nodes = []
# Load shared indexes within the user's permission scope
# Load the user's main index
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
if not os.path.exists(index_path):
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
user_store = FaissVectorStore.from_persist_path(index_path)
user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
# Load shared indexes
shared_indexes = get_user_allowed_indexes(user_id)
for shared_name in shared_indexes:
shared_path = os.path.join(USER_INDEX_PATH, shared_name)
if os.path.exists(shared_path):
shared_store = FaissVectorStore.from_persist_path(shared_path)
shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
if shared_indexes:
for shared_name in shared_indexes:
shared_path = os.path.join(USER_INDEX_PATH, shared_name)
if os.path.exists(shared_path) and shared_path != index_path:
shared_store = FaissVectorStore.from_persist_path(shared_path)
shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
else:
print(f"[INFO] 用户 {user_id} 没有共享索引权限")
# 构造 Prompt
context_str = "\n\n".join([n.get_content() for n in nodes])
# Merge + sort by score
sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
top_nodes = sorted_nodes[:top_k]
context_str = "\n\n".join([n.get_content() for n in top_nodes])
prompt_template = PromptTemplate(
"请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}"
)
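The key=lambda n: -(n.score or 0) guard matters because retrievers can return nodes whose score is None; "or 0" makes those sort last instead of raising a TypeError. A stand-in illustration with plain objects rather than LlamaIndex nodes:

    from types import SimpleNamespace

    all_nodes = [SimpleNamespace(score=0.82), SimpleNamespace(score=None), SimpleNamespace(score=0.91)]
    ranked = sorted(all_nodes, key=lambda n: -(n.score or 0))
    print([n.score for n in ranked])  # [0.91, 0.82, None]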

View File

@@ -1,23 +1,30 @@
import time
import os
import sys
from app.core.index import faiss_index
from app.core.config import settings
def watch_and_reload(interval=600):
print("[HOT-RELOAD] Watching index file for updates...")
def watch_and_reload(user_id: str, interval=600):
index_file = os.path.join("index_data", f"{user_id}.index")
print(f"[HOT-RELOAD] Watching {index_file} for updates...")
last_mtime = None
while True:
try:
mtime = os.path.getmtime(settings.INDEX_FILE)
mtime = os.path.getmtime(index_file)
if last_mtime is None:
last_mtime = mtime
elif mtime != last_mtime:
print("[HOT-RELOAD] Detected new index file. Reloading...")
faiss_index.update_index(settings.INDEX_FILE)
faiss_index.update_index(index_file)
last_mtime = mtime
except Exception as e:
print(f"[HOT-RELOAD] Error: {e}")
time.sleep(interval)
if __name__ == "__main__":
watch_and_reload()
if len(sys.argv) < 2:
print("Usage: python -m scripts.update_index <user_id>")
sys.exit(1)
user_id = sys.argv[1]
watch_and_reload(user_id)
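For local testing a shorter poll interval surfaces reloads quickly; with the default interval=600 a rebuilt index can take up to ten minutes to be noticed:

    # Blocks until interrupted; the user ID is illustrative
    from scripts.update_index import watch_and_reload

    watch_and_reload("user_001", interval=5)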