# faiss_rag_enterprise/llama_index/agent/react/formatter.py

# 128 lines
# 4.0 KiB
# Python

# ReAct agent formatter
import logging
from abc import abstractmethod
from typing import List, Optional, Sequence
from llama_index.agent.react.prompts import (
CONTEXT_REACT_CHAT_SYSTEM_HEADER,
REACT_CHAT_SYSTEM_HEADER,
)
from llama_index.agent.react.types import BaseReasoningStep, ObservationReasoningStep
from llama_index.bridge.pydantic import BaseModel
from llama_index.core.llms.types import ChatMessage, MessageRole
from llama_index.tools import BaseTool
# Module-level logger; ReActChatFormatter.from_context uses it to report that
# it is deprecated.
logger = logging.getLogger(__name__)
def get_react_tool_descriptions(tools: Sequence[BaseTool]) -> List[str]:
    """Render one ReAct-prompt description entry per tool.

    Each entry has the form::

        > Tool Name: <metadata.name>
        Tool Description: <metadata.description>
        Tool Args: <metadata.fn_schema_str>

    and ends with a trailing newline.

    Args:
        tools: Tools whose ``metadata`` supplies ``name``, ``description``
            and ``fn_schema_str``.

    Returns:
        The formatted descriptions, in the same order as ``tools``.
    """
    return [
        (
            f"> Tool Name: {tool.metadata.name}\n"
            f"Tool Description: {tool.metadata.description}\n"
            f"Tool Args: {tool.metadata.fn_schema_str}\n"
        )
        for tool in tools
    ]
# TODO: come up with better name
class BaseAgentChatFormatter(BaseModel):
    """Abstract interface for turning agent state into chat messages.

    Concrete formatters must implement :meth:`format`.
    """

    class Config:
        # Let subclasses declare fields of arbitrary (non-pydantic) types.
        arbitrary_types_allowed = True

    @abstractmethod
    def format(
        self,
        tools: Sequence[BaseTool],
        chat_history: List[ChatMessage],
        current_reasoning: Optional[List[BaseReasoningStep]] = None,
    ) -> List[ChatMessage]:
        """Build the list of chat messages to send to the LLM.

        Args:
            tools: Tools available to the agent.
            chat_history: Prior conversation messages.
            current_reasoning: Reasoning steps taken so far, if any.

        Returns:
            The formatted list of ChatMessage objects.
        """
        ...
class ReActChatFormatter(BaseAgentChatFormatter):
    """ReAct chat formatter.

    Produces a system message rendered from ``system_header``, followed by
    the chat history, followed by the current reasoning steps rendered as
    user (observation) / assistant (everything else) turns.
    """

    # Template for the system message. ``format`` fills it via str.format with
    # the keys ``tool_desc``, ``tool_names`` and ``context``.
    system_header: str = REACT_CHAT_SYSTEM_HEADER  # default
    # Value substituted for ``{context}``; not needed with the default header.
    context: str = ""

    def format(
        self,
        tools: Sequence[BaseTool],
        chat_history: List[ChatMessage],
        current_reasoning: Optional[List[BaseReasoningStep]] = None,
    ) -> List[ChatMessage]:
        """Format chat history into list of ChatMessage.

        Args:
            tools: Tools described in the system header.
            chat_history: Prior conversation, placed after the system message.
            current_reasoning: Reasoning steps taken so far; ``None`` means none.

        Returns:
            ``[system message, *chat_history, *reasoning messages]``.
        """
        current_reasoning = current_reasoning or []

        # str.format ignores unused keyword arguments, so ``context`` is always
        # supplied. This keeps a header with a ``{context}`` placeholder from
        # raising KeyError when ``self.context`` is empty.
        fmt_sys_header = self.system_header.format(
            tool_desc="\n".join(get_react_tool_descriptions(tools)),
            tool_names=", ".join(tool.metadata.get_name() for tool in tools),
            context=self.context,
        )

        # Format reasoning history as alternating user and assistant messages:
        # observations are sent back as user turns, while the remaining steps
        # (thoughts and actions) are the assistant's own output.
        reasoning_history = [
            ChatMessage(
                role=(
                    MessageRole.USER
                    if isinstance(step, ObservationReasoningStep)
                    else MessageRole.ASSISTANT
                ),
                content=step.get_content(),
            )
            for step in current_reasoning
        ]

        return [
            ChatMessage(role=MessageRole.SYSTEM, content=fmt_sys_header),
            *chat_history,
            *reasoning_history,
        ]

    @classmethod
    def from_defaults(
        cls,
        system_header: Optional[str] = None,
        context: Optional[str] = None,
    ) -> "ReActChatFormatter":
        """Create ReActChatFormatter from defaults.

        Args:
            system_header: Custom header template. If empty or None, the
                context-aware header is used when ``context`` is non-empty,
                otherwise the plain ReAct header.
            context: Value for ``{context}``; None is stored as "".

        Returns:
            A new instance of ``cls``, so subclasses get their own type.
        """
        if not system_header:
            system_header = (
                CONTEXT_REACT_CHAT_SYSTEM_HEADER if context else REACT_CHAT_SYSTEM_HEADER
            )
        return cls(system_header=system_header, context=context or "")

    @classmethod
    def from_context(cls, context: str) -> "ReActChatFormatter":
        """Create ReActChatFormatter from context.

        NOTE: deprecated; use ``from_defaults`` instead.
        """
        logger.warning(
            "ReActChatFormatter.from_context is deprecated, please use `from_defaults` instead."
        )
        return cls.from_defaults(
            system_header=CONTEXT_REACT_CHAT_SYSTEM_HEADER, context=context
        )