import logging
from typing import Any, Callable, Dict, List, Optional, Sequence

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from llama_index.core.llms.types import ChatMessage

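# Maximum context window size (in tokens) for each Cohere model, keyed by model
# name; ``cohere_modelname_to_contextsize`` below looks values up in these maps.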
COMMAND_MODELS = {
    "command": 4096,
    "command-nightly": 4096,
    "command-light": 4096,
    "command-light-nightly": 4096,
}

GENERATION_MODELS = {"base": 2048, "base-light": 2048}

REPRESENTATION_MODELS = {
    "embed-english-light-v2.0": 512,
    "embed-english-v2.0": 512,
    "embed-multilingual-v2.0": 256,
}

ALL_AVAILABLE_MODELS = {**COMMAND_MODELS, **GENERATION_MODELS, **REPRESENTATION_MODELS}
CHAT_MODELS = {**COMMAND_MODELS}

logger = logging.getLogger(__name__)

def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
    """Create a tenacity retry decorator with exponential backoff."""
    min_seconds = 4
    max_seconds = 10
    # Wait 2**x seconds between retries, starting at min_seconds (4s) and
    # capping at max_seconds (10s) for every retry thereafter.
    try:
        import cohere
    except ImportError as e:
        raise ImportError(
            "You must install the `cohere` package to use Cohere. "
            "Please `pip install cohere`."
        ) from e

    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=retry_if_exception_type(cohere.error.CohereConnectionError),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )

def completion_with_retry(
    client: Any, max_retries: int, chat: bool = False, **kwargs: Any
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        if chat:
            return client.chat(**kwargs)
        else:
            return client.generate(**kwargs)

    return _completion_with_retry(**kwargs)

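# Example usage (illustrative sketch; assumes a configured pre-v5 ``cohere.Client``,
# whose ``generate``/``chat`` methods this helper wraps):
#
#   import cohere
#   client = cohere.Client(api_key="...")
#   response = completion_with_retry(
#       client, max_retries=3, chat=False, model="command", prompt="Hello"
#   )
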
async def acompletion_with_retry(
    aclient: Any,
    max_retries: int,
    chat: bool = False,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(max_retries=max_retries)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        if chat:
            return await aclient.chat(**kwargs)
        else:
            return await aclient.generate(**kwargs)

    return await _completion_with_retry(**kwargs)

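# Async example (illustrative sketch; assumes ``cohere.AsyncClient`` from the
# pre-v5 SDK, awaited from within a running event loop):
#
#   aclient = cohere.AsyncClient(api_key="...")
#   response = await acompletion_with_retry(
#       aclient, max_retries=3, chat=True, model="command", message="Hello"
#   )
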
def cohere_modelname_to_contextsize(modelname: str) -> int:
    """Return the maximum context size (in tokens) for a Cohere model name."""
    context_size = ALL_AVAILABLE_MODELS.get(modelname, None)
    if context_size is None:
        raise ValueError(
            f"Unknown model: {modelname}. Please provide a valid Cohere model name. "
            "Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys())
        )

    return context_size

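# Example (values come straight from the model tables above):
#
#   cohere_modelname_to_contextsize("command")  # -> 4096
#   cohere_modelname_to_contextsize("embed-english-v2.0")  # -> 512
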
def is_chat_model(model: str) -> bool:
    """Return True if the model supports the Cohere chat endpoint."""
    return model in COMMAND_MODELS

def messages_to_cohere_history(
    messages: Sequence[ChatMessage],
) -> List[Dict[str, Optional[str]]]:
    """Convert a sequence of ChatMessages to Cohere's chat-history format."""
    return [
        {"user_name": message.role, "message": message.content} for message in messages
    ]

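# Example (illustrative; assumes ``MessageRole`` lives alongside ``ChatMessage``
# in the same module — roles are passed through as-is, not remapped):
#
#   from llama_index.core.llms.types import MessageRole
#
#   history = messages_to_cohere_history(
#       [ChatMessage(role=MessageRole.USER, content="Hi there")]
#   )
#   # -> [{"user_name": <MessageRole.USER: 'user'>, "message": "Hi there"}]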