faiss_rag_enterprise/llama_index/embeddings/utils.py

96 lines
3.4 KiB
Python

"""Embedding utils for LlamaIndex."""
import os
from typing import TYPE_CHECKING, List, Optional, Union
if TYPE_CHECKING:
from llama_index.bridge.langchain import Embeddings as LCEmbeddings
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.clip import ClipEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.huggingface_utils import (
INSTRUCTOR_MODELS,
)
from llama_index.embeddings.instructor import InstructorEmbedding
from llama_index.embeddings.langchain import LangchainEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai_utils import validate_openai_api_key
from llama_index.token_counter.mock_embed_model import MockEmbedding
from llama_index.utils import get_cache_dir
EmbedType = Union[BaseEmbedding, "LCEmbeddings", str]
def save_embedding(embedding: List[float], file_path: str) -> None:
    """Write *embedding* to *file_path* as one comma-separated line of floats."""
    serialized = ",".join(str(value) for value in embedding)
    with open(file_path, "w") as f:
        f.write(serialized)
def load_embedding(file_path: str) -> List[float]:
    """Load an embedding from file. Only the first line (embedding) is returned.

    Args:
        file_path: Path to a file holding one comma-separated embedding per line.

    Returns:
        The embedding parsed from the first line of the file.

    Raises:
        ValueError: If the file is empty (previously this surfaced as an
            opaque ``UnboundLocalError``).
    """
    with open(file_path) as f:
        for line in f:
            # Return immediately: any lines after the first are ignored.
            return [float(x) for x in line.strip().split(",")]
    raise ValueError(f"No embedding found in file: {file_path}")
def resolve_embed_model(embed_model: Optional[EmbedType] = None) -> BaseEmbedding:
    """Resolve an ``EmbedType`` spec into a concrete ``BaseEmbedding``.

    Accepted values:
        * ``None`` -> ``MockEmbedding`` (embeddings explicitly disabled)
        * ``"default"`` -> ``OpenAIEmbedding`` (API key is validated first)
        * ``"clip"`` -> ``ClipEmbedding`` (image embeddings)
        * ``"local"`` / ``"local:<model_name>"`` -> Instructor model when the
          name is in ``INSTRUCTOR_MODELS``, otherwise a HuggingFace model,
          cached under ``<cache_dir>/models``
        * a Langchain ``Embeddings`` instance -> wrapped in ``LangchainEmbedding``
        * any ``BaseEmbedding`` instance -> returned unchanged

    Raises:
        ValueError: If the OpenAI API key is missing/invalid for ``"default"``,
            or a string spec does not start with a recognized prefix.
    """
    # Langchain is an optional dependency; degrade gracefully when absent.
    try:
        from llama_index.bridge.langchain import Embeddings as LCEmbeddings
    except ImportError:
        LCEmbeddings = None  # type: ignore

    if embed_model == "default":
        try:
            embed_model = OpenAIEmbedding()
            validate_openai_api_key(embed_model.api_key)
        except ValueError as e:
            # Chain the original exception (`from e`) so the root cause
            # remains visible in the traceback, not just in the message.
            raise ValueError(
                "\n******\n"
                "Could not load OpenAI embedding model. "
                "If you intended to use OpenAI, please check your OPENAI_API_KEY.\n"
                "Original error:\n"
                f"{e!s}"
                "\nConsider using embed_model='local'.\n"
                "Visit our documentation for more embedding options: "
                "https://docs.llamaindex.ai/en/stable/module_guides/models/"
                "embeddings.html#modules"
                "\n******"
            ) from e

    # for image embeddings
    if embed_model == "clip":
        embed_model = ClipEmbedding()

    if isinstance(embed_model, str):
        # Spec format: "local" or "local:<model_name>".
        splits = embed_model.split(":", 1)
        is_local = splits[0]
        model_name = splits[1] if len(splits) > 1 else None
        if is_local != "local":
            raise ValueError(
                "embed_model must start with str 'local' or of type BaseEmbedding"
            )
        cache_folder = os.path.join(get_cache_dir(), "models")
        os.makedirs(cache_folder, exist_ok=True)
        if model_name in INSTRUCTOR_MODELS:
            embed_model = InstructorEmbedding(
                model_name=model_name, cache_folder=cache_folder
            )
        else:
            # model_name=None lets HuggingFaceEmbedding pick its default model.
            embed_model = HuggingFaceEmbedding(
                model_name=model_name, cache_folder=cache_folder
            )

    if LCEmbeddings is not None and isinstance(embed_model, LCEmbeddings):
        embed_model = LangchainEmbedding(embed_model)

    if embed_model is None:
        print("Embeddings have been explicitly disabled. Using MockEmbedding.")
        embed_model = MockEmbedding(embed_dim=1)

    return embed_model