# faiss_rag_enterprise/llama_index/agent/runner/base.py
# (632 lines, 21 KiB, Python)

from abc import abstractmethod
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Union, cast
from llama_index.agent.types import (
BaseAgent,
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
ChatResponseMode,
StreamingAgentChatResponse,
)
from llama_index.llms.base import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.tools.types import BaseTool
class BaseAgentRunner(BaseAgent):
    """Abstract interface for agent runners.

    A runner owns tasks and drives them step by step. Subclasses implement
    task management (create / delete / list / get), step inspection, step
    execution in sync/async and wait/stream flavours, response finalization
    and step undo.
    """

    @abstractmethod
    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create a new task for ``input`` and return it."""

    @abstractmethod
    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete the task identified by ``task_id``.

        NOTE: this will not delete any previous executions from memory.
        """

    @abstractmethod
    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """Return every task known to the runner."""

    @abstractmethod
    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Return the task identified by ``task_id``."""

    @abstractmethod
    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Return the steps still waiting to run for ``task_id``."""

    @abstractmethod
    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Return the outputs of the steps already run for ``task_id``."""

    def get_completed_step(
        self, task_id: str, step_id: str, **kwargs: Any
    ) -> TaskStepOutput:
        """Return the completed step output whose step id equals ``step_id``.

        Raises:
            ValueError: if no completed step of ``task_id`` has that step id.
        """
        # Take the first completed output whose originating step matches.
        found = next(
            (
                output
                for output in self.get_completed_steps(task_id, **kwargs)
                if output.task_step.step_id == step_id
            ),
            None,
        )
        if found is None:
            raise ValueError(f"Could not find step_id: {step_id}")
        return found

    @abstractmethod
    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run one step of ``task_id`` and wait for its output."""

    @abstractmethod
    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Asynchronously run one step of ``task_id``."""

    @abstractmethod
    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run one step of ``task_id`` in streaming mode."""

    @abstractmethod
    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Asynchronously run one step of ``task_id`` in streaming mode."""

    @abstractmethod
    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Turn the final step output of ``task_id`` into a chat response."""

    @abstractmethod
    def undo_step(self, task_id: str) -> None:
        """Undo the most recent step of ``task_id``."""
        raise NotImplementedError("undo_step not implemented")
def validate_step_from_args(
    task_id: str, input: Optional[str] = None, step: Optional[Any] = None, **kwargs: Any
) -> Optional[TaskStep]:
    """Check the ``step``/``input`` arguments passed to a step-running call.

    Returns ``None`` when no ``step`` was supplied. Otherwise returns ``step``
    unchanged after checking it.

    Raises:
        ValueError: if both ``step`` and ``input`` are given, or if ``step``
            is not a ``TaskStep``.
    """
    if step is None:
        return None
    if input is not None:
        raise ValueError("Cannot specify both `step` and `input`")
    if not isinstance(step, TaskStep):
        raise ValueError(f"step must be TaskStep: {step}")
    return step
class TaskState(BaseModel):
    """Runner-side bookkeeping for a single task.

    Pairs the task with its queue of pending steps and the outputs of the
    steps that have already been executed.
    """

    # The task being executed (required field).
    task: Task = Field(default=..., description="Task.")
    # Steps still to run; starts empty unless a queue is supplied.
    step_queue: Deque[TaskStep] = Field(
        description="Task step queue.", default_factory=deque
    )
    # Outputs of finished steps, in the order they were appended.
    completed_steps: List[TaskStepOutput] = Field(
        description="Completed step outputs.", default_factory=list
    )
class AgentState(BaseModel):
    """State of every task tracked by an agent runner, keyed by task id."""

    # Maps task_id -> TaskState.
    task_dict: Dict[str, TaskState] = Field(
        description="Task dictionary.", default_factory=dict
    )

    def get_task(self, task_id: str) -> Task:
        """Return the ``Task`` stored under ``task_id`` (``KeyError`` if unknown)."""
        task_state = self.task_dict[task_id]
        return task_state.task

    def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
        """Return the stored (mutable) list of completed outputs for ``task_id``."""
        task_state = self.task_dict[task_id]
        return task_state.completed_steps

    def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
        """Return the stored (mutable) queue of pending steps for ``task_id``."""
        task_state = self.task_dict[task_id]
        return task_state.step_queue

    def reset(self) -> None:
        """Forget every tracked task."""
        self.task_dict = dict()
class AgentRunner(BaseAgentRunner):
    """Agent runner.

    Top-level agent orchestrator that can create tasks, run each step in a task,
    or run a task e2e. Stores state and keeps track of tasks.

    Args:
        agent_worker (BaseAgentWorker): step executor
        chat_history (Optional[List[ChatMessage]], optional): chat history.
            Only used to seed the default memory when ``memory`` is not given.
            Defaults to None.
        state (Optional[AgentState], optional): agent state. Defaults to None.
        memory (Optional[BaseMemory], optional): memory. Defaults to None.
        llm (Optional[LLM], optional): LLM. Only used to build the default
            memory when ``memory`` is not given. Defaults to None.
        callback_manager (Optional[CallbackManager], optional): callback manager.
            Defaults to None.
        init_task_state_kwargs (Optional[dict], optional): initial ``extra_state``
            given to every new task. Defaults to None.
        delete_task_on_finish (bool, optional): delete a task's state once its
            response has been finalized. Defaults to False.
        default_tool_choice (str, optional): tool choice used by the chat methods
            when the caller passes none. Defaults to "auto".
        verbose (bool, optional): print each step before running it.
            Defaults to False.
    """

    # # TODO: implement this in Pydantic
    def __init__(
        self,
        agent_worker: BaseAgentWorker,
        chat_history: Optional[List[ChatMessage]] = None,
        state: Optional[AgentState] = None,
        memory: Optional[BaseMemory] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        init_task_state_kwargs: Optional[dict] = None,
        delete_task_on_finish: bool = False,
        default_tool_choice: str = "auto",
        verbose: bool = False,
    ) -> None:
        """Initialize."""
        self.agent_worker = agent_worker
        self.state = state or AgentState()
        self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)

        # get and set callback manager
        if callback_manager is not None:
            self.agent_worker.set_callback_manager(callback_manager)
            self.callback_manager = callback_manager
        else:
            # TODO: This is *temporary*
            # Stopgap before having a callback on the BaseAgentWorker interface.
            # Doing that requires a bit more refactoring to make sure existing code
            # doesn't break.
            if hasattr(self.agent_worker, "callback_manager"):
                self.callback_manager = (
                    self.agent_worker.callback_manager or CallbackManager()
                )
            else:
                self.callback_manager = CallbackManager()
        self.init_task_state_kwargs = init_task_state_kwargs or {}
        self.delete_task_on_finish = delete_task_on_finish
        self.default_tool_choice = default_tool_choice
        self.verbose = verbose

    @staticmethod
    def from_llm(
        tools: Optional[List[BaseTool]] = None,
        llm: Optional[LLM] = None,
        **kwargs: Any,
    ) -> "AgentRunner":
        """Build an agent suited to ``llm``.

        Returns an ``OpenAIAgent`` when ``llm`` is an OpenAI model that supports
        function calling. Otherwise returns a ``ReActAgent``. Extra ``kwargs`` are
        forwarded to the chosen agent's ``from_tools``.
        """
        from llama_index.llms.openai import OpenAI
        from llama_index.llms.openai_utils import is_function_calling_model

        if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
            from llama_index.agent import OpenAIAgent

            return OpenAIAgent.from_tools(
                tools=tools,
                llm=llm,
                **kwargs,
            )
        else:
            from llama_index.agent import ReActAgent

            return ReActAgent.from_tools(
                tools=tools,
                llm=llm,
                **kwargs,
            )

    @property
    def chat_history(self) -> List[ChatMessage]:
        """Chat history, as returned by the memory's ``get_all()``."""
        return self.memory.get_all()

    def reset(self) -> None:
        """Reset the memory and drop all task state."""
        self.memory.reset()
        self.state.reset()

    def create_task(self, input: str, **kwargs: Any) -> Task:
        """Create a task for ``input``, queue its initial step and register it.

        ``extra_state`` may be passed in ``kwargs`` only when the runner was built
        without ``init_task_state_kwargs``. Otherwise a copy of those kwargs is
        used. A ``callback_manager`` in ``kwargs`` overrides the runner's own.

        Raises:
            ValueError: if both ``extra_state`` and ``init_task_state_kwargs``
                are set.
        """
        if not self.init_task_state_kwargs:
            extra_state = kwargs.pop("extra_state", {})
        elif "extra_state" in kwargs:
            raise ValueError(
                "Cannot specify both `extra_state` and `init_task_state_kwargs`"
            )
        else:
            # Copy so that per-task changes to ``extra_state`` cannot write back
            # into the runner-wide defaults and leak into later tasks.
            extra_state = dict(self.init_task_state_kwargs)

        callback_manager = kwargs.pop("callback_manager", self.callback_manager)
        task = Task(
            input=input,
            memory=self.memory,
            extra_state=extra_state,
            callback_manager=callback_manager,
            **kwargs,
        )
        # get initial step from task, and put it in the step queue
        initial_step = self.agent_worker.initialize_step(task)
        task_state = TaskState(
            task=task,
            step_queue=deque([initial_step]),
        )
        # add it to state
        self.state.task_dict[task.task_id] = task_state
        return task

    def delete_task(
        self,
        task_id: str,
    ) -> None:
        """Delete task.

        NOTE: this will not delete any previous executions from memory.
        Raises ``KeyError`` if ``task_id`` is unknown.
        """
        self.state.task_dict.pop(task_id)

    def list_tasks(self, **kwargs: Any) -> List[Task]:
        """List tasks."""
        # ``task_dict`` stores ``TaskState`` wrappers; unwrap them so callers get
        # ``Task`` objects as the signature promises.
        return [task_state.task for task_state in self.state.task_dict.values()]

    def get_task(self, task_id: str, **kwargs: Any) -> Task:
        """Get task."""
        return self.state.get_task(task_id)

    def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
        """Get upcoming steps (a snapshot copy of the task's step queue)."""
        return list(self.state.get_step_queue(task_id))

    def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
        """Get completed steps."""
        return self.state.get_completed_steps(task_id)

    def _resolve_step(
        self,
        task_id: str,
        step: Optional[TaskStep],
        input: Optional[str],
    ) -> TaskStep:
        """Pick the step to run: ``step`` if given, else pop the queue's head.

        ``input``, when given, overwrites the chosen step's input. Raises
        ``IndexError`` if no step is given and the queue is empty.
        """
        step_queue = self.state.get_step_queue(task_id)
        step = step or step_queue.popleft()
        if input is not None:
            step.input = input
        if self.verbose:
            print(f"> Running step {step.step_id}. Step input: {step.input}")
        return step

    def _record_step_output(self, task_id: str, step_output: TaskStepOutput) -> None:
        """Queue the output's follow-up steps and store it as completed."""
        self.state.get_step_queue(task_id).extend(step_output.next_steps)
        self.state.get_completed_steps(task_id).append(step_output)

    def _run_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        input: Optional[str] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute one step of ``task_id`` synchronously.

        ``mode`` selects the worker's ``run_step`` (WAIT) or ``stream_step``
        (STREAM). The remaining ``kwargs`` are forwarded to the worker.

        Raises:
            ValueError: if ``mode`` is neither WAIT nor STREAM.
        """
        task = self.state.get_task(task_id)
        step = self._resolve_step(task_id, step, input)

        # TODO: figure out if you can dynamically swap in different step executors
        # not clear when you would do that but theoretically possible
        if mode == ChatResponseMode.WAIT:
            cur_step_output = self.agent_worker.run_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        self._record_step_output(task_id, cur_step_output)
        return cur_step_output

    async def _arun_step(
        self,
        task_id: str,
        step: Optional[TaskStep] = None,
        input: Optional[str] = None,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Execute one step of ``task_id`` asynchronously.

        ``mode`` selects the worker's ``arun_step`` (WAIT) or ``astream_step``
        (STREAM). The remaining ``kwargs`` are forwarded to the worker.

        Raises:
            ValueError: if ``mode`` is neither WAIT nor STREAM.
        """
        task = self.state.get_task(task_id)
        step = self._resolve_step(task_id, step, input)

        # TODO: figure out if you can dynamically swap in different step executors
        # not clear when you would do that but theoretically possible
        if mode == ChatResponseMode.WAIT:
            cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
        elif mode == ChatResponseMode.STREAM:
            cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        self._record_step_output(task_id, cur_step_output)
        return cur_step_output

    def run_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return self._run_step(
            task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
        )

    async def arun_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return await self._arun_step(
            task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
        )

    def stream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (stream)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return self._run_step(
            task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
        )

    async def astream_step(
        self,
        task_id: str,
        input: Optional[str] = None,
        step: Optional[TaskStep] = None,
        **kwargs: Any,
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        step = validate_step_from_args(task_id, input, step, **kwargs)
        return await self._arun_step(
            task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
        )

    def finalize_response(
        self,
        task_id: str,
        step_output: Optional[TaskStepOutput] = None,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Finalize and return the response of ``task_id``.

        Uses ``step_output``, or the most recently completed step if it is
        omitted (``IndexError`` if no step has run yet). Calls the worker's
        ``finalize_task``. When ``delete_task_on_finish`` is set, it then deletes
        the task.

        Raises:
            ValueError: if the step output is not the last step, or if its output
                is not an ``AgentChatResponse`` / ``StreamingAgentChatResponse``.
        """
        if step_output is None:
            step_output = self.state.get_completed_steps(task_id)[-1]
        if not step_output.is_last:
            raise ValueError(
                "finalize_response can only be called on the last step output"
            )

        if not isinstance(
            step_output.output,
            (AgentChatResponse, StreamingAgentChatResponse),
        ):
            raise ValueError(
                "When `is_last` is True, cur_step_output.output must be "
                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
            )

        # finalize task
        self.agent_worker.finalize_task(self.state.get_task(task_id))

        if self.delete_task_on_finish:
            self.delete_task(task_id)

        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Run ``message`` as a new task until a step reports ``is_last``.

        If ``chat_history`` is given, it replaces the memory contents first. Only
        the first step receives ``tool_choice``; later steps use "auto".
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_output = self._run_step(
                task.task_id, mode=mode, tool_choice=tool_choice
            )

            if cur_step_output.is_last:
                result_output = cur_step_output
                break

            # ensure tool_choice does not cause endless loops
            tool_choice = "auto"

        return self.finalize_response(task.task_id, result_output)

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Async version of ``_chat``: run steps until one reports ``is_last``.

        If ``chat_history`` is given, it replaces the memory contents first. Only
        the first step receives ``tool_choice``; later steps use "auto".
        """
        if chat_history is not None:
            self.memory.set(chat_history)
        task = self.create_task(message)

        result_output = None
        while True:
            # pass step queue in as argument, assume step executor is stateless
            cur_step_output = await self._arun_step(
                task.task_id, mode=mode, tool_choice=tool_choice
            )

            if cur_step_output.is_last:
                result_output = cur_step_output
                break

            # ensure tool_choice does not cause endless loops
            tool_choice = "auto"

        return self.finalize_response(task.task_id, result_output)

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> AgentChatResponse:
        """Chat synchronously and return the complete response."""
        # override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> AgentChatResponse:
        """Chat asynchronously and return the complete response."""
        # override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> StreamingAgentChatResponse:
        """Chat synchronously in streaming mode."""
        # override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def astream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> StreamingAgentChatResponse:
        """Chat asynchronously in streaming mode."""
        # override tool choice if provided as input.
        if tool_choice is None:
            tool_choice = self.default_tool_choice
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    def undo_step(self, task_id: str) -> None:
        """Undo previous step (not supported by this runner)."""
        raise NotImplementedError("undo_step not implemented")