from collections import ChainMap
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Protocol,
    Sequence,
    get_args,
    runtime_checkable,
)

from llama_index.bridge.pydantic import BaseModel, Field, validator
from llama_index.callbacks import CBEventType, EventPayload
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    MessageRole,
)
from llama_index.core.query_pipeline.query_component import (
    InputKeys,
    OutputKeys,
    QueryComponent,
    StringableInput,
    validate_and_convert_stringable,
)
from llama_index.llms.base import BaseLLM
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.llms.generic_utils import (
    prompt_to_messages,
)
from llama_index.prompts import BasePromptTemplate, PromptTemplate
from llama_index.types import (
    BaseOutputParser,
    PydanticProgramMode,
    TokenAsyncGen,
    TokenGen,
)


# NOTE: These two protocols are needed to appease mypy
@runtime_checkable
class MessagesToPromptType(Protocol):
    def __call__(self, messages: Sequence[ChatMessage]) -> str:
        pass


@runtime_checkable
class CompletionToPromptType(Protocol):
    def __call__(self, prompt: str) -> str:
        pass


def stream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseGen,
) -> TokenGen:
    """Convert a stream completion response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in completion_response_gen:
            yield response.delta or ""

    return gen()


def stream_chat_response_to_tokens(
    chat_response_gen: ChatResponseGen,
) -> TokenGen:
    """Convert a stream chat response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in chat_response_gen:
            yield response.delta or ""

    return gen()


async def astream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseAsyncGen,
) -> TokenAsyncGen:
    """Convert an async stream completion response to a stream of tokens."""

    async def gen() -> TokenAsyncGen:
        async for response in completion_response_gen:
            yield response.delta or ""

    return gen()


async def astream_chat_response_to_tokens(
    chat_response_gen: ChatResponseAsyncGen,
) -> TokenAsyncGen:
    """Convert an async stream chat response to a stream of tokens."""

    async def gen() -> TokenAsyncGen:
        async for response in chat_response_gen:
            yield response.delta or ""

    return gen()


def default_completion_to_prompt(prompt: str) -> str:
    return prompt
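
# Example (illustrative sketch, not part of this module's API): the helpers
# above adapt streaming response generators into plain token generators, so
# callers can iterate over text deltas directly. Assuming `llm` is any
# concrete LLM instance:
#
#     completion_gen = llm.stream_complete("Tell me a story")
#     for token in stream_completion_response_to_tokens(completion_gen):
#         print(token, end="", flush=True)
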
class LLM(BaseLLM):
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for LLM calls."
    )
    messages_to_prompt: MessagesToPromptType = Field(
        description="Function to convert a list of messages to an LLM prompt.",
        default=generic_messages_to_prompt,
        exclude=True,
    )
    completion_to_prompt: CompletionToPromptType = Field(
        description="Function to convert a completion to an LLM prompt.",
        default=default_completion_to_prompt,
        exclude=True,
    )
    output_parser: Optional[BaseOutputParser] = Field(
        description="Output parser to parse, validate, and correct errors programmatically.",
        default=None,
        exclude=True,
    )
    pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT

    # deprecated
    query_wrapper_prompt: Optional[BasePromptTemplate] = Field(
        description="Query wrapper prompt for LLM calls.",
        default=None,
        exclude=True,
    )

    @validator("messages_to_prompt", pre=True)
    def set_messages_to_prompt(
        cls, messages_to_prompt: Optional[MessagesToPromptType]
    ) -> MessagesToPromptType:
        return messages_to_prompt or generic_messages_to_prompt

    @validator("completion_to_prompt", pre=True)
    def set_completion_to_prompt(
        cls, completion_to_prompt: Optional[CompletionToPromptType]
    ) -> CompletionToPromptType:
        return completion_to_prompt or default_completion_to_prompt

    def _log_template_data(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> None:
        template_vars = {
            k: v
            for k, v in ChainMap(prompt.kwargs, prompt_args).items()
            if k in prompt.template_vars
        }
        with self.callback_manager.event(
            CBEventType.TEMPLATING,
            payload={
                EventPayload.TEMPLATE: prompt.get_template(llm=self),
                EventPayload.TEMPLATE_VARS: template_vars,
                EventPayload.SYSTEM_PROMPT: self.system_prompt,
                EventPayload.QUERY_WRAPPER_PROMPT: self.query_wrapper_prompt,
            },
        ):
            pass

    def _get_prompt(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        formatted_prompt = prompt.format(
            llm=self,
            messages_to_prompt=self.messages_to_prompt,
            completion_to_prompt=self.completion_to_prompt,
            **prompt_args,
        )
        if self.output_parser is not None:
            formatted_prompt = self.output_parser.format(formatted_prompt)
        return self._extend_prompt(formatted_prompt)

    def _get_messages(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> List[ChatMessage]:
        messages = prompt.format_messages(llm=self, **prompt_args)
        if self.output_parser is not None:
            messages = self.output_parser.format_messages(messages)
        return self._extend_messages(messages)

    def structured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        from llama_index.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )
        return program(**prompt_args)

    async def astructured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        from llama_index.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )
        return await program.acall(**prompt_args)

    def _parse_output(self, output: str) -> str:
        if self.output_parser is not None:
            return str(self.output_parser.parse(output))
        return output

    def predict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.chat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = self.complete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)

    def stream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenGen:
        """Stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.stream_chat(messages)
            stream_tokens = stream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = self.stream_complete(formatted_prompt, formatted=True)
            stream_tokens = stream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens

    async def apredict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Async predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.achat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = await self.acomplete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)

    async def astream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenAsyncGen:
        """Async stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.astream_chat(messages)
            stream_tokens = await astream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = await self.astream_complete(
                formatted_prompt, formatted=True
            )
            stream_tokens = await astream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens

    def _extend_prompt(
        self,
        formatted_prompt: str,
    ) -> str:
        """Add system and query wrapper prompts to base prompt."""
        extended_prompt = formatted_prompt

        if self.system_prompt:
            extended_prompt = self.system_prompt + "\n\n" + extended_prompt

        if self.query_wrapper_prompt:
            extended_prompt = self.query_wrapper_prompt.format(
                query_str=extended_prompt
            )

        return extended_prompt

    def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """Add system prompt to chat message list."""
        if self.system_prompt:
            messages = [
                ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
                *messages,
            ]
        return messages

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Return query component."""
        if self.metadata.is_chat_model:
            return LLMChatComponent(llm=self, **kwargs)
        else:
            return LLMCompleteComponent(llm=self, **kwargs)
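
# Example (illustrative sketch): `predict` formats a prompt template, routes
# to chat or completion based on the model's metadata, and returns parsed
# text, while `structured_predict` returns a pydantic object instead. The
# `Song` model and template below are hypothetical:
#
#     prompt = PromptTemplate("Name a song by {artist}.")
#     text = llm.predict(prompt, artist="The Beatles")
#
#     class Song(BaseModel):
#         title: str
#
#     song = llm.structured_predict(Song, prompt, artist="The Beatles")
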
class BaseLLMComponent(QueryComponent):
    """Base LLM component."""

    llm: LLM = Field(..., description="LLM")
    streaming: bool = Field(default=False, description="Streaming mode")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""
        self.llm.callback_manager = callback_manager


class LLMCompleteComponent(BaseLLMComponent):
    """LLM completion component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "prompt" not in input:
            raise ValueError("Prompt must be in input dict.")

        # do special check to see if prompt is a list of chat messages
        if isinstance(input["prompt"], get_args(List[ChatMessage])):
            input["prompt"] = self.llm.messages_to_prompt(input["prompt"])
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
        else:
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
            input["prompt"] = self.llm.completion_to_prompt(input["prompt"])

        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only complete for now
        # non-trivial to figure how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        if self.streaming:
            response = self.llm.stream_complete(prompt, formatted=True)
        else:
            response = self.llm.complete(prompt, formatted=True)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async)."""
        # TODO: support only complete for now
        # non-trivial to figure how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        response = await self.llm.acomplete(prompt, formatted=True)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: support only complete for now
        return InputKeys.from_keys({"prompt"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
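
# Example (illustrative sketch): `LLMCompleteComponent` accepts either a raw
# string or a list of ChatMessage under the "prompt" key, normalizes it via
# `messages_to_prompt` / `completion_to_prompt`, then delegates to
# `complete` / `stream_complete`. Assuming `llm` is a completion-style LLM:
#
#     component = LLMCompleteComponent(llm=llm)
#     output = component.run_component(prompt="Hello, world!")["output"]
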
class LLMChatComponent(BaseLLMComponent):
    """LLM chat component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "messages" not in input:
            raise ValueError("Messages must be in input dict.")

        # if `messages` is a string, convert to a list of chat messages
        if isinstance(input["messages"], get_args(StringableInput)):
            input["messages"] = validate_and_convert_stringable(input["messages"])
            input["messages"] = prompt_to_messages(str(input["messages"]))

        for message in input["messages"]:
            if not isinstance(message, ChatMessage):
                raise ValueError("Messages must be a list of ChatMessage")

        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only chat for now
        # non-trivial to figure how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = self.llm.stream_chat(messages)
        else:
            response = self.llm.chat(messages)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async)."""
        # TODO: support only chat for now
        # non-trivial to figure how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = await self.llm.astream_chat(messages)
        else:
            response = await self.llm.achat(messages)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: support only chat for now
        return InputKeys.from_keys({"messages"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
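
# Example (illustrative sketch): `as_query_component` on the LLM returns
# `LLMChatComponent` for chat models and `LLMCompleteComponent` otherwise,
# so an LLM can be dropped into a query pipeline. The wiring below is
# hypothetical:
#
#     from llama_index.query_pipeline import QueryPipeline
#
#     pipeline = QueryPipeline(
#         chain=[PromptTemplate("Q: {question}\nA:"), llm.as_query_component()]
#     )
#     answer = pipeline.run(question="What does LLMChatComponent do?")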