109 lines
3.7 KiB
Python
109 lines
3.7 KiB
Python
"""Tool runner component."""
|
|
|
|
from typing import Any, Dict, Optional, Sequence, cast
|
|
|
|
from llama_index.bridge.pydantic import Field
|
|
from llama_index.callbacks import (
|
|
CallbackManager,
|
|
CBEventType,
|
|
EventPayload,
|
|
)
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.core.query_pipeline.query_component import (
|
|
InputKeys,
|
|
OutputKeys,
|
|
QueryComponent,
|
|
validate_and_convert_stringable,
|
|
)
|
|
from llama_index.tools import AsyncBaseTool, adapt_to_async_tool
|
|
|
|
|
|
class ToolRunnerComponent(QueryComponent):
|
|
"""Tool runner component that takes in a set of tools."""
|
|
|
|
tool_dict: Dict[str, AsyncBaseTool] = Field(
|
|
..., description="Dictionary of tool names to tools."
|
|
)
|
|
callback_manager: CallbackManager = Field(
|
|
default_factory=lambda: CallbackManager([]), exclude=True
|
|
)
|
|
|
|
def __init__(
|
|
self,
|
|
tools: Sequence[AsyncBaseTool],
|
|
callback_manager: Optional[CallbackManager] = None,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Initialize."""
|
|
# determine parameters
|
|
tool_dict = {tool.metadata.name: adapt_to_async_tool(tool) for tool in tools}
|
|
callback_manager = callback_manager or CallbackManager([])
|
|
super().__init__(
|
|
tool_dict=tool_dict, callback_manager=callback_manager, **kwargs
|
|
)
|
|
|
|
class Config:
|
|
arbitrary_types_allowed = True
|
|
|
|
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
|
"""Set callback manager."""
|
|
self.callback_manager = callback_manager
|
|
|
|
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Validate component inputs during run_component."""
|
|
if "tool_name" not in input:
|
|
raise ValueError("tool_name must be provided in input")
|
|
|
|
input["tool_name"] = validate_and_convert_stringable(input["tool_name"])
|
|
|
|
if "tool_input" not in input:
|
|
raise ValueError("tool_input must be provided in input")
|
|
# make sure tool_input is a dictionary
|
|
if not isinstance(input["tool_input"], dict):
|
|
raise ValueError("tool_input must be a dictionary")
|
|
|
|
return input
|
|
|
|
def _run_component(self, **kwargs: Any) -> Dict:
|
|
"""Run component."""
|
|
tool_name = kwargs["tool_name"]
|
|
tool_input = kwargs["tool_input"]
|
|
tool = cast(AsyncBaseTool, self.tool_dict[tool_name])
|
|
with self.callback_manager.event(
|
|
CBEventType.FUNCTION_CALL,
|
|
payload={
|
|
EventPayload.FUNCTION_CALL: tool_input,
|
|
EventPayload.TOOL: tool.metadata,
|
|
},
|
|
) as event:
|
|
tool_output = tool(**tool_input)
|
|
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
|
|
|
return {"output": tool_output}
|
|
|
|
async def _arun_component(self, **kwargs: Any) -> Any:
|
|
"""Run component (async)."""
|
|
tool_name = kwargs["tool_name"]
|
|
tool_input = kwargs["tool_input"]
|
|
tool = cast(AsyncBaseTool, self.tool_dict[tool_name])
|
|
with self.callback_manager.event(
|
|
CBEventType.FUNCTION_CALL,
|
|
payload={
|
|
EventPayload.FUNCTION_CALL: tool_input,
|
|
EventPayload.TOOL: tool.metadata,
|
|
},
|
|
) as event:
|
|
tool_output = await tool.acall(**tool_input)
|
|
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
|
return {"output": tool_output}
|
|
|
|
@property
|
|
def input_keys(self) -> InputKeys:
|
|
"""Input keys."""
|
|
return InputKeys.from_keys({"tool_name", "tool_input"})
|
|
|
|
@property
|
|
def output_keys(self) -> OutputKeys:
|
|
"""Output keys."""
|
|
return OutputKeys.from_keys({"output"})
|