# faiss_rag_enterprise/llama_index/agent/runner/parallel.py
# (source listing metadata: 473 lines, 15 KiB, Python)

"""Agent executor."""
import asyncio
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Union, cast
from llama_index.agent.runner.base import BaseAgentRunner
from llama_index.agent.types import (
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
ChatResponseMode,
StreamingAgentChatResponse,
)
from llama_index.llms.base import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
class DAGTaskState(BaseModel):
    """Per-task bookkeeping for the DAG / parallel agent runner.

    Holds the task itself, the step it was started from, the steps still
    waiting to be executed, and the outputs of steps that already finished.
    """

    # The task this state object tracks.
    task: Task = Field(default=..., description="Task.")
    # First step produced for the task.
    root_step: TaskStep = Field(default=..., description="Root step.")
    # Steps waiting to be executed.
    step_queue: Deque[TaskStep] = Field(
        description="Task step queue.", default_factory=deque
    )
    # Outputs of steps that have finished running.
    completed_steps: List[TaskStepOutput] = Field(
        description="Completed step outputs.", default_factory=list
    )

    @property
    def task_id(self) -> str:
        """Identifier of the wrapped task."""
        wrapped_task = self.task
        return wrapped_task.task_id
class DAGAgentState(BaseModel):
    """Registry of per-task DAG states, keyed by task id."""

    # Maps a task id to the DAGTaskState that tracks it.
    task_dict: Dict[str, DAGTaskState] = Field(
        description="Task dictionary.", default_factory=dict
    )

    def get_task(self, task_id: str) -> Task:
        """Return the task registered under ``task_id``.

        Raises:
            KeyError: if ``task_id`` is not registered.
        """
        task_state = self.task_dict[task_id]
        return task_state.task

    def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
        """Return the completed-step list for ``task_id``.

        The stored list object itself is returned (not a copy), so appending
        to it updates the state.

        Raises:
            KeyError: if ``task_id`` is not registered.
        """
        task_state = self.task_dict[task_id]
        return task_state.completed_steps

    def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
        """Return the pending-step deque for ``task_id``.

        The stored deque itself is returned (not a copy), so pushing or
        popping on it updates the state.

        Raises:
            KeyError: if ``task_id`` is not registered.
        """
        task_state = self.task_dict[task_id]
        return task_state.step_queue
class ParallelAgentRunner(BaseAgentRunner):
    """Parallel agent runner.

    Executes steps in queue in parallel. Requires async support.

    Every step currently queued for a task is treated as independent of the
    others and the whole batch ("wave") is dispatched concurrently with
    ``asyncio.gather``. The ready ``next_steps`` of that wave's outputs form
    the next wave.
    """

    def __init__(
        self,
        agent_worker: BaseAgentWorker,
        chat_history: Optional[List[ChatMessage]] = None,
        state: Optional[DAGAgentState] = None,
        memory: Optional[BaseMemory] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        init_task_state_kwargs: Optional[dict] = None,
        delete_task_on_finish: bool = False,
    ) -> None:
        """Initialize.

        Args:
            agent_worker: Worker that executes individual task steps.
            chat_history: History used to build the default memory when
                ``memory`` is not given.
            state: Existing runner state; a fresh ``DAGAgentState`` otherwise.
            memory: Conversation memory attached to every created task.
            llm: LLM passed to ``ChatMemoryBuffer.from_defaults`` when the
                default memory is built.
            callback_manager: Callback manager; an empty one by default.
            init_task_state_kwargs: Initial ``extra_state`` for new tasks.
            delete_task_on_finish: If True, ``finalize_response`` removes the
                task from ``state`` once its response is produced.
        """
        self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)
        self.state = state or DAGAgentState()
        self.callback_manager = callback_manager or CallbackManager([])
        self.init_task_state_kwargs = init_task_state_kwargs or {}
        self.agent_worker = agent_worker
        self.delete_task_on_finish = delete_task_on_finish

    @property
    def chat_history(self) -> List[ChatMessage]:
        """All messages currently stored in memory."""
        return self.memory.get_all()

    def reset(self) -> None:
        """Reset the conversation memory (task state is left untouched)."""
        self.memory.reset()

    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create a task for ``input`` and seed its queue with the first step.

        Args:
            input: User input the task should address.
            **kwargs: Extra fields forwarded to ``Task``. An ``extra_state``
                entry is merged on top of ``init_task_state_kwargs``.

        Returns:
            The new task, also registered in ``self.state``.
        """
        # Build a fresh dict per task so tasks never share the runner's
        # ``init_task_state_kwargs`` object; caller-supplied keys take
        # precedence instead of colliding with the runner-provided value.
        extra_state = {
            **self.init_task_state_kwargs,
            **(kwargs.pop("extra_state", None) or {}),
        }
        task = Task(
            input=input,
            memory=self.memory,
            extra_state=extra_state,
            **kwargs,
        )
        # Ask the worker for the initial step and make it the first queued step.
        initial_step = self.agent_worker.initialize_step(task)
        task_state = DAGTaskState(
            task=task,
            root_step=initial_step,
            step_queue=deque([initial_step]),
        )
        self.state.task_dict[task.task_id] = task_state
        return task

    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.

        Raises:
            KeyError: if ``task_id`` is not registered.
        """
        self.state.task_dict.pop(task_id)

    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List all registered tasks."""
        return [task_state.task for task_state in self.state.task_dict.values()]

    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get the task registered under ``task_id``."""
        return self.state.get_task(task_id)

    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Return a snapshot (copy) of the steps still queued for ``task_id``."""
        return list(self.state.get_step_queue(task_id))

    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Return the outputs of steps already completed for ``task_id``."""
        return self.state.get_completed_steps(task_id)

    def run_steps_in_queue(
        self,
        task_id: str,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> List[TaskStepOutput]:
        """Execute steps in queue.

        Run all steps in queue, clearing it out.
        Assume that all steps can be run in parallel.

        NOTE: this wraps ``asyncio.run``, which cannot be called while an event
        loop is already running; use ``arun_steps_in_queue`` from async code.
        """
        return asyncio.run(self.arun_steps_in_queue(task_id, mode=mode, **kwargs))

    async def arun_steps_in_queue(
        self,
        task_id: str,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> List[TaskStepOutput]:
        """Execute all steps in queue concurrently.

        All steps in queue are assumed to be ready. The queue is drained before
        anything runs, so follow-up steps appended while this wave executes are
        left for the next call.

        Returns:
            One output per drained step, in queue order.
        """
        step_queue = self.state.get_step_queue(task_id)
        steps = list(step_queue)
        step_queue.clear()
        return await asyncio.gather(
            *(
                self._arun_step(task_id, step=step, mode=mode, **kwargs)
                for step in steps
            )
        )

    @staticmethod
    def _resolve_step(
        step_queue: Deque[TaskStep],
        step: Optional[TaskStep],
        input: Optional[str],
    ) -> TaskStep:
        """Pick the step to execute: ``step`` if given, else pop the queue head.

        An explicit ``input`` overrides the chosen step's input.

        Raises:
            IndexError: if no step is given and the queue is empty.
            ValueError: if the chosen step is not ready.
        """
        if step is None:
            step = step_queue.popleft()
        if not step.is_ready:
            raise ValueError(f"Step {step.step_id} is not ready")
        if input is not None:
            step.input = input
        return step

    def _record_step_output(
        self,
        task_id: str,
        step_queue: Deque[TaskStep],
        step_output: TaskStepOutput,
    ) -> None:
        """Enqueue the ready follow-up steps and store ``step_output``.

        Follow-up steps whose ``is_ready`` is False are not enqueued, and this
        runner does not keep them anywhere else.
        """
        step_queue.extend(
            next_step for next_step in step_output.next_steps if next_step.is_ready
        )
        self.state.get_completed_steps(task_id).append(step_output)

    def _run_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute one step synchronously and record its output.

        Raises:
            ValueError: if the step is not ready or ``mode`` is unsupported.
        """
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = self._resolve_step(step_queue, step, input)

        if mode == ChatResponseMode.WAIT:
            cur_step_output: TaskStepOutput = self.agent_worker.run_step(
                step, task, **kwargs
            )
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        self._record_step_output(task_id, step_queue, cur_step_output)
        return cur_step_output

    async def _arun_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute one step asynchronously and record its output.

        Raises:
            ValueError: if the step is not ready or ``mode`` is unsupported.
        """
        task = self.state.get_task(task_id)
        step_queue = self.state.get_step_queue(task_id)
        step = self._resolve_step(step_queue, step, input)

        if mode == ChatResponseMode.WAIT:
            cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        self._record_step_output(task_id, step_queue, cur_step_output)
        return cur_step_output

    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (``input``, if given, overrides the step's input)."""
        return self._run_step(
            task_id, step, mode=ChatResponseMode.WAIT, input=input, **kwargs
        )

    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async; ``input``, if given, overrides the step's input)."""
        return await self._arun_step(
            task_id, step, mode=ChatResponseMode.WAIT, input=input, **kwargs
        )

    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream; ``input``, if given, overrides the step's input)."""
        return self._run_step(
            task_id, step, mode=ChatResponseMode.STREAM, input=input, **kwargs
        )

    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream; ``input``, if given, overrides the step's input)."""
        return await self._arun_step(
            task_id, step, mode=ChatResponseMode.STREAM, input=input, **kwargs
        )

    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize the task and return its chat response.

        Args:
            task_id: Task to finalize.
            step_output: Final step output; defaults to the most recently
                completed step of the task.

        Raises:
            IndexError: if ``step_output`` is None and no step has completed.
            ValueError: if the output is not marked ``is_last`` or does not
                hold an ``AgentChatResponse``/``StreamingAgentChatResponse``.
        """
        if step_output is None:
            step_output = self.state.get_completed_steps(task_id)[-1]
        if not step_output.is_last:
            raise ValueError(
                "finalize_response can only be called on the last step output"
            )

        if not isinstance(
            step_output.output,
            (AgentChatResponse, StreamingAgentChatResponse),
        ):
            raise ValueError(
                "When `is_last` is True, cur_step_output.output must be "
                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
            )

        # finalize task
        self.agent_worker.finalize_task(self.state.get_task(task_id))

        if self.delete_task_on_finish:
            self.delete_task(task_id)

        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)

    @staticmethod
    def _get_final_step_output(
        step_outputs: List[TaskStepOutput],
    ) -> Optional[TaskStepOutput]:
        """Return the final output of a wave, or None if the task must continue.

        Raises:
            ValueError: if the wave is empty (the queue was empty, so the loop
                could never make progress), or if a final wave contains more
                than one output.
        """
        if not step_outputs:
            # Without this check the chat loop would spin forever re-running an
            # empty queue: non-ready follow-up steps are never re-enqueued.
            raise ValueError(
                "Step queue is empty but no step output was marked `is_last`; "
                "the task cannot make further progress."
            )
        if not any(step_output.is_last for step_output in step_outputs):
            return None
        if len(step_outputs) > 1:
            raise ValueError("More than one step output returned in final step.")
        return step_outputs[0]

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor.

        Runs waves of queued steps until one output is marked ``is_last``.

        NOTE: ``tool_choice`` is accepted for interface compatibility but is
        not forwarded to the worker by this runner.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_outputs = self.run_steps_in_queue(task.task_id, mode=mode)
            result_output = self._get_final_step_output(cur_step_outputs)
            if result_output is not None:
                return self.finalize_response(task.task_id, result_output)

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Chat with step executor (async).

        Runs waves of queued steps until one output is marked ``is_last``.

        NOTE: ``tool_choice`` is accepted for interface compatibility but is
        not forwarded to the worker by this runner.
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_outputs = await self.arun_steps_in_queue(task.task_id, mode=mode)
            result_output = self._get_final_step_output(cur_step_outputs)
            if result_output is not None:
                return self.finalize_response(task.task_id, result_output)

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Run a blocking chat turn inside an AGENT_STEP callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Async variant of ``chat``."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Run a chat turn in streaming mode inside an AGENT_STEP callback event."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def astream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        """Async variant of ``stream_chat``."""
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    def undo_step(self, task_id: str) -> None:
        """Undo previous step (not supported by this runner)."""
        raise NotImplementedError("undo_step not implemented")