"""Agent executor."""

import asyncio
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Union, cast

from llama_index.agent.runner.base import BaseAgentRunner
from llama_index.agent.types import (
    BaseAgentWorker,
    Task,
    TaskStep,
    TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    trace_method,
)
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
    ChatResponseMode,
    StreamingAgentChatResponse,
)
from llama_index.llms.base import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.memory.types import BaseMemory


class DAGTaskState(BaseModel):
    """DAG Task state.

    Bundles a task with its root step, the queue of steps that are still
    waiting to run, and the outputs of the steps that already ran.
    """

    task: Task = Field(..., description="Task.")
    root_step: TaskStep = Field(..., description="Root step.")
    step_queue: Deque[TaskStep] = Field(
        default_factory=deque, description="Task step queue."
    )
    completed_steps: List[TaskStepOutput] = Field(
        default_factory=list, description="Completed step outputs."
    )

    @property
    def task_id(self) -> str:
        """Task id."""
        return self.task.task_id


class DAGAgentState(BaseModel):
    """Agent state.

    Maps task ids to their ``DAGTaskState``. Every accessor raises
    ``KeyError`` when the task id is unknown.
    """

    task_dict: Dict[str, DAGTaskState] = Field(
        default_factory=dict, description="Task dictionary."
    )

    def get_task(self, task_id: str) -> Task:
        """Get task state."""
        return self.task_dict[task_id].task

    def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.task_dict[task_id].completed_steps

    def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
        """Get step queue."""
        return self.task_dict[task_id].step_queue


class ParallelAgentRunner(BaseAgentRunner):
    """Parallel agent runner.

    Executes steps in queue in parallel. Requires async support.

    """

    def __init__(
        self,
        agent_worker: BaseAgentWorker,
        chat_history: Optional[List[ChatMessage]] = None,
        state: Optional[DAGAgentState] = None,
        memory: Optional[BaseMemory] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        init_task_state_kwargs: Optional[dict] = None,
        delete_task_on_finish: bool = False,
    ) -> None:
        """Initialize.

        Args:
            agent_worker: Worker that initializes steps, runs them and
                finalizes tasks.
            chat_history: Used to build the default ``ChatMemoryBuffer`` when
                ``memory`` is not given.
            state: Existing agent state; a fresh ``DAGAgentState`` is created
                when omitted.
            memory: Memory shared by every task created by this runner.
            llm: Passed to ``ChatMemoryBuffer.from_defaults`` when the default
                memory is built.
            callback_manager: Callback manager; an empty one is created when
                omitted.
            init_task_state_kwargs: Passed as ``extra_state`` to every new
                ``Task``.
            delete_task_on_finish: If True, ``finalize_response`` removes the
                task from the state.
        """
        self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)
        self.state = state or DAGAgentState()
        self.callback_manager = callback_manager or CallbackManager([])
        self.init_task_state_kwargs = init_task_state_kwargs or {}
        self.agent_worker = agent_worker
        self.delete_task_on_finish = delete_task_on_finish

    @property
    def chat_history(self) -> List[ChatMessage]:
        """Return all messages currently held in memory."""
        return self.memory.get_all()

    def reset(self) -> None:
        """Reset the memory. Tasks tracked in the state are left untouched."""
        self.memory.reset()

    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create task.

        Builds a ``Task`` around ``input``, asks the agent worker for its
        initial step and registers both in the agent state.

        Args:
            input: Task input.
            **kwargs: Extra keyword arguments forwarded to ``Task``.

        Returns:
            The newly created task.
        """
        task = Task(
            input=input,
            memory=self.memory,
            extra_state=self.init_task_state_kwargs,
            **kwargs,
        )
        # NOTE(review): the input is not written to memory here (an earlier
        # `self.memory.put(...)` call was disabled) -- presumably the agent
        # worker takes care of it; confirm.

        # get initial step from task, and put it in the step queue
        initial_step = self.agent_worker.initialize_step(task)
        task_state = DAGTaskState(
            task=task,
            root_step=initial_step,
            step_queue=deque([initial_step]),
        )
        # add it to state
        self.state.task_dict[task.task_id] = task_state

        return task

    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        Raises:
            KeyError: If ``task_id`` is not tracked by the state.
        """
        self.state.task_dict.pop(task_id)

    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List tasks."""
        return [task_state.task for task_state in self.state.task_dict.values()]

    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task."""
        return self.state.get_task(task_id)

    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Get upcoming steps.

        Returns a snapshot (a new list) of the step queue.
        """
        return list(self.state.get_step_queue(task_id))

    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.state.get_completed_steps(task_id)

    def run_steps_in_queue(
        self,
        task_id: str,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> List[TaskStepOutput]:
        """Execute steps in queue.

        Run all steps in queue, clearing it out.

        Assume that all steps can be run in parallel.

        NOTE: this wraps the async implementation in ``asyncio.run``, so it
        cannot be called while an event loop is already running in the
        current thread; use ``arun_steps_in_queue`` there instead.
        """
        return asyncio.run(self.arun_steps_in_queue(task_id, mode=mode, **kwargs))

    async def arun_steps_in_queue(
        self,
        task_id: str,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> List[TaskStepOutput]:
        """Execute all steps in queue.

        All steps in queue are assumed to be ready.

        Returns:
            The step outputs, in the order the steps were queued. The list is
            empty when the queue was empty.
        """
        step_queue = self.state.get_step_queue(task_id)

        # Drain the queue before running anything: steps appended while this
        # batch runs belong to the next batch.
        steps: List[TaskStep] = []
        while step_queue:
            steps.append(step_queue.popleft())

        # take every item that was in the queue, and run it
        step_coros = [
            self._arun_step(task_id, step=step, mode=mode, **kwargs) for step in steps
        ]
        return await asyncio.gather(*step_coros)

    def _run_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute step.

        Runs ``step`` (or the next queued step when ``step`` is omitted),
        queues the ready next steps it produced and records its output.

        Raises:
            ValueError: If the step is not ready or ``mode`` is unsupported.
            IndexError: If ``step`` is omitted and the queue is empty.
        """
        task = self.state.get_task(task_id)
        task_queue = self.state.get_step_queue(task_id)
        step = step or task_queue.popleft()

        if not step.is_ready:
            raise ValueError(f"Step {step.step_id} is not ready")

        if mode == ChatResponseMode.WAIT:
            cur_step_output: TaskStepOutput = self.agent_worker.run_step(
                step, task, **kwargs
            )
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        # only next steps that are already ready get queued
        for next_step in cur_step_output.next_steps:
            if next_step.is_ready:
                task_queue.append(next_step)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    async def _arun_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute step (async).

        Async counterpart of ``_run_step``: same bookkeeping, but the agent
        worker's ``arun_step`` / ``astream_step`` are awaited.

        Raises:
            ValueError: If the step is not ready or ``mode`` is unsupported.
            IndexError: If ``step`` is omitted and the queue is empty.
        """
        task = self.state.get_task(task_id)
        task_queue = self.state.get_step_queue(task_id)
        step = step or task_queue.popleft()

        if not step.is_ready:
            raise ValueError(f"Step {step.step_id} is not ready")

        if mode == ChatResponseMode.WAIT:
            cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        # only next steps that are already ready get queued
        for next_step in cur_step_output.next_steps:
            if next_step.is_ready:
                task_queue.append(next_step)

        # add cur_step_output to completed steps
        completed_steps = self.state.get_completed_steps(task_id)
        completed_steps.append(cur_step_output)

        return cur_step_output

    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step.

        ``input`` is accepted but not used by this runner.
        """
        return self._run_step(task_id, step, mode=ChatResponseMode.WAIT, **kwargs)

    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async).

        ``input`` is accepted but not used by this runner.
        """
        return await self._arun_step(
            task_id, step, mode=ChatResponseMode.WAIT, **kwargs
        )

    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream).

        ``input`` is accepted but not used by this runner.
        """
        return self._run_step(task_id, step, mode=ChatResponseMode.STREAM, **kwargs)

    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream).

        ``input`` is accepted but not used by this runner.
        """
        return await self._arun_step(
            task_id, step, mode=ChatResponseMode.STREAM, **kwargs
        )

    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize response.

        Args:
            task_id: Id of the task to finalize.
            step_output: Final step output; defaults to the most recently
                completed step of the task.

        Returns:
            The chat response carried by the final step output.

        Raises:
            ValueError: If the step output is not marked as last, or its
                output is not an agent chat response.
            IndexError: If ``step_output`` is omitted and the task has no
                completed steps.
        """
        if step_output is None:
            step_output = self.state.get_completed_steps(task_id)[-1]
        if not step_output.is_last:
            raise ValueError(
                "finalize_response can only be called on the last step output"
            )

        if not isinstance(
            step_output.output,
            (AgentChatResponse, StreamingAgentChatResponse),
        ):
            raise ValueError(
                "When `is_last` is True, cur_step_output.output must be "
                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
            )

        # finalize task
        self.agent_worker.finalize_task(self.state.get_task(task_id))

        if self.delete_task_on_finish:
            self.delete_task(task_id)

        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)

    @staticmethod
    def _select_final_output(
        step_outputs: List[TaskStepOutput],
    ) -> Optional[TaskStepOutput]:
        """Return the final step output of a batch, or None to keep going.

        Raises:
            ValueError: If the batch is empty (nothing was queued, so the
                task can never reach a final step), or if a batch holding a
                final step output contains more than one output.
        """
        if not step_outputs:
            # Only the steps of a batch can enqueue further steps, so an
            # empty batch means the chat loop would otherwise spin forever.
            raise ValueError(
                "No steps left to run, but no final step output was produced."
            )
        if not any(step_output.is_last for step_output in step_outputs):
            return None
        if len(step_outputs) > 1:
            raise ValueError("More than one step output returned in final step.")
        return step_outputs[0]

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor.

        Creates a task for ``message`` and runs the queued steps batch by
        batch until a batch yields the final step output.

        ``tool_choice`` is accepted but currently not forwarded anywhere.

        Raises:
            ValueError: If the step queue runs dry before a final step output
                is produced, or the final batch holds more than one output.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while result_output is None:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_outputs = self.run_steps_in_queue(task.task_id, mode=mode)
            result_output = self._select_final_output(cur_step_outputs)

        return self.finalize_response(task.task_id, result_output)

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor (async).

        Creates a task for ``message`` and runs the queued steps batch by
        batch until a batch yields the final step output.

        ``tool_choice`` is accepted but currently not forwarded anywhere.

        Raises:
            ValueError: If the step queue runs dry before a final step output
                is produced, or the final batch holds more than one output.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while result_output is None:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_outputs = await self.arun_steps_in_queue(task.task_id, mode=mode)
            result_output = self._select_final_output(cur_step_outputs)

        return self.finalize_response(task.task_id, result_output)

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat in WAIT mode inside an ``AGENT_STEP`` callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat in WAIT mode (async) inside an ``AGENT_STEP`` callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Chat in STREAM mode inside an ``AGENT_STEP`` callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def astream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Chat in STREAM mode (async) inside an ``AGENT_STEP`` callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    def undo_step(self, task_id: str) -> None:
        """Undo previous step."""
        raise NotImplementedError("undo_step not implemented")