faiss_rag_enterprise/llama_index/embeddings/cohereai.py

161 lines
6.1 KiB
Python

from enum import Enum
from typing import Any, List, Optional
from llama_index.bridge.pydantic import Field
from llama_index.callbacks import CallbackManager
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
# Enums for validation and type safety
class CohereAIModelName(str, Enum):
ENGLISH_V3 = "embed-english-v3.0"
ENGLISH_LIGHT_V3 = "embed-english-light-v3.0"
MULTILINGUAL_V3 = "embed-multilingual-v3.0"
MULTILINGUAL_LIGHT_V3 = "embed-multilingual-light-v3.0"
ENGLISH_V2 = "embed-english-v2.0"
ENGLISH_LIGHT_V2 = "embed-english-light-v2.0"
MULTILINGUAL_V2 = "embed-multilingual-v2.0"
class CohereAIInputType(str, Enum):
SEARCH_QUERY = "search_query"
SEARCH_DOCUMENT = "search_document"
CLASSIFICATION = "classification"
CLUSTERING = "clustering"
class CohereAITruncate(str, Enum):
START = "START"
END = "END"
NONE = "NONE"
# convenient shorthand
CAMN = CohereAIModelName
CAIT = CohereAIInputType
CAT = CohereAITruncate
# This list would be used for model name and input type validation
VALID_MODEL_INPUT_TYPES = [
(CAMN.ENGLISH_V3, CAIT.SEARCH_QUERY),
(CAMN.ENGLISH_LIGHT_V3, CAIT.SEARCH_QUERY),
(CAMN.MULTILINGUAL_V3, CAIT.SEARCH_QUERY),
(CAMN.MULTILINGUAL_LIGHT_V3, CAIT.SEARCH_QUERY),
(CAMN.ENGLISH_V3, CAIT.SEARCH_DOCUMENT),
(CAMN.ENGLISH_LIGHT_V3, CAIT.SEARCH_DOCUMENT),
(CAMN.MULTILINGUAL_V3, CAIT.SEARCH_DOCUMENT),
(CAMN.MULTILINGUAL_LIGHT_V3, CAIT.SEARCH_DOCUMENT),
(CAMN.ENGLISH_V3, CAIT.CLASSIFICATION),
(CAMN.ENGLISH_LIGHT_V3, CAIT.CLASSIFICATION),
(CAMN.MULTILINGUAL_V3, CAIT.CLASSIFICATION),
(CAMN.MULTILINGUAL_LIGHT_V3, CAIT.CLASSIFICATION),
(CAMN.ENGLISH_V3, CAIT.CLUSTERING),
(CAMN.ENGLISH_LIGHT_V3, CAIT.CLUSTERING),
(CAMN.MULTILINGUAL_V3, CAIT.CLUSTERING),
(CAMN.MULTILINGUAL_LIGHT_V3, CAIT.CLUSTERING),
(CAMN.ENGLISH_V2, None),
(CAMN.ENGLISH_LIGHT_V2, None),
(CAMN.MULTILINGUAL_V2, None),
]
VALID_TRUNCATE_OPTIONS = [CAT.START, CAT.END, CAT.NONE]
# Assuming BaseEmbedding is a Pydantic model and handles its own initializations
class CohereEmbedding(BaseEmbedding):
"""CohereEmbedding uses the Cohere API to generate embeddings for text."""
# Instance variables initialized via Pydantic's mechanism
cohere_client: Any = Field(description="CohereAI client")
truncate: str = Field(description="Truncation type - START/ END/ NONE")
input_type: Optional[str] = Field(description="Model Input type")
def __init__(
self,
cohere_api_key: Optional[str] = None,
model_name: str = "embed-english-v2.0",
truncate: str = "END",
input_type: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
):
"""
A class representation for generating embeddings using the Cohere API.
Args:
cohere_client (Any): An instance of the Cohere client, which is used to communicate with the Cohere API.
truncate (str): A string indicating the truncation strategy to be applied to input text. Possible values
are 'START', 'END', or 'NONE'.
input_type (Optional[str]): An optional string that specifies the type of input provided to the model.
This is model-dependent and could be one of the following: 'search_query',
'search_document', 'classification', or 'clustering'.
model_name (str): The name of the model to be used for generating embeddings. The class ensures that
this model is supported and that the input type provided is compatible with the model.
"""
# Attempt to import cohere. If it fails, raise an informative ImportError.
try:
import cohere
except ImportError:
raise ImportError(
"CohereEmbedding requires the 'cohere' package to be installed.\n"
"Please install it with `pip install cohere`."
)
# Validate model_name and input_type
if (model_name, input_type) not in VALID_MODEL_INPUT_TYPES:
raise ValueError(
f"{(model_name, input_type)} is not valid for model '{model_name}'"
)
if truncate not in VALID_TRUNCATE_OPTIONS:
raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}")
super().__init__(
cohere_client=cohere.Client(cohere_api_key, client_name="llama_index"),
cohere_api_key=cohere_api_key,
model_name=model_name,
truncate=truncate,
input_type=input_type,
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
)
@classmethod
def class_name(cls) -> str:
return "CohereEmbedding"
def _embed(self, texts: List[str]) -> List[List[float]]:
"""Embed sentences using Cohere."""
if self.input_type:
result = self.cohere_client.embed(
texts=texts,
input_type=self.input_type,
model=self.model_name,
truncate=self.truncate,
).embeddings
else:
result = self.cohere_client.embed(
texts=texts, model=self.model_name, truncate=self.truncate
).embeddings
return [list(map(float, e)) for e in result]
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self._embed([query])[0]
async def _aget_query_embedding(self, query: str) -> List[float]:
"""Get query embedding async."""
return self._get_query_embedding(query)
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._embed([text])[0]
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Get text embedding async."""
return self._get_text_embedding(text)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
return self._embed(texts)