first commit!

This commit is contained in:
hailin 2025-05-08 15:12:46 +08:00
commit bb65afab1f
16 changed files with 326 additions and 0 deletions

18
Dockerfile Normal file
View File

@ -0,0 +1,18 @@
FROM python:3.10-slim
WORKDIR /app
COPY . /app
# 安装系统依赖(必要时)
RUN apt-get update && apt-get install -y build-essential && \
rm -rf /var/lib/apt/lists/*
# 安装 Python 依赖
RUN pip install --no-cache-dir -r requirements.txt
# 声明服务监听端口(供 Kubernetes 或 docker run -P 使用)
EXPOSE 8000
# 启动服务(可指定绑定 IP 和端口)
CMD ["gunicorn", "app.main:app", "-k", "uvicorn.workers.UvicornWorker", "-c", "gunicorn_config.py"]

3
README.md Normal file
View File

@ -0,0 +1,3 @@
# FAISS Enterprise RAG
企业级部署方案,支持千万向量、本地高并发、热更新、容器部署。

26
app/api/search.py Normal file
View File

@ -0,0 +1,26 @@
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
import numpy as np
from app.core.index_manager import UserIndexManager
from app.core.embedding import embedder
from app.core.config import settings
router = APIRouter()
index_manager = UserIndexManager()
class QueryRequest(BaseModel):
query: str
@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="用户ID")):
try:
query_vector = embedder.encode([request.query])
index = index_manager.get_index(user_id)
D, I = index.search(query_vector, settings.TOP_K)
return {
"user_id": user_id,
"query": request.query,
"results": [{"id": int(idx), "score": float(dist)} for dist, idx in zip(D, I)]
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

23
app/api/upload.py Normal file
View File

@ -0,0 +1,23 @@
from fastapi import APIRouter, UploadFile, File, Form
import os
from shutil import copyfileobj
from scripts.rag_build_query import build_user_index
router = APIRouter()
@router.post("/upload")
def upload_user_file(user_id: str = Form(...), file: UploadFile = File(...)):
user_doc_dir = os.path.join("docs", user_id)
os.makedirs(user_doc_dir, exist_ok=True)
file_path = os.path.join(user_doc_dir, file.filename)
with open(file_path, "wb") as f:
copyfileobj(file.file, f)
print(f"[UPLOAD] 文件已保存至 {file_path}")
# 自动重建用户索引
build_user_index(user_id)
print(f"[UPLOAD] 用户 {user_id} 的索引已重建")
return {"status": "ok", "filename": file.filename}

12
app/core/config.py Normal file
View File

@ -0,0 +1,12 @@
from pydantic import BaseSettings
import os
class Settings(BaseSettings):
INDEX_FILE: str = os.getenv("INDEX_FILE", "index_data/index.faiss")
EMBEDDING_DIM: int = 768
TOP_K: int = 5
DOC_PATH: str = "docs/"
DEVICE: str = "cpu" # 可设置为 cuda:0
MODEL_NAME: str = "BAAI/bge-m3"
settings = Settings()

25
app/core/embedding.py Normal file
View File

@ -0,0 +1,25 @@
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
from app.core.config import settings
class BGEEmbedding:
def __init__(self):
self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
self.model = AutoModel.from_pretrained(settings.MODEL_NAME)
self.model.eval()
def encode(self, texts):
with torch.no_grad():
inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
outputs = self.model(**inputs)
embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
return embeddings.cpu().numpy()
def mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
embedder = BGEEmbedding()

31
app/core/index.py Normal file
View File

@ -0,0 +1,31 @@
import faiss
import numpy as np
import threading
from app.core.config import settings
class FaissIndexWrapper:
def __init__(self):
self.index_lock = threading.Lock()
self.index = None
self.load_index(settings.INDEX_FILE)
def load_index(self, path):
with self.index_lock:
self.index = faiss.read_index(path)
print(f"[FAISS] Index loaded from {path}")
def update_index(self, path):
"""热更新替换当前索引"""
with self.index_lock:
new_index = faiss.read_index(path)
self.index = new_index
print(f"[FAISS] Index hot-swapped from {path}")
def search(self, vector: np.ndarray, top_k: int = 5):
with self.index_lock:
D, I = self.index.search(vector.astype(np.float32), top_k)
return D[0], I[0]
# 单例,供 API 层引用
faiss_index = FaissIndexWrapper()

21
app/core/index_manager.py Normal file
View File

@ -0,0 +1,21 @@
import os
import faiss
class UserIndexManager:
def __init__(self, base_dir="index_data/"):
self.base_dir = base_dir
self.index_map = {} # user_id -> faiss index
self._load_all_indexes()
def _load_all_indexes(self):
for fname in os.listdir(self.base_dir):
if fname.endswith(".index"):
user_id = fname.replace(".index", "")
path = os.path.join(self.base_dir, fname)
self.index_map[user_id] = faiss.read_index(path)
print(f"[INIT] Loaded index for user {user_id}")
def get_index(self, user_id):
if user_id not in self.index_map:
raise ValueError(f"Index for user {user_id} not loaded.")
return self.index_map[user_id]

8
app/main.py Normal file
View File

@ -0,0 +1,8 @@
from fastapi import FastAPI
from app.api.search import router as search_router
from app.api.upload import router as upload_router
app = FastAPI(title="Enterprise FAISS RAG Server")
app.include_router(search_router, prefix="/api", tags=["search"])
app.include_router(upload_router, prefix="/api", tags=["upload"])

4
gunicorn_config.py Normal file
View File

@ -0,0 +1,4 @@
workers = 12 # 可按CPU核心调整
timeout = 60
preload_app = True
bind = "0.0.0.0:8000"

8
requirements.txt Normal file
View File

@ -0,0 +1,8 @@
faiss-cpu
fastapi
uvicorn
gunicorn
pydantic
numpy
transformers
torch

31
scripts/build_index.py Normal file
View File

@ -0,0 +1,31 @@
import os
import faiss
import numpy as np
from app.core.embedding import embedder
from app.core.config import settings
def load_documents(doc_folder):
texts = []
for fname in os.listdir(doc_folder):
path = os.path.join(doc_folder, fname)
if os.path.isfile(path):
with open(path, "r", encoding="utf-8") as f:
texts.append(f.read())
return texts
def build_faiss_index(docs, dim):
vectors = embedder.encode(docs)
index = faiss.IndexFlatIP(dim)
index.add(vectors)
return index
if __name__ == "__main__":
print("[BUILD] Loading documents...")
docs = load_documents(settings.DOC_PATH)
print(f"[BUILD] Loaded {len(docs)} documents")
print("[BUILD] Building FAISS index...")
index = build_faiss_index(docs, settings.EMBEDDING_DIM)
print(f"[BUILD] Saving index to {settings.INDEX_FILE}")
faiss.write_index(index, settings.INDEX_FILE)

View File

@ -0,0 +1,5 @@
{
"user_001": ["common.index", "engineering.index"],
"user_002": ["common.index", "hr.index"],
"user_003": []
}

12
scripts/permissions.py Normal file
View File

@ -0,0 +1,12 @@
import json
PERMISSION_PATH = "scripts/index_permissions.json"
def get_user_allowed_indexes(user_id: str):
try:
with open(PERMISSION_PATH, "r", encoding="utf-8") as f:
permission_map = json.load(f)
return permission_map.get(user_id, [])
except Exception as e:
print(f"[Permission Error] {e}")
return []

View File

@ -0,0 +1,76 @@
import os
import faiss
from typing import List
from llama_index import (
SimpleDirectoryReader,
VectorStoreIndex,
ServiceContext,
PromptTemplate,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.llms.base import ChatMessage
# 假设你要用的本地嵌入模型
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
USER_INDEX_PATH = "index_data"
USER_DOC_PATH = "docs"
def build_user_index(user_id: str):
doc_dir = os.path.join(USER_DOC_PATH, user_id)
if not os.path.exists(doc_dir):
raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
documents = SimpleDirectoryReader(doc_dir).load_data()
embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME)
service_context = ServiceContext.from_defaults(embed_model=embed_model)
# 构建向量索引
index = VectorStoreIndex.from_documents(
documents,
service_context=service_context,
vector_store=FaissVectorStore()
)
# 保存为用户专属 .index 文件
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
faiss.write_index(index.vector_store.index, index_path)
print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")
from scripts.permissions import get_user_allowed_indexes
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
if not os.path.exists(index_path):
raise FileNotFoundError(f"[ERROR] 用户 {user_id} 的索引不存在")
embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME)
service_context = ServiceContext.from_defaults(embed_model=embed_model)
# 加载索引
vector_store = FaissVectorStore.from_persist_path(index_path)
index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)
retriever = index.as_retriever(similarity_top_k=top_k)
nodes = retriever.retrieve(question)
# 构造 Prompt
context_str = "\n\n".join([n.get_content() for n in nodes])
prompt_template = PromptTemplate(
"请根据以下内容回答用户问题:\n\n{context}\n\n问题:{query}"
)
final_prompt = prompt_template.format(
context=context_str,
query=question,
)
print("[PROMPT构建完成]")
return final_prompt
# 示例:
if __name__ == "__main__":
uid = "user_001"
build_user_index(uid)
prompt = query_user_rag(uid, "这份资料中提到了哪些关键点?")
print("\n------ 最终构建的 Prompt 给 LLM 使用 ------\n")
print(prompt)

23
scripts/update_index.py Normal file
View File

@ -0,0 +1,23 @@
import time
import os
from app.core.index import faiss_index
from app.core.config import settings
def watch_and_reload(interval=600):
print("[HOT-RELOAD] Watching index file for updates...")
last_mtime = None
while True:
try:
mtime = os.path.getmtime(settings.INDEX_FILE)
if last_mtime is None:
last_mtime = mtime
elif mtime != last_mtime:
print("[HOT-RELOAD] Detected new index file. Reloading...")
faiss_index.update_index(settings.INDEX_FILE)
last_mtime = mtime
except Exception as e:
print(f"[HOT-RELOAD] Error: {e}")
time.sleep(interval)
if __name__ == "__main__":
watch_and_reload()