# faiss_rag_enterprise/llama_index/llms/vertex_utils.py
"""Utility helpers for the Vertex AI LLM integration: retry-wrapped
completion calls (sync and async), SDK initialization, and parsing of chat
histories and few-shot examples."""

import logging
from typing import Any, Callable, Optional

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from llama_index.core.llms.types import ChatMessage, MessageRole

# Known Vertex AI PaLM model names, grouped by API surface.
CHAT_MODELS = ["chat-bison", "chat-bison-32k", "chat-bison@001"]
TEXT_MODELS = ["text-bison", "text-bison-32k", "text-bison@001"]
CODE_MODELS = ["code-bison", "code-bison-32k", "code-bison@001"]
CODE_CHAT_MODELS = ["codechat-bison", "codechat-bison-32k", "codechat-bison@001"]

logger = logging.getLogger(__name__)


def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
    """Build a tenacity retry decorator for transient Vertex AI API errors."""
    # Import the submodule explicitly so the `exceptions` attribute is
    # guaranteed to be present.
    import google.api_core.exceptions

    min_seconds = 4
    max_seconds = 10
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
            | retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
            | retry_if_exception_type(google.api_core.exceptions.Aborted)
            | retry_if_exception_type(google.api_core.exceptions.DeadlineExceeded)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
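

# A minimal usage sketch (`flaky_call` is hypothetical): any callable that
# raises one of the google.api_core exceptions above is retried with
# exponential backoff, and the last exception is re-raised once the attempt
# budget is exhausted.
#
#     @_create_retry_decorator(max_retries=3)
#     def flaky_call(model: Any, prompt: str) -> Any:
#         return model.predict(prompt)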


def completion_with_retry(
    client: Any,
    prompt: Optional[Any],
    max_retries: int = 5,
    chat: bool = False,
    stream: bool = False,
    is_gemini: bool = False,
    params: Any = {},
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        if is_gemini:
            history = params.get("message_history", [])
            generation = client.start_chat(history=history)
            generation_config = dict(kwargs)
            return generation.send_message(
                prompt, stream=stream, generation_config=generation_config
            )
        elif chat:
            generation = client.start_chat(**params)
            if stream:
                return generation.send_message_streaming(prompt, **kwargs)
            return generation.send_message(prompt, **kwargs)
        elif stream:
            return client.predict_streaming(prompt, **kwargs)
        return client.predict(prompt, **kwargs)

    return _completion_with_retry(**kwargs)
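

# Usage sketch (assumes a legacy PaLM `TextGenerationModel` client from the
# vertexai SDK; the model name and sampling parameters are illustrative):
#
#     from vertexai.language_models import TextGenerationModel
#
#     model = TextGenerationModel.from_pretrained("text-bison")
#     response = completion_with_retry(
#         client=model, prompt="Say hello.", max_retries=3, temperature=0.2
#     )
#     print(response.text)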


async def acompletion_with_retry(
    client: Any,
    prompt: Optional[str],
    max_retries: int = 5,
    chat: bool = False,
    is_gemini: bool = False,
    params: Any = {},
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        if is_gemini:
            history = params.get("message_history", [])
            generation = client.start_chat(history=history)
            generation_config = dict(kwargs)
            return await generation.send_message_async(
                prompt, generation_config=generation_config
            )
        elif chat:
            generation = client.start_chat(**params)
            return await generation.send_message_async(prompt, **kwargs)
        return await client.predict_async(prompt, **kwargs)

    return await _completion_with_retry(**kwargs)


def init_vertexai(
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[Any] = None,
) -> None:
    """Initialize the vertexai SDK.

    Args:
        project: The default GCP project to use when making Vertex API calls.
        location: The default location to use when making API calls.
        credentials: The default custom credentials to use when making API
            calls. If not provided, credentials will be ascertained from the
            environment.

    Raises:
        ImportError: If importing the vertexai SDK did not succeed.
    """
    try:
        import vertexai
    except ImportError:
        raise ImportError(
            "Could not import the vertexai SDK. "
            "Please install it with `pip install google-cloud-aiplatform`."
        )
    vertexai.init(
        project=project,
        location=location,
        credentials=credentials,
    )
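

# Usage sketch (project ID and region are placeholders):
#
#     init_vertexai(project="my-gcp-project", location="us-central1")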


def _parse_message(message: ChatMessage, is_gemini: bool) -> Any:
    """Convert a single llama_index ChatMessage into the provider's format."""
    if is_gemini:
        from llama_index.llms.vertex_gemini_utils import (
            convert_chat_message_to_gemini_content,
        )

        return convert_chat_message_to_gemini_content(
            message=message, is_history=False
        )
    else:
        return message.content


def _parse_chat_history(history: Any, is_gemini: bool) -> Any:
    """Parse a sequence of messages into a chat history.

    Args:
        history: The list of messages to re-create the history of the chat.

    Returns:
        A parsed chat history as a dict with "context" and "message_history".

    Raises:
        ValueError: If a system message appears anywhere but the first
            position, if a system message is passed to a Gemini model, or if
            the resulting history has an odd number of messages.
    """
    # Note: this local import shadows the llama_index ChatMessage within
    # this function.
    from vertexai.language_models import ChatMessage

    vertex_messages, context = [], None
    for i, message in enumerate(history):
        if i == 0 and message.role == MessageRole.SYSTEM:
            if is_gemini:
                raise ValueError("Gemini models don't support system messages.")
            context = message.content
        elif message.role in (MessageRole.ASSISTANT, MessageRole.USER):
            if is_gemini:
                from llama_index.llms.vertex_gemini_utils import (
                    convert_chat_message_to_gemini_content,
                )

                vertex_messages.append(
                    convert_chat_message_to_gemini_content(
                        message=message, is_history=True
                    )
                )
            else:
                vertex_message = ChatMessage(
                    content=message.content,
                    author="bot" if message.role == MessageRole.ASSISTANT else "user",
                )
                vertex_messages.append(vertex_message)
        else:
            raise ValueError(
                f"Unexpected message with role {message.role} at position {i}."
            )
    if len(vertex_messages) % 2 != 0:
        raise ValueError(
            "The chat history must contain an even number of "
            "user/assistant messages."
        )
    return {"context": context, "message_history": vertex_messages}


def _parse_examples(examples: Any) -> Any:
    """Convert alternating user/assistant messages into InputOutputTextPairs."""
    from vertexai.language_models import InputOutputTextPair

    if len(examples) % 2 != 0:
        raise ValueError(
            f"Expected examples to have an even number of messages, "
            f"got {len(examples)}."
        )
    example_pairs = []
    input_text = None
    for i, example in enumerate(examples):
        if i % 2 == 0:
            if example.role != MessageRole.USER:
                raise ValueError(
                    f"Expected the first message in a pair to be from the "
                    f"user, got role {example.role} for message {i}."
                )
            input_text = example.content
        if i % 2 == 1:
            if example.role != MessageRole.ASSISTANT:
                raise ValueError(
                    f"Expected the second message in a pair to be from the "
                    f"assistant, got role {example.role} for message {i}."
                )
            pair = InputOutputTextPair(
                input_text=input_text, output_text=example.content
            )
            example_pairs.append(pair)
    return example_pairs
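

# Sketch of a few-shot example list: each (user, assistant) pair becomes one
# vertexai InputOutputTextPair:
#
#     examples = [
#         ChatMessage(role=MessageRole.USER, content="2+2?"),
#         ChatMessage(role=MessageRole.ASSISTANT, content="4"),
#     ]
#     pairs = _parse_examples(examples)
#     # -> [InputOutputTextPair(input_text="2+2?", output_text="4")]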