faiss_rag_enterprise/llama_index/embeddings/loading.py

43 lines
1.8 KiB
Python

from typing import Dict, Type, Union

from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.google import GoogleUnivSentEncoderEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.langchain import LangchainEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference
from llama_index.embeddings.utils import resolve_embed_model
from llama_index.token_counter.mock_embed_model import MockEmbedding
# Registry mapping each embedding class's ``class_name()`` to the class itself.
# ``load_embed_model`` looks up the serialized ``"class_name"`` here to pick the
# class used to rebuild the embedding. Each class must appear exactly once.
RECOGNIZED_EMBEDDINGS: Dict[str, Type[BaseEmbedding]] = {
    GoogleUnivSentEncoderEmbedding.class_name(): GoogleUnivSentEncoderEmbedding,
    OpenAIEmbedding.class_name(): OpenAIEmbedding,
    LangchainEmbedding.class_name(): LangchainEmbedding,
    MockEmbedding.class_name(): MockEmbedding,
    HuggingFaceEmbedding.class_name(): HuggingFaceEmbedding,
    TextEmbeddingsInference.class_name(): TextEmbeddingsInference,
}
def load_embed_model(data: Union[dict, BaseEmbedding]) -> BaseEmbedding:
    """Load an embedding model by its registered class name.

    Args:
        data: Either an already-constructed ``BaseEmbedding``, which is
            returned unchanged, or a dict with a ``"class_name"`` key naming
            one of the classes in ``RECOGNIZED_EMBEDDINGS``. For
            ``LangchainEmbedding`` the dict must also carry ``"model_name"``.

    Returns:
        The embedding model instance.

    Raises:
        ValueError: If ``class_name`` is missing or not in
            ``RECOGNIZED_EMBEDDINGS``, or if a ``LangchainEmbedding`` entry
            has no ``model_name``.
    """
    # Already an embedding instance: nothing to deserialize.
    if isinstance(data, BaseEmbedding):
        return data

    name = data.get("class_name", None)
    if name is None:
        raise ValueError("Embedding loading requires a class_name")
    if name not in RECOGNIZED_EMBEDDINGS:
        raise ValueError(f"Invalid Embedding name: {name}")

    # Special handling for LangchainEmbedding: it can technically wrap any
    # local model, so it is rebuilt via resolve_embed_model with a "local:"
    # prefix instead of the class's from_dict.
    if name == LangchainEmbedding.class_name():
        local_name = data.get("model_name", None)
        if local_name is None:
            raise ValueError("LangchainEmbedding requires a model_name")
        return resolve_embed_model("local:" + local_name)

    return RECOGNIZED_EMBEDDINGS[name].from_dict(data)