commit bb65afab1f0e60f202f593744256da99922ace30
Author: hailin
Date:   Thu May 8 15:12:46 2025 +0800

    first commit!

diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..9e4c08e
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,20 @@
+FROM python:3.10-slim
+
+WORKDIR /app
+
+# Install system dependencies (if needed)
+RUN apt-get update && apt-get install -y --no-install-recommends build-essential && \
+    rm -rf /var/lib/apt/lists/*
+
+# Install Python dependencies before copying the source so this layer stays
+# cached when only application code changes
+COPY requirements.txt /app/requirements.txt
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . /app
+
+# Declare the service port (for Kubernetes or docker run -P)
+EXPOSE 8000
+
+# Start the service (bind address and port are set in gunicorn_config.py)
+CMD ["gunicorn", "app.main:app", "-k", "uvicorn.workers.UvicornWorker", "-c", "gunicorn_config.py"]
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..308addf
--- /dev/null
+++ b/README.md
@@ -0,0 +1,3 @@
+# FAISS Enterprise RAG
+
+企业级部署方案,支持千万向量、本地高并发、热更新、容器部署。
\ No newline at end of file
diff --git a/app/api/search.py b/app/api/search.py
new file mode 100644
index 0000000..43876cd
--- /dev/null
+++ b/app/api/search.py
@@ -0,0 +1,41 @@
+from fastapi import APIRouter, HTTPException, Query
+from pydantic import BaseModel
+import numpy as np
+from app.core.index_manager import UserIndexManager
+from app.core.embedding import embedder
+from app.core.config import settings
+
+router = APIRouter()
+index_manager = UserIndexManager()
+
+class QueryRequest(BaseModel):
+    """Request body for /search: the natural-language query text."""
+    query: str
+
+@router.post("/search")
+def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")):
+    """Embed the query and return the TOP_K nearest vectors of the user's index.
+
+    Responds 404 when no index exists for ``user_id`` and 500 on any other
+    failure.
+    """
+    try:
+        index = index_manager.get_index(user_id)
+    except ValueError as e:
+        raise HTTPException(status_code=404, detail=str(e)) from e
+    try:
+        query_vector = np.ascontiguousarray(embedder.encode([request.query]), dtype=np.float32)
+        # faiss returns (n_queries, k) arrays; row 0 belongs to our single query.
+        D, I = index.search(query_vector, settings.TOP_K)
+        return {
+            "user_id": user_id,
+            "query": request.query,
+            # faiss pads with id -1 when the index holds fewer than k vectors.
+            "results": [
+                {"id": int(idx), "score": float(dist)}
+                for dist, idx in zip(D[0], I[0])
+                if idx != -1
+            ],
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
\ No newline at end of file
diff --git a/app/api/upload.py b/app/api/upload.py
new file mode 100644
index 0000000..861487c
--- /dev/null
+++ b/app/api/upload.py
@@ -0,0 +1,40 @@
+from fastapi import APIRouter, UploadFile, File, Form, HTTPException
+import os
+from shutil import copyfileobj
+from scripts.rag_build_query import build_user_index
+
+router = APIRouter()
+
+def _is_plain_name(name):
+    """Return True when *name* is a single, non-special path component."""
+    return bool(name) and name not in (".", "..") and os.path.basename(name) == name and "\\" not in name
+
+@router.post("/upload")
+def upload_user_file(user_id: str = Form(...), file: UploadFile = File(...)):
+    """Store an uploaded document under docs/<user_id>/ and rebuild that user's index.
+
+    Responds 400 when ``user_id`` or the file name is empty or could escape
+    the docs directory (path separators, "." or "..").
+    """
+    # Both values come from the client and are used to build filesystem paths.
+    if not _is_plain_name(user_id):
+        raise HTTPException(status_code=400, detail="invalid user_id")
+    # Some clients send a full path as the file name; keep only the last component.
+    filename = os.path.basename((file.filename or "").replace("\\", "/"))
+    if not _is_plain_name(filename):
+        raise HTTPException(status_code=400, detail="invalid filename")
+
+    user_doc_dir = os.path.join("docs", user_id)
+    os.makedirs(user_doc_dir, exist_ok=True)
+
+    file_path = os.path.join(user_doc_dir, filename)
+    with open(file_path, "wb") as f:
+        copyfileobj(file.file, f)
+
+    print(f"[UPLOAD] 文件已保存至 {file_path}")
+
+    # Rebuild the user's index automatically
+    build_user_index(user_id)
+    print(f"[UPLOAD] 用户 {user_id} 的索引已重建")
+
+    return {"status": "ok", "filename": filename}
\ No newline at end of file
diff --git a/app/core/config.py b/app/core/config.py
new file mode 100644
index 0000000..ec4cfbe
--- /dev/null
+++ b/app/core/config.py
@@ -0,0 +1,14 @@
+from pydantic import BaseSettings
+import os
+
+class Settings(BaseSettings):
+    """Application settings; values can be overridden through environment variables."""
+    INDEX_FILE: str = os.getenv("INDEX_FILE", "index_data/index.faiss")
+    # Must match the output size of MODEL_NAME (BAAI/bge-m3 produces 1024-d vectors).
+    EMBEDDING_DIM: int = 1024
+    TOP_K: int = 5
+    DOC_PATH: str = "docs/"
+    DEVICE: str = "cpu"  # can be set to cuda:0
+    MODEL_NAME: str = "BAAI/bge-m3"
+
+settings = Settings()
\ No newline at end of file
diff --git a/app/core/embedding.py b/app/core/embedding.py
new file mode 100644
index 0000000..814c906
--- /dev/null
+++ b/app/core/embedding.py
@@ -0,0 +1,31 @@
+from transformers import AutoTokenizer, AutoModel
+import torch
+import numpy as np
+from app.core.config import settings
+
+class BGEEmbedding:
+    """Sentence embedder: mean-pooled, L2-normalised hidden states of settings.MODEL_NAME."""
+
+    def __init__(self):
+        # Run the model on the configured device (settings.DEVICE).
+        self.device = torch.device(settings.DEVICE)
+        self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
+        self.model = AutoModel.from_pretrained(settings.MODEL_NAME).to(self.device)
+        self.model.eval()
+
+    def encode(self, texts):
+        """Return a numpy array with one unit-norm embedding row per input text."""
+        with torch.no_grad():
+            inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(self.device)
+            outputs = self.model(**inputs)
+            embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
+            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        return embeddings.cpu().numpy()
+
+    def mean_pooling(self, model_output, attention_mask):
+        """Average token embeddings over the positions the attention mask marks as real tokens."""
+        token_embeddings = model_output[0]
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+embedder = BGEEmbedding()
\ No newline at end of file
diff --git a/app/core/index.py b/app/core/index.py
new file mode 100644
index 0000000..049c264
--- /dev/null
+++ b/app/core/index.py
@@ -0,0 +1,37 @@
+
+import faiss
+import numpy as np
+import threading
+from app.core.config import settings
+
+class FaissIndexWrapper:
+    """Lock-protected holder for a single faiss index that can be swapped at runtime."""
+
+    def __init__(self):
+        self.index_lock = threading.Lock()
+        self.index = None
+        self.load_index(settings.INDEX_FILE)
+
+    def load_index(self, path):
+        """Read the index at *path* and make it the active index."""
+        with self.index_lock:
+            self.index = faiss.read_index(path)
+        print(f"[FAISS] Index loaded from {path}")
+
+    def update_index(self, path):
+        """Hot-swap the active index with the one stored at *path*."""
+        # Read from disk before taking the lock so searches are only blocked
+        # for the reference swap, not for the whole file load.
+        new_index = faiss.read_index(path)
+        with self.index_lock:
+            self.index = new_index
+        print(f"[FAISS] Index hot-swapped from {path}")
+
+    def search(self, vector: np.ndarray, top_k: int = 5):
+        """Search the active index; returns (distances, ids) for the first query row."""
+        with self.index_lock:
+            D, I = self.index.search(vector.astype(np.float32), top_k)
+        return D[0], I[0]
+
+# Singleton used by the API layer
+faiss_index = FaissIndexWrapper()
diff --git a/app/core/index_manager.py b/app/core/index_manager.py
new file mode 100644
index 0000000..8626fbf
--- /dev/null
+++ b/app/core/index_manager.py
@@ -0,0 +1,54 @@
+import os
+import threading
+import faiss
+
+INDEX_SUFFIX = ".index"
+
+class UserIndexManager:
+    """In-memory cache of per-user faiss indexes stored as <base_dir>/<user_id>.index."""
+
+    def __init__(self, base_dir="index_data/"):
+        self.base_dir = base_dir
+        self.index_map = {}  # user_id -> faiss index
+        self._mtimes = {}  # user_id -> mtime of the file the cached index came from
+        self._lock = threading.Lock()
+        # A fresh deployment has no index directory yet; listing it must not crash startup.
+        os.makedirs(self.base_dir, exist_ok=True)
+        self._load_all_indexes()
+
+    def _load_index(self, user_id, path):
+        """Read one index file into the cache and remember its modification time."""
+        mtime = os.path.getmtime(path)
+        self.index_map[user_id] = faiss.read_index(path)
+        self._mtimes[user_id] = mtime
+        print(f"[INIT] Loaded index for user {user_id}")
+
+    def _load_all_indexes(self):
+        """Load every *.index file found in base_dir."""
+        for fname in os.listdir(self.base_dir):
+            if fname.endswith(INDEX_SUFFIX):
+                # Strip only the trailing suffix; str.replace would also eat
+                # ".index" occurring inside the user id.
+                user_id = fname[:-len(INDEX_SUFFIX)]
+                self._load_index(user_id, os.path.join(self.base_dir, fname))
+
+    def get_index(self, user_id):
+        """Return the index for *user_id*, (re)loading it when its file is new or changed.
+
+        Raises:
+            ValueError: if no index is available for the user.
+        """
+        with self._lock:
+            # Only probe the disk for ids that are a single path component, so a
+            # crafted id cannot make us read files outside base_dir.
+            if user_id and os.path.basename(user_id) == user_id and "\\" not in user_id:
+                path = os.path.join(self.base_dir, user_id + INDEX_SUFFIX)
+                try:
+                    changed = os.path.getmtime(path) != self._mtimes.get(user_id)
+                except OSError:
+                    changed = False  # no file on disk: fall back to the cache
+                if changed:
+                    self._load_index(user_id, path)
+            if user_id not in self.index_map:
+                raise ValueError(f"Index for user {user_id} not loaded.")
+            return self.index_map[user_id]
\ No newline at end of file
diff --git a/app/main.py b/app/main.py
new file mode 100644
index 0000000..73e59a6
--- /dev/null
+++ b/app/main.py
@@ -0,0 +1,8 @@
+from fastapi import FastAPI
+from app.api.search import router as search_router
+from app.api.upload import router as upload_router
+
+app = FastAPI(title="Enterprise FAISS RAG Server")
+
+app.include_router(search_router, prefix="/api", tags=["search"])
+app.include_router(upload_router, prefix="/api", tags=["upload"])
\ No newline at end of file
diff --git a/gunicorn_config.py b/gunicorn_config.py
new file mode 100644
index 0000000..8228bea
--- /dev/null
+++ b/gunicorn_config.py
@@ -0,0 +1,4 @@
+workers = 12  # adjust to the number of CPU cores
+timeout = 60
+preload_app = True
+bind = "0.0.0.0:8000"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..56f4a46
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+faiss-cpu
+fastapi
+uvicorn
+gunicorn
+pydantic<2
+numpy
+transformers
+torch
+python-multipart
+llama-index<0.10
\ No newline at end of file
diff --git a/scripts/build_index.py b/scripts/build_index.py
new file mode 100644
index 0000000..21c19c0
--- /dev/null
+++ b/scripts/build_index.py
@@ -0,0 +1,45 @@
+import os
+import faiss
+import numpy as np
+from app.core.embedding import embedder
+from app.core.config import settings
+
+def load_documents(doc_folder):
+    """Return the UTF-8 text of every regular file directly inside *doc_folder*."""
+    texts = []
+    for fname in os.listdir(doc_folder):
+        path = os.path.join(doc_folder, fname)
+        if os.path.isfile(path):
+            with open(path, "r", encoding="utf-8") as f:
+                texts.append(f.read())
+    return texts
+
+def build_faiss_index(docs, dim, batch_size=32):
+    """Embed *docs* and return an inner-product faiss index of dimension *dim*.
+
+    Documents are encoded *batch_size* at a time so a single forward pass
+    never has to hold the whole corpus. An empty *docs* yields an empty index.
+
+    Raises:
+        ValueError: if the embedder's output size differs from *dim*.
+    """
+    index = faiss.IndexFlatIP(dim)
+    for start in range(0, len(docs), batch_size):
+        vectors = np.ascontiguousarray(embedder.encode(docs[start:start + batch_size]), dtype=np.float32)
+        if vectors.shape[1] != dim:
+            raise ValueError(f"embedding dimension {vectors.shape[1]} does not match index dimension {dim}")
+        index.add(vectors)
+    return index
+
+if __name__ == "__main__":
+    print("[BUILD] Loading documents...")
+    docs = load_documents(settings.DOC_PATH)
+    print(f"[BUILD] Loaded {len(docs)} documents")
+
+    print("[BUILD] Building FAISS index...")
+    index = build_faiss_index(docs, settings.EMBEDDING_DIM)
+
+    print(f"[BUILD] Saving index to {settings.INDEX_FILE}")
+    # Make sure the target directory exists before writing the index file.
+    os.makedirs(os.path.dirname(settings.INDEX_FILE) or ".", exist_ok=True)
+    faiss.write_index(index, settings.INDEX_FILE)
\ No newline at end of file
diff --git a/scripts/index_permissions.json b/scripts/index_permissions.json
new file mode 100644
index 0000000..8de11d6
--- /dev/null
+++ b/scripts/index_permissions.json
@@ -0,0 +1,5 @@
+{
+  "user_001": ["common.index", "engineering.index"],
+  "user_002": ["common.index", "hr.index"],
+  "user_003": []
+}
\ No newline at end of file
diff --git a/scripts/permissions.py b/scripts/permissions.py
new file mode 100644
index 0000000..c432595
--- /dev/null
+++ b/scripts/permissions.py
@@ -0,0 +1,17 @@
+import json
+
+PERMISSION_PATH = "scripts/index_permissions.json"
+
+def get_user_allowed_indexes(user_id: str):
+    """Return the entry for *user_id* from the permission file; [] when the user is absent.
+
+    Any problem reading or parsing the permission file is printed and
+    treated as "no access" (empty list) rather than raised.
+    """
+    try:
+        with open(PERMISSION_PATH, "r", encoding="utf-8") as f:
+            permission_map = json.load(f)
+        return permission_map.get(user_id, [])
+    except Exception as e:
+        print(f"[Permission Error] {e}")
+        return []
\ No newline at end of file
diff --git a/scripts/rag_build_query.py b/scripts/rag_build_query.py
new file mode 100644
index 0000000..897a856
--- /dev/null
+++ b/scripts/rag_build_query.py
@@ -0,0 +1,124 @@
+import os
+import faiss
+from functools import lru_cache
+from typing import List
+from llama_index import (
+    SimpleDirectoryReader,
+    VectorStoreIndex,
+    ServiceContext,
+    StorageContext,
+    PromptTemplate,
+    load_index_from_storage,
+)
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from llama_index.vector_stores.faiss import FaissVectorStore
+from llama_index.llms.base import ChatMessage
+from app.core.config import settings
+from scripts.permissions import get_user_allowed_indexes
+
+# Use the same model as app.core.embedding so the vectors written here have
+# the dimension the /search endpoint queries with.
+EMBED_MODEL_NAME = settings.MODEL_NAME
+USER_INDEX_PATH = "index_data"
+USER_DOC_PATH = "docs"
+
+@lru_cache(maxsize=1)
+def _get_embed_model():
+    """Load the embedding model once per process instead of once per call."""
+    # NOTE(review): pooling="mean" is meant to mirror BGEEmbedding.mean_pooling;
+    # confirm the keyword against the installed llama-index version.
+    return HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME, pooling="mean")
+
+def _get_service_context():
+    """Build a retrieval-only service context (this module never calls an LLM)."""
+    # llm=None keeps llama-index from trying to configure its default OpenAI LLM.
+    return ServiceContext.from_defaults(embed_model=_get_embed_model(), llm=None)
+
+def _storage_dir(user_id: str) -> str:
+    """Directory holding the persisted llama-index storage for *user_id*."""
+    return os.path.join(USER_INDEX_PATH, f"{user_id}_storage")
+
+def build_user_index(user_id: str):
+    """Index every document under docs/<user_id>/ and persist the result.
+
+    Writes two artefacts: index_data/<user_id>.index (the raw faiss index) and
+    index_data/<user_id>_storage/ (llama-index storage, which query_user_rag
+    needs to map vectors back to text).
+
+    Raises:
+        FileNotFoundError: if the user's document directory does not exist.
+    """
+    doc_dir = os.path.join(USER_DOC_PATH, user_id)
+    if not os.path.exists(doc_dir):
+        raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
+
+    documents = SimpleDirectoryReader(doc_dir).load_data()
+    service_context = _get_service_context()
+
+    # FaissVectorStore needs a concrete faiss index sized to the model output.
+    dim = len(_get_embed_model().get_text_embedding("dimension probe"))
+    vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatIP(dim))
+    storage_context = StorageContext.from_defaults(vector_store=vector_store)
+
+    # Build the vector index
+    index = VectorStoreIndex.from_documents(
+        documents,
+        storage_context=storage_context,
+        service_context=service_context,
+    )
+
+    os.makedirs(USER_INDEX_PATH, exist_ok=True)
+    index.storage_context.persist(persist_dir=_storage_dir(user_id))
+
+    # Save the user's .index file; write-then-rename so a concurrent reader
+    # never sees a half-written file.
+    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
+    tmp_path = index_path + ".tmp"
+    faiss.write_index(vector_store.client, tmp_path)
+    os.replace(tmp_path, index_path)
+    print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")
+
+def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
+    """Retrieve the *top_k* chunks most similar to *question* and return the LLM prompt.
+
+    NOTE(review): get_user_allowed_indexes is imported but not consulted here;
+    confirm whether permission filtering is expected at this layer.
+
+    Raises:
+        FileNotFoundError: if the user's index has not been built.
+    """
+    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
+    storage_dir = _storage_dir(user_id)
+    if not os.path.exists(index_path) or not os.path.isdir(storage_dir):
+        raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
+
+    # Load the index together with its docstore: the faiss store holds vectors
+    # only, so the node text has to come from the persisted storage context.
+    vector_store = FaissVectorStore.from_persist_dir(storage_dir)
+    storage_context = StorageContext.from_defaults(
+        vector_store=vector_store, persist_dir=storage_dir
+    )
+    index = load_index_from_storage(storage_context, service_context=_get_service_context())
+
+    retriever = index.as_retriever(similarity_top_k=top_k)
+    nodes = retriever.retrieve(question)
+
+    # Assemble the prompt
+    context_str = "\n\n".join([n.get_content() for n in nodes])
+    prompt_template = PromptTemplate(
+        "请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}"
+    )
+    final_prompt = prompt_template.format(
+        context=context_str,
+        query=question,
+    )
+
+    print("[PROMPT构建完成]")
+    return final_prompt
+
+# Example:
+if __name__ == "__main__":
+    uid = "user_001"
+    build_user_index(uid)
+    prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?")
+    print("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
+    print(prompt)
\ No newline at end of file
diff --git a/scripts/update_index.py b/scripts/update_index.py
new file mode 100644
index 0000000..03d611c
--- /dev/null
+++ b/scripts/update_index.py
@@ -0,0 +1,29 @@
+import time
+import os
+from app.core.index import faiss_index
+from app.core.config import settings
+
+def watch_and_reload(interval=600):
+    """Poll settings.INDEX_FILE every *interval* seconds and hot-swap the index when its mtime changes.
+
+    Runs forever; errors are printed and the loop continues.
+    NOTE(review): this swaps the faiss_index singleton of the process running
+    this script only — confirm how API workers are meant to pick up the change.
+    """
+    print("[HOT-RELOAD] Watching index file for updates...")
+    last_mtime = None
+    while True:
+        try:
+            mtime = os.path.getmtime(settings.INDEX_FILE)
+            if last_mtime is None:
+                last_mtime = mtime
+            elif mtime != last_mtime:
+                print("[HOT-RELOAD] Detected new index file. Reloading...")
+                faiss_index.update_index(settings.INDEX_FILE)
+                last_mtime = mtime
+        except Exception as e:
+            print(f"[HOT-RELOAD] Error: {e}")
+        time.sleep(interval)
+
+if __name__ == "__main__":
+    watch_and_reload()
\ No newline at end of file