199 lines
6.7 KiB
Python
199 lines
6.7 KiB
Python
"""Router components."""
|
|
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.core.base_selector import BaseSelector
|
|
from llama_index.core.query_pipeline.query_component import (
|
|
QUERY_COMPONENT_TYPE,
|
|
ChainableMixin,
|
|
InputKeys,
|
|
OutputKeys,
|
|
QueryComponent,
|
|
validate_and_convert_stringable,
|
|
)
|
|
from llama_index.utils import print_text
|
|
|
|
|
|
class SelectorComponent(QueryComponent):
|
|
"""Selector component."""
|
|
|
|
selector: BaseSelector = Field(..., description="Selector")
|
|
|
|
class Config:
|
|
arbitrary_types_allowed = True
|
|
|
|
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
|
"""Set callback manager."""
|
|
|
|
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Validate component inputs during run_component."""
|
|
if "choices" not in input:
|
|
raise ValueError("Input must have key 'choices'")
|
|
if not isinstance(input["choices"], list):
|
|
raise ValueError("Input choices must be a list")
|
|
|
|
for idx, choice in enumerate(input["choices"]):
|
|
# make stringable
|
|
input["choices"][idx] = validate_and_convert_stringable(choice)
|
|
|
|
# make sure `query` is stringable
|
|
if "query" not in input:
|
|
raise ValueError("Input must have key 'query'")
|
|
input["query"] = validate_and_convert_stringable(input["query"])
|
|
|
|
return input
|
|
|
|
def _run_component(self, **kwargs: Any) -> Any:
|
|
"""Run component."""
|
|
output = self.selector.select(kwargs["choices"], kwargs["query"])
|
|
return {"output": output.selections}
|
|
|
|
async def _arun_component(self, **kwargs: Any) -> Any:
|
|
"""Run component (async)."""
|
|
# NOTE: no native async for postprocessor
|
|
return self._run_component(**kwargs)
|
|
|
|
@property
|
|
def input_keys(self) -> InputKeys:
|
|
"""Input keys."""
|
|
return InputKeys.from_keys({"choices", "query"})
|
|
|
|
@property
|
|
def output_keys(self) -> OutputKeys:
|
|
"""Output keys."""
|
|
return OutputKeys.from_keys({"output"})
|
|
|
|
|
|
class RouterComponent(QueryComponent):
|
|
"""Router Component.
|
|
|
|
Routes queries to different query components based on a selector.
|
|
|
|
Assumes a single query component is selected.
|
|
|
|
"""
|
|
|
|
selector: BaseSelector = Field(..., description="Selector")
|
|
choices: List[str] = Field(
|
|
..., description="Choices (must correspond to components)"
|
|
)
|
|
components: List[QueryComponent] = Field(
|
|
..., description="Components (must correspond to choices)"
|
|
)
|
|
verbose: bool = Field(default=False, description="Verbose")
|
|
|
|
_query_keys: List[str] = PrivateAttr()
|
|
|
|
class Config:
|
|
arbitrary_types_allowed = True
|
|
|
|
def __init__(
|
|
self,
|
|
selector: BaseSelector,
|
|
choices: List[str],
|
|
components: List[QUERY_COMPONENT_TYPE],
|
|
verbose: bool = False,
|
|
) -> None:
|
|
"""Init."""
|
|
new_components = []
|
|
query_keys = []
|
|
for component in components:
|
|
if isinstance(component, ChainableMixin):
|
|
new_component = component.as_query_component()
|
|
else:
|
|
new_component = component
|
|
|
|
# validate component has one input key
|
|
if len(new_component.free_req_input_keys) != 1:
|
|
raise ValueError("Expected one required input key")
|
|
query_keys.append(next(iter(new_component.free_req_input_keys)))
|
|
new_components.append(new_component)
|
|
|
|
self._query_keys = query_keys
|
|
|
|
super().__init__(
|
|
selector=selector,
|
|
choices=choices,
|
|
components=new_components,
|
|
verbose=verbose,
|
|
)
|
|
|
|
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
|
"""Set callback manager."""
|
|
for component in self.components:
|
|
component.set_callback_manager(callback_manager)
|
|
|
|
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Validate component inputs during run_component."""
|
|
# make sure `query` is stringable
|
|
if "query" not in input:
|
|
raise ValueError("Input must have key 'query'")
|
|
input["query"] = validate_and_convert_stringable(input["query"])
|
|
|
|
return input
|
|
|
|
def validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Validate component inputs during run_component."""
|
|
return input
|
|
|
|
def _validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
|
|
raise NotImplementedError
|
|
|
|
def _run_component(self, **kwargs: Any) -> Any:
|
|
"""Run component."""
|
|
# for the output selection, run the corresponding component, aggregate into list
|
|
sel_output = self.selector.select(self.choices, kwargs["query"])
|
|
# assume one selection
|
|
if len(sel_output.selections) != 1:
|
|
raise ValueError("Expected one selection")
|
|
component = self.components[sel_output.ind]
|
|
log_str = f"Selecting component {sel_output.ind}: " f"{sel_output.reason}."
|
|
if self.verbose:
|
|
print_text(log_str + "\n", color="pink")
|
|
# run component
|
|
# run with input_keys of component
|
|
return component.run_component(
|
|
**{self._query_keys[sel_output.ind]: kwargs["query"]}
|
|
)
|
|
|
|
async def _arun_component(self, **kwargs: Any) -> Any:
|
|
"""Run component (async)."""
|
|
# for the output selection, run the corresponding component, aggregate into list
|
|
sel_output = await self.selector.aselect(self.choices, kwargs["query"])
|
|
# assume one selection
|
|
if len(sel_output.selections) != 1:
|
|
raise ValueError("Expected one selection")
|
|
component = self.components[sel_output.ind]
|
|
log_str = f"Selecting component {sel_output.ind}: " f"{sel_output.reason}."
|
|
if self.verbose:
|
|
print_text(log_str + "\n", color="pink")
|
|
# run component
|
|
return await component.arun_component(
|
|
**{self._query_keys[sel_output.ind]: kwargs["query"]}
|
|
)
|
|
|
|
@property
|
|
def input_keys(self) -> InputKeys:
|
|
"""Input keys."""
|
|
return InputKeys.from_keys({"query"})
|
|
|
|
@property
|
|
def output_keys(self) -> OutputKeys:
|
|
"""Output keys."""
|
|
# not used
|
|
return OutputKeys.from_keys(set())
|
|
|
|
@property
|
|
def sub_query_components(self) -> List["QueryComponent"]:
|
|
"""Get sub query components.
|
|
|
|
Certain query components may have sub query components, e.g. a
|
|
query pipeline will have sub query components, and so will
|
|
an IfElseComponent.
|
|
|
|
"""
|
|
return self.components
|