diff --git a/app/api/search.py b/app/api/search.py index 43876cd..b6d30c3 100644 --- a/app/api/search.py +++ b/app/api/search.py @@ -1,12 +1,13 @@ from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel -import numpy as np -from app.core.index_manager import UserIndexManager from app.core.embedding import embedder from app.core.config import settings +from llama_index.vector_stores.faiss import FaissVectorStore +from llama_index import VectorStoreIndex, ServiceContext +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +import os router = APIRouter() -index_manager = UserIndexManager() class QueryRequest(BaseModel): query: str @@ -14,13 +15,27 @@ class QueryRequest(BaseModel): @router.post("/search") def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")): try: - query_vector = embedder.encode([request.query]) - index = index_manager.get_index(user_id) - D, I = index.search(query_vector, settings.TOP_K) + index_path = os.path.join("index_data", f"{user_id}.index") + if not os.path.exists(index_path): + raise HTTPException(status_code=404, detail="用户索引不存在") + + # 构建 LlamaIndex 检索器 + faiss_store = FaissVectorStore.from_persist_path(index_path) + service_context = ServiceContext.from_defaults(embed_model=embedder) + index = VectorStoreIndex.from_vector_store(faiss_store, service_context=service_context) + + # 检索结果(真实文本) + retriever = index.as_retriever(similarity_top_k=settings.TOP_K) + nodes = retriever.retrieve(request.query) + return { "user_id": user_id, "query": request.query, - "results": [{"id": int(idx), "score": float(dist)} for dist, idx in zip(D, I)] + "results": [ + {"score": float(node.score or 0), "text": node.get_content()} + for node in nodes + ] } + except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=500, detail=str(e)) diff --git a/app/api/upload.py b/app/api/upload.py index 861487c..1baba46 100644 --- a/app/api/upload.py +++ b/app/api/upload.py @@ -1,23 +1,32 @@ -from fastapi import APIRouter, UploadFile, File, Form +from fastapi import APIRouter, UploadFile, File, Form, HTTPException import os -from shutil import copyfileobj +import shutil from scripts.rag_build_query import build_user_index router = APIRouter() +ALLOWED_SUFFIXES = {".txt", ".md", ".pdf", ".docx"} @router.post("/upload") def upload_user_file(user_id: str = Form(...), file: UploadFile = File(...)): + filename = os.path.basename(file.filename) + suffix = os.path.splitext(filename)[-1].lower() + if suffix not in ALLOWED_SUFFIXES: + raise HTTPException(status_code=400, detail="不支持的文件类型") + user_doc_dir = os.path.join("docs", user_id) os.makedirs(user_doc_dir, exist_ok=True) - file_path = os.path.join(user_doc_dir, file.filename) - with open(file_path, "wb") as f: - copyfileobj(file.file, f) + file_path = os.path.join(user_doc_dir, filename) + try: + with open(file_path, "wb") as f: + shutil.copyfileobj(file.file, f) + print(f"[UPLOAD] 文件已保存至 {file_path}") - print(f"[UPLOAD] 文件已保存至 {file_path}") + build_user_index(user_id) + print(f"[UPLOAD] 用户 {user_id} 的索引已重建") - # 自动重建用户索引 - build_user_index(user_id) - print(f"[UPLOAD] 用户 {user_id} 的索引已重建") + except Exception as e: + print(f"[UPLOAD ERROR] {e}") + raise HTTPException(status_code=500, detail="索引构建失败") - return {"status": "ok", "filename": file.filename} \ No newline at end of file + return {"status": "ok", "filename": filename} diff --git a/app/core/config.py b/app/core/config.py index f6b786d..14e3405 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,12 +1,10 @@ from pydantic_settings import BaseSettings -import os class Settings(BaseSettings): - INDEX_FILE: str = os.getenv("INDEX_FILE", "index_data/index.faiss") - EMBEDDING_DIM: int = 768 - TOP_K: int = 5 - DOC_PATH: str = "docs/" - DEVICE: str = "cpu" - MODEL_NAME: str = "BAAI/bge-m3" + EMBEDDING_DIM: int = 768 # 嵌入维度(取决于模型) + TOP_K: int = 5 # 默认检索 top K 个段落 + DOC_PATH: str = "docs/" # 默认文档根目录 + DEVICE: str = "cpu" # 可改为 "cuda" 使用 GPU + MODEL_NAME: str = "BAAI/bge-m3" # 多语种语义嵌入模型 settings = Settings() diff --git a/app/core/embedding.py b/app/core/embedding.py index 814c906..86ddb8a 100644 --- a/app/core/embedding.py +++ b/app/core/embedding.py @@ -5,21 +5,26 @@ from app.core.config import settings class BGEEmbedding: def __init__(self): + self.device = torch.device(settings.DEVICE) self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME) - self.model = AutoModel.from_pretrained(settings.MODEL_NAME) + self.model = AutoModel.from_pretrained(settings.MODEL_NAME).to(self.device) self.model.eval() - def encode(self, texts): - with torch.no_grad(): - inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt") - outputs = self.model(**inputs) - embeddings = self.mean_pooling(outputs, inputs['attention_mask']) - embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) - return embeddings.cpu().numpy() + def encode(self, texts, batch_size=8): + all_embeddings = [] + for i in range(0, len(texts), batch_size): + batch_texts = texts[i:i+batch_size] + with torch.no_grad(): + inputs = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(self.device) + outputs = self.model(**inputs) + embeddings = self.mean_pooling(outputs, inputs['attention_mask']) + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + all_embeddings.append(embeddings.cpu().numpy()) + return np.vstack(all_embeddings) def mean_pooling(self, model_output, attention_mask): token_embeddings = model_output[0] input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) -embedder = BGEEmbedding() \ No newline at end of file +embedder = BGEEmbedding() diff --git a/app/core/index.py b/app/core/index.py index 049c264..9dbd8dd 100644 --- a/app/core/index.py +++ b/app/core/index.py @@ -1,31 +1,28 @@ - import faiss import numpy as np import threading -from app.core.config import settings class FaissIndexWrapper: - def __init__(self): + def __init__(self, index_path: str): self.index_lock = threading.Lock() self.index = None - self.load_index(settings.INDEX_FILE) + self.index_path = index_path + self.load_index(index_path) - def load_index(self, path): + def load_index(self, path: str = None): with self.index_lock: + path = path or self.index_path self.index = faiss.read_index(path) print(f"[FAISS] Index loaded from {path}") - def update_index(self, path): + def update_index(self, path: str = None): """热更新替换当前索引""" with self.index_lock: - new_index = faiss.read_index(path) - self.index = new_index + path = path or self.index_path + self.index = faiss.read_index(path) print(f"[FAISS] Index hot-swapped from {path}") def search(self, vector: np.ndarray, top_k: int = 5): with self.index_lock: D, I = self.index.search(vector.astype(np.float32), top_k) return D[0], I[0] - -# 单例,供 API 层引用 -faiss_index = FaissIndexWrapper() diff --git a/app/core/index_manager.py b/app/core/index_manager.py index 8626fbf..9841b33 100644 --- a/app/core/index_manager.py +++ b/app/core/index_manager.py @@ -1,21 +1,38 @@ import os import faiss +import threading class UserIndexManager: def __init__(self, base_dir="index_data/"): self.base_dir = base_dir - self.index_map = {} # user_id -> faiss index + self.index_map = {} # user_id -> faiss.Index + self.index_lock = threading.Lock() self._load_all_indexes() def _load_all_indexes(self): for fname in os.listdir(self.base_dir): if fname.endswith(".index"): user_id = fname.replace(".index", "") - path = os.path.join(self.base_dir, fname) - self.index_map[user_id] = faiss.read_index(path) - print(f"[INIT] Loaded index for user {user_id}") + self._load_single_index(user_id) + + def _load_single_index(self, user_id): + path = os.path.join(self.base_dir, f"{user_id}.index") + if os.path.exists(path): + index = faiss.read_index(path) + self.index_map[user_id] = index + print(f"[INIT] Loaded index for user: {user_id}") def get_index(self, user_id): - if user_id not in self.index_map: - raise ValueError(f"Index for user {user_id} not loaded.") - return self.index_map[user_id] \ No newline at end of file + with self.index_lock: + if user_id not in self.index_map: + raise FileNotFoundError(f"No FAISS index loaded for user: {user_id}") + return self.index_map[user_id] + + def update_index(self, user_id): + """重新加载用户索引文件""" + path = os.path.join(self.base_dir, f"{user_id}.index") + if not os.path.exists(path): + raise FileNotFoundError(f"No index file found for user: {user_id}") + with self.index_lock: + self.index_map[user_id] = faiss.read_index(path) + print(f"[UPDATE] Index for user {user_id} has been hot-reloaded.") diff --git a/scripts/build_index.py b/scripts/build_index.py index 21c19c0..f3cbd9e 100644 --- a/scripts/build_index.py +++ b/scripts/build_index.py @@ -1,8 +1,8 @@ import os import faiss import numpy as np +import sys from app.core.embedding import embedder -from app.core.config import settings def load_documents(doc_folder): texts = [] @@ -20,12 +20,20 @@ def build_faiss_index(docs, dim): return index if __name__ == "__main__": - print("[BUILD] Loading documents...") - docs = load_documents(settings.DOC_PATH) + if len(sys.argv) < 2: + print("Usage: python -m scripts.build_index ") + sys.exit(1) + + user_id = sys.argv[1] + doc_dir = os.path.join("docs", user_id) + index_path = os.path.join("index_data", f"{user_id}.index") + + print(f"[BUILD] Loading documents from {doc_dir} ...") + docs = load_documents(doc_dir) print(f"[BUILD] Loaded {len(docs)} documents") print("[BUILD] Building FAISS index...") - index = build_faiss_index(docs, settings.EMBEDDING_DIM) + index = build_faiss_index(docs, dim=768) # 或根据模型配置动态设定 - print(f"[BUILD] Saving index to {settings.INDEX_FILE}") - faiss.write_index(index, settings.INDEX_FILE) \ No newline at end of file + print(f"[BUILD] Saving index to {index_path}") + faiss.write_index(index, index_path) diff --git a/scripts/index_permissions.json b/scripts/index_permissions.json index 8de11d6..3a42dc4 100644 --- a/scripts/index_permissions.json +++ b/scripts/index_permissions.json @@ -1,5 +1,5 @@ { - "user_001": ["common.index", "engineering.index"], - "user_002": ["common.index", "hr.index"], - "user_003": [] -} \ No newline at end of file + "user_001": ["common.index", "engineering.index"], // 工程部员工 + "user_002": ["common.index", "hr.index"], // 人力资源员工 + "user_003": [] // 普通用户,无共享库 +} diff --git a/scripts/permissions.py b/scripts/permissions.py index c432595..bf13614 100644 --- a/scripts/permissions.py +++ b/scripts/permissions.py @@ -1,12 +1,24 @@ import json +import os PERMISSION_PATH = "scripts/index_permissions.json" def get_user_allowed_indexes(user_id: str): + if not os.path.exists(PERMISSION_PATH): + print(f"[Permission Warning] 权限配置文件不存在: {PERMISSION_PATH}") + return [] + try: with open(PERMISSION_PATH, "r", encoding="utf-8") as f: permission_map = json.load(f) + + if user_id not in permission_map: + print(f"[Permission Info] 用户 {user_id} 无共享库配置,使用默认权限。") return permission_map.get(user_id, []) + + except json.JSONDecodeError: + print(f"[Permission Error] 权限配置文件 JSON 格式错误。") + return [] except Exception as e: - print(f"[Permission Error] {e}") - return [] \ No newline at end of file + print(f"[Permission Error] 未知错误: {e}") + return [] diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py index be26149..17a0947 100644 --- a/scripts/rag_build_query.py +++ b/scripts/rag_build_query.py @@ -24,43 +24,47 @@ def build_user_index(user_id: str): embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) service_context = ServiceContext.from_defaults(embed_model=embed_model) - # 构建向量索引 index = VectorStoreIndex.from_documents( documents, service_context=service_context, vector_store=FaissVectorStore() ) - # 保存为用户专属 .index 文件 index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") faiss.write_index(index.vector_store.index, index_path) print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}") def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str: - index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") - if not os.path.exists(index_path): - raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在") - embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME) service_context = ServiceContext.from_defaults(embed_model=embed_model) - # 加载主索引 - vector_store = FaissVectorStore.from_persist_path(index_path) - index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context) - - nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question) + all_nodes = [] - # 加载权限范围内的共享索引 + # 加载用户主索引 + index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index") + if not os.path.exists(index_path): + raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在") + user_store = FaissVectorStore.from_persist_path(index_path) + user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context) + all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question) + + # 加载共享索引 shared_indexes = get_user_allowed_indexes(user_id) - for shared_name in shared_indexes: - shared_path = os.path.join(USER_INDEX_PATH, shared_name) - if os.path.exists(shared_path): - shared_store = FaissVectorStore.from_persist_path(shared_path) - shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context) - nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) + if shared_indexes: + for shared_name in shared_indexes: + shared_path = os.path.join(USER_INDEX_PATH, shared_name) + if os.path.exists(shared_path) and shared_path != index_path: + shared_store = FaissVectorStore.from_persist_path(shared_path) + shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context) + all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question) + else: + print(f"[INFO] 用户 {user_id} 没有共享索引权限") - # 构造 Prompt - context_str = "\n\n".join([n.get_content() for n in nodes]) + # 合并 + 按 score 排序 + sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0)) + top_nodes = sorted_nodes[:top_k] + + context_str = "\n\n".join([n.get_content() for n in top_nodes]) prompt_template = PromptTemplate( "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}" ) diff --git a/scripts/update_index.py b/scripts/update_index.py index 03d611c..70d3e60 100644 --- a/scripts/update_index.py +++ b/scripts/update_index.py @@ -1,23 +1,30 @@ import time import os +import sys from app.core.index import faiss_index -from app.core.config import settings -def watch_and_reload(interval=600): - print("[HOT-RELOAD] Watching index file for updates...") +def watch_and_reload(user_id: str, interval=600): + index_file = os.path.join("index_data", f"{user_id}.index") + print(f"[HOT-RELOAD] Watching {index_file} for updates...") last_mtime = None + while True: try: - mtime = os.path.getmtime(settings.INDEX_FILE) + mtime = os.path.getmtime(index_file) if last_mtime is None: last_mtime = mtime elif mtime != last_mtime: print("[HOT-RELOAD] Detected new index file. Reloading...") - faiss_index.update_index(settings.INDEX_FILE) + faiss_index.update_index(index_file) last_mtime = mtime except Exception as e: print(f"[HOT-RELOAD] Error: {e}") time.sleep(interval) if __name__ == "__main__": - watch_and_reload() \ No newline at end of file + if len(sys.argv) < 2: + print("Usage: python -m scripts.update_index ") + sys.exit(1) + + user_id = sys.argv[1] + watch_and_reload(user_id)