This commit is contained in:
parent
727305fbd4
commit
31458b7ac7
|
|
@ -1,4 +1,8 @@
|
|||
import os
|
||||
from typing import List
|
||||
import asyncio
|
||||
import faiss
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from llama_index import (
|
||||
SimpleDirectoryReader,
|
||||
|
|
@ -7,47 +11,23 @@ from llama_index import (
|
|||
PromptTemplate,
|
||||
ServiceContext,
|
||||
)
|
||||
from llama_index.embeddings.base import BaseEmbedding
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
from app.core.config import settings
|
||||
from scripts.permissions import get_user_allowed_indexes
|
||||
import faiss
|
||||
from typing import List
|
||||
import asyncio
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
USER_INDEX_PATH = "index_data"
|
||||
USER_DOC_PATH = "docs"
|
||||
|
||||
|
||||
class CustomEmbedding(BaseEmbedding):
|
||||
model: SentenceTransformer # ✅ 显式声明是字段
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self.model = SentenceTransformer(model_name)
|
||||
|
||||
# 同步方法(必须实现)
|
||||
def _get_text_embedding(self, text: str) -> List[float]:
|
||||
return self.model.encode(text).tolist()
|
||||
|
||||
# ✅ 替代 CustomEmbedding,用于 bge-m3 模型,自动加前缀
|
||||
class BGEEmbedding(HuggingFaceEmbedding):
|
||||
def _get_query_embedding(self, query: str) -> List[float]:
|
||||
return self.model.encode(query).tolist()
|
||||
|
||||
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||
return self.model.encode(texts).tolist()
|
||||
prefix = "Represent this sentence for searching relevant passages: "
|
||||
return super()._get_query_embedding(prefix + query)
|
||||
|
||||
def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
||||
return self.model.encode(queries).tolist()
|
||||
|
||||
# 异步方法(必须实现,哪怕用同步方式包起来)
|
||||
async def _aget_query_embedding(self, query: str) -> List[float]:
|
||||
return self._get_query_embedding(query)
|
||||
|
||||
async def _aget_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
||||
return self._get_query_embeddings(queries)
|
||||
|
||||
prefix = "Represent this sentence for searching relevant passages: "
|
||||
return super()._get_query_embeddings([prefix + q for q in queries])
|
||||
|
||||
def build_user_index(user_id: str):
|
||||
doc_dir = os.path.join(USER_DOC_PATH, user_id)
|
||||
|
|
@ -55,13 +35,17 @@ def build_user_index(user_id: str):
|
|||
raise FileNotFoundError(f"文档目录不存在: {doc_dir}")
|
||||
|
||||
documents = SimpleDirectoryReader(doc_dir).load_data()
|
||||
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
|
||||
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
|
||||
|
||||
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
|
||||
|
||||
# ✅ 指定正确维度:bge-m3 是 1024
|
||||
faiss_index = faiss.IndexFlatL2(1024)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
vector_store=FaissVectorStore(),
|
||||
vector_store=vector_store,
|
||||
service_context=service_context
|
||||
)
|
||||
|
||||
|
|
@ -70,7 +54,7 @@ def build_user_index(user_id: str):
|
|||
print(f"[BUILD] 为用户 {user_id} 构建并保存了索引 → {index_path}")
|
||||
|
||||
def query_user_rag(user_id: str, question: str, top_k: int = 4) -> str:
|
||||
embed_model = HuggingFaceEmbedding(model_name=settings.MODEL_NAME)
|
||||
embed_model = BGEEmbedding(model_name=settings.MODEL_NAME)
|
||||
service_context = ServiceContext.from_defaults(embed_model=embed_model, llm=None)
|
||||
|
||||
all_nodes = []
|
||||
|
|
|
|||
Loading…
Reference in New Issue