faiss_rag_enterprise/llama_index/llms/bedrock_utils.py

import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Sequence

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from llama_index.core.llms.types import ChatMessage
from llama_index.llms.anthropic_utils import messages_to_anthropic_prompt
from llama_index.llms.generic_utils import (
    prompt_to_messages,
)
from llama_index.llms.llama_utils import (
    completion_to_prompt as completion_to_llama_prompt,
)
from llama_index.llms.llama_utils import (
    messages_to_prompt as messages_to_llama_prompt,
)

HUMAN_PREFIX = "\n\nHuman:"
ASSISTANT_PREFIX = "\n\nAssistant:"
# Values taken from https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html#model-parameters-claude
COMPLETION_MODELS = {
    "amazon.titan-tg1-large": 8000,
    "amazon.titan-text-express-v1": 8000,
    "ai21.j2-grande-instruct": 8000,
    "ai21.j2-jumbo-instruct": 8000,
    "ai21.j2-mid": 8000,
    "ai21.j2-mid-v1": 8000,
    "ai21.j2-ultra": 8000,
    "ai21.j2-ultra-v1": 8000,
    "cohere.command-text-v14": 4096,
}

# Anthropic models require the prompt to start with "Human:" and
# end with "Assistant:"
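# (i.e. prompts of the form "\n\nHuman: <user text>\n\nAssistant:"; see
# HUMAN_PREFIX and ASSISTANT_PREFIX above)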
CHAT_ONLY_MODELS = {
    "anthropic.claude-instant-v1": 100000,
    "anthropic.claude-v1": 100000,
    "anthropic.claude-v2": 100000,
    "anthropic.claude-v2:1": 200000,
    "meta.llama2-13b-chat-v1": 2048,
    "meta.llama2-70b-chat-v1": 4096,
}

BEDROCK_FOUNDATION_LLMS = {**COMPLETION_MODELS, **CHAT_ONLY_MODELS}

# Only the following models support streaming, per the result of
# Bedrock.Client.list_foundation_models
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html
STREAMING_MODELS = {
    "amazon.titan-tg1-large",
    "amazon.titan-text-express-v1",
    "anthropic.claude-instant-v1",
    "anthropic.claude-v1",
    "anthropic.claude-v2",
    "anthropic.claude-v2:1",
    "meta.llama2-13b-chat-v1",
}


class Provider(ABC):
    """Maps Bedrock provider-specific request/response payloads to plain text."""

    @property
    @abstractmethod
    def max_tokens_key(self) -> str:
        ...

    @abstractmethod
    def get_text_from_response(self, response: dict) -> str:
        ...

    def get_text_from_stream_response(self, response: dict) -> str:
        return self.get_text_from_response(response)

    def get_request_body(self, prompt: str, inference_parameters: dict) -> dict:
        return {"prompt": prompt, **inference_parameters}

    # Optional hooks for providers that need model-specific prompt formatting.
    messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None
    completion_to_prompt: Optional[Callable[[str], str]] = None


class AmazonProvider(Provider):
    max_tokens_key = "maxTokenCount"

    def get_text_from_response(self, response: dict) -> str:
        return response["results"][0]["outputText"]

    def get_text_from_stream_response(self, response: dict) -> str:
        return response["outputText"]

    def get_request_body(self, prompt: str, inference_parameters: dict) -> dict:
        return {
            "inputText": prompt,
            "textGenerationConfig": {**inference_parameters},
        }


class Ai21Provider(Provider):
    max_tokens_key = "maxTokens"

    def get_text_from_response(self, response: dict) -> str:
        return response["completions"][0]["data"]["text"]


def completion_to_anthopic_prompt(completion: str) -> str:
    return messages_to_anthropic_prompt(prompt_to_messages(completion))


class AnthropicProvider(Provider):
    max_tokens_key = "max_tokens_to_sample"

    def __init__(self) -> None:
        self.messages_to_prompt = messages_to_anthropic_prompt
        self.completion_to_prompt = completion_to_anthopic_prompt

    def get_text_from_response(self, response: dict) -> str:
        return response["completion"]


class CohereProvider(Provider):
    max_tokens_key = "max_tokens"

    def get_text_from_response(self, response: dict) -> str:
        return response["generations"][0]["text"]


class MetaProvider(Provider):
    max_tokens_key = "max_gen_len"

    def __init__(self) -> None:
        self.messages_to_prompt = messages_to_llama_prompt
        self.completion_to_prompt = completion_to_llama_prompt

    def get_text_from_response(self, response: dict) -> str:
        return response["generation"]


PROVIDERS = {
    "amazon": AmazonProvider(),
    "ai21": Ai21Provider(),
    "anthropic": AnthropicProvider(),
    "cohere": CohereProvider(),
    "meta": MetaProvider(),
}
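

# Bedrock model identifiers are namespaced as "<provider>.<model>" (see the
# model tables above), so get_provider() keys off that prefix, e.g.
# "meta.llama2-13b-chat-v1" -> MetaProvider.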
def get_provider(model: str) -> Provider:
    provider_name = model.split(".")[0]
    if provider_name not in PROVIDERS:
        raise ValueError(
            f"Provider {provider_name} for model {model} is not supported"
        )
    return PROVIDERS[provider_name]


logger = logging.getLogger(__name__)


def _create_retry_decorator(client: Any, max_retries: int) -> Callable[[Any], Any]:
    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    try:
        import boto3  # noqa
    except ImportError as e:
        raise ImportError(
            "You must install the `boto3` package to use Bedrock. "
            "Please run `pip install boto3`."
        ) from e
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(retry_if_exception_type(client.exceptions.ThrottlingException)),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def completion_with_retry(
    client: Any,
    model: str,
    request_body: str,
    max_retries: int,
    stream: bool = False,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(client=client, max_retries=max_retries)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        if stream:
            return client.invoke_model_with_response_stream(
                modelId=model, body=request_body
            )
        return client.invoke_model(modelId=model, body=request_body)

    return _completion_with_retry(**kwargs)
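

# Illustrative usage sketch (assumes a boto3 "bedrock-runtime" client; the
# model id and parameter values are arbitrary examples):
#
#     import json
#     import boto3
#
#     client = boto3.client("bedrock-runtime")
#     model = "anthropic.claude-v2"
#     provider = get_provider(model)
#     prompt = provider.completion_to_prompt("Tell me a joke")
#     body = provider.get_request_body(prompt, {provider.max_tokens_key: 256})
#     raw = completion_with_retry(
#         client=client,
#         model=model,
#         request_body=json.dumps(body),
#         max_retries=3,
#     )
#     text = provider.get_text_from_response(json.loads(raw["body"].read()))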