from abc import abstractmethod
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Union, cast

from llama_index.agent.types import (
    BaseAgent,
    BaseAgentWorker,
    Task,
    TaskStep,
    TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    trace_method,
)
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
    ChatResponseMode,
    StreamingAgentChatResponse,
)
from llama_index.llms.base import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.tools.types import BaseTool


class BaseAgentRunner(BaseAgent):
    """Base agent runner."""

    @abstractmethod
    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create task."""

    @abstractmethod
    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        """

    @abstractmethod
    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List tasks."""

    @abstractmethod
    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task."""

    @abstractmethod
    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Get upcoming steps."""

    @abstractmethod
    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""

    def get_completed_step(
        self, task_id: str, step_id: str, **kwargs: Any
    ) -> TaskStepOutput:
        """Get completed step."""
        # call get_completed_steps, then find the step with the matching step_id
        completed_steps = self.get_completed_steps(task_id, **kwargs)
        for step_output in completed_steps:
            if step_output.task_step.step_id == step_id:
                return step_output
        raise ValueError(f"Could not find step_id: {step_id}")

    @abstractmethod
    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step."""

    @abstractmethod
    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async)."""

    @abstractmethod
    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream)."""

    @abstractmethod
    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream)."""

    @abstractmethod
    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize response."""

    # NOTE: not abstract; the raising body is the default for runners that
    # do not support undo.
    def undo_step(self, task_id: str) -> None:
        """Undo previous step."""
        raise NotImplementedError("undo_step not implemented")


def validate_step_from_args(
    task_id: str, input: Optional[str] = None, step: Optional[Any] = None, **kwargs: Any
) -> Optional[TaskStep]:
    """Validate step from args."""
    if step is not None:
        if input is not None:
            raise ValueError("Cannot specify both `step` and `input`")
        if not isinstance(step, TaskStep):
            raise ValueError(f"step must be TaskStep: {step}")
        return step
    else:
        return None


class TaskState(BaseModel):
    """Task state."""

    task: Task = Field(..., description="Task.")
    step_queue: Deque[TaskStep] = Field(
        default_factory=deque, description="Task step queue."
    )
    completed_steps: List[TaskStepOutput] = Field(
        default_factory=list, description="Completed step outputs."
    )


class AgentState(BaseModel):
    """Agent state."""

    task_dict: Dict[str, TaskState] = Field(
        default_factory=dict, description="Task dictionary."
    )

    def get_task(self, task_id: str) -> Task:
        """Get task."""
        return self.task_dict[task_id].task

    def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.task_dict[task_id].completed_steps

    def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
        """Get step queue."""
        return self.task_dict[task_id].step_queue

    def reset(self) -> None:
        """Reset."""
        self.task_dict = {}


class AgentRunner(BaseAgentRunner):
    """Agent runner.

    Top-level agent orchestrator that can create tasks, run each step in a task,
    or run a task end-to-end. Stores state and keeps track of tasks.

    Args:
        agent_worker (BaseAgentWorker): step executor
        chat_history (Optional[List[ChatMessage]], optional): chat history. Defaults to None.
        state (Optional[AgentState], optional): agent state. Defaults to None.
        memory (Optional[BaseMemory], optional): memory. Defaults to None.
        llm (Optional[LLM], optional): LLM. Defaults to None.
        callback_manager (Optional[CallbackManager], optional): callback manager. Defaults to None.
        init_task_state_kwargs (Optional[dict], optional): init task state kwargs. Defaults to None.
        delete_task_on_finish (bool, optional): whether to delete a task once it finishes. Defaults to False.
        default_tool_choice (str, optional): default tool choice. Defaults to "auto".
        verbose (bool, optional): whether to print verbose output. Defaults to False.

    """
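
    # Illustrative sketch of the step-wise API (not from the original source);
    # `worker` below is a hypothetical concrete BaseAgentWorker implementation:
    #
    #   agent = AgentRunner(agent_worker=worker)
    #   task = agent.create_task("What is 2 + 2?")
    #   step_output = agent.run_step(task.task_id)
    #   while not step_output.is_last:
    #       step_output = agent.run_step(task.task_id)
    #   print(agent.finalize_response(task.task_id))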

    # TODO: implement this in Pydantic

    def __init__(
        self,
        agent_worker: BaseAgentWorker,
        chat_history: Optional[List[ChatMessage]] = None,
        state: Optional[AgentState] = None,
        memory: Optional[BaseMemory] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        init_task_state_kwargs: Optional[dict] = None,
        delete_task_on_finish: bool = False,
        default_tool_choice: str = "auto",
        verbose: bool = False,
    ) -> None:
        """Initialize."""
        self.agent_worker = agent_worker
        self.state = state or AgentState()
        self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)

        # get and set callback manager
        if callback_manager is not None:
            self.agent_worker.set_callback_manager(callback_manager)
            self.callback_manager = callback_manager
        else:
            # TODO: This is *temporary*
            # Stopgap before having a callback on the BaseAgentWorker interface.
            # Doing that requires a bit more refactoring to make sure existing code
            # doesn't break.
            if hasattr(self.agent_worker, "callback_manager"):
                self.callback_manager = (
                    self.agent_worker.callback_manager or CallbackManager()
                )
            else:
                self.callback_manager = CallbackManager()

        self.init_task_state_kwargs = init_task_state_kwargs or {}
        self.delete_task_on_finish = delete_task_on_finish
        self.default_tool_choice = default_tool_choice
        self.verbose = verbose

    @staticmethod
    def from_llm(
        tools: Optional[List[BaseTool]] = None,
        llm: Optional[LLM] = None,
        **kwargs: Any,
    ) -> "AgentRunner":
        from llama_index.llms.openai import OpenAI
        from llama_index.llms.openai_utils import is_function_calling_model

        if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
            from llama_index.agent import OpenAIAgent

            return OpenAIAgent.from_tools(
                tools=tools,
                llm=llm,
                **kwargs,
            )
        else:
            from llama_index.agent import ReActAgent

            return ReActAgent.from_tools(
                tools=tools,
                llm=llm,
                **kwargs,
            )
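
    # Illustrative sketch of the dispatch above: a function-calling OpenAI model
    # yields an OpenAIAgent; any other LLM falls back to a ReActAgent.
    #
    #   from llama_index.llms.openai import OpenAI
    #   agent = AgentRunner.from_llm(tools=[], llm=OpenAI(model="gpt-4"))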

    @property
    def chat_history(self) -> List[ChatMessage]:
        return self.memory.get_all()

    def reset(self) -> None:
        self.memory.reset()
        self.state.reset()

    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create task."""
        if not self.init_task_state_kwargs:
            extra_state = kwargs.pop("extra_state", {})
        else:
            if "extra_state" in kwargs:
                raise ValueError(
                    "Cannot specify both `extra_state` and `init_task_state_kwargs`"
                )
            else:
                extra_state = self.init_task_state_kwargs

        callback_manager = kwargs.pop("callback_manager", self.callback_manager)
        task = Task(
            input=input,
            memory=self.memory,
            extra_state=extra_state,
            callback_manager=callback_manager,
            **kwargs,
        )
        # # put input into memory
        # self.memory.put(ChatMessage(content=input, role=MessageRole.USER))

        # get initial step from task, and put it in the step queue
        initial_step = self.agent_worker.initialize_step(task)
        task_state = TaskState(
            task=task,
            step_queue=deque([initial_step]),
        )
        # add it to state
        self.state.task_dict[task.task_id] = task_state

        return task

    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        """
        self.state.task_dict.pop(task_id)

    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List tasks."""
        # return Tasks, not TaskStates, to match the declared return type
        return [task_state.task for task_state in self.state.task_dict.values()]

    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task."""
        return self.state.get_task(task_id)

    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Get upcoming steps."""
        return list(self.state.get_step_queue(task_id))

    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.state.get_completed_steps(task_id)

    def _run_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        input: Optional[str] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute step."""
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = step or step_queue.popleft()
        if input is not None:
            step.input = input

        if self.verbose:
            print(f"> Running step {step.step_id}. Step input: {step.input}")

        # TODO: figure out if you can dynamically swap in different step executors
        # not clear when you would do that, but theoretically possible

        if mode == ChatResponseMode.WAIT:
            cur_step_output = self.agent_worker.run_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")
        # append cur_step_output next steps to queue
        next_steps = cur_step_output.next_steps
        step_queue.extend(next_steps)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    async def _arun_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        input: Optional[str] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute step (async)."""
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = step or step_queue.popleft()
        if input is not None:
            step.input = input

        if self.verbose:
            print(f"> Running step {step.step_id}. Step input: {step.input}")

        # TODO: figure out if you can dynamically swap in different step executors
        # not clear when you would do that, but theoretically possible
        if mode == ChatResponseMode.WAIT:
            cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")
        # append cur_step_output next steps to queue
        next_steps = cur_step_output.next_steps
        step_queue.extend(next_steps)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return self._run_step(
            task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
        )

    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return await self._arun_step(
            task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
        )

    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return self._run_step(
            task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
        )

    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return await self._arun_step(
            task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
        )

    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize response."""
        if step_output is None:
            step_output = self.state.get_completed_steps(task_id)[-1]
        if not step_output.is_last:
            raise ValueError(
                "finalize_response can only be called on the last step output"
            )

        if not isinstance(
            step_output.output,
            (AgentChatResponse, StreamingAgentChatResponse),
        ):
            raise ValueError(
                "When `is_last` is True, cur_step_output.output must be "
                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
            )

        # finalize task
        self.agent_worker.finalize_task(self.state.get_task(task_id))

        if self.delete_task_on_finish:
            self.delete_task(task_id)

        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor."""
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_output = self._run_step(
                task.task_id, mode=mode, tool_choice=tool_choice
            )

            if cur_step_output.is_last:
                result_output = cur_step_output
                break

            # ensure tool_choice does not cause endless loops
            tool_choice = "auto"

        return self.finalize_response(task.task_id, result_output)

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor (async)."""
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_output = await self._arun_step(
                task.task_id, mode=mode, tool_choice=tool_choice
            )

            if cur_step_output.is_last:
                result_output = cur_step_output
                break

            # ensure tool_choice does not cause endless loops
            tool_choice = "auto"

        return self.finalize_response(task.task_id, result_output)

@trace_method("chat")
|
|
def chat(
|
|
self,
|
|
message: str,
|
|
chat_history: Optional[List[ChatMessage]] = None,
|
|
tool_choice: Optional[Union[str, dict]] = None,
|
|
) -> AgentChatResponse:
|
|
# override tool choice is provided as input.
|
|
if tool_choice is None:
|
|
tool_choice = self.default_tool_choice
|
|
with self.callback_manager.event(
|
|
CBEventType.AGENT_STEP,
|
|
payload={EventPayload.MESSAGES: [message]},
|
|
) as e:
|
|
chat_response = self._chat(
|
|
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
|
|
)
|
|
assert isinstance(chat_response, AgentChatResponse)
|
|
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
|
return chat_response
|
|
|
|
@trace_method("chat")
|
|
async def achat(
|
|
self,
|
|
message: str,
|
|
chat_history: Optional[List[ChatMessage]] = None,
|
|
tool_choice: Optional[Union[str, dict]] = None,
|
|
) -> AgentChatResponse:
|
|
# override tool choice is provided as input.
|
|
if tool_choice is None:
|
|
tool_choice = self.default_tool_choice
|
|
with self.callback_manager.event(
|
|
CBEventType.AGENT_STEP,
|
|
payload={EventPayload.MESSAGES: [message]},
|
|
) as e:
|
|
chat_response = await self._achat(
|
|
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
|
|
)
|
|
assert isinstance(chat_response, AgentChatResponse)
|
|
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
|
return chat_response
|
|
|
|
@trace_method("chat")
|
|
def stream_chat(
|
|
self,
|
|
message: str,
|
|
chat_history: Optional[List[ChatMessage]] = None,
|
|
tool_choice: Optional[Union[str, dict]] = None,
|
|
) -> StreamingAgentChatResponse:
|
|
# override tool choice is provided as input.
|
|
if tool_choice is None:
|
|
tool_choice = self.default_tool_choice
|
|
with self.callback_manager.event(
|
|
CBEventType.AGENT_STEP,
|
|
payload={EventPayload.MESSAGES: [message]},
|
|
) as e:
|
|
chat_response = self._chat(
|
|
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
|
|
)
|
|
assert isinstance(chat_response, StreamingAgentChatResponse)
|
|
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
|
return chat_response
|
|
|
|
@trace_method("chat")
|
|
async def astream_chat(
|
|
self,
|
|
message: str,
|
|
chat_history: Optional[List[ChatMessage]] = None,
|
|
tool_choice: Optional[Union[str, dict]] = None,
|
|
) -> StreamingAgentChatResponse:
|
|
# override tool choice is provided as input.
|
|
if tool_choice is None:
|
|
tool_choice = self.default_tool_choice
|
|
with self.callback_manager.event(
|
|
CBEventType.AGENT_STEP,
|
|
payload={EventPayload.MESSAGES: [message]},
|
|
) as e:
|
|
chat_response = await self._achat(
|
|
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
|
|
)
|
|
assert isinstance(chat_response, StreamingAgentChatResponse)
|
|
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
|
return chat_response
|
|
|
|
    def undo_step(self, task_id: str) -> None:
        """Undo previous step."""
        raise NotImplementedError("undo_step not implemented")
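

# Illustrative end-to-end sketch: the high-level chat APIs above wrap
# create_task / _run_step / finalize_response internally, so a single call
# drives the whole task loop. `my_tools` and `my_llm` are hypothetical
# placeholders for a tool list and an LLM instance.
#
#   agent = AgentRunner.from_llm(tools=my_tools, llm=my_llm)
#   response = agent.chat("Summarize the latest results.")
#   streaming_response = agent.stream_chat("Now stream the summary.")
#   async_response = await agent.achat("And once more, asynchronously.")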