faiss_rag_enterprise/llama_index/agent/react_multimodal/step.py

"""ReAct multimodal agent."""
import uuid
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    cast,
)

from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.agent.react.output_parser import ReActOutputParser
from llama_index.agent.react.types import (
    ActionReasoningStep,
    BaseReasoningStep,
    ObservationReasoningStep,
    ResponseReasoningStep,
)
from llama_index.agent.react_multimodal.prompts import REACT_MM_CHAT_SYSTEM_HEADER
from llama_index.agent.types import (
    BaseAgentWorker,
    Task,
    TaskStep,
    TaskStepOutput,
)
from llama_index.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    trace_method,
)
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
)
from llama_index.core.llms.types import MessageRole
from llama_index.llms.base import ChatMessage, ChatResponse
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.multi_modal_llms.base import MultiModalLLM
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.multi_modal_llms.openai_utils import (
    generate_openai_multi_modal_chat_message,
)
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.tools.types import AsyncBaseTool
from llama_index.utils import print_text

DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"


def add_user_step_to_reasoning(
    step: TaskStep,
    memory: BaseMemory,
    current_reasoning: List[BaseReasoningStep],
    verbose: bool = False,
) -> None:
    """Add user step to reasoning.

    Adds both text input and image input to reasoning.

    """
    # raise error if step.input is None
    if step.input is None:
        raise ValueError("Step input is None.")
    # TODO: support Gemini as well; currently this just supports OpenAI.

    # TODO: we currently assume that images can't be generated during the loop,
    # so step_state contains the original image_docs from the task
    # (it doesn't change)
    image_docs = step.step_state["image_docs"]
    image_kwargs = step.step_state.get("image_kwargs", {})

    if step.step_state.get("is_first", False):
        # wrap the text prompt and the images into a single multimodal
        # user message and add it to the new memory
        mm_message = generate_openai_multi_modal_chat_message(
            prompt=step.input,
            role=MessageRole.USER,
            image_documents=image_docs,
            **image_kwargs,
        )
        memory.put(mm_message)
        step.step_state["is_first"] = False
    else:
        # NOTE: this is where the user specifies an intermediate step in the middle
        # TODO: don't support specifying image_docs here for now
        reasoning_step = ObservationReasoningStep(observation=step.input)
        current_reasoning.append(reasoning_step)
    if verbose:
        print(f"Added user message to memory: {step.input}")


class MultimodalReActAgentWorker(BaseAgentWorker):
    """Multimodal ReAct Agent worker.

    **NOTE**: This is a BETA feature.

    """

    def __init__(
        self,
        tools: Sequence[BaseTool],
        multi_modal_llm: MultiModalLLM,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ) -> None:
        self._multi_modal_llm = multi_modal_llm
        self.callback_manager = callback_manager or CallbackManager([])
        self._max_iterations = max_iterations
        self._react_chat_formatter = react_chat_formatter or ReActChatFormatter(
            system_header=REACT_MM_CHAT_SYSTEM_HEADER
        )
        self._output_parser = output_parser or ReActOutputParser()
        self._verbose = verbose

        # exactly one source of tools: an explicit list, a retriever, or neither
        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif len(tools) > 0:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            self._get_tools = lambda _: []

    @classmethod
    def from_tools(
        cls,
        tools: Optional[Sequence[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        multi_modal_llm: Optional[MultiModalLLM] = None,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "MultimodalReActAgentWorker":
        """Convenience constructor method from a set of BaseTools (Optional).

        NOTE: kwargs should have been exhausted by this point. In other words,
        the various upstream components such as BaseSynthesizer (response
        synthesizer) or BaseRetriever should have picked up their respective
        kwargs in their constructions.

        Returns:
            MultimodalReActAgentWorker

        """
        multi_modal_llm = multi_modal_llm or OpenAIMultiModal(
            model="gpt-4-vision-preview", max_new_tokens=1000
        )
        return cls(
            tools=tools or [],
            tool_retriever=tool_retriever,
            multi_modal_llm=multi_modal_llm,
            max_iterations=max_iterations,
            react_chat_formatter=react_chat_formatter,
            output_parser=output_parser,
            callback_manager=callback_manager,
            verbose=verbose,
        )
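
    # Illustrative usage (a minimal sketch, not part of this module; assumes
    # `AgentRunner` from `llama_index.agent`, an existing tool `query_tool`,
    # and a list of ImageDocument objects `image_documents`):
    #
    #   worker = MultimodalReActAgentWorker.from_tools(
    #       [query_tool],
    #       multi_modal_llm=OpenAIMultiModal(model="gpt-4-vision-preview"),
    #       verbose=True,
    #   )
    #   agent = AgentRunner(worker)
    #   task = agent.create_task(
    #       "Describe the attached image.",
    #       extra_state={"image_docs": image_documents},
    #   )
    #   step_output = agent.run_step(task.task_id)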

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize step from task."""
        sources: List[ToolOutput] = []
        current_reasoning: List[BaseReasoningStep] = []
        # temporary memory for new messages
        new_memory = ChatMemoryBuffer.from_defaults()

        # validation
        if "image_docs" not in task.extra_state:
            raise ValueError("Image docs not found in task extra state.")

        # initialize task state
        task_state = {
            "sources": sources,
            "current_reasoning": current_reasoning,
            "new_memory": new_memory,
        }
        task.extra_state.update(task_state)

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
            step_state={"is_first": True, "image_docs": task.extra_state["image_docs"]},
        )
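
    # After `initialize_step`, `task.extra_state` holds (in addition to the
    # caller-provided "image_docs"): "sources", "current_reasoning", and
    # "new_memory", which the step methods below read and mutate in place.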

    def get_tools(self, input: str) -> List[AsyncBaseTool]:
        """Get tools."""
        return [adapt_to_async_tool(t) for t in self._get_tools(input)]

    def _extract_reasoning_step(
        self, output: ChatResponse, is_streaming: bool = False
    ) -> Tuple[str, List[BaseReasoningStep], bool]:
        """
        Extract the reasoning step from the given output.

        This method parses the message content from the output,
        extracts the reasoning step, and determines whether the processing is
        complete. It also performs validation checks on the output and
        handles possible errors.
        """
        if output.message.content is None:
            raise ValueError("Got empty message.")
        message_content = output.message.content
        current_reasoning = []
        try:
            reasoning_step = self._output_parser.parse(message_content, is_streaming)
        except BaseException as exc:
            raise ValueError(f"Could not parse output: {message_content}") from exc
        if self._verbose:
            print_text(f"{reasoning_step.get_content()}\n", color="pink")
        current_reasoning.append(reasoning_step)

        if reasoning_step.is_done:
            return message_content, current_reasoning, True

        # if we're not done, the step must be an action step (tool call)
        if not isinstance(reasoning_step, ActionReasoningStep):
            raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")

        return message_content, current_reasoning, False
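
    # For reference, `ReActOutputParser` consumes ReAct-formatted completions
    # along these lines (an illustrative sketch of the convention, with
    # hypothetical tool and input names):
    #
    #   Thought: I need a tool to inspect the image.
    #   Action: image_caption_tool
    #   Action Input: {"query": "describe the chart"}
    #
    # or, once the model can answer directly:
    #
    #   Thought: I can answer without using any more tools.
    #   Answer: The chart shows ...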

    def _process_actions(
        self,
        task: Task,
        tools: Sequence[AsyncBaseTool],
        output: ChatResponse,
        is_streaming: bool = False,
    ) -> Tuple[List[BaseReasoningStep], bool]:
        tools_dict: Dict[str, AsyncBaseTool] = {
            tool.metadata.get_name(): tool for tool in tools
        }
        _, current_reasoning, is_done = self._extract_reasoning_step(
            output, is_streaming
        )

        if is_done:
            return current_reasoning, True

        # call tool with input
        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
        tool = tools_dict[reasoning_step.action]
        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: reasoning_step.action_input,
                EventPayload.TOOL: tool.metadata,
            },
        ) as event:
            tool_output = tool.call(**reasoning_step.action_input)
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})

        task.extra_state["sources"].append(tool_output)

        observation_step = ObservationReasoningStep(observation=str(tool_output))
        current_reasoning.append(observation_step)
        if self._verbose:
            print_text(f"{observation_step.get_content()}\n", color="blue")
        return current_reasoning, False

    async def _aprocess_actions(
        self,
        task: Task,
        tools: Sequence[AsyncBaseTool],
        output: ChatResponse,
        is_streaming: bool = False,
    ) -> Tuple[List[BaseReasoningStep], bool]:
        # NOTE: use get_name() (as in the sync version), since metadata.name
        # is Optional and would otherwise make the dict keys Optional[str]
        tools_dict: Dict[str, AsyncBaseTool] = {
            tool.metadata.get_name(): tool for tool in tools
        }
        _, current_reasoning, is_done = self._extract_reasoning_step(
            output, is_streaming
        )

        if is_done:
            return current_reasoning, True

        # call tool with input
        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
        tool = tools_dict[reasoning_step.action]
        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: reasoning_step.action_input,
                EventPayload.TOOL: tool.metadata,
            },
        ) as event:
            tool_output = await tool.acall(**reasoning_step.action_input)
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})

        task.extra_state["sources"].append(tool_output)

        observation_step = ObservationReasoningStep(observation=str(tool_output))
        current_reasoning.append(observation_step)
        if self._verbose:
            print_text(f"{observation_step.get_content()}\n", color="blue")
        return current_reasoning, False

    def _get_response(
        self,
        current_reasoning: List[BaseReasoningStep],
        sources: List[ToolOutput],
    ) -> AgentChatResponse:
        """Get response from reasoning steps."""
        if len(current_reasoning) == 0:
            raise ValueError("No reasoning steps were taken.")
        elif len(current_reasoning) == self._max_iterations:
            raise ValueError("Reached max iterations.")

        if isinstance(current_reasoning[-1], ResponseReasoningStep):
            response_step = cast(ResponseReasoningStep, current_reasoning[-1])
            response_str = response_step.response
        else:
            response_str = current_reasoning[-1].get_content()

        # TODO: add sources from reasoning steps
        return AgentChatResponse(response=response_str, sources=sources)

    def _get_task_step_response(
        self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
    ) -> TaskStepOutput:
        """Get task step response."""
        if is_done:
            new_steps: List[TaskStep] = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        return TaskStepOutput(
            output=agent_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    def _run_step(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step."""
        # step.input is not None on the first step, or when the user supplies
        # an intermediate step in the middle of the loop
        if step.input is not None:
            add_user_step_to_reasoning(
                step,
                task.extra_state["new_memory"],
                task.extra_state["current_reasoning"],
                verbose=self._verbose,
            )
        # TODO: see if we want to do step-based inputs
        tools = self.get_tools(task.input)

        input_chat = self._react_chat_formatter.format(
            tools,
            chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
            current_reasoning=task.extra_state["current_reasoning"],
        )
        # send prompt
        chat_response = self._multi_modal_llm.chat(input_chat)
        # given react prompt outputs, call tools or return response
        reasoning_steps, is_done = self._process_actions(
            task, tools, output=chat_response
        )
        task.extra_state["current_reasoning"].extend(reasoning_steps)
        agent_response = self._get_response(
            task.extra_state["current_reasoning"], task.extra_state["sources"]
        )
        if is_done:
            task.extra_state["new_memory"].put(
                ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
            )

        return self._get_task_step_response(agent_response, step, is_done)
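
    # Each `_run_step` / `_arun_step` call performs one ReAct iteration:
    # format the prompt -> chat with the multimodal LLM -> parse a
    # Thought/Action or a final Answer -> (possibly) call a tool and record
    # the observation -> package the result as a TaskStepOutput.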

    async def _arun_step(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step (async)."""
        if step.input is not None:
            add_user_step_to_reasoning(
                step,
                task.extra_state["new_memory"],
                task.extra_state["current_reasoning"],
                verbose=self._verbose,
            )
        # TODO: see if we want to do step-based inputs
        tools = self.get_tools(task.input)

        input_chat = self._react_chat_formatter.format(
            tools,
            chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
            current_reasoning=task.extra_state["current_reasoning"],
        )
        # send prompt
        chat_response = await self._multi_modal_llm.achat(input_chat)
        # given react prompt outputs, call tools or return response
        reasoning_steps, is_done = await self._aprocess_actions(
            task, tools, output=chat_response
        )
        task.extra_state["current_reasoning"].extend(reasoning_steps)
        agent_response = self._get_response(
            task.extra_state["current_reasoning"], task.extra_state["sources"]
        )
        if is_done:
            task.extra_state["new_memory"].put(
                ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
            )

        return self._get_task_step_response(agent_response, step, is_done)

    def _run_step_stream(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step (stream)."""
        raise NotImplementedError("Stream step not implemented yet.")

    async def _arun_step_stream(
        self,
        step: TaskStep,
        task: Task,
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        raise NotImplementedError("Stream step not implemented yet.")

    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step."""
        return self._run_step(step, task)

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async)."""
        return await self._arun_step(step, task)

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream)."""
        # TODO: figure out if we need a different type for TaskStepOutput
        return self._run_step_stream(step, task)

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        return await self._arun_step_stream(step, task)

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed."""
        # add new messages to memory
        task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
        # reset new memory
        task.extra_state["new_memory"].reset()

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        # TODO: make this an abstractmethod (right now it would break some agent impls)
        self.callback_manager = callback_manager