parent 3a1fc39c48
commit 2d0a6c11ee
@@ -1,12 +1,13 @@
 from fastapi import APIRouter, HTTPException, Query
 from pydantic import BaseModel
-import numpy as np
-from app.core.index_manager import UserIndexManager
 from app.core.embedding import embedder
 from app.core.config import settings
+from llama_index.vector_stores.faiss import FaissVectorStore
+from llama_index import VectorStoreIndex, ServiceContext
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+import os
 
 router = APIRouter()
-index_manager = UserIndexManager()
 
 class QueryRequest(BaseModel):
     query: str
@@ -14,13 +15,27 @@ class QueryRequest(BaseModel):
 @router.post("/search")
 def search_docs(request: QueryRequest, user_id: str = Query(..., description="User ID")):
     try:
-        query_vector = embedder.encode([request.query])
-        index = index_manager.get_index(user_id)
-        D, I = index.search(query_vector, settings.TOP_K)
+        index_path = os.path.join("index_data", f"{user_id}.index")
+        if not os.path.exists(index_path):
+            raise HTTPException(status_code=404, detail="User index does not exist")
+
+        # Build the LlamaIndex retriever
+        faiss_store = FaissVectorStore.from_persist_path(index_path)
+        service_context = ServiceContext.from_defaults(embed_model=embedder)
+        index = VectorStoreIndex.from_vector_store(faiss_store, service_context=service_context)
+
+        # Retrieve results (real text, not just IDs)
+        retriever = index.as_retriever(similarity_top_k=settings.TOP_K)
+        nodes = retriever.retrieve(request.query)
+
         return {
             "user_id": user_id,
             "query": request.query,
-            "results": [{"id": int(idx), "score": float(dist)} for dist, idx in zip(D, I)]
+            "results": [
+                {"score": float(node.score or 0), "text": node.get_content()}
+                for node in nodes
+            ]
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
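
Not part of the commit: a minimal sketch of calling the new /search endpoint, assuming the app runs locally on port 8000 and the router is mounted without a prefix (both assumptions; adjust to the actual deployment).

import requests

resp = requests.post(
    "http://localhost:8000/search",
    params={"user_id": "user_001"},               # user_id is a query parameter
    json={"query": "What is the leave policy?"},  # QueryRequest body
)
resp.raise_for_status()
for hit in resp.json()["results"]:
    print(f"{hit['score']:.3f}  {hit['text'][:80]}")
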
@@ -1,23 +1,32 @@
-from fastapi import APIRouter, UploadFile, File, Form
+from fastapi import APIRouter, UploadFile, File, Form, HTTPException
 import os
-from shutil import copyfileobj
+import shutil
 from scripts.rag_build_query import build_user_index
 
 router = APIRouter()
+
+ALLOWED_SUFFIXES = {".txt", ".md", ".pdf", ".docx"}
 
 @router.post("/upload")
 def upload_user_file(user_id: str = Form(...), file: UploadFile = File(...)):
+    filename = os.path.basename(file.filename)
+    suffix = os.path.splitext(filename)[-1].lower()
+    if suffix not in ALLOWED_SUFFIXES:
+        raise HTTPException(status_code=400, detail="Unsupported file type")
+
     user_doc_dir = os.path.join("docs", user_id)
     os.makedirs(user_doc_dir, exist_ok=True)
 
-    file_path = os.path.join(user_doc_dir, file.filename)
-    with open(file_path, "wb") as f:
-        copyfileobj(file.file, f)
-
-    print(f"[UPLOAD] File saved to {file_path}")
-
-    # Automatically rebuild the user's index
-    build_user_index(user_id)
-    print(f"[UPLOAD] Index rebuilt for user {user_id}")
+    file_path = os.path.join(user_doc_dir, filename)
+    try:
+        with open(file_path, "wb") as f:
+            shutil.copyfileobj(file.file, f)
+        print(f"[UPLOAD] File saved to {file_path}")
+
+        build_user_index(user_id)
+        print(f"[UPLOAD] Index rebuilt for user {user_id}")
+    except Exception as e:
+        print(f"[UPLOAD ERROR] {e}")
+        raise HTTPException(status_code=500, detail="Index build failed")
 
-    return {"status": "ok", "filename": file.filename}
+    return {"status": "ok", "filename": filename}
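
The basename() call and the suffix whitelist keep uploads from escaping the per-user docs directory or smuggling in unexpected file types. Not part of the commit: a hedged sketch of exercising /upload with requests, assuming the same local deployment as above.

import requests

with open("handbook.md", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/upload",
        data={"user_id": "user_001"},                          # Form field
        files={"file": ("handbook.md", fh, "text/markdown")},  # UploadFile
    )
print(resp.json())  # expected: {"status": "ok", "filename": "handbook.md"}
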
@@ -1,12 +1,10 @@
 from pydantic_settings import BaseSettings
-import os
 
 class Settings(BaseSettings):
-    INDEX_FILE: str = os.getenv("INDEX_FILE", "index_data/index.faiss")
-    EMBEDDING_DIM: int = 768
-    TOP_K: int = 5
-    DOC_PATH: str = "docs/"
-    DEVICE: str = "cpu"
-    MODEL_NAME: str = "BAAI/bge-m3"
+    EMBEDDING_DIM: int = 768  # Embedding dimension (depends on the model)
+    TOP_K: int = 5  # Default number of passages to retrieve
+    DOC_PATH: str = "docs/"  # Default document root directory
+    DEVICE: str = "cpu"  # Set to "cuda" to use a GPU
+    MODEL_NAME: str = "BAAI/bge-m3"  # Multilingual semantic embedding model
 
 settings = Settings()
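
Because Settings extends pydantic's BaseSettings, every field can still be overridden through an environment variable of the same name, which is presumably why the os.getenv indirection was dropped. A hedged sketch (not part of the commit); note the module-level settings singleton is created at import time, so variables must be set before app.core.config is first imported.

import os
os.environ["DEVICE"] = "cuda"  # field names double as env var names
os.environ["TOP_K"] = "10"

from app.core.config import Settings
print(Settings().DEVICE, Settings().TOP_K)  # -> cuda 10
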
@@ -5,21 +5,26 @@ from app.core.config import settings
 
 class BGEEmbedding:
     def __init__(self):
+        self.device = torch.device(settings.DEVICE)
         self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
-        self.model = AutoModel.from_pretrained(settings.MODEL_NAME)
+        self.model = AutoModel.from_pretrained(settings.MODEL_NAME).to(self.device)
         self.model.eval()
 
-    def encode(self, texts):
-        with torch.no_grad():
-            inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
-            outputs = self.model(**inputs)
-            embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
-            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-            return embeddings.cpu().numpy()
+    def encode(self, texts, batch_size=8):
+        all_embeddings = []
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i+batch_size]
+            with torch.no_grad():
+                inputs = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(self.device)
+                outputs = self.model(**inputs)
+                embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+                all_embeddings.append(embeddings.cpu().numpy())
+        return np.vstack(all_embeddings)
 
     def mean_pooling(self, model_output, attention_mask):
         token_embeddings = model_output[0]
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
 embedder = BGEEmbedding()
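
Not part of the commit: a small sketch of the batched encoder. Because every row is L2-normalized, cosine similarity between two texts reduces to a dot product. Texts and shapes here are illustrative.

import numpy as np
from app.core.embedding import embedder

vecs = embedder.encode(["annual leave policy", "vacation days"], batch_size=2)
print(vecs.shape)                # (2, D) where D is the model's hidden size
print(float(vecs[0] @ vecs[1]))  # cosine similarity, since rows are unit-norm
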
@@ -1,31 +1,28 @@
 import faiss
 import numpy as np
 import threading
-from app.core.config import settings
 
 class FaissIndexWrapper:
-    def __init__(self):
+    def __init__(self, index_path: str):
         self.index_lock = threading.Lock()
         self.index = None
-        self.load_index(settings.INDEX_FILE)
+        self.index_path = index_path
+        self.load_index(index_path)
 
-    def load_index(self, path):
+    def load_index(self, path: str = None):
         with self.index_lock:
+            path = path or self.index_path
             self.index = faiss.read_index(path)
             print(f"[FAISS] Index loaded from {path}")
 
-    def update_index(self, path):
+    def update_index(self, path: str = None):
         """Hot-swap the current index"""
         with self.index_lock:
-            new_index = faiss.read_index(path)
-            self.index = new_index
+            path = path or self.index_path
+            self.index = faiss.read_index(path)
             print(f"[FAISS] Index hot-swapped from {path}")
 
     def search(self, vector: np.ndarray, top_k: int = 5):
         with self.index_lock:
             D, I = self.index.search(vector.astype(np.float32), top_k)
             return D[0], I[0]
-
-# Singleton referenced by the API layer
-faiss_index = FaissIndexWrapper()
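
With the module-level singleton removed, callers now construct one wrapper per index file, and update_index re-reads the file under the same lock that search holds, so readers never observe a half-swapped index. Not part of the commit: a hedged usage sketch with illustrative paths and dimensions.

import numpy as np
from app.core.index import FaissIndexWrapper

wrapper = FaissIndexWrapper("index_data/user_001.index")
query = np.random.rand(1, 768).astype(np.float32)  # dim must match the index
D, I = wrapper.search(query, top_k=5)
print(list(zip(I.tolist(), D.tolist())))

wrapper.update_index()  # re-reads self.index_path under the lock
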
@@ -1,21 +1,38 @@
 import os
 import faiss
+import threading
 
 class UserIndexManager:
     def __init__(self, base_dir="index_data/"):
         self.base_dir = base_dir
-        self.index_map = {}  # user_id -> faiss index
+        self.index_map = {}  # user_id -> faiss.Index
+        self.index_lock = threading.Lock()
         self._load_all_indexes()
 
     def _load_all_indexes(self):
         for fname in os.listdir(self.base_dir):
             if fname.endswith(".index"):
                 user_id = fname.replace(".index", "")
-                path = os.path.join(self.base_dir, fname)
-                self.index_map[user_id] = faiss.read_index(path)
-                print(f"[INIT] Loaded index for user {user_id}")
+                self._load_single_index(user_id)
+
+    def _load_single_index(self, user_id):
+        path = os.path.join(self.base_dir, f"{user_id}.index")
+        if os.path.exists(path):
+            index = faiss.read_index(path)
+            self.index_map[user_id] = index
+            print(f"[INIT] Loaded index for user: {user_id}")
 
     def get_index(self, user_id):
-        if user_id not in self.index_map:
-            raise ValueError(f"Index for user {user_id} not loaded.")
-        return self.index_map[user_id]
+        with self.index_lock:
+            if user_id not in self.index_map:
+                raise FileNotFoundError(f"No FAISS index loaded for user: {user_id}")
+            return self.index_map[user_id]
+
+    def update_index(self, user_id):
+        """Reload the user's index file"""
+        path = os.path.join(self.base_dir, f"{user_id}.index")
+        if not os.path.exists(path):
+            raise FileNotFoundError(f"No index file found for user: {user_id}")
+        with self.index_lock:
+            self.index_map[user_id] = faiss.read_index(path)
+        print(f"[UPDATE] Index for user {user_id} has been hot-reloaded.")
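
Not part of the commit: a hedged sketch of the manager's lifecycle. get_index and update_index now take the lock, while _load_single_index runs unlocked only during single-threaded startup.

from app.core.index_manager import UserIndexManager

manager = UserIndexManager(base_dir="index_data/")
try:
    idx = manager.get_index("user_001")  # thread-safe lookup
    print(idx.ntotal, "vectors")         # standard faiss.Index attribute
except FileNotFoundError as e:
    print(e)

manager.update_index("user_001")         # hot-reload after a rebuild
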
@@ -1,8 +1,8 @@
 import os
 import faiss
 import numpy as np
+import sys
 from app.core.embedding import embedder
-from app.core.config import settings
 
 def load_documents(doc_folder):
     texts = []
@@ -20,12 +20,20 @@ def build_faiss_index(docs, dim):
     return index
 
 if __name__ == "__main__":
-    print("[BUILD] Loading documents...")
-    docs = load_documents(settings.DOC_PATH)
+    if len(sys.argv) < 2:
+        print("Usage: python -m scripts.build_index <user_id>")
+        sys.exit(1)
+
+    user_id = sys.argv[1]
+    doc_dir = os.path.join("docs", user_id)
+    index_path = os.path.join("index_data", f"{user_id}.index")
+
+    print(f"[BUILD] Loading documents from {doc_dir} ...")
+    docs = load_documents(doc_dir)
     print(f"[BUILD] Loaded {len(docs)} documents")
 
     print("[BUILD] Building FAISS index...")
-    index = build_faiss_index(docs, settings.EMBEDDING_DIM)
+    index = build_faiss_index(docs, dim=768)  # or set dynamically from the model config
 
-    print(f"[BUILD] Saving index to {settings.INDEX_FILE}")
-    faiss.write_index(index, settings.INDEX_FILE)
+    print(f"[BUILD] Saving index to {index_path}")
+    faiss.write_index(index, index_path)
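
The hard-coded dim=768 loses the settings indirection the old line had. Not part of the commit: a hedged sketch that derives the dimension from the embedder itself so the script cannot drift from the model; it assumes embedder, docs, and build_faiss_index as defined in this file.

probe = embedder.encode(["dimension probe"])  # shape: (1, D)
dim = probe.shape[1]
index = build_faiss_index(docs, dim=dim)
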
@@ -1,5 +1,5 @@
 {
-    "user_001": ["common.index", "engineering.index"],
-    "user_002": ["common.index", "hr.index"],
-    "user_003": []
+    "user_001": ["common.index", "engineering.index"],  // Engineering staff
+    "user_002": ["common.index", "hr.index"],  // HR staff
+    "user_003": []  // Regular user, no shared libraries
 }
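
Caution: JSON has no comment syntax, so these // annotations would make json.load raise JSONDecodeError, and the loader below would then return [] for every user. Not part of the commit: a hedged sketch that strips line comments before parsing, keeping the annotations without breaking the format.

import json
import re

def load_permissions(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        raw = f.read()
    # Naive strip of // line comments (does not handle "//" inside strings)
    cleaned = re.sub(r"//[^\n]*", "", raw)
    return json.loads(cleaned)
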
@@ -1,12 +1,24 @@
 import json
+import os
 
 PERMISSION_PATH = "scripts/index_permissions.json"
 
 def get_user_allowed_indexes(user_id: str):
+    if not os.path.exists(PERMISSION_PATH):
+        print(f"[Permission Warning] Permission config file not found: {PERMISSION_PATH}")
+        return []
+
     try:
         with open(PERMISSION_PATH, "r", encoding="utf-8") as f:
             permission_map = json.load(f)
+
+        if user_id not in permission_map:
+            print(f"[Permission Info] User {user_id} has no shared-library config; using default permissions.")
         return permission_map.get(user_id, [])
+    except json.JSONDecodeError:
+        print("[Permission Error] Permission config file contains invalid JSON.")
+        return []
     except Exception as e:
-        print(f"[Permission Error] {e}")
+        print(f"[Permission Error] Unknown error: {e}")
         return []
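
Not part of the commit: a hedged sketch of the loader's three outcomes; the import path is hypothetical, adjust to wherever this module actually lives.

from scripts.index_permissions_loader import get_user_allowed_indexes  # hypothetical path

print(get_user_allowed_indexes("user_001"))  # ["common.index", "engineering.index"]
print(get_user_allowed_indexes("user_999"))  # [] plus an info log (no config entry)
# A missing or malformed config file also logs and returns [].
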
@@ -24,43 +24,47 @@ def build_user_index(user_id: str):
     embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model)
 
-    # Build the vector index
     index = VectorStoreIndex.from_documents(
         documents,
         service_context=service_context,
         vector_store=FaissVectorStore()
     )
 
-    # Save as a user-specific .index file
     index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
     faiss.write_index(index.vector_store.index, index_path)
     print(f"[BUILD] Built and saved index for user {user_id} → {index_path}")
 
 def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
-    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
-    if not os.path.exists(index_path):
-        raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")
-
     embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
     service_context = ServiceContext.from_defaults(embed_model=embed_model)
 
-    # Load the main index
-    vector_store = FaissVectorStore.from_persist_path(index_path)
-    index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)
-
-    nodes = index.as_retriever(similarity_top_k=top_k).retrieve(question)
-
-    # Load shared indexes within the permission scope
+    all_nodes = []
+
+    # Load the user's primary index
+    index_path = os.path.join(USER_INDEX_PATH, f"{user_id}.index")
+    if not os.path.exists(index_path):
+        raise FileNotFoundError(f"[ERROR] Index for user {user_id} does not exist")
+    user_store = FaissVectorStore.from_persist_path(index_path)
+    user_index = VectorStoreIndex.from_vector_store(user_store, service_context=service_context)
+    all_nodes += user_index.as_retriever(similarity_top_k=top_k).retrieve(question)
+
+    # Load shared indexes
     shared_indexes = get_user_allowed_indexes(user_id)
-    for shared_name in shared_indexes:
-        shared_path = os.path.join(USER_INDEX_PATH, shared_name)
-        if os.path.exists(shared_path):
-            shared_store = FaissVectorStore.from_persist_path(shared_path)
-            shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
-            nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
+    if shared_indexes:
+        for shared_name in shared_indexes:
+            shared_path = os.path.join(USER_INDEX_PATH, shared_name)
+            if os.path.exists(shared_path) and shared_path != index_path:
+                shared_store = FaissVectorStore.from_persist_path(shared_path)
+                shared_index = VectorStoreIndex.from_vector_store(shared_store, service_context=service_context)
+                all_nodes += shared_index.as_retriever(similarity_top_k=top_k).retrieve(question)
+    else:
+        print(f"[INFO] User {user_id} has no shared-index permissions")
 
-    # Construct the prompt
-    context_str = "\n\n".join([n.get_content() for n in nodes])
+    # Merge and sort by score
+    sorted_nodes = sorted(all_nodes, key=lambda n: -(n.score or 0))
+    top_nodes = sorted_nodes[:top_k]
+
+    context_str = "\n\n".join([n.get_content() for n in top_nodes])
     prompt_template = PromptTemplate(
         "Answer the user's question based on the following content:\n\n{context}\n\nQuestion: {query}"
     )
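
The rewrite pools top_k candidates from the user index and every permitted shared index, then reranks the pooled list by score. Not part of the commit: the same merge-and-rerank idiom in isolation; comparable scores across indexes assume the same embed model and similarity metric, and None scores sort last.

def merge_nodes(node_lists, top_k):
    """Pool retrieval results from several indexes and keep the global top_k."""
    pooled = [n for nodes in node_lists for n in nodes]
    pooled.sort(key=lambda n: -(n.score or 0))
    return pooled[:top_k]
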
@@ -1,23 +1,30 @@
 import time
 import os
+import sys
 from app.core.index import faiss_index
-from app.core.config import settings
 
-def watch_and_reload(interval=600):
-    print("[HOT-RELOAD] Watching index file for updates...")
+def watch_and_reload(user_id: str, interval=600):
+    index_file = os.path.join("index_data", f"{user_id}.index")
+    print(f"[HOT-RELOAD] Watching {index_file} for updates...")
     last_mtime = None
 
     while True:
         try:
-            mtime = os.path.getmtime(settings.INDEX_FILE)
+            mtime = os.path.getmtime(index_file)
             if last_mtime is None:
                 last_mtime = mtime
             elif mtime != last_mtime:
                 print("[HOT-RELOAD] Detected new index file. Reloading...")
-                faiss_index.update_index(settings.INDEX_FILE)
+                faiss_index.update_index(index_file)
                 last_mtime = mtime
         except Exception as e:
             print(f"[HOT-RELOAD] Error: {e}")
         time.sleep(interval)
 
 if __name__ == "__main__":
-    watch_and_reload()
+    if len(sys.argv) < 2:
+        print("Usage: python -m scripts.update_index <user_id>")
+        sys.exit(1)
+
+    user_id = sys.argv[1]
+    watch_and_reload(user_id)
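
One caveat not addressed by the commit: app.core.index no longer exports a faiss_index singleton after the change above, so the import at the top of this script would fail. Not part of the commit: a hedged sketch of one way to keep the watcher working by constructing the wrapper explicitly for the watched user, following the path layout used elsewhere in the commit.

import sys
from app.core.index import FaissIndexWrapper

if __name__ == "__main__":
    user_id = sys.argv[1]
    wrapper = FaissIndexWrapper(f"index_data/{user_id}.index")
    # watch_and_reload would then call wrapper.update_index(index_file)
    # instead of the removed module-level faiss_index.
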