first commit!
commit bb65afab1f
Dockerfile
@@ -0,0 +1,18 @@
FROM python:3.10-slim

WORKDIR /app

COPY . /app

# Install system dependencies (if needed)
RUN apt-get update && apt-get install -y build-essential && \
    rm -rf /var/lib/apt/lists/*

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Declare the service port (for Kubernetes or docker run -P)
EXPOSE 8000

# Start the service (bind IP and port are configured in gunicorn_config.py)
CMD ["gunicorn", "app.main:app", "-k", "uvicorn.workers.UvicornWorker", "-c", "gunicorn_config.py"]
README.md
@@ -0,0 +1,18 @@
# FAISS Enterprise RAG

An enterprise-grade deployment solution supporting tens of millions of vectors, high local concurrency, hot index updates, and containerized deployment.
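
Quick start (a minimal sketch; the image tag and sample file below are placeholders):

```bash
docker build -t faiss-rag .
docker run -p 8000:8000 faiss-rag

# Upload a document for a user (this also rebuilds that user's index)
curl -X POST "http://localhost:8000/api/upload" \
     -F "user_id=user_001" -F "file=@mydoc.txt"

# Query that user's index
curl -X POST "http://localhost:8000/api/search?user_id=user_001" \
     -H "Content-Type: application/json" -d '{"query": "What are the key points?"}'
```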
app/api/search.py
@@ -0,0 +1,27 @@
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
import numpy as np
from app.core.index_manager import UserIndexManager
from app.core.embedding import embedder
from app.core.config import settings

router = APIRouter()
index_manager = UserIndexManager()

class QueryRequest(BaseModel):
    query: str

@router.post("/search")
def search_docs(request: QueryRequest, user_id: str = Query(..., description="User ID")):
    try:
        query_vector = embedder.encode([request.query]).astype(np.float32)
        index = index_manager.get_index(user_id)
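        # faiss returns (n_queries, top_k) arrays; take row 0 for the single query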
        D, I = index.search(query_vector, settings.TOP_K)
        return {
            "user_id": user_id,
            "query": request.query,
            "results": [{"id": int(idx), "score": float(dist)} for dist, idx in zip(D[0], I[0])]
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
app/api/upload.py
@@ -0,0 +1,24 @@
from fastapi import APIRouter, UploadFile, File, Form
import os
from shutil import copyfileobj
from scripts.rag_build_query import build_user_index

router = APIRouter()

@router.post("/upload")
def upload_user_file(user_id: str = Form(...), file: UploadFile = File(...)):
    user_doc_dir = os.path.join("docs", user_id)
    os.makedirs(user_doc_dir, exist_ok=True)

    # basename() guards against path traversal in the client-supplied filename
    file_path = os.path.join(user_doc_dir, os.path.basename(file.filename))
    with open(file_path, "wb") as f:
        copyfileobj(file.file, f)

    print(f"[UPLOAD] File saved to {file_path}")

    # Automatically rebuild this user's index
    build_user_index(user_id)
    print(f"[UPLOAD] Index rebuilt for user {user_id}")

    return {"status": "ok", "filename": file.filename}
app/core/config.py
@@ -0,0 +1,12 @@
from pydantic import BaseSettings  # pydantic v1 API; in pydantic v2 this moved to pydantic-settings
import os

class Settings(BaseSettings):
    INDEX_FILE: str = os.getenv("INDEX_FILE", "index_data/index.faiss")
    EMBEDDING_DIM: int = 1024  # BAAI/bge-m3 produces 1024-dim dense embeddings
    TOP_K: int = 5
    DOC_PATH: str = "docs/"
    DEVICE: str = "cpu"  # can be set to e.g. cuda:0
    MODEL_NAME: str = "BAAI/bge-m3"

settings = Settings()
app/core/embedding.py
@@ -0,0 +1,26 @@
from transformers import AutoTokenizer, AutoModel
import torch
from app.core.config import settings

class BGEEmbedding:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
        self.model = AutoModel.from_pretrained(settings.MODEL_NAME).to(settings.DEVICE)
        self.model.eval()

    def encode(self, texts):
        with torch.no_grad():
            inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(settings.DEVICE)
            outputs = self.model(**inputs)
            embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
            # L2-normalize so inner-product search (IndexFlatIP) ranks by cosine similarity
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            return embeddings.cpu().numpy()

    def mean_pooling(self, model_output, attention_mask):
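        # Masked average of token embeddings: padding positions are excluded from the mean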
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

embedder = BGEEmbedding()
app/core/index.py
@@ -0,0 +1,31 @@
import faiss
import numpy as np
import threading
from app.core.config import settings

class FaissIndexWrapper:
    def __init__(self):
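        # One lock serializes searches and swaps so readers never see a half-replaced index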
        self.index_lock = threading.Lock()
        self.index = None
        self.load_index(settings.INDEX_FILE)

    def load_index(self, path):
        with self.index_lock:
            self.index = faiss.read_index(path)
            print(f"[FAISS] Index loaded from {path}")

    def update_index(self, path):
        """Hot-swap the current index for a freshly built one."""
        with self.index_lock:
            new_index = faiss.read_index(path)
            self.index = new_index
            print(f"[FAISS] Index hot-swapped from {path}")

    def search(self, vector: np.ndarray, top_k: int = 5):
        with self.index_lock:
            D, I = self.index.search(vector.astype(np.float32), top_k)
            return D[0], I[0]

# Singleton, referenced by the API layer
faiss_index = FaissIndexWrapper()
app/core/index_manager.py
@@ -0,0 +1,26 @@
import os
import faiss

class UserIndexManager:
    def __init__(self, base_dir="index_data/"):
        self.base_dir = base_dir
        os.makedirs(self.base_dir, exist_ok=True)
        self.index_map = {}  # user_id -> faiss index
        self._load_all_indexes()

    def _load_all_indexes(self):
        for fname in os.listdir(self.base_dir):
            if fname.endswith(".index"):
                user_id = fname.replace(".index", "")
                path = os.path.join(self.base_dir, fname)
                self.index_map[user_id] = faiss.read_index(path)
                print(f"[INIT] Loaded index for user {user_id}")

    def get_index(self, user_id):
        if user_id not in self.index_map:
            # Lazily load indexes built after startup (e.g. by the /upload endpoint)
            path = os.path.join(self.base_dir, f"{user_id}.index")
            if not os.path.exists(path):
                raise ValueError(f"Index for user {user_id} not loaded.")
            self.index_map[user_id] = faiss.read_index(path)
        return self.index_map[user_id]
app/main.py
@@ -0,0 +1,8 @@
from fastapi import FastAPI
from app.api.search import router as search_router
from app.api.upload import router as upload_router

app = FastAPI(title="Enterprise FAISS RAG Server")

app.include_router(search_router, prefix="/api", tags=["search"])
app.include_router(upload_router, prefix="/api", tags=["upload"])
gunicorn_config.py
@@ -0,0 +1,5 @@
workers = 12  # tune to the number of CPU cores
timeout = 60
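# preload_app imports the app (and loads FAISS indexes) once in the master; forked workers share it copy-on-write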
preload_app = True
bind = "0.0.0.0:8000"
requirements.txt
@@ -0,0 +1,10 @@
faiss-cpu
fastapi
uvicorn
gunicorn
pydantic<2           # config.py uses the pydantic v1 BaseSettings API
python-multipart     # required by FastAPI for Form/File uploads
numpy
transformers
torch
llama-index<0.10     # scripts/rag_build_query.py uses the legacy ServiceContext API
@@ -0,0 +1,32 @@
import os
import faiss
from app.core.embedding import embedder
from app.core.config import settings

def load_documents(doc_folder):
    texts = []
    for fname in os.listdir(doc_folder):
        path = os.path.join(doc_folder, fname)
        if os.path.isfile(path):
            with open(path, "r", encoding="utf-8") as f:
                texts.append(f.read())
    return texts

def build_faiss_index(docs, dim):
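    # Embeddings are L2-normalized, so inner product (IndexFlatIP) ranks by cosine similarity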
    vectors = embedder.encode(docs)
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)
    return index

if __name__ == "__main__":
    print("[BUILD] Loading documents...")
    docs = load_documents(settings.DOC_PATH)
    print(f"[BUILD] Loaded {len(docs)} documents")

    print("[BUILD] Building FAISS index...")
    index = build_faiss_index(docs, settings.EMBEDDING_DIM)

    print(f"[BUILD] Saving index to {settings.INDEX_FILE}")
    os.makedirs(os.path.dirname(settings.INDEX_FILE), exist_ok=True)
    faiss.write_index(index, settings.INDEX_FILE)
scripts/index_permissions.json
@@ -0,0 +1,5 @@
{
    "user_001": ["common.index", "engineering.index"],
    "user_002": ["common.index", "hr.index"],
    "user_003": []
}
scripts/permissions.py
@@ -0,0 +1,13 @@
import json

PERMISSION_PATH = "scripts/index_permissions.json"

def get_user_allowed_indexes(user_id: str):
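    # Returns the index files this user may query; unknown users get an empty list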
    try:
        with open(PERMISSION_PATH, "r", encoding="utf-8") as f:
            permission_map = json.load(f)
        return permission_map.get(user_id, [])
    except Exception as e:
        print(f"[Permission Error] {e}")
        return []
scripts/rag_build_query.py
@@ -0,0 +1,87 @@
import os
import faiss
from llama_index import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    ServiceContext,
    StorageContext,
    PromptTemplate,
    load_index_from_storage,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from scripts.permissions import get_user_allowed_indexes

# Local embedding model assumed for per-user indexes
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
EMBED_DIM = 384  # all-MiniLM-L6-v2 output dimension
USER_INDEX_PATH = "index_data"
USER_DOC_PATH = "docs"

def build_user_index(user_id: str):
    doc_dir = os.path.join(USER_DOC_PATH, user_id)
    if not os.path.exists(doc_dir):
        raise FileNotFoundError(f"Document directory does not exist: {doc_dir}")

    documents = SimpleDirectoryReader(doc_dir).load_data()
    embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)

    # Build the vector index on top of a FAISS store
    vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatIP(EMBED_DIM))
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
        service_context=service_context,
    )

    # Persist the full storage context: FAISS does not store node text,
    # so the docstore must be saved alongside the vectors
    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    index.storage_context.persist(persist_dir=persist_dir)

    # Also save the raw FAISS index as the user-specific .index file
    # consumed by UserIndexManager
    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
    faiss.write_index(vector_store.client, index_path)
    print(f"[BUILD] Built and saved index for user {user_id} → {index_path}")

def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
    persist_dir = os.path.join(USER_INDEX_PATH, user_id)
    if not os.path.exists(persist_dir):
        raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")

    embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)

    # Load the persisted index
    vector_store = FaissVectorStore.from_persist_dir(persist_dir)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir=persist_dir
    )
    index = load_index_from_storage(storage_context, service_context=service_context)

    retriever = index.as_retriever(similarity_top_k=top_k)
    nodes = retriever.retrieve(question)

    # Build the prompt
    context_str = "\n\n".join([n.get_content() for n in nodes])
    prompt_template = PromptTemplate(
        "Answer the user's question based on the following content:\n\n{context}\n\nQuestion: {query}"
    )
    final_prompt = prompt_template.format(
        context=context_str,
        query=question,
    )

    print("[PROMPT] Prompt built")
    return final_prompt

# Example:
if __name__ == "__main__":
    uid = "user_001"
    build_user_index(uid)
    prompt = query_user_rag(uid, "What key points does this material mention?")
    print("\n------ Final prompt passed to the LLM ------\n")
    print(prompt)
@@ -0,0 +1,24 @@
import time
import os
from app.core.index import faiss_index
from app.core.config import settings

def watch_and_reload(interval=600):
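    # Poll the index file's mtime every `interval` seconds; on change, hot-swap the in-process index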
    print("[HOT-RELOAD] Watching index file for updates...")
    last_mtime = None
    while True:
        try:
            mtime = os.path.getmtime(settings.INDEX_FILE)
            if last_mtime is None:
                last_mtime = mtime
            elif mtime != last_mtime:
                print("[HOT-RELOAD] Detected new index file. Reloading...")
                faiss_index.update_index(settings.INDEX_FILE)
                last_mtime = mtime
        except Exception as e:
            print(f"[HOT-RELOAD] Error: {e}")
        time.sleep(interval)

if __name__ == "__main__":
    watch_and_reload()