97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
"""Prompt Mixin."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from collections import defaultdict
|
|
from copy import deepcopy
|
|
from typing import Dict, Union
|
|
|
|
from llama_index.prompts.base import BasePromptTemplate
|
|
|
|
HasPromptType = Union["PromptMixin", BasePromptTemplate]
|
|
PromptDictType = Dict[str, BasePromptTemplate]
|
|
PromptMixinType = Dict[str, "PromptMixin"]
|
|
|
|
|
|
class PromptMixin(ABC):
|
|
"""Prompt mixin.
|
|
|
|
This mixin is used in other modules, like query engines, response synthesizers.
|
|
This shows that the module supports getting, setting prompts,
|
|
both within the immediate module as well as child modules.
|
|
|
|
"""
|
|
|
|
def _validate_prompts(
|
|
self,
|
|
prompts_dict: PromptDictType,
|
|
module_dict: PromptMixinType,
|
|
) -> None:
|
|
"""Validate prompts."""
|
|
# check if prompts_dict, module_dict has restricted ":" token
|
|
for key in prompts_dict:
|
|
if ":" in key:
|
|
raise ValueError(f"Prompt key {key} cannot contain ':'.")
|
|
|
|
for key in module_dict:
|
|
if ":" in key:
|
|
raise ValueError(f"Prompt key {key} cannot contain ':'.")
|
|
|
|
def get_prompts(self) -> Dict[str, BasePromptTemplate]:
|
|
"""Get a prompt."""
|
|
prompts_dict = self._get_prompts()
|
|
module_dict = self._get_prompt_modules()
|
|
self._validate_prompts(prompts_dict, module_dict)
|
|
|
|
# avoid modifying the original dict
|
|
all_prompts = deepcopy(prompts_dict)
|
|
for module_name, prompt_module in module_dict.items():
|
|
# append module name to each key in sub-modules by ":"
|
|
for key, prompt in prompt_module.get_prompts().items():
|
|
all_prompts[f"{module_name}:{key}"] = prompt
|
|
return all_prompts
|
|
|
|
def update_prompts(self, prompts_dict: Dict[str, BasePromptTemplate]) -> None:
|
|
"""Update prompts.
|
|
|
|
Other prompts will remain in place.
|
|
|
|
"""
|
|
prompt_modules = self._get_prompt_modules()
|
|
|
|
# update prompts for current module
|
|
self._update_prompts(prompts_dict)
|
|
|
|
# get sub-module keys
|
|
# mapping from module name to sub-module prompt keys
|
|
sub_prompt_dicts: Dict[str, PromptDictType] = defaultdict(dict)
|
|
for key in prompts_dict:
|
|
if ":" in key:
|
|
module_name, sub_key = key.split(":")
|
|
sub_prompt_dicts[module_name][sub_key] = prompts_dict[key]
|
|
|
|
# now update prompts for submodules
|
|
for module_name, sub_prompt_dict in sub_prompt_dicts.items():
|
|
if module_name not in prompt_modules:
|
|
raise ValueError(f"Module {module_name} not found.")
|
|
module = prompt_modules[module_name]
|
|
module.update_prompts(sub_prompt_dict)
|
|
|
|
@abstractmethod
|
|
def _get_prompts(self) -> PromptDictType:
|
|
"""Get prompts."""
|
|
|
|
@abstractmethod
|
|
def _get_prompt_modules(self) -> PromptMixinType:
|
|
"""Get prompt sub-modules.
|
|
|
|
Return a dictionary of sub-modules within the current module
|
|
that also implement PromptMixin (so that their prompts can also be get/set).
|
|
|
|
Can be blank if no sub-modules.
|
|
|
|
"""
|
|
|
|
@abstractmethod
|
|
def _update_prompts(self, prompts_dict: PromptDictType) -> None:
|
|
"""Update prompts."""
|