# faiss_rag_enterprise/llama_index/agent/custom/simple.py
#
# 262 lines
# 8.6 KiB
# Python

"""Custom agent worker."""
import uuid
from abc import abstractmethod
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
cast,
)
from llama_index.agent.types import (
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field, PrivateAttr
from llama_index.callbacks import (
CallbackManager,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
)
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.tools.types import AsyncBaseTool
# Model name passed to `OpenAI(model=...)` by `CustomSimpleAgentWorker.from_tools`
# when the caller does not supply an `llm`.
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
class CustomSimpleAgentWorker(BaseModel, BaseAgentWorker):
    """Custom simple agent worker.

    This is "simple" in the sense that some of the scaffolding is setup already.

    Assumptions:
    - assumes that the agent has tools, llm, callback manager, and tool retriever
    - has a `from_tools` convenience function
    - assumes that the agent is sequential, and doesn't take in any additional
      intermediate inputs.

    Subclasses must implement `_initialize_state`, `_run_step` and
    `_finalize_task`; `_arun_step` may be overridden to add async support.

    Args:
        tools (Sequence[BaseTool]): Tools to use for reasoning
        llm (LLM): LLM to use
        callback_manager (CallbackManager): Callback manager
        tool_retriever (Optional[ObjectRetriever[BaseTool]]): Tool retriever
        verbose (bool): Whether to print out reasoning steps

    """

    tools: Sequence[BaseTool] = Field(..., description="Tools to use for reasoning")
    llm: LLM = Field(..., description="LLM to use")
    callback_manager: CallbackManager = Field(
        default_factory=lambda: CallbackManager([]), exclude=True
    )
    tool_retriever: Optional[ObjectRetriever[BaseTool]] = Field(
        default=None, description="Tool retriever"
    )
    verbose: bool = Field(False, description="Whether to print out reasoning steps")

    # Maps an input message to the tools available for it; set in `__init__`.
    _get_tools: Callable[[str], Sequence[BaseTool]] = PrivateAttr()

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        tools: Sequence[BaseTool],
        llm: LLM,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ) -> None:
        """Initialize the worker.

        Args:
            tools: Fixed set of tools. Mutually exclusive with `tool_retriever`.
            llm: LLM to use.
            callback_manager: Callback manager; an empty `CallbackManager` is
                created when omitted.
            verbose: Whether to print out reasoning steps.
            tool_retriever: Retriever that selects tools per input message.
                Mutually exclusive with a non-empty `tools`.

        Raises:
            ValueError: If both a non-empty `tools` and a `tool_retriever`
                are given.

        """
        # Reject the ambiguous configuration before any model construction.
        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")

        # The `callback_manager` field is declared non-Optional, so an explicit
        # None must not be forwarded to the model constructor (it would bypass
        # the field's default factory).
        if callback_manager is None:
            callback_manager = CallbackManager([])

        super().__init__(
            tools=tools,
            llm=llm,
            callback_manager=callback_manager,
            tool_retriever=tool_retriever,
            verbose=verbose,
        )

        # Private attributes are assigned only once the model itself has been
        # initialised, so the assignment never targets a half-built instance.
        if len(tools) > 0:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            self._get_tools = lambda _: []

    @classmethod
    def from_tools(
        cls,
        tools: Optional[Sequence[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "CustomSimpleAgentWorker":
        """Convenience constructor method from a set of BaseTools (Optional).

        Args:
            tools: Tools to use; defaults to an empty list.
            tool_retriever: Optional retriever used instead of fixed tools.
            llm: LLM to use; defaults to `OpenAI(model=DEFAULT_MODEL_NAME)`.
            callback_manager: When given, it is also installed on `llm`.
            verbose: Whether to print out reasoning steps.
            **kwargs: Accepted for signature compatibility; currently NOT
                forwarded to the constructor.

        Returns:
            A new worker instance of `cls`.

        """
        if llm is None:
            llm = OpenAI(model=DEFAULT_MODEL_NAME)
        if callback_manager is not None:
            # Keep the LLM's callbacks in sync with the worker's.
            llm.callback_manager = callback_manager
        return cls(
            tools=tools or [],
            tool_retriever=tool_retriever,
            llm=llm,
            callback_manager=callback_manager,
            verbose=verbose,
        )

    @abstractmethod
    def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
        """Initialize state.

        Returns the subclass-specific step state. The keys "sources" and
        "memory" are reserved by `initialize_step` and must not be used.
        """

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize step from task.

        Merges the subclass state from `_initialize_state` with the reserved
        "sources" and "memory" entries and wraps it in the first `TaskStep`.

        Raises:
            ValueError: If the subclass state uses a reserved key.

        """
        sources: List[ToolOutput] = []
        # temporary memory for new messages
        new_memory = ChatMemoryBuffer.from_defaults()

        # initialize initial state
        initial_state = {
            "sources": sources,
            "memory": new_memory,
        }

        step_state = self._initialize_state(task, **kwargs)
        # if intersecting keys, error
        if set(step_state.keys()).intersection(set(initial_state.keys())):
            raise ValueError(
                f"Step state keys {step_state.keys()} and initial state keys "
                f"{initial_state.keys()} intersect. "
                f"*NOTE*: initial state keys {initial_state.keys()} are reserved."
            )
        step_state.update(initial_state)

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
            step_state=step_state,
        )

    def get_tools(self, input: str) -> List[AsyncBaseTool]:
        """Get tools for `input`, each adapted with `adapt_to_async_tool`."""
        return [adapt_to_async_tool(t) for t in self._get_tools(input)]

    def _get_task_step_response(
        self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
    ) -> TaskStepOutput:
        """Get task step response.

        When the task is not done, a single follow-up step (with a fresh
        step id) is queued; otherwise no next steps are produced.
        """
        if is_done:
            new_steps = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        return TaskStepOutput(
            output=agent_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    @abstractmethod
    def _run_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run step.

        Returns:
            Tuple of (agent_response, is_done)

        """

    async def _arun_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run step (async).

        Can override this method if you want to run the step asynchronously.

        Returns:
            Tuple of (agent_response, is_done)

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        raise NotImplementedError(
            "This agent does not support async. Please implement _arun_step."
        )

    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step."""
        agent_response, is_done = self._run_step(
            step.step_state, task, input=step.input
        )
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async)."""
        agent_response, is_done = await self._arun_step(
            step.step_state, task, input=step.input
        )
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream). Not supported; always raises."""
        raise NotImplementedError("This agent does not support streaming.")

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream). Not supported; always raises."""
        raise NotImplementedError("This agent does not support streaming.")

    @abstractmethod
    def _finalize_task(self, state: Dict[str, Any], **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed.

        State is all the step states.

        """

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed.

        Appends the messages collected in the temporary "memory" buffer to
        the task's memory, clears that buffer, then calls `_finalize_task`
        with the task's extra state.
        """
        # add new messages to memory
        task.memory.set(task.memory.get() + task.extra_state["memory"].get_all())
        # reset new memory
        task.extra_state["memory"].reset()
        self._finalize_task(task.extra_state, **kwargs)

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager