This commit is contained in:
parent
1ff999085f
commit
8b15c7553c
|
|
@ -12,27 +12,39 @@ from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from scripts.permissions import get_user_allowed_indexes
|
from scripts.permissions import get_user_allowed_indexes
|
||||||
import faiss
|
import faiss
|
||||||
|
from typing import List
|
||||||
|
import asyncio
|
||||||
|
|
||||||
USER_INDEX_PATH = "index_data"
|
USER_INDEX_PATH = "index_data"
|
||||||
USER_DOC_PATH = "docs"
|
USER_DOC_PATH = "docs"
|
||||||
|
|
||||||
|
|
||||||
class CustomEmbedding(BaseEmbedding):
|
class CustomEmbedding(BaseEmbedding):
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str):
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
self.model = SentenceTransformer(model_name)
|
self.model = SentenceTransformer(model_name)
|
||||||
|
|
||||||
def _get_text_embedding(self, text: str) -> list[float]:
|
# 同步方法(必须实现)
|
||||||
|
def _get_text_embedding(self, text: str) -> List[float]:
|
||||||
return self.model.encode(text).tolist()
|
return self.model.encode(text).tolist()
|
||||||
|
|
||||||
def _get_query_embedding(self, query: str) -> list[float]:
|
def _get_query_embedding(self, query: str) -> List[float]:
|
||||||
return self.model.encode(query).tolist()
|
return self.model.encode(query).tolist()
|
||||||
|
|
||||||
def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
|
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||||
return self.model.encode(texts).tolist()
|
return self.model.encode(texts).tolist()
|
||||||
|
|
||||||
def _get_query_embeddings(self, queries: list[str]) -> list[list[float]]:
|
def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
||||||
return self.model.encode(queries).tolist()
|
return self.model.encode(queries).tolist()
|
||||||
|
|
||||||
|
# 异步方法(必须实现,哪怕用同步方式包起来)
|
||||||
|
async def _aget_query_embedding(self, query: str) -> List[float]:
|
||||||
|
return self._get_query_embedding(query)
|
||||||
|
|
||||||
|
async def _aget_query_embeddings(self, queries: List[str]) -> List[List[float]]:
|
||||||
|
return self._get_query_embeddings(queries)
|
||||||
|
|
||||||
|
|
||||||
def build_user_index(user_id: str):
|
def build_user_index(user_id: str):
|
||||||
doc_dir = os.path.join(USER_DOC_PATH, user_id)
|
doc_dir = os.path.join(USER_DOC_PATH, user_id)
|
||||||
if not os.path.exists(doc_dir):
|
if not os.path.exists(doc_dir):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue