# faiss_rag_enterprise/llama_index/agent/legacy/react/base.py
#
# (Source listing metadata: 527 lines, 19 KiB, Python.)

import asyncio
from itertools import chain
from threading import Thread
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Generator,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    cast,
)

from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.agent.react.output_parser import ReActOutputParser
from llama_index.agent.react.types import (
    ActionReasoningStep,
    BaseReasoningStep,
    ObservationReasoningStep,
    ResponseReasoningStep,
)
from llama_index.agent.types import BaseAgent
from llama_index.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    trace_method,
)
from llama_index.chat_engine.types import AgentChatResponse, StreamingAgentChatResponse
from llama_index.core.llms.types import MessageRole
from llama_index.llms.base import ChatMessage, ChatResponse
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.tools.types import AsyncBaseTool
from llama_index.utils import print_text, unit_generator
# OpenAI model name used by ``ReActAgent.from_tools`` when no ``llm`` is passed.
DEFAULT_MODEL_NAME: str = "gpt-3.5-turbo-0613"
class ReActAgent(BaseAgent):
    """ReAct agent.

    Uses a ReAct prompt that can be used in both chat and text
    completion endpoints.

    Can take in a set of tools that require structured inputs.
    """

    def __init__(
        self,
        tools: Sequence[BaseTool],
        llm: LLM,
        memory: BaseMemory,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ) -> None:
        """Initialize the agent.

        Args:
            tools: Fixed set of tools. Mutually exclusive with ``tool_retriever``.
            llm: LLM queried once per reasoning iteration.
            memory: Memory holding the conversation history.
            max_iterations: Maximum number of LLM calls per user message.
            react_chat_formatter: Prompt formatter; defaults to
                ``ReActChatFormatter()``.
            output_parser: Parser for LLM outputs; defaults to
                ``ReActOutputParser()``.
            callback_manager: Callback manager; defaults to the LLM's.
            verbose: If True, print reasoning steps and observations.
            tool_retriever: Retriever that picks tools from the user message.
                Mutually exclusive with ``tools``.

        Raises:
            ValueError: If both ``tools`` and ``tool_retriever`` are given.
        """
        super().__init__(callback_manager=callback_manager or llm.callback_manager)
        self._llm = llm
        self._memory = memory
        self._max_iterations = max_iterations
        self._react_chat_formatter = react_chat_formatter or ReActChatFormatter()
        self._output_parser = output_parser or ReActOutputParser()
        self._verbose = verbose
        self.sources: List[ToolOutput] = []
        # Strong references to the history-writing tasks created by
        # ``astream_chat``: the event loop only keeps weak references to tasks,
        # so an unreferenced task may be garbage-collected mid-execution.
        self._background_tasks: Set["asyncio.Task[None]"] = set()

        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif len(tools) > 0:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            self._get_tools = lambda _: []

    @classmethod
    def from_tools(
        cls,
        tools: Optional[List[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        llm: Optional[LLM] = None,
        chat_history: Optional[List[ChatMessage]] = None,
        memory: Optional[BaseMemory] = None,
        memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
        max_iterations: int = 10,
        react_chat_formatter: Optional[ReActChatFormatter] = None,
        output_parser: Optional[ReActOutputParser] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "ReActAgent":
        """Convenience constructor method from a set of BaseTools (optional).

        If ``llm`` is omitted, an ``OpenAI`` LLM using ``DEFAULT_MODEL_NAME`` is
        created. If ``callback_manager`` is given it is also assigned to the
        LLM (this mutates a caller-supplied ``llm``). If ``memory`` is omitted,
        one is built with ``memory_cls.from_defaults`` from ``chat_history``.

        NOTE: kwargs should have been exhausted by this point. In other words
        the various upstream components such as BaseSynthesizer (response
        synthesizer) or BaseRetriever should have picked up their respective
        kwargs in their constructions. Any remaining kwargs are ignored.

        Returns:
            ReActAgent
        """
        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if callback_manager is not None:
            llm.callback_manager = callback_manager
        memory = memory or memory_cls.from_defaults(
            chat_history=chat_history or [], llm=llm
        )
        return cls(
            tools=tools or [],
            tool_retriever=tool_retriever,
            llm=llm,
            memory=memory,
            max_iterations=max_iterations,
            react_chat_formatter=react_chat_formatter,
            output_parser=output_parser,
            callback_manager=callback_manager,
            verbose=verbose,
        )

    @property
    def chat_history(self) -> List[ChatMessage]:
        """Chat history (all messages currently held in memory)."""
        return self._memory.get_all()

    def reset(self) -> None:
        """Reset the agent's memory."""
        self._memory.reset()

    def _extract_reasoning_step(
        self, output: ChatResponse, is_streaming: bool = False
    ) -> Tuple[str, List[BaseReasoningStep], bool]:
        """Parse one LLM output into a single reasoning step.

        Args:
            output: The LLM response to parse.
            is_streaming: Forwarded to the output parser; True when ``output``
                may be a partial chunk of a stream.

        Returns:
            ``(message_content, [reasoning_step], is_done)`` where ``is_done``
            is True when the parsed step is a final response.

        Raises:
            ValueError: If the message is empty, cannot be parsed, or is
                neither a final response nor an ``ActionReasoningStep``.
        """
        message_content = output.message.content
        if message_content is None:
            raise ValueError("Got empty message.")

        try:
            reasoning_step = self._output_parser.parse(message_content, is_streaming)
        except Exception as exc:
            # Catch Exception (not BaseException) so KeyboardInterrupt and
            # SystemExit are not disguised as parse failures.
            raise ValueError(f"Could not parse output: {message_content}") from exc

        if self._verbose:
            print_text(f"{reasoning_step.get_content()}\n", color="pink")

        current_reasoning: List[BaseReasoningStep] = [reasoning_step]
        if reasoning_step.is_done:
            return message_content, current_reasoning, True

        if not isinstance(reasoning_step, ActionReasoningStep):
            raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")
        return message_content, current_reasoning, False

    def _get_tools_dict(
        self, tools: Sequence[AsyncBaseTool]
    ) -> Dict[str, AsyncBaseTool]:
        """Map each tool's ``metadata.get_name()`` to the tool."""
        return {tool.metadata.get_name(): tool for tool in tools}

    def _add_observation(
        self, current_reasoning: List[BaseReasoningStep], observation: str
    ) -> None:
        """Append an observation step to ``current_reasoning``; echo if verbose."""
        observation_step = ObservationReasoningStep(observation=observation)
        current_reasoning.append(observation_step)
        if self._verbose:
            print_text(f"{observation_step.get_content()}\n", color="blue")

    @staticmethod
    def _unknown_tool_observation(action: str) -> str:
        """Observation text fed back to the LLM when it names a missing tool."""
        return f"Error: No such tool named `{action}`."

    def _process_actions(
        self,
        tools: Sequence[AsyncBaseTool],
        output: ChatResponse,
        is_streaming: bool = False,
    ) -> Tuple[List[BaseReasoningStep], bool]:
        """Parse ``output`` and, if it is an action, call the tool synchronously.

        Returns:
            ``(new_reasoning_steps, is_done)``. For an action the steps are the
            action followed by an observation holding the tool output, or an
            error message when the named tool does not exist.
        """
        tools_dict = self._get_tools_dict(tools)
        _, current_reasoning, is_done = self._extract_reasoning_step(
            output, is_streaming
        )
        if is_done:
            return current_reasoning, True

        # call tool with input
        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
        tool = tools_dict.get(reasoning_step.action)
        if tool is None:
            # Let the LLM recover from a hallucinated tool name instead of
            # aborting the whole chat with a KeyError.
            self._add_observation(
                current_reasoning,
                self._unknown_tool_observation(reasoning_step.action),
            )
            return current_reasoning, False

        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: reasoning_step.action_input,
                EventPayload.TOOL: tool.metadata,
            },
        ) as event:
            tool_output = tool.call(**reasoning_step.action_input)
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})

        self.sources.append(tool_output)
        self._add_observation(current_reasoning, str(tool_output))
        return current_reasoning, False

    async def _aprocess_actions(
        self,
        tools: Sequence[AsyncBaseTool],
        output: ChatResponse,
        is_streaming: bool = False,
    ) -> Tuple[List[BaseReasoningStep], bool]:
        """Async variant of ``_process_actions`` (awaits ``tool.acall``)."""
        tools_dict = self._get_tools_dict(tools)
        _, current_reasoning, is_done = self._extract_reasoning_step(
            output, is_streaming
        )
        if is_done:
            return current_reasoning, True

        # call tool with input
        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
        tool = tools_dict.get(reasoning_step.action)
        if tool is None:
            # Same recovery path as the sync variant.
            self._add_observation(
                current_reasoning,
                self._unknown_tool_observation(reasoning_step.action),
            )
            return current_reasoning, False

        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: reasoning_step.action_input,
                EventPayload.TOOL: tool.metadata,
            },
        ) as event:
            tool_output = await tool.acall(**reasoning_step.action_input)
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})

        self.sources.append(tool_output)
        self._add_observation(current_reasoning, str(tool_output))
        return current_reasoning, False

    def _get_response(
        self,
        current_reasoning: List[BaseReasoningStep],
    ) -> AgentChatResponse:
        """Build the final response from the accumulated reasoning steps.

        Raises:
            ValueError: If no steps were taken, or if the last step is not a
                final response (the iteration budget ran out first).
        """
        if not current_reasoning:
            raise ValueError("No reasoning steps were taken.")
        last_step = current_reasoning[-1]
        # Each tool iteration adds two steps (action + observation) and the
        # answer adds one, so the step count says nothing about whether the
        # loop finished; only a done step means a final answer exists.
        if not last_step.is_done:
            raise ValueError("Reached max iterations.")
        response_step = cast(ResponseReasoningStep, last_step)
        # TODO: add sources from reasoning steps
        return AgentChatResponse(response=response_step.response, sources=self.sources)

    def _infer_stream_chunk_is_final(self, chunk: ChatResponse) -> bool:
        """Infer whether a live stream chunk is the start of the final answer.

        ``chunk.message.content`` is expected to hold the text accumulated so
        far (the last chunk read is parsed as the whole LLM output). The chunk
        counts as final once that text can no longer be the start of the
        "Thought ..." format, or once it contains "Answer: ". Turning it into a
        ResponseReasoningStep is not part of this method.

        Args:
            chunk (ChatResponse): the current chunk of the stream to check.

        Returns:
            bool: whether the chunk is the start of the final response.
        """
        latest_content = chunk.message.content
        if not latest_content:
            return False
        if "Thought".startswith(latest_content):
            # Still a prefix of "Thought" (e.g. "Th"): too early to decide.
            return False
        if not latest_content.startswith("Thought"):
            # doesn't follow thought-action format
            return True
        return "Answer: " in latest_content

    def _add_back_chunk_to_stream(
        self, chunk: ChatResponse, chat_stream: Generator[ChatResponse, None, None]
    ) -> Generator[ChatResponse, None, None]:
        """Prepend an already-consumed chunk to the rest of ``chat_stream``.

        Args:
            chunk (ChatResponse): the chunk to put back at the beginning of
                the chat_stream.
            chat_stream: the partially consumed LLM stream.

        Returns:
            Generator[ChatResponse, None, None]: the updated chat_stream. It is
            lazy: ``chat_stream`` is only read as the result is iterated.
        """
        # cast because itertools.chain is an iterator, not a Generator (mypy)
        return cast(
            Generator[ChatResponse, None, None],
            chain(unit_generator(chunk), chat_stream),
        )

    async def _async_add_back_chunk_to_stream(
        self, chunk: ChatResponse, chat_stream: AsyncGenerator[ChatResponse, None]
    ) -> AsyncGenerator[ChatResponse, None]:
        """Prepend an already-consumed chunk to the rest of an async stream.

        NOTE: this is an async *generator* function: call it without ``await``
        and iterate the result with ``async for``.

        Args:
            chunk (ChatResponse): the chunk to put back at the beginning of
                the chat_stream.
            chat_stream: the partially consumed async LLM stream.

        Returns:
            AsyncGenerator[ChatResponse, None]: the updated async chat_stream.
        """
        yield chunk
        async for item in chat_stream:
            yield item

    def _start_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]]
    ) -> List[AsyncBaseTool]:
        """Reset per-turn sources, record the user message, return the tools."""
        self.sources = []
        # TODO: do get tools dynamically at every iteration of the agent loop
        tools = self.get_tools(message)
        if chat_history is not None:
            self._memory.set(chat_history)
        self._memory.put(ChatMessage(content=message, role=MessageRole.USER))
        return tools

    @trace_method("chat")
    def chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AgentChatResponse:
        """Chat.

        Runs up to ``max_iterations`` LLM calls, executing tools in between,
        and stores the user message and the final answer in memory.

        Raises:
            ValueError: If the LLM output cannot be parsed or no final answer
                is produced within ``max_iterations``.
        """
        tools = self._start_chat(message, chat_history)
        current_reasoning: List[BaseReasoningStep] = []

        for _ in range(self._max_iterations):
            input_chat = self._react_chat_formatter.format(
                tools,
                chat_history=self._memory.get(),
                current_reasoning=current_reasoning,
            )
            chat_response = self._llm.chat(input_chat)
            # given react prompt outputs, call tools or return response
            reasoning_steps, is_done = self._process_actions(
                tools, output=chat_response
            )
            current_reasoning.extend(reasoning_steps)
            if is_done:
                break

        response = self._get_response(current_reasoning)
        self._memory.put(
            ChatMessage(content=response.response, role=MessageRole.ASSISTANT)
        )
        return response

    @trace_method("chat")
    async def achat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AgentChatResponse:
        """Async chat; same contract as ``chat``."""
        tools = self._start_chat(message, chat_history)
        current_reasoning: List[BaseReasoningStep] = []

        for _ in range(self._max_iterations):
            input_chat = self._react_chat_formatter.format(
                tools,
                chat_history=self._memory.get(),
                current_reasoning=current_reasoning,
            )
            chat_response = await self._llm.achat(input_chat)
            # given react prompt outputs, call tools or return response
            reasoning_steps, is_done = await self._aprocess_actions(
                tools, output=chat_response
            )
            current_reasoning.extend(reasoning_steps)
            if is_done:
                break

        response = self._get_response(current_reasoning)
        self._memory.put(
            ChatMessage(content=response.response, role=MessageRole.ASSISTANT)
        )
        return response

    @trace_method("chat")
    def stream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Stream chat.

        Thought/action outputs are consumed internally. Once a chunk is
        inferred to start the final answer, that chunk plus the rest of the
        LLM stream is returned, and ``write_response_to_history`` is run on a
        background thread with this agent's memory.

        Raises:
            ValueError: If the LLM output cannot be parsed or no final answer
                starts within ``max_iterations``.
        """
        tools = self._start_chat(message, chat_history)
        current_reasoning: List[BaseReasoningStep] = []
        chat_stream: Optional[Generator[ChatResponse, None, None]] = None
        full_response = ChatResponse(
            message=ChatMessage(content=None, role=MessageRole.ASSISTANT)
        )

        is_done, ix = False, 0
        while (not is_done) and (ix < self._max_iterations):
            ix += 1
            input_chat = self._react_chat_formatter.format(
                tools,
                chat_history=self._memory.get(),
                current_reasoning=current_reasoning,
            )
            chat_stream = self._llm.stream_chat(input_chat)
            # iterate over stream, break out once the final answer begins
            full_response = ChatResponse(
                message=ChatMessage(content=None, role=MessageRole.ASSISTANT)
            )
            for latest_chunk in chat_stream:
                full_response = latest_chunk
                is_done = self._infer_stream_chunk_is_final(latest_chunk)
                if is_done:
                    break
            # given react prompt outputs, call tools or return response
            reasoning_steps, parsed_done = self._process_actions(
                tools=tools, output=full_response, is_streaming=True
            )
            current_reasoning.extend(reasoning_steps)
            # A complete output that the parser treats as final also ends the
            # loop instead of re-prompting the LLM.
            is_done = is_done or parsed_done

        if not is_done or chat_stream is None:
            # Do not hand the last thought/action text back as the answer.
            raise ValueError("Reached max iterations.")

        # Put back the chunk already read so the caller sees the whole answer.
        response_stream = self._add_back_chunk_to_stream(
            chunk=full_response, chat_stream=chat_stream
        )
        chat_stream_response = StreamingAgentChatResponse(
            chat_stream=response_stream,
            sources=self.sources,
        )
        # Run the history write on a separate thread so the stream can be
        # returned to the caller immediately.
        thread = Thread(
            target=chat_stream_response.write_response_to_history,
            args=(self._memory,),
        )
        thread.start()
        return chat_stream_response

    @trace_method("chat")
    async def astream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Async stream chat; same contract as ``stream_chat``.

        Tools are awaited via ``acall`` and the history write is scheduled as
        an asyncio task instead of a thread.
        """
        tools = self._start_chat(message, chat_history)
        current_reasoning: List[BaseReasoningStep] = []
        chat_stream: Optional[AsyncGenerator[ChatResponse, None]] = None
        full_response = ChatResponse(
            message=ChatMessage(content=None, role=MessageRole.ASSISTANT)
        )

        is_done, ix = False, 0
        while (not is_done) and (ix < self._max_iterations):
            ix += 1
            input_chat = self._react_chat_formatter.format(
                tools,
                chat_history=self._memory.get(),
                current_reasoning=current_reasoning,
            )
            chat_stream = await self._llm.astream_chat(input_chat)
            # iterate over stream, break out once the final answer begins
            full_response = ChatResponse(
                message=ChatMessage(content=None, role=MessageRole.ASSISTANT)
            )
            async for latest_chunk in chat_stream:
                full_response = latest_chunk
                is_done = self._infer_stream_chunk_is_final(latest_chunk)
                if is_done:
                    break
            # Await the async variant so tool calls do not block the loop.
            reasoning_steps, parsed_done = await self._aprocess_actions(
                tools=tools, output=full_response, is_streaming=True
            )
            current_reasoning.extend(reasoning_steps)
            is_done = is_done or parsed_done

        if not is_done or chat_stream is None:
            raise ValueError("Reached max iterations.")

        # Put back the chunk already read so the caller sees the whole answer.
        response_stream = self._async_add_back_chunk_to_stream(
            chunk=full_response, chat_stream=chat_stream
        )
        chat_stream_response = StreamingAgentChatResponse(
            achat_stream=response_stream, sources=self.sources
        )
        # Schedule the history write as a task and keep a strong reference to
        # it until it finishes (the event loop only holds weak references).
        task = asyncio.create_task(
            chat_stream_response.awrite_response_to_history(self._memory)
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return chat_stream_response

    def get_tools(self, message: str) -> List[AsyncBaseTool]:
        """Get the tools for ``message``, adapted to the async tool interface."""
        return [adapt_to_async_tool(t) for t in self._get_tools(message)]