"""Prompts."""

from abc import ABC, abstractmethod
from copy import deepcopy
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from llama_index.bridge.pydantic import Field

if TYPE_CHECKING:
    from llama_index.bridge.langchain import BasePromptTemplate as LangchainTemplate
    from llama_index.bridge.langchain import (
        ConditionalPromptSelector as LangchainSelector,
    )
from llama_index.bridge.pydantic import BaseModel
from llama_index.core.llms.types import ChatMessage
from llama_index.core.query_pipeline.query_component import (
    ChainableMixin,
    InputKeys,
    OutputKeys,
    QueryComponent,
    validate_and_convert_stringable,
)
from llama_index.llms.base import BaseLLM
from llama_index.llms.generic_utils import (
    messages_to_prompt as default_messages_to_prompt,
)
from llama_index.llms.generic_utils import (
    prompt_to_messages,
)
from llama_index.prompts.prompt_type import PromptType
from llama_index.prompts.utils import get_template_vars
from llama_index.types import BaseOutputParser


class BasePromptTemplate(ChainableMixin, BaseModel, ABC):
    metadata: Dict[str, Any]
    template_vars: List[str]
    kwargs: Dict[str, str]
    output_parser: Optional[BaseOutputParser]
    template_var_mappings: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Template variable mappings (Optional)."
    )
    function_mappings: Optional[Dict[str, Callable]] = Field(
        default_factory=dict,
        description=(
            "Function mappings (Optional). This is a mapping from template "
            "variable names to functions that take in the current kwargs and "
            "return a string."
        ),
    )

    def _map_template_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """For keys in template_var_mappings, swap in the right keys."""
        template_var_mappings = self.template_var_mappings or {}
        return {template_var_mappings.get(k, k): v for k, v in kwargs.items()}

    def _map_function_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """For keys in function_mappings, compute values and combine w/ kwargs.

        Users can pass in functions instead of fixed values as format variables.
        For each function, we call the function with the current kwargs,
        get back the value, and then use that value in the template
        for the corresponding format variable.
        """
        function_mappings = self.function_mappings or {}
        # first generate the values for the functions
        new_kwargs = {}
        for k, v in function_mappings.items():
            # TODO: figure out what variables to pass into each function
            # is it the kwargs specified during query time? just the fixed kwargs?
            # all kwargs?
            new_kwargs[k] = v(**kwargs)

        # then, add the fixed variables only if not in new_kwargs already
        # (implying that function mapping will override fixed variables)
        for k, v in kwargs.items():
            if k not in new_kwargs:
                new_kwargs[k] = v

        return new_kwargs

    def _map_all_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Map both template and function variables.

        We (1) first call function mappings to compute functions,
        and then (2) call the template_var_mappings.
        """
        # map function
        new_kwargs = self._map_function_vars(kwargs)
        # map template vars (to point to existing format vars in string template)
        return self._map_template_vars(new_kwargs)

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    def partial_format(self, **kwargs: Any) -> "BasePromptTemplate":
        ...

    @abstractmethod
    def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
        ...

    @abstractmethod
    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        ...

    @abstractmethod
    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        ...

    def _as_query_component(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> QueryComponent:
        """As query component."""
        return PromptComponent(prompt=self, format_messages=False, llm=llm)


class PromptTemplate(BasePromptTemplate):
    template: str

    def __init__(
        self,
        template: str,
        prompt_type: str = PromptType.CUSTOM,
        output_parser: Optional[BaseOutputParser] = None,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        **kwargs: Any,
    ) -> None:
        if metadata is None:
            metadata = {}
        metadata["prompt_type"] = prompt_type

        template_vars = get_template_vars(template)

        super().__init__(
            template=template,
            template_vars=template_vars,
            kwargs=kwargs,
            metadata=metadata,
            output_parser=output_parser,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
        )

    def partial_format(self, **kwargs: Any) -> "PromptTemplate":
        """Partially format the prompt."""
        # NOTE: this is a hack to get around deepcopy failing on output parser
        output_parser = self.output_parser
        self.output_parser = None

        # get function and fixed kwargs, and add that to a copy
        # of the current prompt object
        prompt = deepcopy(self)
        prompt.kwargs.update(kwargs)

        # NOTE: put the output parser back
        prompt.output_parser = output_parser
        self.output_parser = output_parser
        return prompt

    def format(
        self,
        llm: Optional[BaseLLM] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        **kwargs: Any,
    ) -> str:
        """Format the prompt into a string."""
        del llm  # unused
        all_kwargs = {
            **self.kwargs,
            **kwargs,
        }

        mapped_all_kwargs = self._map_all_vars(all_kwargs)
        prompt = self.template.format(**mapped_all_kwargs)

        if self.output_parser is not None:
            prompt = self.output_parser.format(prompt)

        if completion_to_prompt is not None:
            prompt = completion_to_prompt(prompt)

        return prompt

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        del llm  # unused
        prompt = self.format(**kwargs)
        return prompt_to_messages(prompt)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        return self.template
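

# Illustrative usage sketch (not part of the library code): constructing and
# formatting a ``PromptTemplate``. The template text and variable names below
# are made up for the example.
#
#     qa_template = PromptTemplate(
#         "Context: {context}\nAnswer the question: {query}"
#     )
#     prompt_str = qa_template.format(
#         context="Prompts fill templates.", query="What is a prompt?"
#     )
#
#     # partial_format pins some variables now and returns a new template copy
#     partial_template = qa_template.partial_format(context="Prompts fill templates.")
#     prompt_str = partial_template.format(query="What is a prompt?")
#
#     # function_mappings compute a variable from the format-time kwargs;
#     # template_var_mappings rename caller-facing keys to the template's own keys
#     mapped_template = PromptTemplate(
#         "Today is {today}. {question}",
#         function_mappings={"today": lambda **kwargs: "2024-01-01"},
#         template_var_mappings={"query_str": "question"},
#     )
#     prompt_str = mapped_template.format(query_str="What day is it?")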


class ChatPromptTemplate(BasePromptTemplate):
    message_templates: List[ChatMessage]

    def __init__(
        self,
        message_templates: List[ChatMessage],
        prompt_type: str = PromptType.CUSTOM,
        output_parser: Optional[BaseOutputParser] = None,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        **kwargs: Any,
    ):
        if metadata is None:
            metadata = {}
        metadata["prompt_type"] = prompt_type

        template_vars = []
        for message_template in message_templates:
            template_vars.extend(get_template_vars(message_template.content or ""))

        super().__init__(
            message_templates=message_templates,
            kwargs=kwargs,
            metadata=metadata,
            output_parser=output_parser,
            template_vars=template_vars,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
        )

    def partial_format(self, **kwargs: Any) -> "ChatPromptTemplate":
        prompt = deepcopy(self)
        prompt.kwargs.update(kwargs)
        return prompt

    def format(
        self,
        llm: Optional[BaseLLM] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        **kwargs: Any,
    ) -> str:
        del llm  # unused
        messages = self.format_messages(**kwargs)

        if messages_to_prompt is not None:
            return messages_to_prompt(messages)

        return default_messages_to_prompt(messages)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        del llm  # unused
        all_kwargs = {
            **self.kwargs,
            **kwargs,
        }
        mapped_all_kwargs = self._map_all_vars(all_kwargs)

        messages: List[ChatMessage] = []
        for message_template in self.message_templates:
            template_vars = get_template_vars(message_template.content or "")
            relevant_kwargs = {
                k: v for k, v in mapped_all_kwargs.items() if k in template_vars
            }
            content_template = message_template.content or ""

            # if there's mappings specified, make sure those are used
            content = content_template.format(**relevant_kwargs)

            message: ChatMessage = message_template.copy()
            message.content = content
            messages.append(message)

        if self.output_parser is not None:
            messages = self.output_parser.format_messages(messages)

        return messages

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        return default_messages_to_prompt(self.message_templates)

    def _as_query_component(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> QueryComponent:
        """As query component."""
        return PromptComponent(prompt=self, format_messages=True, llm=llm)
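

# Illustrative usage sketch (not part of the library code): building a chat
# prompt from per-message templates. ``ChatMessage`` is imported above;
# ``MessageRole`` is assumed to live in the same module. Message text and
# variable names are made up for the example.
#
#     from llama_index.core.llms.types import ChatMessage, MessageRole
#
#     chat_template = ChatPromptTemplate(
#         message_templates=[
#             ChatMessage(
#                 role=MessageRole.SYSTEM, content="You answer questions about {topic}."
#             ),
#             ChatMessage(role=MessageRole.USER, content="{question}"),
#         ]
#     )
#     messages = chat_template.format_messages(topic="prompts", question="What is a prompt?")
#     flat_prompt = chat_template.format(topic="prompts", question="What is a prompt?")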


class SelectorPromptTemplate(BasePromptTemplate):
    default_template: BasePromptTemplate
    conditionals: Optional[
        List[Tuple[Callable[[BaseLLM], bool], BasePromptTemplate]]
    ] = None

    def __init__(
        self,
        default_template: BasePromptTemplate,
        conditionals: Optional[
            List[Tuple[Callable[[BaseLLM], bool], BasePromptTemplate]]
        ] = None,
    ):
        metadata = default_template.metadata
        kwargs = default_template.kwargs
        template_vars = default_template.template_vars
        output_parser = default_template.output_parser
        super().__init__(
            default_template=default_template,
            conditionals=conditionals,
            metadata=metadata,
            kwargs=kwargs,
            template_vars=template_vars,
            output_parser=output_parser,
        )

    def select(self, llm: Optional[BaseLLM] = None) -> BasePromptTemplate:
        # ensure output parser is up to date
        self.default_template.output_parser = self.output_parser

        if llm is None:
            return self.default_template

        if self.conditionals is not None:
            for condition, prompt in self.conditionals:
                if condition(llm):
                    # ensure output parser is up to date
                    prompt.output_parser = self.output_parser
                    return prompt

        return self.default_template

    def partial_format(self, **kwargs: Any) -> "SelectorPromptTemplate":
        default_template = self.default_template.partial_format(**kwargs)
        if self.conditionals is None:
            conditionals = None
        else:
            conditionals = [
                (condition, prompt.partial_format(**kwargs))
                for condition, prompt in self.conditionals
            ]
        return SelectorPromptTemplate(
            default_template=default_template, conditionals=conditionals
        )

    def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
        """Format the prompt into a string."""
        prompt = self.select(llm=llm)
        return prompt.format(**kwargs)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        prompt = self.select(llm=llm)
        return prompt.format_messages(**kwargs)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        prompt = self.select(llm=llm)
        return prompt.get_template(llm=llm)
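

# Illustrative usage sketch (not part of the library code): picking a prompt
# based on the LLM. The predicate assumes ``BaseLLM.metadata.is_chat_model``;
# ``qa_template`` and ``chat_template`` refer to the made-up examples above.
#
#     def _use_chat_prompt(llm: BaseLLM) -> bool:
#         return llm.metadata.is_chat_model
#
#     selector_template = SelectorPromptTemplate(
#         default_template=qa_template,
#         conditionals=[(_use_chat_prompt, chat_template)],
#     )
#     # no LLM (or one that matches no condition) falls back to the default
#     # template; a chat model selects the chat template instead
#     chosen = selector_template.select(llm=None)  # -> qa_template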


class LangchainPromptTemplate(BasePromptTemplate):
    selector: Any
    requires_langchain_llm: bool = False

    def __init__(
        self,
        template: Optional["LangchainTemplate"] = None,
        selector: Optional["LangchainSelector"] = None,
        output_parser: Optional[BaseOutputParser] = None,
        prompt_type: str = PromptType.CUSTOM,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        requires_langchain_llm: bool = False,
    ) -> None:
        try:
            from llama_index.bridge.langchain import (
                ConditionalPromptSelector as LangchainSelector,
            )
        except ImportError:
            raise ImportError(
                "Must install `llama_index[langchain]` to use LangchainPromptTemplate."
            )
        if selector is None:
            if template is None:
                raise ValueError("Must provide either template or selector.")
            selector = LangchainSelector(default_prompt=template)
        else:
            if template is not None:
                raise ValueError("Must provide either template or selector, not both.")

        kwargs = selector.default_prompt.partial_variables
        template_vars = selector.default_prompt.input_variables

        if metadata is None:
            metadata = {}
        metadata["prompt_type"] = prompt_type

        super().__init__(
            selector=selector,
            metadata=metadata,
            kwargs=kwargs,
            template_vars=template_vars,
            output_parser=output_parser,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
            requires_langchain_llm=requires_langchain_llm,
        )

    def partial_format(self, **kwargs: Any) -> "BasePromptTemplate":
        """Partially format the prompt."""
        from llama_index.bridge.langchain import (
            ConditionalPromptSelector as LangchainSelector,
        )

        mapped_kwargs = self._map_all_vars(kwargs)
        default_prompt = self.selector.default_prompt.partial(**mapped_kwargs)
        conditionals = [
            (condition, prompt.partial(**mapped_kwargs))
            for condition, prompt in self.selector.conditionals
        ]
        lc_selector = LangchainSelector(
            default_prompt=default_prompt, conditionals=conditionals
        )

        # copy full prompt object, replace selector
        lc_prompt = deepcopy(self)
        lc_prompt.selector = lc_selector
        return lc_prompt

    def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
        """Format the prompt into a string."""
        from llama_index.llms.langchain import LangChainLLM

        if llm is not None:
            # if a LlamaIndex LLM is provided and a langchain LLM is required,
            # then error. Otherwise, if `requires_langchain_llm` is False,
            # fall back to the default prompt.
            if not isinstance(llm, LangChainLLM) and self.requires_langchain_llm:
                raise ValueError("Must provide a LangChainLLM.")
            elif not isinstance(llm, LangChainLLM):
                lc_template = self.selector.default_prompt
            else:
                lc_template = self.selector.get_prompt(llm=llm.llm)
        else:
            lc_template = self.selector.default_prompt

        # if there's mappings specified, make sure those are used
        mapped_kwargs = self._map_all_vars(kwargs)
        return lc_template.format(**mapped_kwargs)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages."""
        from llama_index.llms.langchain import LangChainLLM
        from llama_index.llms.langchain_utils import from_lc_messages

        if llm is not None:
            # if a LlamaIndex LLM is provided and a langchain LLM is required,
            # then error. Otherwise, if `requires_langchain_llm` is False,
            # fall back to the default prompt.
            if not isinstance(llm, LangChainLLM) and self.requires_langchain_llm:
                raise ValueError("Must provide a LangChainLLM.")
            elif not isinstance(llm, LangChainLLM):
                lc_template = self.selector.default_prompt
            else:
                lc_template = self.selector.get_prompt(llm=llm.llm)
        else:
            lc_template = self.selector.default_prompt

        # if there's mappings specified, make sure those are used
        mapped_kwargs = self._map_all_vars(kwargs)
        lc_prompt_value = lc_template.format_prompt(**mapped_kwargs)
        lc_messages = lc_prompt_value.to_messages()
        return from_lc_messages(lc_messages)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        from llama_index.llms.langchain import LangChainLLM

        if llm is not None:
            # if a LlamaIndex LLM is provided and a langchain LLM is required,
            # then error. Otherwise, if `requires_langchain_llm` is False,
            # fall back to the default prompt.
            if not isinstance(llm, LangChainLLM) and self.requires_langchain_llm:
                raise ValueError("Must provide a LangChainLLM.")
            elif not isinstance(llm, LangChainLLM):
                lc_template = self.selector.default_prompt
            else:
                lc_template = self.selector.get_prompt(llm=llm.llm)
        else:
            lc_template = self.selector.default_prompt

        try:
            return str(lc_template.template)  # type: ignore
        except AttributeError:
            return str(lc_template)
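

# Illustrative usage sketch (not part of the library code): wrapping a LangChain
# prompt so it can be used wherever a LlamaIndex ``BasePromptTemplate`` is
# expected. Assumes the `langchain` extra is installed; the template text is
# made up for the example.
#
#     from langchain.prompts import PromptTemplate as LCPromptTemplate
#
#     lc_template = LCPromptTemplate.from_template("Summarize the following:\n{text}")
#     wrapped_template = LangchainPromptTemplate(template=lc_template)
#     prompt_str = wrapped_template.format(text="LlamaIndex bridges prompt formats.")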


# NOTE: only for backwards compatibility
Prompt = PromptTemplate


class PromptComponent(QueryComponent):
    """Prompt component."""

    prompt: BasePromptTemplate = Field(..., description="Prompt")
    llm: Optional[BaseLLM] = Field(
        default=None, description="LLM to use for formatting prompt."
    )
    format_messages: bool = Field(
        default=False,
        description="Whether to format the prompt into a list of chat messages.",
    )

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        keys = list(input.keys())
        for k in keys:
            input[k] = validate_and_convert_stringable(input[k])
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        if self.format_messages:
            output: Union[str, List[ChatMessage]] = self.prompt.format_messages(
                llm=self.llm, **kwargs
            )
        else:
            output = self.prompt.format(llm=self.llm, **kwargs)
        return {"prompt": output}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async)."""
        # NOTE: no native async for prompt
        return self._run_component(**kwargs)

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        return InputKeys.from_keys(
            set(self.prompt.template_vars) - set(self.prompt.kwargs)
        )

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"prompt"})
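

# Illustrative usage sketch (not part of the library code): a ``PromptComponent``
# is normally obtained via ``as_query_component()`` (provided by ``ChainableMixin``)
# and chained with an LLM in a query pipeline. The ``QueryPipeline`` import path
# and the ``llm`` object are assumptions for the example.
#
#     from llama_index.query_pipeline import QueryPipeline
#
#     summary_template = PromptTemplate("Summarize the following text:\n{text}")
#     prompt_component = summary_template.as_query_component()
#     pipeline = QueryPipeline(chain=[prompt_component, llm])
#     output = pipeline.run(text="LlamaIndex prompt templates fill in variables.")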