200 lines
6.5 KiB
Python
200 lines
6.5 KiB
Python
"""Agent worker that takes in a query pipeline."""
|
|
|
|
import uuid
|
|
from typing import (
|
|
Any,
|
|
List,
|
|
Optional,
|
|
cast,
|
|
)
|
|
|
|
from llama_index.agent.types import (
|
|
BaseAgentWorker,
|
|
Task,
|
|
TaskStep,
|
|
TaskStepOutput,
|
|
)
|
|
from llama_index.bridge.pydantic import BaseModel, Field
|
|
from llama_index.callbacks import (
|
|
CallbackManager,
|
|
trace_method,
|
|
)
|
|
from llama_index.chat_engine.types import (
|
|
AGENT_CHAT_RESPONSE_TYPE,
|
|
)
|
|
from llama_index.core.query_pipeline.query_component import QueryComponent
|
|
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
|
|
from llama_index.query_pipeline.components.agent import (
|
|
AgentFnComponent,
|
|
AgentInputComponent,
|
|
BaseAgentComponent,
|
|
)
|
|
from llama_index.query_pipeline.query import QueryPipeline
|
|
from llama_index.tools import ToolOutput
|
|
|
|
# Default model name constant.
# NOTE(review): nothing in the visible part of this module references it;
# confirm whether external code imports it before removing.
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
|
|
|
|
|
|
def _get_agent_components(query_component: QueryComponent) -> List[BaseAgentComponent]:
    """Collect the agent components nested under ``query_component``.

    Walks ``query_component.sub_query_components`` depth-first. Each child that
    is a ``BaseAgentComponent`` is recorded, and any child that itself has
    sub-components is searched recursively, with its matches appended right
    after the child.

    Args:
        query_component: Component whose sub-components are searched. The
            component itself is never included in the result.

    Returns:
        The agent components found, in traversal order.
    """
    found: List[BaseAgentComponent] = []
    for child in query_component.sub_query_components:
        if isinstance(child, BaseAgentComponent):
            found.append(cast(BaseAgentComponent, child))

        # Descend only when the child actually exposes nested components.
        nested = child.sub_query_components
        if len(nested) > 0:
            found.extend(_get_agent_components(child))

    return found
|
|
|
|
|
|
class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker):
    """Query Pipeline agent worker.

    Barebones agent worker that takes in a query pipeline.

    Assumes that the first component in the query pipeline is an
    `AgentInputComponent` and last is `AgentFnComponent`. Only the
    first-component requirement is enforced by the constructor.

    Args:
        pipeline (QueryPipeline): Query pipeline
        callback_manager (Optional[CallbackManager]): Callback manager. When
            given, it is also set on the pipeline; when omitted, the
            pipeline's own callback manager is used.

    """

    pipeline: QueryPipeline = Field(..., description="Query pipeline")
    # Excluded from the model's exported fields.
    callback_manager: CallbackManager = Field(..., exclude=True)

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        pipeline: QueryPipeline,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Initialize.

        Args:
            pipeline: Query pipeline to drive. If ``callback_manager`` is
                given, the pipeline's callback manager is replaced with it.
            callback_manager: Optional callback manager; defaults to the
                pipeline's own.

        Raises:
            ValueError: If the pipeline has no root component, or its first
                root component is not an `AgentInputComponent`.
        """
        if callback_manager is not None:
            # set query pipeline callback
            pipeline.set_callback_manager(callback_manager)
        else:
            callback_manager = pipeline.callback_manager
        super().__init__(
            pipeline=pipeline,
            callback_manager=callback_manager,
        )
        # validate query pipeline: evaluating the properties raises if the
        # pipeline is not shaped the way this worker requires.
        _ = self.agent_input_component
        _ = self.agent_components

    @property
    def agent_input_component(self) -> AgentInputComponent:
        """Get agent input component.

        Returns:
            The module registered under the pipeline's first root key.

        Raises:
            ValueError: If the pipeline has no root keys, or the module under
                the first root key is not an `AgentInputComponent`.
        """
        root_keys = self.pipeline.get_root_keys()
        if not root_keys:
            # Fail with the same exception type as the other validation error
            # instead of an opaque IndexError from indexing an empty list.
            raise ValueError(
                "Query pipeline must contain at least one component; the first "
                "one must be an AgentInputComponent."
            )

        root_component = self.pipeline.module_dict[root_keys[0]]
        if not isinstance(root_component, AgentInputComponent):
            raise ValueError(
                "Query pipeline first component must be AgentInputComponent, got "
                f"{root_component}"
            )

        return cast(AgentInputComponent, root_component)

    @property
    def agent_components(self) -> List[BaseAgentComponent]:
        """Get all agent components nested in the pipeline.

        Returns:
            Every `BaseAgentComponent` found by recursively walking the
            pipeline's sub-components.
        """
        return _get_agent_components(self.pipeline)

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize step from task.

        Args:
            task: Task to create the first step for.
            **kwargs: Unused.

        Returns:
            A `TaskStep` with a fresh step id, the task's input, and an initial
            state holding an empty ``"sources"`` list and a new ``"memory"``
            buffer.
        """
        sources: List[ToolOutput] = []
        # temporary memory for new messages
        new_memory = ChatMemoryBuffer.from_defaults()

        # initialize initial state
        initial_state = {
            "sources": sources,
            "memory": new_memory,
        }

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
            step_state=initial_state,
        )

    def _get_task_step_response(
        self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
    ) -> TaskStepOutput:
        """Get task step response.

        Args:
            agent_response: Response produced by the pipeline for this step.
            step: The step that was just run.
            is_done: Whether the pipeline reported the task as finished.

        Returns:
            A `TaskStepOutput` wrapping the response. When ``is_done`` is false
            it carries exactly one follow-up step; otherwise none.
        """
        if is_done:
            new_steps = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        return TaskStepOutput(
            output=agent_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    def _partial_agent_components(self, task: Task, step: TaskStep) -> None:
        """Bind the task and the step's state onto every agent component."""
        for agent_component in self.agent_components:
            agent_component.partial(task=task, state=step.step_state)

    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step.

        Args:
            step: Step to execute; its ``step_state`` is passed to the pipeline.
            task: Task the step belongs to.
            **kwargs: Unused.

        Returns:
            The step output built from the pipeline's
            ``(agent_response, is_done)`` result.
        """
        # partial agent components with task and step state
        self._partial_agent_components(task, step)

        agent_response, is_done = self.pipeline.run(state=step.step_state, task=task)
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async).

        Args:
            step: Step to execute; its ``step_state`` is passed to the pipeline.
            task: Task the step belongs to.
            **kwargs: Unused.

        Returns:
            The step output built from the pipeline's
            ``(agent_response, is_done)`` result.
        """
        # partial agent components with task and step state
        self._partial_agent_components(task, step)

        agent_response, is_done = await self.pipeline.arun(
            state=step.step_state, task=task
        )
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream).

        Raises:
            NotImplementedError: Always; streaming is not supported.
        """
        raise NotImplementedError("This agent does not support streaming.")

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream).

        Raises:
            NotImplementedError: Always; streaming is not supported.
        """
        raise NotImplementedError("This agent does not support streaming.")

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed.

        Appends the messages held in ``task.extra_state["memory"]`` to
        ``task.memory`` and then resets that temporary buffer.

        Args:
            task: Task whose memory is updated.
            **kwargs: Unused.
        """
        # add new messages to memory
        task.memory.set(task.memory.get() + task.extra_state["memory"].get_all())
        # reset new memory
        task.extra_state["memory"].reset()

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager on both this worker and its pipeline."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager
        self.pipeline.set_callback_manager(callback_manager)
|