148 lines
4.7 KiB
Python
148 lines
4.7 KiB
Python
from typing import Any, Dict, Optional, Sequence
|
|
|
|
from llama_index.core.base_selector import (
|
|
BaseSelector,
|
|
MultiSelection,
|
|
SelectorResult,
|
|
SingleSelection,
|
|
)
|
|
from llama_index.llms.openai import OpenAI
|
|
from llama_index.program.openai_program import OpenAIPydanticProgram
|
|
from llama_index.prompts.mixin import PromptDictType
|
|
from llama_index.schema import QueryBundle
|
|
from llama_index.selectors.llm_selectors import _build_choices_text
|
|
from llama_index.selectors.prompts import (
|
|
DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
|
|
DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
|
|
)
|
|
from llama_index.tools.types import ToolMetadata
|
|
from llama_index.types import BasePydanticProgram
|
|
|
|
|
|
def _pydantic_output_to_selector_result(output: Any) -> SelectorResult:
    """
    Convert pydantic output to selector result.

    Answer indexes coming back from the program are treated as one-based,
    so every index is shifted down by one to give a zero-based position in
    the original ``choices`` sequence.

    The program's output object is NOT modified: shifted copies are built
    instead. Mutating in place would make repeated conversion of the same
    object (e.g. a program that returns a cached/shared instance) decrement
    its index more than once.

    Args:
        output: A ``SingleSelection`` or ``MultiSelection`` produced by a
            pydantic program.

    Returns:
        A ``SelectorResult`` whose selections carry zero-based indexes.

    Raises:
        ValueError: If ``output`` is neither a ``SingleSelection`` nor a
            ``MultiSelection``.
    """
    if isinstance(output, SingleSelection):
        # Copy with the adjusted index rather than decrementing in place.
        shifted = output.copy(update={"index": output.index - 1})
        return SelectorResult(selections=[shifted])
    elif isinstance(output, MultiSelection):
        shifted_selections = [
            selection.copy(update={"index": selection.index - 1})
            for selection in output.selections
        ]
        return SelectorResult(selections=shifted_selections)
    else:
        raise ValueError(f"Unsupported output type: {type(output)}")
|
|
|
|
|
|
class PydanticSingleSelector(BaseSelector):
    """Select a single choice by running a pydantic program.

    The wrapped program is called with the rendered choices and the query;
    its output (a ``SingleSelection`` for the default program) is converted
    into a ``SelectorResult`` with a zero-based index.
    """

    def __init__(self, selector_program: BasePydanticProgram) -> None:
        """
        Args:
            selector_program: Program invoked with the keyword arguments
                ``num_choices``, ``context_list`` and ``query_str``.
        """
        self._selector_program = selector_program

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional[OpenAI] = None,
        prompt_template_str: str = DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
        verbose: bool = False,
    ) -> "PydanticSingleSelector":
        """Build a selector, creating an OpenAI pydantic program if needed.

        Args:
            program: Pre-built program to use. When provided, ``llm``,
                ``prompt_template_str`` and ``verbose`` are ignored.
            llm: OpenAI LLM passed to the default program (``None`` is
                forwarded as-is).
            prompt_template_str: Prompt template for the default program.
            verbose: Forwarded to the default program.

        Returns:
            A ``PydanticSingleSelector`` wrapping the chosen program.
        """
        if program is None:
            program = OpenAIPydanticProgram.from_defaults(
                output_cls=SingleSelection,
                prompt_template_str=prompt_template_str,
                llm=llm,
                verbose=verbose,
            )

        return cls(selector_program=program)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        # TODO: no accessible prompts for a base pydantic program
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        # No-op: _get_prompts exposes no prompts, so there is nothing to update.

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Pick one of ``choices`` for ``query`` using the wrapped program.

        Raises:
            ValueError: If the program returns something other than a
                ``SingleSelection`` or ``MultiSelection``.
        """
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = self._selector_program(
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        return _pydantic_output_to_selector_result(prediction)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Async entry point; delegates to the synchronous ``_select``.

        The pydantic program is invoked synchronously, so this call blocks
        the running event loop until the program returns. Previously this
        raised ``NotImplementedError``, which broke async routing even
        though the multi-selection counterpart already fell back to the
        synchronous path.
        """
        return self._select(choices, query)
|
|
|
|
|
|
class PydanticMultiSelector(BaseSelector):
    """Select one or more choices by running a pydantic program.

    The wrapped program receives the rendered choices, the query and an
    upper bound on the number of selections; its output (a
    ``MultiSelection`` for the default program) is converted into a
    ``SelectorResult`` with zero-based indexes.
    """

    def __init__(
        self, selector_program: BasePydanticProgram, max_outputs: Optional[int] = None
    ) -> None:
        """
        Args:
            selector_program: Program invoked with the keyword arguments
                ``num_choices``, ``max_outputs``, ``context_list`` and
                ``query_str``.
            max_outputs: Upper bound forwarded to the program. A falsy value
                (``None`` or ``0``) means "the number of choices". The bound
                is only passed along; results are not truncated here.
        """
        self._selector_program = selector_program
        self._max_outputs = max_outputs

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional[OpenAI] = None,
        prompt_template_str: str = DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
        max_outputs: Optional[int] = None,
        verbose: bool = False,
    ) -> "PydanticMultiSelector":
        """Construct a selector, falling back to an OpenAI pydantic program.

        Args:
            program: Program to wrap. If given, ``llm``,
                ``prompt_template_str`` and ``verbose`` are not used.
            llm: OpenAI LLM handed to the fallback program.
            prompt_template_str: Prompt template for the fallback program.
            max_outputs: See ``__init__``.
            verbose: Forwarded to the fallback program.

        Returns:
            A ``PydanticMultiSelector`` wrapping the resolved program.
        """
        selector_program = (
            program
            if program is not None
            else OpenAIPydanticProgram.from_defaults(
                output_cls=MultiSelection,
                prompt_template_str=prompt_template_str,
                llm=llm,
                verbose=verbose,
            )
        )
        return cls(selector_program=selector_program, max_outputs=max_outputs)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        # The base pydantic program offers no prompt accessors, so expose none.
        return dict()

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        # Nothing to update: this selector exposes no prompts.

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Ask the wrapped program which of ``choices`` match ``query``.

        Raises:
            ValueError: If the program's output is neither a
                ``SingleSelection`` nor a ``MultiSelection``.
        """
        rendered_choices = _build_choices_text(choices)
        choice_count = len(choices)
        # A falsy configured cap falls back to "every choice may be picked".
        output_cap = self._max_outputs if self._max_outputs else choice_count

        raw_output = self._selector_program(
            num_choices=choice_count,
            max_outputs=output_cap,
            context_list=rendered_choices,
            query_str=query.query_str,
        )
        return _pydantic_output_to_selector_result(raw_output)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Async entry point that simply runs the synchronous selection."""
        result = self._select(choices, query)
        return result
|