# faiss_rag_enterprise/llama_index/query_pipeline/components/router.py
#
# 199 lines
# 6.7 KiB
# Python

"""Router components."""
from typing import Any, Dict, List
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks.base import CallbackManager
from llama_index.core.base_selector import BaseSelector
from llama_index.core.query_pipeline.query_component import (
QUERY_COMPONENT_TYPE,
ChainableMixin,
InputKeys,
OutputKeys,
QueryComponent,
validate_and_convert_stringable,
)
from llama_index.utils import print_text
class SelectorComponent(QueryComponent):
    """Selector component.

    Exposes a selector as a query component: takes ``choices`` and a
    ``query`` as inputs and returns the selector's selections under the
    ``output`` key.
    """

    selector: BaseSelector = Field(..., description="Selector")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager.

        No-op: this component neither stores nor forwards the callback
        manager.
        """

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component.

        Args:
            input: Input mapping; must contain ``choices`` (a list) and
                ``query``.

        Returns:
            The same mapping, with every choice and the query passed through
            ``validate_and_convert_stringable``.

        Raises:
            ValueError: If ``choices`` or ``query`` is missing, or if
                ``choices`` is not a list.
        """
        if "choices" not in input:
            raise ValueError("Input must have key 'choices'")
        if not isinstance(input["choices"], list):
            raise ValueError("Input choices must be a list")

        # make every choice stringable; build a new list rather than
        # overwriting the elements of the caller's list object
        input["choices"] = [
            validate_and_convert_stringable(choice) for choice in input["choices"]
        ]

        # make sure `query` is stringable
        if "query" not in input:
            raise ValueError("Input must have key 'query'")
        input["query"] = validate_and_convert_stringable(input["query"])

        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component.

        Reads ``choices`` and ``query`` from ``kwargs`` and returns
        ``{"output": <selections>}``.
        """
        output = self.selector.select(kwargs["choices"], kwargs["query"])
        return {"output": output.selections}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async).

        Awaits the selector's ``aselect`` instead of calling the synchronous
        ``select`` from inside a coroutine.
        """
        output = await self.selector.aselect(kwargs["choices"], kwargs["query"])
        return {"output": output.selections}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        return InputKeys.from_keys({"choices", "query"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
class RouterComponent(QueryComponent):
    """Router Component.

    Routes queries to different query components based on a selector.

    Assumes a single query component is selected.
    """

    selector: BaseSelector = Field(..., description="Selector")
    choices: List[str] = Field(
        ..., description="Choices (must correspond to components)"
    )
    components: List[QueryComponent] = Field(
        ..., description="Components (must correspond to choices)"
    )
    verbose: bool = Field(default=False, description="Verbose")

    # For each component (same order as `components`), the name of its single
    # free required input key; the query is passed under that name.
    _query_keys: List[str] = PrivateAttr()

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        selector: BaseSelector,
        choices: List[str],
        components: List[QUERY_COMPONENT_TYPE],
        verbose: bool = False,
    ) -> None:
        """Init.

        Args:
            selector: Selector used to pick one of ``choices`` for a query.
            choices: Choice descriptions, one per component.
            components: Components (or chainable objects convertible via
                ``as_query_component``), one per choice.
            verbose: Whether to print the selection made at run time.

        Raises:
            ValueError: If ``choices`` and ``components`` differ in length,
                or if a component does not have exactly one free required
                input key.
        """
        # choices[i] describes components[i]; a length mismatch would only
        # surface later as a wrong or out-of-range selection.
        if len(choices) != len(components):
            raise ValueError("Number of choices must match number of components")

        new_components = []
        query_keys = []
        for component in components:
            if isinstance(component, ChainableMixin):
                new_component = component.as_query_component()
            else:
                new_component = component

            # validate component has one input key
            if len(new_component.free_req_input_keys) != 1:
                raise ValueError("Expected one required input key")
            query_keys.append(next(iter(new_component.free_req_input_keys)))
            new_components.append(new_component)

        super().__init__(
            selector=selector,
            choices=choices,
            components=new_components,
            verbose=verbose,
        )
        # Assign the private attribute only after the model is initialized so
        # that it does not depend on pre-init attribute storage.
        self._query_keys = query_keys

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager on every routed component."""
        for component in self.components:
            component.set_callback_manager(callback_manager)

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component.

        Raises:
            ValueError: If ``query`` is missing from ``input``.
        """
        # make sure `query` is stringable
        if "query" not in input:
            raise ValueError("Input must have key 'query'")
        input["query"] = validate_and_convert_stringable(input["query"])

        return input

    def validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component outputs during run_component.

        Returns the outputs unchanged. ``_validate_component_outputs`` is not
        implemented for this component, so this public method does not
        delegate to it. The parameter keeps its original name ``input`` for
        backward compatibility, although it holds the outputs.
        """
        return input

    def _validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
        """Not implemented; see ``validate_component_outputs``."""
        raise NotImplementedError

    def _resolve_selection(self, sel_output: Any) -> int:
        """Check that exactly one choice was selected and return its index.

        Prints the selection and its reason when ``verbose`` is set.

        Raises:
            ValueError: If the selector did not return exactly one selection.
        """
        # assume one selection
        if len(sel_output.selections) != 1:
            raise ValueError("Expected one selection")

        ind = sel_output.ind
        if self.verbose:
            log_str = f"Selecting component {ind}: {sel_output.reason}."
            print_text(log_str + "\n", color="pink")
        return ind

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component.

        Selects one component for ``kwargs["query"]`` and returns the result
        of running it.
        """
        sel_output = self.selector.select(self.choices, kwargs["query"])
        ind = self._resolve_selection(sel_output)

        # run the selected component, passing the query under its own input key
        return self.components[ind].run_component(
            **{self._query_keys[ind]: kwargs["query"]}
        )

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async).

        Async counterpart of ``_run_component``: awaits both the selector and
        the selected component.
        """
        sel_output = await self.selector.aselect(self.choices, kwargs["query"])
        ind = self._resolve_selection(sel_output)

        # run the selected component, passing the query under its own input key
        return await self.components[ind].arun_component(
            **{self._query_keys[ind]: kwargs["query"]}
        )

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        return InputKeys.from_keys({"query"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        # not used
        return OutputKeys.from_keys(set())

    @property
    def sub_query_components(self) -> List["QueryComponent"]:
        """Get sub query components.

        Certain query components may have sub query components, e.g. a
        query pipeline will have sub query components, and so will
        an IfElseComponent.

        """
        return self.components