faiss_rag_enterprise/llama_index/query_pipeline/components/agent.py

318 lines
10 KiB
Python

"""Agent components."""
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks.base import CallbackManager
from llama_index.core.query_pipeline.query_component import (
    InputKeys,
    OutputKeys,
    QueryComponent,
)
def get_parameters(fn: Callable) -> Tuple[Set[str], Set[str]]:
    """Split the named parameters of ``fn`` into required and optional sets.

    A parameter is *required* when it has no default value and *optional*
    when it has one. Variadic parameters (``*args`` / ``**kwargs``) are
    skipped: they are not named inputs. Calling ``fn(args=...)`` raises a
    ``TypeError``, and ``fn(kwargs=...)`` merely adds a stray ``"kwargs"``
    entry. Reporting them would make components demand bogus input keys.

    Args:
        fn: Callable whose signature is inspected.

    Returns:
        Tuple[Set[str], Set[str]]: required and optional parameter names.

    """
    required_params: Set[str] = set()
    optional_params: Set[str] = set()
    for param_name, param in signature(fn).parameters.items():
        # ``*args`` / ``**kwargs`` cannot be supplied as a single keyword input.
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        if param.default is param.empty:
            required_params.add(param_name)
        else:
            optional_params.add(param_name)
    return required_params, optional_params
def default_agent_input_fn(task: Any, state: dict) -> dict:
    """Build the default agent input, ``{"input": task.input}``.

    ``state`` is not read here. It is presumably accepted so the function
    fits ``AgentInputComponent``, which always passes both ``task`` and
    ``state``.
    """
    from llama_index.agent.types import Task

    typed_task = cast(Task, task)  # static-typing aid only; no runtime effect
    agent_input = typed_task.input
    return {"input": agent_input}
class AgentInputComponent(QueryComponent):
    """Query component that turns agent inputs into user-defined outputs.

    The wrapped ``fn`` is called with the component inputs as keyword
    arguments and must return a ``dict``. That dict is the component output.
    Output keys are deliberately left unconstrained.
    """

    fn: Callable = Field(..., description="Function to run.")
    async_fn: Optional[Callable] = Field(
        None, description="Async function to run. If not provided, will run `fn`."
    )
    # Required / optional input keys. They are inferred from ``fn``'s
    # signature unless passed explicitly to ``__init__``.
    _req_params: Set[str] = PrivateAttr()
    _opt_params: Set[str] = PrivateAttr()

    def __init__(
        self,
        fn: Callable,
        async_fn: Optional[Callable] = None,
        req_params: Optional[Set[str]] = None,
        opt_params: Optional[Set[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize.

        Args:
            fn: Synchronous function returning the output dict.
            async_fn: Optional coroutine function used for async runs.
            req_params: Required input keys; inferred from ``fn`` if omitted.
            opt_params: Optional input keys; inferred from ``fn`` if omitted.
            **kwargs: Forwarded to the base ``QueryComponent``.
        """
        inferred_req, inferred_opt = get_parameters(fn)
        # NOTE(review): private attrs are assigned before the pydantic
        # ``__init__`` runs. This presumably relies on v1-style private
        # attribute handling in the bridge; confirm before upgrading pydantic.
        self._req_params = inferred_req if req_params is None else req_params
        self._opt_params = inferred_opt if opt_params is None else opt_params
        super().__init__(fn=fn, async_fn=async_fn, **kwargs)

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager (currently a no-op)."""
        # TODO: implement

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Check that ``task`` is a ``Task`` and ``state`` is a ``dict``."""
        from llama_index.agent.types import Task

        # Checked in order: task presence, task type, state presence, state type.
        for key, expected_type, type_name in (
            ("task", Task, "Task"),
            ("state", dict, "dict"),
        ):
            if key not in input:
                raise ValueError(f"Input must have key '{key}'")
            if not isinstance(input[key], expected_type):
                raise ValueError(f"Input must have key '{key}' of type {type_name}")
        return input

    def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
        """Skip output validation; the output schema is whatever ``fn`` returns."""
        return output

    def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Pass outputs through unchanged."""
        return input

    def _run_component(self, **kwargs: Any) -> Dict:
        """Call ``fn`` and require its result to be a dict."""
        result = self.fn(**kwargs)
        if isinstance(result, dict):
            return result
        raise ValueError("Output must be a dictionary")

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Await ``async_fn`` if set; otherwise fall back to the sync path."""
        if self.async_fn is not None:
            result = await self.async_fn(**kwargs)
            if isinstance(result, dict):
                return result
            raise ValueError("Output must be a dictionary")
        return self._run_component(**kwargs)

    @property
    def input_keys(self) -> InputKeys:
        """Input keys: ``task`` and ``state`` plus ``fn``'s parameters."""
        required = {"task", "state"}.union(self._req_params)
        return InputKeys.from_keys(
            required_keys=required,
            optional_keys=self._opt_params,
        )

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys (empty: outputs are unconstrained, validation is overridden)."""
        no_declared_keys: Set[str] = set()
        return OutputKeys.from_keys(no_declared_keys)
class BaseAgentComponent(QueryComponent):
    """Abstract base for agent query components.

    Adds no behavior of its own. It exists for type checking, i.e. to tell
    agent components apart from other query components.
    """
class AgentFnComponent(BaseAgentComponent):
    """Function component for agents.

    Designed to let users easily modify state. It wraps a function that
    must take ``task`` and ``state`` as required parameters. Whatever the
    function returns (any type) is exposed under the single output key
    ``"output"``.
    """

    fn: Callable = Field(..., description="Function to run.")
    async_fn: Optional[Callable] = Field(
        None, description="Async function to run. If not provided, will run `fn`."
    )
    # Input keys of ``fn`` other than ``task``/``state`` when inferred from
    # its signature (``input_keys`` always adds ``task`` and ``state``).
    _req_params: Set[str] = PrivateAttr()
    _opt_params: Set[str] = PrivateAttr()

    def __init__(
        self,
        fn: Callable,
        async_fn: Optional[Callable] = None,
        req_params: Optional[Set[str]] = None,
        opt_params: Optional[Set[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize.

        Args:
            fn: Synchronous function; must take ``task`` and ``state`` as
                required parameters.
            async_fn: Optional coroutine function used for async runs.
            req_params: Extra required input keys; inferred from ``fn`` if omitted.
            opt_params: Extra optional input keys; inferred from ``fn`` if omitted.
            **kwargs: Forwarded to the base component.

        Raises:
            ValueError: If ``task`` or ``state`` is not a required
                (no-default) parameter of ``fn``.
        """
        default_req_params, default_opt_params = get_parameters(fn)
        # ``task`` and ``state`` must be required parameters of ``fn``. They
        # are removed from the inferred sets because ``input_keys`` always
        # adds them back.
        if "task" not in default_req_params or "state" not in default_req_params:
            raise ValueError(
                "AgentFnComponent must have 'task' and 'state' as required parameters"
            )
        default_req_params = default_req_params - {"task", "state"}
        default_opt_params = default_opt_params - {"task", "state"}
        if req_params is None:
            req_params = default_req_params
        if opt_params is None:
            opt_params = default_opt_params
        # NOTE(review): private attrs are assigned before the pydantic
        # ``__init__`` runs; presumably relies on v1-style bridge semantics.
        self._req_params = req_params
        self._opt_params = opt_params
        super().__init__(fn=fn, async_fn=async_fn, **kwargs)

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager (currently a no-op)."""
        # TODO: implement

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Check that ``task`` is a ``Task`` and ``state`` is a ``dict``."""
        from llama_index.agent.types import Task

        if "task" not in input:
            raise ValueError("Input must have key 'task'")
        if not isinstance(input["task"], Task):
            raise ValueError("Input must have key 'task' of type Task")
        if "state" not in input:
            raise ValueError("Input must have key 'state'")
        if not isinstance(input["state"], dict):
            raise ValueError("Input must have key 'state' of type dict")
        return input

    def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component outputs."""
        # NOTE: we override this to do nothing
        return output

    def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Pass outputs through unchanged."""
        return input

    def _run_component(self, **kwargs: Any) -> Dict:
        """Run ``fn`` and wrap its return value as ``{"output": ...}``."""
        # The return value may be of any type; it is intentionally not
        # required to be a dict.
        return {"output": self.fn(**kwargs)}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async); falls back to ``fn`` when ``async_fn`` is unset."""
        if self.async_fn is None:
            return self._run_component(**kwargs)
        return {"output": await self.async_fn(**kwargs)}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys: ``task`` and ``state`` plus any extra ``fn`` parameters."""
        return InputKeys.from_keys(
            required_keys={"task", "state", *self._req_params},
            optional_keys=self._opt_params,
        )

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys: the single key ``"output"`` holding ``fn``'s result."""
        return OutputKeys.from_keys({"output"})
class CustomAgentComponent(BaseAgentComponent):
    """Custom component for agents.

    Designed to let users easily modify state. Subclasses are expected to
    override ``_input_keys`` and ``_output_keys``, which raise
    ``NotImplementedError`` here. They must also provide ``_run_component``,
    which this class does not define. ``task`` and ``state`` are always
    added to the required input keys.
    """

    callback_manager: CallbackManager = Field(
        default_factory=CallbackManager, description="Callback manager"
    )

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set the callback manager here and on every sub query component."""
        self.callback_manager = callback_manager
        # TODO: refactor to put this on base class
        for component in self.sub_query_components:
            component.set_callback_manager(callback_manager)

    @property
    def sub_query_components(self) -> List[QueryComponent]:
        """Get sub query components.

        Defaults to none. Override this if the component wraps other query
        components, so that ``set_callback_manager`` propagates to them.
        Defined here so ``set_callback_manager`` does not depend on the base
        class providing this attribute.
        """
        return []

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        # NOTE: user can override this method to validate inputs
        # but we do this by default for convenience
        return input

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async); not supported by default."""
        raise NotImplementedError("This component does not support async run.")

    @property
    def _input_keys(self) -> Set[str]:
        """Required input keys besides ``task``/``state`` (must be overridden)."""
        raise NotImplementedError("Not implemented yet. Please override this method.")

    @property
    def _optional_input_keys(self) -> Set[str]:
        """Optional input keys (none by default)."""
        return set()

    @property
    def _output_keys(self) -> Set[str]:
        """Output keys (must be overridden)."""
        raise NotImplementedError("Not implemented yet. Please override this method.")

    @property
    def input_keys(self) -> InputKeys:
        """Input keys: ``_input_keys`` plus ``task`` and ``state``."""
        # NOTE: user can override this too, but we have them implement an
        # abstract method to make sure they do it
        input_keys = self._input_keys.union({"task", "state"})
        return InputKeys.from_keys(
            required_keys=input_keys, optional_keys=self._optional_input_keys
        )

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys, taken from ``_output_keys``."""
        # NOTE: user can override this too, but we have them implement an
        # abstract method to make sure they do it
        return OutputKeys.from_keys(self._output_keys)