faiss_rag_enterprise/llama_index/agent/custom/pipeline_worker.py

200 lines
6.5 KiB
Python

"""Agent worker that takes in a query pipeline."""
import uuid
from typing import (
Any,
List,
Optional,
cast,
)
from llama_index.agent.types import (
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import (
CallbackManager,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
)
from llama_index.core.query_pipeline.query_component import QueryComponent
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.query_pipeline.components.agent import (
AgentFnComponent,
AgentInputComponent,
BaseAgentComponent,
)
from llama_index.query_pipeline.query import QueryPipeline
from llama_index.tools import ToolOutput
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
def _get_agent_components(query_component: QueryComponent) -> List[BaseAgentComponent]:
"""Get agent components."""
agent_components: List[BaseAgentComponent] = []
for c in query_component.sub_query_components:
if isinstance(c, BaseAgentComponent):
agent_components.append(cast(BaseAgentComponent, c))
if len(c.sub_query_components) > 0:
agent_components.extend(_get_agent_components(c))
return agent_components
class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker):
"""Query Pipeline agent worker.
Barebones agent worker that takes in a query pipeline.
Assumes that the first component in the query pipeline is an
`AgentInputComponent` and last is `AgentFnComponent`.
Args:
pipeline (QueryPipeline): Query pipeline
"""
pipeline: QueryPipeline = Field(..., description="Query pipeline")
callback_manager: CallbackManager = Field(..., exclude=True)
class Config:
arbitrary_types_allowed = True
def __init__(
self,
pipeline: QueryPipeline,
callback_manager: Optional[CallbackManager] = None,
) -> None:
"""Initialize."""
if callback_manager is not None:
# set query pipeline callback
pipeline.set_callback_manager(callback_manager)
else:
callback_manager = pipeline.callback_manager
super().__init__(
pipeline=pipeline,
callback_manager=callback_manager,
)
# validate query pipeline
self.agent_input_component
self.agent_components
@property
def agent_input_component(self) -> AgentInputComponent:
"""Get agent input component."""
root_key = self.pipeline.get_root_keys()[0]
if not isinstance(self.pipeline.module_dict[root_key], AgentInputComponent):
raise ValueError(
"Query pipeline first component must be AgentInputComponent, got "
f"{self.pipeline.module_dict[root_key]}"
)
return cast(AgentInputComponent, self.pipeline.module_dict[root_key])
@property
def agent_components(self) -> List[AgentFnComponent]:
"""Get agent output component."""
return _get_agent_components(self.pipeline)
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
sources: List[ToolOutput] = []
# temporary memory for new messages
new_memory = ChatMemoryBuffer.from_defaults()
# initialize initial state
initial_state = {
"sources": sources,
"memory": new_memory,
}
return TaskStep(
task_id=task.task_id,
step_id=str(uuid.uuid4()),
input=task.input,
step_state=initial_state,
)
def _get_task_step_response(
self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
) -> TaskStepOutput:
"""Get task step response."""
if is_done:
new_steps = []
else:
new_steps = [
step.get_next_step(
step_id=str(uuid.uuid4()),
# NOTE: input is unused
input=None,
)
]
return TaskStepOutput(
output=agent_response,
task_step=step,
is_last=is_done,
next_steps=new_steps,
)
@trace_method("run_step")
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)
agent_response, is_done = self.pipeline.run(state=step.step_state, task=task)
response = self._get_task_step_response(agent_response, step, is_done)
# sync step state with task state
task.extra_state.update(step.step_state)
return response
@trace_method("run_step")
async def arun_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async)."""
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)
agent_response, is_done = await self.pipeline.arun(
state=step.step_state, task=task
)
response = self._get_task_step_response(agent_response, step, is_done)
task.extra_state.update(step.step_state)
return response
@trace_method("run_step")
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step (stream)."""
raise NotImplementedError("This agent does not support streaming.")
@trace_method("run_step")
async def astream_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async stream)."""
raise NotImplementedError("This agent does not support streaming.")
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""
# add new messages to memory
task.memory.set(task.memory.get() + task.extra_state["memory"].get_all())
# reset new memory
task.extra_state["memory"].reset()
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: make this abstractmethod (right now will break some agent impls)
self.callback_manager = callback_manager
self.pipeline.set_callback_manager(callback_manager)