43 lines
1.8 KiB
Python
43 lines
1.8 KiB
Python
from typing import Dict, Type, Union
|
|
|
|
from llama_index.embeddings.base import BaseEmbedding
|
|
from llama_index.embeddings.google import GoogleUnivSentEncoderEmbedding
|
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
from llama_index.embeddings.langchain import LangchainEmbedding
|
|
from llama_index.embeddings.openai import OpenAIEmbedding
|
|
from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference
|
|
from llama_index.embeddings.utils import resolve_embed_model
|
|
from llama_index.token_counter.mock_embed_model import MockEmbedding
|
|
|
|
# Registry mapping each embedding class's ``class_name()`` to the class itself.
# ``load_embed_model`` looks up the serialized "class_name" here to decide
# which class rebuilds the embedding. Each class appears exactly once; a
# repeated key would silently overwrite the earlier entry.
RECOGNIZED_EMBEDDINGS: Dict[str, Type[BaseEmbedding]] = {
    GoogleUnivSentEncoderEmbedding.class_name(): GoogleUnivSentEncoderEmbedding,
    OpenAIEmbedding.class_name(): OpenAIEmbedding,
    LangchainEmbedding.class_name(): LangchainEmbedding,
    MockEmbedding.class_name(): MockEmbedding,
    HuggingFaceEmbedding.class_name(): HuggingFaceEmbedding,
    TextEmbeddingsInference.class_name(): TextEmbeddingsInference,
}
|
|
|
|
|
|
def load_embed_model(data: Union[dict, BaseEmbedding]) -> BaseEmbedding:
    """Load an embedding model from its serialized dict form.

    Args:
        data: Either an already-constructed ``BaseEmbedding``, which is
            returned unchanged, or a dict with at least a ``"class_name"``
            key naming one of the classes in ``RECOGNIZED_EMBEDDINGS``.

    Returns:
        The embedding instance. ``LangchainEmbedding`` entries are rebuilt via
        ``resolve_embed_model("local:" + model_name)``. Every other recognized
        class is rebuilt with its ``from_dict(data)``.

    Raises:
        ValueError: If ``class_name`` is missing or not in
            ``RECOGNIZED_EMBEDDINGS``, or if a ``LangchainEmbedding`` entry
            has no ``model_name``.
    """
    # Already-built models pass straight through, so callers may hand over
    # either a live object or its serialized dict.
    if isinstance(data, BaseEmbedding):
        return data

    name = data.get("class_name")
    if name is None:
        raise ValueError("Embedding loading requires a class_name")
    if name not in RECOGNIZED_EMBEDDINGS:
        raise ValueError(f"Invalid Embedding name: {name}")

    # Special handling for LangchainEmbedding: it can technically wrap any
    # local model, so it is re-resolved from its model_name rather than
    # rebuilt with from_dict.
    if name == LangchainEmbedding.class_name():
        local_name = data.get("model_name")
        if local_name is None:
            raise ValueError("LangchainEmbedding requires a model_name")
        return resolve_embed_model("local:" + local_name)

    return RECOGNIZED_EMBEDDINGS[name].from_dict(data)
|