# faiss_rag_enterprise/llama_index/selectors/pydantic_selectors.py

from typing import Any, Dict, Optional, Sequence
from llama_index.core.base_selector import (
    BaseSelector,
    MultiSelection,
    SelectorResult,
    SingleSelection,
)
from llama_index.llms.openai import OpenAI
from llama_index.program.openai_program import OpenAIPydanticProgram
from llama_index.prompts.mixin import PromptDictType
from llama_index.schema import QueryBundle
from llama_index.selectors.llm_selectors import _build_choices_text
from llama_index.selectors.prompts import (
    DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
    DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
)
from llama_index.tools.types import ToolMetadata
from llama_index.types import BasePydanticProgram


def _pydantic_output_to_selector_result(output: Any) -> SelectorResult:
    """Convert pydantic output to selector result.

    Takes into account zero-indexing on answer indexes.
    """
    if isinstance(output, SingleSelection):
        output.index -= 1
        return SelectorResult(selections=[output])
    elif isinstance(output, MultiSelection):
        for idx in range(len(output.selections)):
            output.selections[idx].index -= 1
        return SelectorResult(selections=output.selections)
    else:
        raise ValueError(f"Unsupported output type: {type(output)}")


class PydanticSingleSelector(BaseSelector):
    """Selector that picks a single choice via a Pydantic structured-output program."""

    def __init__(self, selector_program: BasePydanticProgram) -> None:
        self._selector_program = selector_program

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional[OpenAI] = None,
        prompt_template_str: str = DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
        verbose: bool = False,
    ) -> "PydanticSingleSelector":
        if program is None:
            program = OpenAIPydanticProgram.from_defaults(
                output_cls=SingleSelection,
                prompt_template_str=prompt_template_str,
                llm=llm,
                verbose=verbose,
            )
        return cls(selector_program=program)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        # TODO: no accessible prompts for a base pydantic program
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = self._selector_program(
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        return _pydantic_output_to_selector_result(prediction)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        raise NotImplementedError(
            "Async selection not supported for Pydantic Selectors."
        )


class PydanticMultiSelector(BaseSelector):
    """Selector that picks one or more choices via a Pydantic structured-output program."""

    def __init__(
        self, selector_program: BasePydanticProgram, max_outputs: Optional[int] = None
    ) -> None:
        self._selector_program = selector_program
        self._max_outputs = max_outputs

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional[OpenAI] = None,
        prompt_template_str: str = DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
        max_outputs: Optional[int] = None,
        verbose: bool = False,
    ) -> "PydanticMultiSelector":
        if program is None:
            program = OpenAIPydanticProgram.from_defaults(
                output_cls=MultiSelection,
                prompt_template_str=prompt_template_str,
                llm=llm,
                verbose=verbose,
            )
        return cls(selector_program=program, max_outputs=max_outputs)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        # TODO: no accessible prompts for a base pydantic program
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        context_list = _build_choices_text(choices)
        max_outputs = self._max_outputs or len(choices)

        # predict
        prediction = self._selector_program(
            num_choices=len(choices),
            max_outputs=max_outputs,
            context_list=context_list,
            query_str=query.query_str,
        )

        # parse output
        return _pydantic_output_to_selector_result(prediction)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # Fall back to the synchronous implementation.
        return self._select(choices, query)
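

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It shows
# how these selectors can be driven directly via `select()`. It assumes an
# OpenAI API key is configured in the environment; the tool names,
# descriptions, and queries below are made up for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    choices = [
        ToolMetadata(
            name="vector_tool",
            description="Useful for semantic search over the document corpus.",
        ),
        ToolMetadata(
            name="summary_tool",
            description="Useful for summarizing an entire document.",
        ),
    ]

    # Single selection: the program returns a SingleSelection, which the
    # helper above converts into a zero-indexed SelectorResult.
    single_selector = PydanticSingleSelector.from_defaults(llm=OpenAI())
    single_result = single_selector.select(choices, "Summarize the report.")
    print(single_result.selections)

    # Multi selection: up to `max_outputs` choices may be returned.
    multi_selector = PydanticMultiSelector.from_defaults(llm=OpenAI(), max_outputs=2)
    multi_result = multi_selector.select(
        choices, "Find and summarize the relevant sections."
    )
    print(multi_result.selections)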