"""Prompts.""" from abc import ABC, abstractmethod from copy import deepcopy from typing import ( TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, ) from llama_index.bridge.pydantic import Field if TYPE_CHECKING: from llama_index.bridge.langchain import BasePromptTemplate as LangchainTemplate from llama_index.bridge.langchain import ( ConditionalPromptSelector as LangchainSelector, ) from llama_index.bridge.pydantic import BaseModel from llama_index.core.llms.types import ChatMessage from llama_index.core.query_pipeline.query_component import ( ChainableMixin, InputKeys, OutputKeys, QueryComponent, validate_and_convert_stringable, ) from llama_index.llms.base import BaseLLM from llama_index.llms.generic_utils import ( messages_to_prompt as default_messages_to_prompt, ) from llama_index.llms.generic_utils import ( prompt_to_messages, ) from llama_index.prompts.prompt_type import PromptType from llama_index.prompts.utils import get_template_vars from llama_index.types import BaseOutputParser class BasePromptTemplate(ChainableMixin, BaseModel, ABC): metadata: Dict[str, Any] template_vars: List[str] kwargs: Dict[str, str] output_parser: Optional[BaseOutputParser] template_var_mappings: Optional[Dict[str, Any]] = Field( default_factory=dict, description="Template variable mappings (Optional)." ) function_mappings: Optional[Dict[str, Callable]] = Field( default_factory=dict, description=( "Function mappings (Optional). This is a mapping from template " "variable names to functions that take in the current kwargs and " "return a string." ), ) def _map_template_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: """For keys in template_var_mappings, swap in the right keys.""" template_var_mappings = self.template_var_mappings or {} return {template_var_mappings.get(k, k): v for k, v in kwargs.items()} def _map_function_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: """For keys in function_mappings, compute values and combine w/ kwargs. Users can pass in functions instead of fixed values as format variables. For each function, we call the function with the current kwargs, get back the value, and then use that value in the template for the corresponding format variable. """ function_mappings = self.function_mappings or {} # first generate the values for the functions new_kwargs = {} for k, v in function_mappings.items(): # TODO: figure out what variables to pass into each function # is it the kwargs specified during query time? just the fixed kwargs? # all kwargs? new_kwargs[k] = v(**kwargs) # then, add the fixed variables only if not in new_kwargs already # (implying that function mapping will override fixed variables) for k, v in kwargs.items(): if k not in new_kwargs: new_kwargs[k] = v return new_kwargs def _map_all_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: """Map both template and function variables. We (1) first call function mappings to compute functions, and then (2) call the template_var_mappings. """ # map function new_kwargs = self._map_function_vars(kwargs) # map template vars (to point to existing format vars in string template) return self._map_template_vars(new_kwargs) class Config: arbitrary_types_allowed = True @abstractmethod def partial_format(self, **kwargs: Any) -> "BasePromptTemplate": ... @abstractmethod def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str: ... @abstractmethod def format_messages( self, llm: Optional[BaseLLM] = None, **kwargs: Any ) -> List[ChatMessage]: ... 

class PromptTemplate(BasePromptTemplate):
    template: str

    def __init__(
        self,
        template: str,
        prompt_type: str = PromptType.CUSTOM,
        output_parser: Optional[BaseOutputParser] = None,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        **kwargs: Any,
    ) -> None:
        if metadata is None:
            metadata = {}
        metadata["prompt_type"] = prompt_type

        template_vars = get_template_vars(template)

        super().__init__(
            template=template,
            template_vars=template_vars,
            kwargs=kwargs,
            metadata=metadata,
            output_parser=output_parser,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
        )

    def partial_format(self, **kwargs: Any) -> "PromptTemplate":
        """Partially format the prompt."""
        # NOTE: this is a hack to get around deepcopy failing on output parser
        output_parser = self.output_parser
        self.output_parser = None

        # get function and fixed kwargs, and add that to a copy
        # of the current prompt object
        prompt = deepcopy(self)
        prompt.kwargs.update(kwargs)

        # NOTE: put the output parser back
        prompt.output_parser = output_parser
        self.output_parser = output_parser
        return prompt

    def format(
        self,
        llm: Optional[BaseLLM] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        **kwargs: Any,
    ) -> str:
        """Format the prompt into a string."""
        del llm  # unused
        all_kwargs = {
            **self.kwargs,
            **kwargs,
        }
        mapped_all_kwargs = self._map_all_vars(all_kwargs)

        prompt = self.template.format(**mapped_all_kwargs)
        if self.output_parser is not None:
            prompt = self.output_parser.format(prompt)

        if completion_to_prompt is not None:
            prompt = completion_to_prompt(prompt)

        return prompt

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        del llm  # unused
        prompt = self.format(**kwargs)
        return prompt_to_messages(prompt)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        return self.template
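
# Example usage (illustrative; `{topic}` is a hypothetical template variable):
#
#     qa_template = PromptTemplate("Please write a passage about {topic}.")
#     prompt_str = qa_template.format(topic="prompt templates")
#
# `partial_format` pins some variables ahead of time and returns a copy:
#
#     partial_template = qa_template.partial_format(topic="prompt templates")
#     prompt_str = partial_template.format()
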

class ChatPromptTemplate(BasePromptTemplate):
    message_templates: List[ChatMessage]

    def __init__(
        self,
        message_templates: List[ChatMessage],
        prompt_type: str = PromptType.CUSTOM,
        output_parser: Optional[BaseOutputParser] = None,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        **kwargs: Any,
    ):
        if metadata is None:
            metadata = {}
        metadata["prompt_type"] = prompt_type

        template_vars = []
        for message_template in message_templates:
            template_vars.extend(get_template_vars(message_template.content or ""))

        super().__init__(
            message_templates=message_templates,
            kwargs=kwargs,
            metadata=metadata,
            output_parser=output_parser,
            template_vars=template_vars,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
        )

    def partial_format(self, **kwargs: Any) -> "ChatPromptTemplate":
        prompt = deepcopy(self)
        prompt.kwargs.update(kwargs)
        return prompt

    def format(
        self,
        llm: Optional[BaseLLM] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        **kwargs: Any,
    ) -> str:
        del llm  # unused
        messages = self.format_messages(**kwargs)

        if messages_to_prompt is not None:
            return messages_to_prompt(messages)

        return default_messages_to_prompt(messages)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        del llm  # unused
        all_kwargs = {
            **self.kwargs,
            **kwargs,
        }
        mapped_all_kwargs = self._map_all_vars(all_kwargs)

        messages: List[ChatMessage] = []
        for message_template in self.message_templates:
            template_vars = get_template_vars(message_template.content or "")
            relevant_kwargs = {
                k: v for k, v in mapped_all_kwargs.items() if k in template_vars
            }
            content_template = message_template.content or ""

            # if there's mappings specified, make sure those are used
            content = content_template.format(**relevant_kwargs)

            message: ChatMessage = message_template.copy()
            message.content = content
            messages.append(message)

        if self.output_parser is not None:
            messages = self.output_parser.format_messages(messages)

        return messages

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        return default_messages_to_prompt(self.message_templates)

    def _as_query_component(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> QueryComponent:
        """As query component."""
        return PromptComponent(prompt=self, format_messages=True, llm=llm)
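
# Example usage (illustrative; assumes `MessageRole` is importable from
# llama_index.core.llms.types, and `{persona}` / `{question}` are
# hypothetical template variables):
#
#     from llama_index.core.llms.types import MessageRole
#
#     chat_template = ChatPromptTemplate(
#         message_templates=[
#             ChatMessage(role=MessageRole.SYSTEM, content="You are {persona}."),
#             ChatMessage(role=MessageRole.USER, content="{question}"),
#         ]
#     )
#     messages = chat_template.format_messages(
#         persona="a concise assistant", question="What is a prompt?"
#     )
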

class SelectorPromptTemplate(BasePromptTemplate):
    default_template: BasePromptTemplate
    conditionals: Optional[
        List[Tuple[Callable[[BaseLLM], bool], BasePromptTemplate]]
    ] = None

    def __init__(
        self,
        default_template: BasePromptTemplate,
        conditionals: Optional[
            List[Tuple[Callable[[BaseLLM], bool], BasePromptTemplate]]
        ] = None,
    ):
        metadata = default_template.metadata
        kwargs = default_template.kwargs
        template_vars = default_template.template_vars
        output_parser = default_template.output_parser
        super().__init__(
            default_template=default_template,
            conditionals=conditionals,
            metadata=metadata,
            kwargs=kwargs,
            template_vars=template_vars,
            output_parser=output_parser,
        )

    def select(self, llm: Optional[BaseLLM] = None) -> BasePromptTemplate:
        # ensure output parser is up to date
        self.default_template.output_parser = self.output_parser

        if llm is None:
            return self.default_template

        if self.conditionals is not None:
            for condition, prompt in self.conditionals:
                if condition(llm):
                    # ensure output parser is up to date
                    prompt.output_parser = self.output_parser
                    return prompt

        return self.default_template

    def partial_format(self, **kwargs: Any) -> "SelectorPromptTemplate":
        default_template = self.default_template.partial_format(**kwargs)
        if self.conditionals is None:
            conditionals = None
        else:
            conditionals = [
                (condition, prompt.partial_format(**kwargs))
                for condition, prompt in self.conditionals
            ]
        return SelectorPromptTemplate(
            default_template=default_template, conditionals=conditionals
        )

    def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
        """Format the prompt into a string."""
        prompt = self.select(llm=llm)
        return prompt.format(**kwargs)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        prompt = self.select(llm=llm)
        return prompt.format_messages(**kwargs)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        prompt = self.select(llm=llm)
        return prompt.get_template(llm=llm)
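
# Example usage (illustrative sketch; `_is_chat_model`, `text_template`,
# `chat_template`, and `{query}` are hypothetical names):
#
#     def _is_chat_model(llm: BaseLLM) -> bool:
#         return llm.metadata.is_chat_model
#
#     selector_template = SelectorPromptTemplate(
#         default_template=text_template,
#         conditionals=[(_is_chat_model, chat_template)],
#     )
#     # `select` returns the first template whose condition matches the LLM,
#     # falling back to `default_template` when no condition matches (or no
#     # LLM is given).
#     prompt_str = selector_template.format(llm=llm, query="...")
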

class LangchainPromptTemplate(BasePromptTemplate):
    selector: Any
    requires_langchain_llm: bool = False

    def __init__(
        self,
        template: Optional["LangchainTemplate"] = None,
        selector: Optional["LangchainSelector"] = None,
        output_parser: Optional[BaseOutputParser] = None,
        prompt_type: str = PromptType.CUSTOM,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        requires_langchain_llm: bool = False,
    ) -> None:
        try:
            from llama_index.bridge.langchain import (
                ConditionalPromptSelector as LangchainSelector,
            )
        except ImportError:
            raise ImportError(
                "Must install `llama_index[langchain]` to use LangchainPromptTemplate."
            )
        # exactly one of `template` or `selector` must be provided
        if selector is None:
            if template is None:
                raise ValueError("Must provide either template or selector.")
            selector = LangchainSelector(default_prompt=template)
        else:
            if template is not None:
                raise ValueError("Must provide either template or selector.")

        kwargs = selector.default_prompt.partial_variables
        template_vars = selector.default_prompt.input_variables

        if metadata is None:
            metadata = {}
        metadata["prompt_type"] = prompt_type

        super().__init__(
            selector=selector,
            metadata=metadata,
            kwargs=kwargs,
            template_vars=template_vars,
            output_parser=output_parser,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
            requires_langchain_llm=requires_langchain_llm,
        )

    def partial_format(self, **kwargs: Any) -> "BasePromptTemplate":
        """Partially format the prompt."""
        from llama_index.bridge.langchain import (
            ConditionalPromptSelector as LangchainSelector,
        )

        mapped_kwargs = self._map_all_vars(kwargs)
        default_prompt = self.selector.default_prompt.partial(**mapped_kwargs)
        conditionals = [
            (condition, prompt.partial(**mapped_kwargs))
            for condition, prompt in self.selector.conditionals
        ]
        lc_selector = LangchainSelector(
            default_prompt=default_prompt, conditionals=conditionals
        )

        # copy full prompt object, replace selector
        lc_prompt = deepcopy(self)
        lc_prompt.selector = lc_selector
        return lc_prompt

    def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
        """Format the prompt into a string."""
        from llama_index.llms.langchain import LangChainLLM

        if llm is not None:
            # if a llama_index LLM is provided and a langchain LLM is required,
            # error out; otherwise, when `requires_langchain_llm` is False,
            # just fall back to the default prompt
            if not isinstance(llm, LangChainLLM) and self.requires_langchain_llm:
                raise ValueError("Must provide a LangChainLLM.")
            elif not isinstance(llm, LangChainLLM):
                lc_template = self.selector.default_prompt
            else:
                lc_template = self.selector.get_prompt(llm=llm.llm)
        else:
            lc_template = self.selector.default_prompt

        # if there's mappings specified, make sure those are used
        mapped_kwargs = self._map_all_vars(kwargs)
        return lc_template.format(**mapped_kwargs)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        from llama_index.llms.langchain import LangChainLLM
        from llama_index.llms.langchain_utils import from_lc_messages

        if llm is not None:
            # if a llama_index LLM is provided and a langchain LLM is required,
            # error out; otherwise, when `requires_langchain_llm` is False,
            # just fall back to the default prompt
            if not isinstance(llm, LangChainLLM) and self.requires_langchain_llm:
                raise ValueError("Must provide a LangChainLLM.")
            elif not isinstance(llm, LangChainLLM):
                lc_template = self.selector.default_prompt
            else:
                lc_template = self.selector.get_prompt(llm=llm.llm)
        else:
            lc_template = self.selector.default_prompt

        # if there's mappings specified, make sure those are used
        mapped_kwargs = self._map_all_vars(kwargs)
        lc_prompt_value = lc_template.format_prompt(**mapped_kwargs)
        lc_messages = lc_prompt_value.to_messages()
        return from_lc_messages(lc_messages)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        from llama_index.llms.langchain import LangChainLLM

        if llm is not None:
            # if a llama_index LLM is provided and a langchain LLM is required,
            # error out; otherwise, when `requires_langchain_llm` is False,
            # just fall back to the default prompt
            if not isinstance(llm, LangChainLLM) and self.requires_langchain_llm:
                raise ValueError("Must provide a LangChainLLM.")
            elif not isinstance(llm, LangChainLLM):
                lc_template = self.selector.default_prompt
            else:
                lc_template = self.selector.get_prompt(llm=llm.llm)
        else:
            lc_template = self.selector.default_prompt

        try:
            return str(lc_template.template)  # type: ignore
        except AttributeError:
            return str(lc_template)


# NOTE: only for backwards compatibility
Prompt = PromptTemplate


class PromptComponent(QueryComponent):
    """Prompt component."""

    prompt: BasePromptTemplate = Field(..., description="Prompt")
    llm: Optional[BaseLLM] = Field(
        default=None, description="LLM to use for formatting prompt."
    )
    format_messages: bool = Field(
        default=False,
        description="Whether to format the prompt into a list of chat messages.",
    )

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        keys = list(input.keys())
        for k in keys:
            input[k] = validate_and_convert_stringable(input[k])
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        if self.format_messages:
            output: Union[str, List[ChatMessage]] = self.prompt.format_messages(
                llm=self.llm, **kwargs
            )
        else:
            output = self.prompt.format(llm=self.llm, **kwargs)
        return {"prompt": output}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async)."""
        # NOTE: no native async for prompt
        return self._run_component(**kwargs)

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        return InputKeys.from_keys(
            set(self.prompt.template_vars) - set(self.prompt.kwargs)
        )

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"prompt"})
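
# Example usage (illustrative sketch; `{topic}` is a hypothetical template
# variable, and `as_query_component` comes from ChainableMixin):
#
#     template = PromptTemplate("Please write a passage about {topic}.")
#     prompt_component = template.as_query_component()
#     # free template variables (template_vars minus pre-bound kwargs) become
#     # the component's input keys; the formatted prompt is emitted under the
#     # "prompt" output key
#     output = prompt_component.run_component(topic="prompt templates")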