"""OpenAI embeddings file."""
|
|
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import httpx
|
|
from openai import AsyncOpenAI, OpenAI
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.embeddings.base import BaseEmbedding
|
|
from llama_index.llms.openai_utils import (
|
|
create_retry_decorator,
|
|
resolve_openai_credentials,
|
|
)
|
|
|
|
# Shared retry policy for every embedding call in this module: up to 6
# attempts with randomized exponential backoff between 1 and 20 seconds,
# giving up entirely once 60 seconds have elapsed.
embedding_retry_decorator = create_retry_decorator(
    max_retries=6,
    random_exponential=True,
    stop_after_delay_seconds=60,
    min_seconds=1,
    max_seconds=20,
)
|
|
|
|
|
|
class OpenAIEmbeddingMode(str, Enum):
    """How the embeddings will be used.

    ``SIMILARITY_MODE`` selects symmetric similarity engines, while
    ``TEXT_SEARCH_MODE`` selects asymmetric query/document search engines
    (the distinction only matters for the legacy first-generation models).
    """

    SIMILARITY_MODE = "similarity"
    TEXT_SEARCH_MODE = "text_search"
|
|
|
|
|
|
class OpenAIEmbeddingModelType(str, Enum):
    """Families of OpenAI embedding models accepted by this module.

    The first four members are legacy GPT-3-era model families; the
    ``TEXT_EMBED_*`` members name the newer dedicated embedding models.
    """

    # Legacy first-generation model families.
    DAVINCI = "davinci"
    CURIE = "curie"
    BABBAGE = "babbage"
    ADA = "ada"
    # Second- and third-generation embedding models.
    TEXT_EMBED_ADA_002 = "text-embedding-ada-002"
    TEXT_EMBED_3_LARGE = "text-embedding-3-large"
    TEXT_EMBED_3_SMALL = "text-embedding-3-small"
|
|
|
|
|
|
class OpenAIEmbeddingModeModel(str, Enum):
    """Concrete engine names sent to the OpenAI embeddings endpoint.

    Legacy families expose three engines each (similarity, search-query,
    search-doc); the modern ``text-embedding-*`` models use one engine for
    every mode.
    """

    # davinci family
    TEXT_SIMILARITY_DAVINCI = "text-similarity-davinci-001"
    TEXT_SEARCH_DAVINCI_QUERY = "text-search-davinci-query-001"
    TEXT_SEARCH_DAVINCI_DOC = "text-search-davinci-doc-001"

    # curie family
    TEXT_SIMILARITY_CURIE = "text-similarity-curie-001"
    TEXT_SEARCH_CURIE_QUERY = "text-search-curie-query-001"
    TEXT_SEARCH_CURIE_DOC = "text-search-curie-doc-001"

    # babbage family
    TEXT_SIMILARITY_BABBAGE = "text-similarity-babbage-001"
    TEXT_SEARCH_BABBAGE_QUERY = "text-search-babbage-query-001"
    TEXT_SEARCH_BABBAGE_DOC = "text-search-babbage-doc-001"

    # ada family
    TEXT_SIMILARITY_ADA = "text-similarity-ada-001"
    TEXT_SEARCH_ADA_QUERY = "text-search-ada-query-001"
    TEXT_SEARCH_ADA_DOC = "text-search-ada-doc-001"

    # modern single-engine models (mode-agnostic)
    TEXT_EMBED_ADA_002 = "text-embedding-ada-002"
    TEXT_EMBED_3_LARGE = "text-embedding-3-large"
    TEXT_EMBED_3_SMALL = "text-embedding-3-small"
|
|
|
|
|
|
# Convenient shorthand aliases for the three enums above, used to keep the
# mode/model lookup tables below readable.
OAEM = OpenAIEmbeddingMode
OAEMT = OpenAIEmbeddingModelType
OAEMM = OpenAIEmbeddingModeModel

# NOTE(review): named as a token limit for embedding inputs, but this
# constant is not referenced anywhere in this file's visible code — confirm
# whether external callers rely on it before removing.
EMBED_MAX_TOKEN_LIMIT = 2048
|
|
|
|
|
|
# Maps (embedding mode, model-type string) -> engine used to embed QUERY
# text.  Legacy search-mode families resolve to their dedicated
# "-query-001" engine, while the text-embedding-ada-002 / -3-* models use
# the same engine for every mode.
_QUERY_MODE_MODEL_DICT = {
    (OAEM.SIMILARITY_MODE, "davinci"): OAEMM.TEXT_SIMILARITY_DAVINCI,
    (OAEM.SIMILARITY_MODE, "curie"): OAEMM.TEXT_SIMILARITY_CURIE,
    (OAEM.SIMILARITY_MODE, "babbage"): OAEMM.TEXT_SIMILARITY_BABBAGE,
    (OAEM.SIMILARITY_MODE, "ada"): OAEMM.TEXT_SIMILARITY_ADA,
    (OAEM.SIMILARITY_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.SIMILARITY_MODE, "text-embedding-3-small"): OAEMM.TEXT_EMBED_3_SMALL,
    (OAEM.SIMILARITY_MODE, "text-embedding-3-large"): OAEMM.TEXT_EMBED_3_LARGE,
    (OAEM.TEXT_SEARCH_MODE, "davinci"): OAEMM.TEXT_SEARCH_DAVINCI_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "curie"): OAEMM.TEXT_SEARCH_CURIE_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "babbage"): OAEMM.TEXT_SEARCH_BABBAGE_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "ada"): OAEMM.TEXT_SEARCH_ADA_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-3-large"): OAEMM.TEXT_EMBED_3_LARGE,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-3-small"): OAEMM.TEXT_EMBED_3_SMALL,
}
|
|
|
|
# Maps (embedding mode, model-type string) -> engine used to embed DOCUMENT
# text.  Mirrors _QUERY_MODE_MODEL_DICT except that legacy search-mode
# families resolve to the "-doc-001" engine instead of "-query-001".
_TEXT_MODE_MODEL_DICT = {
    (OAEM.SIMILARITY_MODE, "davinci"): OAEMM.TEXT_SIMILARITY_DAVINCI,
    (OAEM.SIMILARITY_MODE, "curie"): OAEMM.TEXT_SIMILARITY_CURIE,
    (OAEM.SIMILARITY_MODE, "babbage"): OAEMM.TEXT_SIMILARITY_BABBAGE,
    (OAEM.SIMILARITY_MODE, "ada"): OAEMM.TEXT_SIMILARITY_ADA,
    (OAEM.SIMILARITY_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.SIMILARITY_MODE, "text-embedding-3-small"): OAEMM.TEXT_EMBED_3_SMALL,
    (OAEM.SIMILARITY_MODE, "text-embedding-3-large"): OAEMM.TEXT_EMBED_3_LARGE,
    (OAEM.TEXT_SEARCH_MODE, "davinci"): OAEMM.TEXT_SEARCH_DAVINCI_DOC,
    (OAEM.TEXT_SEARCH_MODE, "curie"): OAEMM.TEXT_SEARCH_CURIE_DOC,
    (OAEM.TEXT_SEARCH_MODE, "babbage"): OAEMM.TEXT_SEARCH_BABBAGE_DOC,
    (OAEM.TEXT_SEARCH_MODE, "ada"): OAEMM.TEXT_SEARCH_ADA_DOC,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-3-large"): OAEMM.TEXT_EMBED_3_LARGE,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-3-small"): OAEMM.TEXT_EMBED_3_SMALL,
}
|
|
|
|
|
|
@embedding_retry_decorator
def get_embedding(client: OpenAI, text: str, engine: str, **kwargs: Any) -> List[float]:
    """Embed a single string synchronously and return its vector.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    # Newlines are replaced with spaces before embedding (upstream OpenAI
    # convention inherited from the copied utility).
    cleaned = text.replace("\n", " ")
    response = client.embeddings.create(input=[cleaned], model=engine, **kwargs)
    return response.data[0].embedding
|
|
|
|
|
|
@embedding_retry_decorator
async def aget_embedding(
    aclient: AsyncOpenAI, text: str, engine: str, **kwargs: Any
) -> List[float]:
    """Embed a single string asynchronously and return its vector.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    # Newlines are replaced with spaces before embedding (upstream OpenAI
    # convention inherited from the copied utility).
    cleaned = text.replace("\n", " ")
    response = await aclient.embeddings.create(input=[cleaned], model=engine, **kwargs)
    return response.data[0].embedding
|
|
|
|
|
|
@embedding_retry_decorator
def get_embeddings(
    client: OpenAI, list_of_text: List[str], engine: str, **kwargs: Any
) -> List[List[float]]:
    """Embed a batch of strings synchronously, one vector per input.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    # The embeddings endpoint caps a single request at 2048 inputs.
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # Newlines are replaced with spaces before embedding (upstream OpenAI
    # convention inherited from the copied utility).
    cleaned = [item.replace("\n", " ") for item in list_of_text]

    response = client.embeddings.create(input=cleaned, model=engine, **kwargs)
    return [entry.embedding for entry in response.data]
|
|
|
|
|
|
@embedding_retry_decorator
async def aget_embeddings(
    aclient: AsyncOpenAI,
    list_of_text: List[str],
    engine: str,
    **kwargs: Any,
) -> List[List[float]]:
    """Embed a batch of strings asynchronously, one vector per input.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    # The embeddings endpoint caps a single request at 2048 inputs.
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # Newlines are replaced with spaces before embedding (upstream OpenAI
    # convention inherited from the copied utility).
    cleaned = [item.replace("\n", " ") for item in list_of_text]

    response = await aclient.embeddings.create(input=cleaned, model=engine, **kwargs)
    return [entry.embedding for entry in response.data]
|
|
|
|
|
|
def get_engine(
    mode: str,
    model: str,
    mode_model_dict: Dict[Tuple[OpenAIEmbeddingMode, str], OpenAIEmbeddingModeModel],
) -> OpenAIEmbeddingModeModel:
    """Resolve a (mode, model) pair to a concrete embedding engine.

    Raises:
        ValueError: if the combination is not present in *mode_model_dict*
            (or if *mode* / *model* is not a valid enum value).
    """
    lookup = (OpenAIEmbeddingMode(mode), OpenAIEmbeddingModelType(model))
    engine = mode_model_dict.get(lookup)
    # None is never a value in the lookup tables, so it doubles as a
    # missing-key sentinel here.
    if engine is None:
        raise ValueError(f"Invalid mode, model combination: {lookup}")
    return engine
|
|
|
|
|
|
class OpenAIEmbedding(BaseEmbedding):
    """OpenAI class for embeddings.

    Args:
        mode (str): Mode for embedding.
            Defaults to OpenAIEmbeddingMode.TEXT_SEARCH_MODE.
            Options are:

            - OpenAIEmbeddingMode.SIMILARITY_MODE
            - OpenAIEmbeddingMode.TEXT_SEARCH_MODE

        model (str): Model for embedding.
            Defaults to OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002.
            Options are:

            - OpenAIEmbeddingModelType.DAVINCI
            - OpenAIEmbeddingModelType.CURIE
            - OpenAIEmbeddingModelType.BABBAGE
            - OpenAIEmbeddingModelType.ADA
            - OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002
            - OpenAIEmbeddingModelType.TEXT_EMBED_3_LARGE
            - OpenAIEmbeddingModelType.TEXT_EMBED_3_SMALL

        dimensions (Optional[int]): Output dimensionality of the embedding
            vectors. Only supported by the v3 embedding models.
    """

    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the OpenAI API."
    )

    api_key: str = Field(description="The OpenAI API key.")
    api_base: str = Field(description="The base URL for OpenAI API.")
    api_version: str = Field(description="The version for OpenAI API.")

    # BUGFIX: pydantic's numeric-constraint keyword is ``ge`` (not ``gte``);
    # the previous ``gte=0`` was treated as unknown extra metadata, so
    # negative values were never rejected.
    max_retries: int = Field(
        default=10, description="Maximum number of retries.", ge=0
    )
    timeout: float = Field(default=60.0, description="Timeout for each request.", ge=0)
    default_headers: Optional[Dict[str, str]] = Field(
        default=None, description="The default headers for API requests."
    )
    reuse_client: bool = Field(
        default=True,
        description=(
            "Reuse the OpenAI client between requests. When doing anything with large "
            "volumes of async API calls, setting this to false can improve stability."
        ),
    )
    dimensions: Optional[int] = Field(
        default=None,
        description=(
            "The number of dimensions on the output embedding vectors. "
            "Works only with v3 embedding models."
        ),
    )

    # Engines resolved from (mode, model); they differ only for legacy
    # search-mode model families (query vs. doc variants).
    _query_engine: OpenAIEmbeddingModeModel = PrivateAttr()
    _text_engine: OpenAIEmbeddingModeModel = PrivateAttr()
    # Lazily created API clients; None until first use when reuse_client is
    # True (see _get_client / _get_aclient).
    _client: Optional[OpenAI] = PrivateAttr()
    _aclient: Optional[AsyncOpenAI] = PrivateAttr()
    _http_client: Optional[httpx.Client] = PrivateAttr()

    def __init__(
        self,
        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
        embed_batch_size: int = 100,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        reuse_client: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        **kwargs: Any,
    ) -> None:
        additional_kwargs = additional_kwargs or {}
        # "dimensions" is forwarded to the API via the per-request kwargs
        # rather than as a dedicated client option.
        if dimensions is not None:
            additional_kwargs["dimensions"] = dimensions

        # Fall back to environment variables / global config for any
        # credentials not passed explicitly.
        api_key, api_base, api_version = resolve_openai_credentials(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
        )

        self._query_engine = get_engine(mode, model, _QUERY_MODE_MODEL_DICT)
        self._text_engine = get_engine(mode, model, _TEXT_MODE_MODEL_DICT)

        # An explicit "model_name" kwarg overrides the resolved engines,
        # e.g. for Azure deployments with custom deployment names.
        if "model_name" in kwargs:
            model_name = kwargs.pop("model_name")
            self._query_engine = self._text_engine = model_name
        else:
            model_name = model

        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            callback_manager=callback_manager,
            model_name=model_name,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            reuse_client=reuse_client,
            timeout=timeout,
            default_headers=default_headers,
            **kwargs,
        )

        self._client = None
        self._aclient = None
        self._http_client = http_client

    def _get_client(self) -> OpenAI:
        """Return a sync client, creating a fresh one per call unless reused."""
        if not self.reuse_client:
            return OpenAI(**self._get_credential_kwargs())

        if self._client is None:
            self._client = OpenAI(**self._get_credential_kwargs())
        return self._client

    def _get_aclient(self) -> AsyncOpenAI:
        """Return an async client, creating a fresh one per call unless reused."""
        if not self.reuse_client:
            return AsyncOpenAI(**self._get_credential_kwargs())

        if self._aclient is None:
            self._aclient = AsyncOpenAI(**self._get_credential_kwargs())
        return self._aclient

    @classmethod
    def class_name(cls) -> str:
        return "OpenAIEmbedding"

    def _get_credential_kwargs(self) -> Dict[str, Any]:
        """Build the constructor kwargs shared by OpenAI / AsyncOpenAI."""
        return {
            "api_key": self.api_key,
            "base_url": self.api_base,
            "max_retries": self.max_retries,
            "timeout": self.timeout,
            "default_headers": self.default_headers,
            "http_client": self._http_client,
        }

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        client = self._get_client()
        return get_embedding(
            client,
            query,
            engine=self._query_engine,
            **self.additional_kwargs,
        )

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        aclient = self._get_aclient()
        return await aget_embedding(
            aclient,
            query,
            engine=self._query_engine,
            **self.additional_kwargs,
        )

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        client = self._get_client()
        return get_embedding(
            client,
            text,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        aclient = self._get_aclient()
        return await aget_embedding(
            aclient,
            text,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overridden for batch queries.

        """
        client = self._get_client()
        return get_embeddings(
            client,
            texts,
            engine=self._text_engine,
            **self.additional_kwargs,
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        aclient = self._get_aclient()
        return await aget_embeddings(
            aclient,
            texts,
            engine=self._text_engine,
            **self.additional_kwargs,
        )
|