from typing import Any, Awaitable, Callable, Dict, Optional, Sequence

from llama_index.bridge.pydantic import Field
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_TEMPERATURE
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.llms.generic_utils import (
    achat_to_completion_decorator,
    acompletion_to_chat_decorator,
    astream_chat_to_completion_decorator,
    astream_completion_to_chat_decorator,
    chat_to_completion_decorator,
    completion_to_chat_decorator,
    stream_chat_to_completion_decorator,
    stream_completion_to_chat_decorator,
)
from llama_index.llms.litellm_utils import (
    acompletion_with_retry,
    completion_with_retry,
    from_litellm_message,
    is_function_calling_model,
    openai_modelname_to_contextsize,
    to_openai_message_dicts,
    validate_litellm_api_key,
)
from llama_index.llms.llm import LLM
from llama_index.types import BaseOutputParser, PydanticProgramMode

DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo"


class LiteLLM(LLM):
    model: str = Field(
        default=DEFAULT_LITELLM_MODEL,
        description=(
            "The LiteLLM model to use. "
            "For a complete list of providers, see https://docs.litellm.ai/docs/providers"
        ),
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        gte=0.0,
        lte=1.0,
    )
    max_tokens: Optional[int] = Field(
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional kwargs for the LLM API.",
        # for all supported inputs, see https://docs.litellm.ai/docs/completion/input
    )
    max_retries: int = Field(
        default=10, description="The maximum number of API retries."
    )

    def __init__(
        self,
        model: str = DEFAULT_LITELLM_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 10,
        api_key: Optional[str] = None,
        api_type: Optional[str] = None,
        api_base: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> None:
        if "custom_llm_provider" in kwargs:
            if (
                kwargs["custom_llm_provider"] != "ollama"
                and kwargs["custom_llm_provider"] != "vllm"
            ):  # don't check keys for local models
                validate_litellm_api_key(api_key, api_type)
        else:  # by default assume it's a hosted endpoint
            validate_litellm_api_key(api_key, api_type)

        additional_kwargs = additional_kwargs or {}

        if api_key is not None:
            additional_kwargs["api_key"] = api_key
        if api_type is not None:
            additional_kwargs["api_type"] = api_type
        if api_base is not None:
            additional_kwargs["api_base"] = api_base

        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            max_retries=max_retries,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )

    def _get_model_name(self) -> str:
        model_name = self.model
        if "ft-" in model_name:  # legacy fine-tuning
            model_name = model_name.split(":")[0]
        elif model_name.startswith("ft:"):
            model_name = model_name.split(":")[1]
        return model_name

    @classmethod
    def class_name(cls) -> str:
        return "litellm_llm"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=openai_modelname_to_contextsize(self._get_model_name()),
            num_output=self.max_tokens or -1,
            is_chat_model=True,
            is_function_calling_model=is_function_calling_model(
                self._get_model_name()
            ),
            model_name=self.model,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        if self._is_chat_model:
            chat_fn = self._chat
        else:
            chat_fn = completion_to_chat_decorator(self._complete)
        return chat_fn(messages, **kwargs)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        if self._is_chat_model:
            stream_chat_fn = self._stream_chat
        else:
            stream_chat_fn = stream_completion_to_chat_decorator(self._stream_complete)
        return stream_chat_fn(messages, **kwargs)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # litellm assumes all llms are chat llms
        if self._is_chat_model:
            complete_fn = chat_to_completion_decorator(self._chat)
        else:
            complete_fn = self._complete
        return complete_fn(prompt, **kwargs)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        if self._is_chat_model:
            stream_complete_fn = stream_chat_to_completion_decorator(self._stream_chat)
        else:
            stream_complete_fn = self._stream_complete
        return stream_complete_fn(prompt, **kwargs)

    @property
    def _is_chat_model(self) -> bool:
        # litellm assumes all llms are chat llms
        return True

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "model": self.model,
"temperature": self.temperature, "max_tokens": self.max_tokens, } return { **base_kwargs, **self.additional_kwargs, } def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]: return { **self._model_kwargs, **kwargs, } def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: if not self._is_chat_model: raise ValueError("This model is not a chat model.") message_dicts = to_openai_message_dicts(messages) all_kwargs = self._get_all_kwargs(**kwargs) if "max_tokens" in all_kwargs and all_kwargs["max_tokens"] is None: all_kwargs.pop( "max_tokens" ) # don't send max_tokens == None, this throws errors for Non OpenAI providers response = completion_with_retry( is_chat_model=self._is_chat_model, max_retries=self.max_retries, messages=message_dicts, stream=False, **all_kwargs, ) message_dict = response["choices"][0]["message"] message = from_litellm_message(message_dict) return ChatResponse( message=message, raw=response, additional_kwargs=self._get_response_token_counts(response), ) def _stream_chat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseGen: if not self._is_chat_model: raise ValueError("This model is not a chat model.") message_dicts = to_openai_message_dicts(messages) all_kwargs = self._get_all_kwargs(**kwargs) if "max_tokens" in all_kwargs and all_kwargs["max_tokens"] is None: all_kwargs.pop( "max_tokens" ) # don't send max_tokens == None, this throws errors for Non OpenAI providers def gen() -> ChatResponseGen: content = "" function_call: Optional[dict] = None for response in completion_with_retry( is_chat_model=self._is_chat_model, max_retries=self.max_retries, messages=message_dicts, stream=True, **all_kwargs, ): delta = response["choices"][0]["delta"] role = delta.get("role", "assistant") content_delta = delta.get("content", "") or "" content += content_delta function_call_delta = delta.get("function_call", None) if function_call_delta is not None: if function_call is None: function_call = function_call_delta ## ensure we do not add a blank function call if function_call.get("function_name", "") is None: del function_call["function_name"] else: function_call["arguments"] += function_call_delta["arguments"] additional_kwargs = {} if function_call is not None: additional_kwargs["function_call"] = function_call yield ChatResponse( message=ChatMessage( role=role, content=content, additional_kwargs=additional_kwargs, ), delta=content_delta, raw=response, additional_kwargs=self._get_response_token_counts(response), ) return gen() def _complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: raise NotImplementedError("litellm assumes all llms are chat llms.") def _stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: raise NotImplementedError("litellm assumes all llms are chat llms.") def _get_max_token_for_prompt(self, prompt: str) -> int: try: import tiktoken except ImportError: raise ImportError( "Please install tiktoken to use the max_tokens=None feature." ) context_window = self.metadata.context_window try: encoding = tiktoken.encoding_for_model(self._get_model_name()) except KeyError: encoding = encoding = tiktoken.get_encoding( "cl100k_base" ) # default to using cl10k_base tokens = encoding.encode(prompt) max_token = context_window - len(tokens) if max_token <= 0: raise ValueError( f"The prompt is too long for the model. " f"Please use a prompt that is less than {context_window} tokens." 
            )
        return max_token

    def _get_response_token_counts(self, raw_response: Any) -> dict:
        """Get the token usage reported by the response."""
        if not isinstance(raw_response, dict):
            return {}

        usage = raw_response.get("usage", {})
        return {
            "prompt_tokens": usage.get("prompt_tokens", 0),
            "completion_tokens": usage.get("completion_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        }

    # ===== Async Endpoints =====

    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        achat_fn: Callable[..., Awaitable[ChatResponse]]
        if self._is_chat_model:
            achat_fn = self._achat
        else:
            achat_fn = acompletion_to_chat_decorator(self._acomplete)
        return await achat_fn(messages, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        astream_chat_fn: Callable[..., Awaitable[ChatResponseAsyncGen]]
        if self._is_chat_model:
            astream_chat_fn = self._astream_chat
        else:
            astream_chat_fn = astream_completion_to_chat_decorator(
                self._astream_complete
            )
        return await astream_chat_fn(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        if self._is_chat_model:
            acomplete_fn = achat_to_completion_decorator(self._achat)
        else:
            acomplete_fn = self._acomplete
        return await acomplete_fn(prompt, **kwargs)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        if self._is_chat_model:
            astream_complete_fn = astream_chat_to_completion_decorator(
                self._astream_chat
            )
        else:
            astream_complete_fn = self._astream_complete
        return await astream_complete_fn(prompt, **kwargs)

    async def _achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        if not self._is_chat_model:
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)
        response = await acompletion_with_retry(
            is_chat_model=self._is_chat_model,
            max_retries=self.max_retries,
            messages=message_dicts,
            stream=False,
            **all_kwargs,
        )
        message_dict = response["choices"][0]["message"]
        message = from_litellm_message(message_dict)

        return ChatResponse(
            message=message,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    async def _astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        if not self._is_chat_model:
            raise ValueError("This model is not a chat model.")

        message_dicts = to_openai_message_dicts(messages)
        all_kwargs = self._get_all_kwargs(**kwargs)

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            function_call: Optional[dict] = None
            async for response in await acompletion_with_retry(
                is_chat_model=self._is_chat_model,
                max_retries=self.max_retries,
                messages=message_dicts,
                stream=True,
                **all_kwargs,
            ):
                delta = response["choices"][0]["delta"]
                role = delta.get("role", "assistant")
                content_delta = delta.get("content", "") or ""
                content += content_delta

                function_call_delta = delta.get("function_call", None)
                if function_call_delta is not None:
                    if function_call is None:
                        function_call = function_call_delta

                        ## ensure we do not add a blank function call
                        if function_call.get("function_name", "") is None:
                            del function_call["function_name"]
                    else:
                        function_call["arguments"] += function_call_delta["arguments"]

                additional_kwargs = {}
                if function_call is not None:
                    additional_kwargs["function_call"] = function_call

                yield ChatResponse(
                    message=ChatMessage(
                        role=role,
                        content=content,
                        additional_kwargs=additional_kwargs,
                    ),
                    delta=content_delta,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    async def _acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        raise NotImplementedError("litellm assumes all llms are chat llms.")

    async def _astream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError("litellm assumes all llms are chat llms.")
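

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the class above).
# It assumes a valid key for the default "gpt-3.5-turbo" model is available,
# e.g. via the OPENAI_API_KEY environment variable, so that
# validate_litellm_api_key() passes; any LiteLLM-supported provider/model
# string could be substituted.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    llm = LiteLLM(model=DEFAULT_LITELLM_MODEL, temperature=0.1)

    # Chat-style call: LiteLLM treats every model as a chat model.
    chat_response = llm.chat(
        [ChatMessage(role="user", content="Say hello in one sentence.")]
    )
    print(chat_response.message.content)

    # Completion-style call: routed through chat_to_completion_decorator,
    # since _complete() itself raises NotImplementedError.
    completion_response = llm.complete("Say hello in one sentence.")
    print(completion_response.text)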