# faiss_rag_enterprise/scripts/build_index.py
#
# (file-viewer metadata at time of capture: 40 lines, 1.1 KiB, Python)

import os
import faiss
import numpy as np
import sys
from app.core.embedding import embedder
def load_documents(doc_folder):
    """Read every regular file directly inside *doc_folder* as UTF-8 text.

    Files are visited in sorted filename order so that the position of each
    text in the returned list is reproducible across runs and platforms;
    ``os.listdir`` on its own yields entries in arbitrary order, which would
    make any index built from the list assign ids non-deterministically.

    Args:
        doc_folder: Path of the directory that holds the documents.

    Returns:
        A list containing the full contents of each file, ordered by
        filename. Sub-directories are skipped and not descended into.

    Raises:
        OSError: If *doc_folder* cannot be listed or a file cannot be opened.
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    texts = []
    for fname in sorted(os.listdir(doc_folder)):
        path = os.path.join(doc_folder, fname)
        if not os.path.isfile(path):
            # Only plain files are documents; ignore nested directories.
            continue
        with open(path, "r", encoding="utf-8") as f:
            texts.append(f.read())
    return texts
def build_faiss_index(docs, dim):
    """Embed *docs* and store the vectors in a flat inner-product FAISS index.

    Args:
        docs: Sequence of document texts handed to ``embedder.encode``.
        dim: Dimensionality of the index; must equal the embedding width.

    Returns:
        A ``faiss.IndexFlatIP`` holding one vector per document, added in
        the order of *docs*. When *docs* is empty an empty index is returned
        and the embedder is not called.

    Raises:
        ValueError: If the embeddings are not a 2-D array of width *dim*.
    """
    index = faiss.IndexFlatIP(dim)
    if len(docs) == 0:
        # Nothing to embed; avoid handing the encoder/FAISS an empty batch.
        return index

    # FAISS expects a C-contiguous float32 matrix of shape (n, dim); coerce
    # explicitly so float64 output or non-contiguous views are handled.
    vectors = np.ascontiguousarray(embedder.encode(docs), dtype=np.float32)
    if vectors.ndim != 2 or vectors.shape[1] != dim:
        raise ValueError(
            f"Embeddings have shape {vectors.shape}, expected (n, {dim})"
        )

    # NOTE(review): inner product only equals cosine similarity if the
    # embedder returns L2-normalised vectors -- confirm against the embedder.
    index.add(vectors)
    return index
if __name__ == "__main__":
    # Command-line entry point: build and persist the FAISS index for the
    # user id given as the first argument.
    if len(sys.argv) < 2:
        print("Usage: python -m scripts.build_index <user_id>")
        sys.exit(1)

    user_id = sys.argv[1]
    doc_dir = os.path.join("docs", user_id)
    index_path = os.path.join("index_data", f"{user_id}.index")

    # Fail early with a readable message instead of a traceback from
    # os.listdir when the user has no document directory.
    if not os.path.isdir(doc_dir):
        print(f"[BUILD] Document directory not found: {doc_dir}")
        sys.exit(1)

    print(f"[BUILD] Loading documents from {doc_dir} ...")
    docs = load_documents(doc_dir)
    print(f"[BUILD] Loaded {len(docs)} documents")

    print("[BUILD] Building FAISS index...")
    index = build_faiss_index(docs, dim=768)  # or derive it dynamically from the model configuration

    print(f"[BUILD] Saving index to {index_path}")
    # Make sure the output directory exists before writing, so the
    # embedding work above is not lost on a fresh checkout.
    os.makedirs(os.path.dirname(index_path), exist_ok=True)
    faiss.write_index(index, index_path)