# faiss_rag_enterprise/llama_index/embeddings/bedrock.py
#
# 392 lines
# 14 KiB
# Python

import json
import os
import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence
from deprecated import deprecated
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks.base import CallbackManager
from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.embeddings.base import BaseEmbedding, Embedding
from llama_index.core.llms.types import ChatMessage
from llama_index.types import BaseOutputParser, PydanticProgramMode
class PROVIDERS(str, Enum):
    """Bedrock model providers this module knows how to talk to.

    A provider is identified by the text before the first ``.`` in a Bedrock
    model id (e.g. ``"amazon"`` in ``"amazon.titan-embed-text-v1"``).
    """

    AMAZON = "amazon"
    COHERE = "cohere"
class Models(str, Enum):
    """Bedrock embedding model ids known to this module.

    Every value has the form ``"<provider>.<model>"``, where ``<provider>``
    is the value of a ``PROVIDERS`` member.
    """

    # Amazon Titan text-embedding models.
    TITAN_EMBEDDING = "amazon.titan-embed-text-v1"
    TITAN_EMBEDDING_G1_TEXT_02 = "amazon.titan-embed-g1-text-02"
    # Cohere Embed v3 models.
    COHERE_EMBED_ENGLISH_V3 = "cohere.embed-english-v3"
    COHERE_EMBED_MULTILINGUAL_V3 = "cohere.embed-multilingual-v3"
def _amazon_embedding(response: Dict[str, Any]) -> Any:
    """Return the ``"embedding"`` entry of a decoded response (``None`` if absent)."""
    return response.get("embedding")


def _cohere_embedding(response: Dict[str, Any]) -> Any:
    """Return the first element of the ``"embeddings"`` list of a decoded response."""
    return response.get("embeddings")[0]


# Provider name -> hooks used to read the embedding vector out of the decoded
# JSON body returned by ``invoke_model``.
PROVIDER_SPECIFIC_IDENTIFIERS = {
    PROVIDERS.AMAZON.value: {"get_embeddings_func": _amazon_embedding},
    PROVIDERS.COHERE.value: {"get_embeddings_func": _cohere_embedding},
}
class BedrockEmbedding(BaseEmbedding):
    """Embedding model backed by the Amazon Bedrock ``invoke_model`` API.

    The provider is the prefix of ``model`` before the first ``.`` (for example
    ``"amazon"`` or ``"cohere"``). It selects both the JSON request format and
    how the embedding vector is read back from the response.
    """

    model: str = Field(description="The modelId of the Bedrock model to use.")
    profile_name: Optional[str] = Field(
        description="The name of aws profile to use. If not given, then the default profile is used.",
        exclude=True,
    )
    aws_access_key_id: Optional[str] = Field(
        description="AWS Access Key ID to use", exclude=True
    )
    aws_secret_access_key: Optional[str] = Field(
        description="AWS Secret Access Key to use", exclude=True
    )
    aws_session_token: Optional[str] = Field(
        description="AWS Session Token to use", exclude=True
    )
    region_name: Optional[str] = Field(
        description="AWS region name to use. Uses region configured in AWS CLI if not passed",
        exclude=True,
    )
    botocore_session: Optional[Any] = Field(
        description="Use this Botocore session instead of creating a new default one.",
        exclude=True,
    )
    botocore_config: Optional[Any] = Field(
        description="Custom configuration object to use instead of the default generated one.",
        exclude=True,
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries.", gt=0
    )
    timeout: float = Field(
        default=60.0,
        description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the bedrock client."
    )

    # Bedrock client used for ``invoke_model`` calls; set in ``__init__`` or
    # by ``set_credentials``.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = Models.TITAN_EMBEDDING,
        profile_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region_name: Optional[str] = None,
        client: Optional[Any] = None,
        botocore_session: Optional[Any] = None,
        botocore_config: Optional[Any] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        callback_manager: Optional[CallbackManager] = None,
        # base class
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ):
        """Create the embedding model and its Bedrock client.

        Args:
            model: Bedrock model id; its prefix before the first ``.`` selects
                the provider.
            profile_name, aws_access_key_id, aws_secret_access_key,
            aws_session_token, region_name, botocore_session: Passed to
                ``boto3.Session`` when no ``client`` is supplied; always stored
                on the instance.
            client: Pre-built Bedrock client. When given, it is used as-is and
                no boto3 session is created.
            botocore_config: botocore ``Config`` to use. When ``None``, one is
                built from ``max_retries`` (standard retry mode) and
                ``timeout`` (used for both connect and read timeouts).
            additional_kwargs: Stored on the instance; defaults to ``{}``.
            max_retries: ``max_attempts`` for the generated botocore config.
            timeout: Connect/read timeout in seconds for the generated config.
            callback_manager: Passed to the base class.
            system_prompt, messages_to_prompt, completion_to_prompt,
            pydantic_program_mode, output_parser: Forwarded unchanged to the
                base-class constructor.
            **kwargs: Forwarded unchanged to the base-class constructor.

        Raises:
            ImportError: If boto3/botocore is not installed.
        """
        additional_kwargs = additional_kwargs or {}
        try:
            from botocore.config import Config
        except ImportError as err:
            raise ImportError(
                "boto3 package not found, install with 'pip install boto3'"
            ) from err

        config = (
            Config(
                retries={"max_attempts": max_retries, "mode": "standard"},
                connect_timeout=timeout,
                read_timeout=timeout,
            )
            if botocore_config is None
            else botocore_config
        )

        if client is not None:
            self._client = client
        else:
            # Only build a session when we actually need to create a client;
            # a caller-supplied client makes the credentials irrelevant here.
            session = self._create_session(
                {
                    "profile_name": profile_name,
                    "region_name": region_name,
                    "aws_access_key_id": aws_access_key_id,
                    "aws_secret_access_key": aws_secret_access_key,
                    "aws_session_token": aws_session_token,
                    "botocore_session": botocore_session,
                }
            )
            self._client = self._client_from_session(session, config)

        super().__init__(
            model=model,
            max_retries=max_retries,
            timeout=timeout,
            botocore_config=config,
            profile_name=profile_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=region_name,
            botocore_session=botocore_session,
            additional_kwargs=additional_kwargs,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )

    @staticmethod
    def _create_session(session_kwargs: Dict[str, Any]) -> Any:
        """Import boto3 lazily and return ``boto3.Session(**session_kwargs)``.

        Raises:
            ImportError: If boto3 is not installed.
        """
        try:
            import boto3
        except ImportError as err:
            raise ImportError(
                "boto3 package not found, install with 'pip install boto3'"
            ) from err
        return boto3.Session(**session_kwargs)

    @staticmethod
    def _client_from_session(session: Any, config: Optional[Any] = None) -> Any:
        """Create the Bedrock invoke client from a boto3 session.

        Prior to general availability, custom boto3 wheel files were
        distributed that used the ``bedrock`` service to invokeModel. Falling
        back to ``bedrock`` when ``bedrock-runtime`` is unavailable keeps
        anything still using those wheels working.
        """
        if "bedrock-runtime" in session.get_available_services():
            return session.client("bedrock-runtime", config=config)
        return session.client("bedrock", config=config)

    @staticmethod
    def list_supported_models() -> Dict[str, List[str]]:
        """Map each provider name to the ``Models`` ids served by that provider.

        A model belongs to a provider when the part of its id before the
        first ``.`` equals the provider's value.
        """
        return {
            provider.value: [
                m.value for m in Models if m.value.split(".", 1)[0] == provider.value
            ]
            for provider in PROVIDERS
        }

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used to identify this component."""
        return "BedrockEmbedding"

    @deprecated(
        version="0.9.48",
        reason=(
            "Use the provided kwargs in the constructor, "
            "set_credentials will be removed in future releases."
        ),
        action="once",
    )
    def set_credentials(
        self,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_profile: Optional[str] = None,
    ) -> None:
        """Rebuild the Bedrock client from explicit or environment credentials.

        Each missing argument except ``aws_profile`` falls back to its
        ``AWS_*`` environment variable, and a warning is emitted if it is
        still missing. The session token is optional, because long-term IAM
        keys have none.

        Raises:
            AssertionError: If no access key id or secret access key is found.
            ImportError: If boto3 is not installed.
        """
        aws_region = aws_region or os.getenv("AWS_REGION")
        aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        aws_secret_access_key = aws_secret_access_key or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
        aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        if aws_region is None:
            warnings.warn(
                "AWS_REGION not found. Set environment variable AWS_REGION or set aws_region"
            )

        if aws_access_key_id is None:
            warnings.warn(
                "AWS_ACCESS_KEY_ID not found. Set environment variable AWS_ACCESS_KEY_ID or set aws_access_key_id"
            )
        assert aws_access_key_id is not None

        if aws_secret_access_key is None:
            warnings.warn(
                "AWS_SECRET_ACCESS_KEY not found. Set environment variable AWS_SECRET_ACCESS_KEY or set aws_secret_access_key"
            )
        assert aws_secret_access_key is not None

        if aws_session_token is None:
            # Informational only: temporary credentials need a token, but
            # permanent access keys are valid without one.
            warnings.warn(
                "AWS_SESSION_TOKEN not found. Set environment variable AWS_SESSION_TOKEN or set aws_session_token"
            )

        session = self._create_session(
            {
                "profile_name": aws_profile,
                "region_name": aws_region,
                "aws_access_key_id": aws_access_key_id,
                "aws_secret_access_key": aws_secret_access_key,
                "aws_session_token": aws_session_token,
            }
        )
        # Reuse the retry/timeout config built in __init__, if any.
        self._client = self._client_from_session(
            session, getattr(self, "botocore_config", None)
        )

    @classmethod
    @deprecated(
        version="0.9.48",
        reason=(
            "Use the provided kwargs in the constructor, "
            "from_credentials will be removed in future releases."
        ),
        action="once",
    )
    def from_credentials(
        cls,
        model_name: str = Models.TITAN_EMBEDDING,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_profile: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
    ) -> "BedrockEmbedding":
        """
        Instantiate using AWS credentials.

        Args:
            model_name (str) : Bedrock model id to use for embeddings
            aws_region (str): AWS region where the service is located
            aws_access_key_id (str): AWS access key ID
            aws_secret_access_key (str): AWS secret access key
            aws_session_token (str): AWS session token
            aws_profile (str): AWS profile, when None, default profile is chosen automatically
            embed_batch_size (int): Passed to the constructor
            callback_manager (CallbackManager): Passed to the constructor
            verbose (bool): Passed to the constructor

        Example:
                .. code-block:: python

                    from llama_index.embeddings import BedrockEmbedding

                    # Define the model name
                    model_name = "your_model_name"

                    embeddings = BedrockEmbedding.from_credentials(
                        model_name,
                        aws_region=aws_region,
                        aws_access_key_id=aws_access_key_id,
                        aws_secret_access_key=aws_secret_access_key,
                        aws_session_token=aws_session_token,
                        aws_profile=aws_profile,
                    )
        """
        session = cls._create_session(
            {
                "profile_name": aws_profile,
                "region_name": aws_region,
                "aws_access_key_id": aws_access_key_id,
                "aws_secret_access_key": aws_secret_access_key,
                "aws_session_token": aws_session_token,
            }
        )
        client = cls._client_from_session(session)
        return cls(
            client=client,
            # The constructor's model id parameter is ``model``. Passing only
            # ``model_name`` left the default Titan model in use.
            model=model_name,
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            verbose=verbose,
        )

    def _get_embedding(self, payload: str, type: Literal["text", "query"]) -> Embedding:
        """Embed a single string with the configured Bedrock model.

        Args:
            payload: Text to embed.
            type: ``"text"`` for documents or ``"query"`` for search queries.
                Only the cohere request format distinguishes the two.

        Returns:
            The embedding extracted by the provider's ``get_embeddings_func``.

        Raises:
            ValueError: If no client is available or the model's provider is
                not supported.
        """
        if self._client is None:
            self.set_credentials()
        if self._client is None:
            raise ValueError("Client not set")

        provider = self.model.split(".")[0]
        # Resolve the provider before making a (billable) network call.
        identifiers = PROVIDER_SPECIFIC_IDENTIFIERS.get(provider, None)
        if identifiers is None:
            raise ValueError("Provider not supported")

        request_body = self._get_request_body(provider, payload, type)
        response = self._client.invoke_model(
            body=request_body,
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )
        resp = json.loads(response.get("body").read().decode("utf-8"))
        return identifiers["get_embeddings_func"](resp)

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed a search query."""
        return self._get_embedding(query, "query")

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed a document text."""
        return self._get_embedding(text, "text")

    def _get_request_body(
        self, provider: str, payload: str, type: Literal["text", "query"]
    ) -> Any:
        """Build the JSON request body for *provider*.

        Currently supported providers are amazon and cohere.

        amazon:
            ``{"inputText": payload}``
        cohere:
            The payload is wrapped in a single-element ``texts`` list, e.g.
            {
                'texts': ["This is a test document"],
                'input_type': 'search_document',   # or 'search_query' for queries
                'truncate': 'NONE'
            }

        Raises:
            ValueError: If the provider is not supported.
        """
        if provider == PROVIDERS.AMAZON:
            request_body = json.dumps({"inputText": payload})
        elif provider == PROVIDERS.COHERE:
            input_types = {
                "text": "search_document",
                "query": "search_query",
            }
            request_body = json.dumps(
                {
                    "texts": [payload],
                    "input_type": input_types[type],
                    "truncate": "NONE",
                }
            )
        else:
            raise ValueError("Provider not supported")
        return request_body

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Async variant of query embedding.

        Delegates to the synchronous ``_get_embedding``. The Bedrock call
        itself is not awaited.
        """
        return self._get_embedding(query, "query")

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Async variant of text embedding.

        Delegates to the synchronous ``_get_embedding``. The Bedrock call
        itself is not awaited.
        """
        return self._get_embedding(text, "text")