# faiss_rag_enterprise/llama_index/llms/cohere_utils.py
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from llama_index.core.llms.types import ChatMessage
# Maximum context size per Cohere chat/command model, consumed by
# `cohere_modelname_to_contextsize` below (values are presumably token
# counts — standard for LLM context windows; confirm against Cohere docs).
COMMAND_MODELS = {
"command": 4096,
"command-nightly": 4096,
"command-light": 4096,
"command-light-nightly": 4096,
}
# Plain text-generation (non-chat) models and their context sizes.
GENERATION_MODELS = {"base": 2048, "base-light": 2048}
# Embedding models; values are their (smaller) input size limits.
REPRESENTATION_MODELS = {
"embed-english-light-v2.0": 512,
"embed-english-v2.0": 512,
"embed-multilingual-v2.0": 256,
}
# Union of every model this module knows about, for name -> size lookup.
ALL_AVAILABLE_MODELS = {**COMMAND_MODELS, **GENERATION_MODELS, **REPRESENTATION_MODELS}
# Models that support the chat endpoint (currently the command family).
CHAT_MODELS = {**COMMAND_MODELS}
# Module-level logger; also used by tenacity's before_sleep_log in the
# retry decorator below.
logger = logging.getLogger(__name__)
def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
try:
import cohere
except ImportError as e:
raise ImportError(
"You must install the `cohere` package to use Cohere."
"Please `pip install cohere`"
) from e
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(retry_if_exception_type(cohere.error.CohereConnectionError)),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(
    client: Any, max_retries: int, chat: bool = False, **kwargs: Any
) -> Any:
    """Invoke a Cohere completion endpoint, retrying transient failures.

    Dispatches to ``client.chat`` when ``chat`` is true, otherwise to
    ``client.generate``, wrapping the call in the tenacity retry policy
    from ``_create_retry_decorator``.
    """
    decorator = _create_retry_decorator(max_retries=max_retries)

    @decorator
    def _call(**call_kwargs: Any) -> Any:
        # Pick the endpoint once per attempt based on the `chat` flag.
        endpoint = client.chat if chat else client.generate
        return endpoint(**call_kwargs)

    return _call(**kwargs)
async def acompletion_with_retry(
    aclient: Any,
    max_retries: int,
    chat: bool = False,
    **kwargs: Any,
) -> Any:
    """Invoke an async Cohere completion endpoint, retrying transient failures.

    Awaits ``aclient.chat`` when ``chat`` is true, otherwise
    ``aclient.generate``, wrapping the call in the tenacity retry policy
    from ``_create_retry_decorator``.
    """
    decorator = _create_retry_decorator(max_retries=max_retries)

    @decorator
    async def _call(**call_kwargs: Any) -> Any:
        # Pick the endpoint once per attempt based on the `chat` flag.
        endpoint = aclient.chat if chat else aclient.generate
        return await endpoint(**call_kwargs)

    return await _call(**kwargs)
def cohere_modelname_to_contextsize(modelname: str) -> int:
    """Return the maximum context size for a known Cohere model.

    Args:
        modelname: A model name; must be a key of ``ALL_AVAILABLE_MODELS``.

    Returns:
        The model's context size.

    Raises:
        ValueError: If ``modelname`` is not a known Cohere model.
    """
    context_size = ALL_AVAILABLE_MODELS.get(modelname, None)
    if context_size is None:
        # BUG FIX: the first literal previously ended without a trailing space,
        # so the rendered message read "...model name.Known models are: ...".
        raise ValueError(
            f"Unknown model: {modelname}. Please provide a valid Cohere model name. "
            "Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys())
        )
    return context_size
def is_chat_model(model: str) -> bool:
    """Return True if *model* supports the Cohere chat endpoint."""
    # Consistency fix: check the canonical chat-capable set CHAT_MODELS
    # (declared above, previously unused) instead of its source mapping
    # COMMAND_MODELS, so this predicate stays correct if CHAT_MODELS is
    # ever extended. Today CHAT_MODELS == COMMAND_MODELS, so behavior is
    # unchanged.
    return model in CHAT_MODELS
def messages_to_cohere_history(
    messages: Sequence["ChatMessage"],
) -> List[Dict[str, Optional[str]]]:
    """Convert chat messages into Cohere's chat-history record format.

    Each message becomes a dict with its role under ``"user_name"`` and
    its text under ``"message"``, preserving order.
    """
    history: List[Dict[str, Optional[str]]] = []
    for msg in messages:
        history.append({"user_name": msg.role, "message": msg.content})
    return history