from typing import Any, Dict, List, Optional, Sequence, cast

from llama_index.core.base_selector import (
    BaseSelector,
    SelectorResult,
    SingleSelection,
)
from llama_index.llm_predictor.base import LLMPredictorType
from llama_index.output_parsers.base import StructuredOutput
from llama_index.output_parsers.selection import Answer, SelectionOutputParser
from llama_index.prompts.mixin import PromptDictType
from llama_index.prompts.prompt_type import PromptType
from llama_index.schema import QueryBundle
from llama_index.selectors.prompts import (
    DEFAULT_MULTI_SELECT_PROMPT_TMPL,
    DEFAULT_SINGLE_SELECT_PROMPT_TMPL,
    MultiSelectPrompt,
    SingleSelectPrompt,
)
from llama_index.service_context import ServiceContext
from llama_index.tools.types import ToolMetadata
from llama_index.types import BaseOutputParser


def _build_choices_text(choices: Sequence[ToolMetadata]) -> str:
    """Convert sequence of metadata to enumeration text."""
    texts: List[str] = []
    for ind, choice in enumerate(choices):
        text = " ".join(choice.description.splitlines())
        text = f"({ind + 1}) {text}"  # to one indexing
        texts.append(text)
    return "\n\n".join(texts)
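
# For illustration: two tools with the (hypothetical) descriptions
# "Useful for summarization." and "Useful for keyword lookup." render as:
#
#   (1) Useful for summarization.
#
#   (2) Useful for keyword lookup.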


def _structured_output_to_selector_result(output: Any) -> SelectorResult:
    """Convert structured output to selector result."""
    structured_output = cast(StructuredOutput, output)
    answers = cast(List[Answer], structured_output.parsed_output)

    # adjust for zero indexing
    selections = [
        SingleSelection(index=answer.choice - 1, reason=answer.reason)
        for answer in answers
    ]
    return SelectorResult(selections=selections)
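
# For example, a parsed Answer(choice=1, reason="...") refers to the first
# choice in the one-indexed enumeration above and becomes
# SingleSelection(index=0, reason="...").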


class LLMSingleSelector(BaseSelector):
    """LLM single selector.

    LLM-based selector that chooses one out of many options.

    Args:
        llm (LLM): An LLM.
        prompt (SingleSelectPrompt): An LLM prompt for selecting one out of
            many options.
    """

    def __init__(
        self,
        llm: LLMPredictorType,
        prompt: SingleSelectPrompt,
    ) -> None:
        self._llm = llm
        self._prompt = prompt

        if self._prompt.output_parser is None:
            raise ValueError("Prompt should have output parser.")

    @classmethod
    def from_defaults(
        cls,
        service_context: Optional[ServiceContext] = None,
        prompt_template_str: Optional[str] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> "LLMSingleSelector":
        # optionally initialize defaults
        service_context = service_context or ServiceContext.from_defaults()
        prompt_template_str = prompt_template_str or DEFAULT_SINGLE_SELECT_PROMPT_TMPL
        output_parser = output_parser or SelectionOutputParser()

        # construct prompt
        prompt = SingleSelectPrompt(
            template=prompt_template_str,
            output_parser=output_parser,
            prompt_type=PromptType.SINGLE_SELECT,
        )
        return cls(service_context.llm, prompt)
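
    # Note: a custom ``prompt_template_str`` is formatted with the
    # ``num_choices``, ``context_list`` and ``query_str`` variables
    # (see ``_select`` below), so it should use those placeholders.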
    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {"prompt": self._prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "prompt" in prompts:
            self._prompt = prompts["prompt"]

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = self._llm.predict(
            prompt=self._prompt,
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        assert self._prompt.output_parser is not None
        parse = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parse)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = await self._llm.apredict(
            prompt=self._prompt,
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        assert self._prompt.output_parser is not None
        parse = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parse)
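
# Example usage (a minimal sketch; the tool descriptions and the query are
# hypothetical, and ``from_defaults()`` needs a working LLM configuration,
# e.g. an OpenAI API key in the environment):
#
#   selector = LLMSingleSelector.from_defaults()
#   result = selector.select(
#       choices=[
#           ToolMetadata(name="summary", description="Useful for summarization."),
#           ToolMetadata(name="search", description="Useful for keyword lookup."),
#       ],
#       query=QueryBundle(query_str="What does the document say about revenue?"),
#   )
#   print(result.selections[0].index, result.selections[0].reason)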


class LLMMultiSelector(BaseSelector):
    """LLM multi selector.

    LLM-based selector that chooses multiple out of many options.

    Args:
        llm (LLM): An LLM.
        prompt (MultiSelectPrompt): An LLM prompt for selecting multiple out of
            many options.
        max_outputs (Optional[int]): Maximum number of selections to return.
    """

    def __init__(
        self,
        llm: LLMPredictorType,
        prompt: MultiSelectPrompt,
        max_outputs: Optional[int] = None,
    ) -> None:
        self._llm = llm
        self._prompt = prompt
        self._max_outputs = max_outputs

        if self._prompt.output_parser is None:
            raise ValueError("Prompt should have output parser.")

    @classmethod
    def from_defaults(
        cls,
        service_context: Optional[ServiceContext] = None,
        prompt_template_str: Optional[str] = None,
        output_parser: Optional[BaseOutputParser] = None,
        max_outputs: Optional[int] = None,
    ) -> "LLMMultiSelector":
        service_context = service_context or ServiceContext.from_defaults()
        prompt_template_str = prompt_template_str or DEFAULT_MULTI_SELECT_PROMPT_TMPL
        output_parser = output_parser or SelectionOutputParser()

        # add output formatting
        prompt_template_str = output_parser.format(prompt_template_str)

        # construct prompt
        prompt = MultiSelectPrompt(
            template=prompt_template_str,
            output_parser=output_parser,
            prompt_type=PromptType.MULTI_SELECT,
        )
        return cls(service_context.llm, prompt, max_outputs)
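
    # Note: unlike the single selector, the template string is first passed
    # through ``output_parser.format``, which embeds the parser's formatting
    # instructions into the prompt template.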
    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {"prompt": self._prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "prompt" in prompts:
            self._prompt = prompts["prompt"]

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        context_list = _build_choices_text(choices)
        max_outputs = self._max_outputs or len(choices)

        prediction = self._llm.predict(
            prompt=self._prompt,
            num_choices=len(choices),
            max_outputs=max_outputs,
            context_list=context_list,
            query_str=query.query_str,
        )

        assert self._prompt.output_parser is not None
        parsed = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parsed)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        context_list = _build_choices_text(choices)
        max_outputs = self._max_outputs or len(choices)

        prediction = await self._llm.apredict(
            prompt=self._prompt,
            num_choices=len(choices),
            max_outputs=max_outputs,
            context_list=context_list,
            query_str=query.query_str,
        )

        assert self._prompt.output_parser is not None
        parsed = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parsed)
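
# Example usage (a minimal sketch mirroring the single-selector example above;
# ``max_outputs`` caps how many choices the LLM may pick):
#
#   selector = LLMMultiSelector.from_defaults(max_outputs=2)
#   result = selector.select(choices, QueryBundle(query_str="..."))
#   for selection in result.selections:
#       print(selection.index, selection.reason)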