"""OpenAI agent worker."""

import asyncio
import json
import logging
import uuid
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args

from llama_index.agent.openai.utils import resolve_tool_choice
from llama_index.agent.types import (
    BaseAgentWorker,
    Task,
    TaskStep,
    TaskStepOutput,
)
from llama_index.agent.utils import add_user_step_to_memory
from llama_index.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    trace_method,
)
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
    ChatResponseMode,
    StreamingAgentChatResponse,
)
from llama_index.core.llms.types import MessageRole
from llama_index.llms.base import ChatMessage, ChatResponse
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import OpenAIToolCall
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

DEFAULT_MAX_FUNCTION_CALLS = 5
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"


def get_function_by_name(tools: List[BaseTool], name: str) -> BaseTool:
    """Get function by name."""
    name_to_tool = {tool.metadata.name: tool for tool in tools}
    if name not in name_to_tool:
        raise ValueError(f"Tool with name {name} not found")
    return name_to_tool[name]


def call_tool_with_error_handling(
    tool: BaseTool,
    input_dict: Dict,
    error_message: Optional[str] = None,
    raise_error: bool = False,
) -> ToolOutput:
    """Call tool with error handling.

    Input is a dictionary with args and kwargs.
    """
    try:
        return tool(**input_dict)
    except Exception as e:
        if raise_error:
            raise
        error_message = error_message or f"Error: {e!s}"
        return ToolOutput(
            content=error_message,
            tool_name=tool.metadata.name,
            raw_input={"kwargs": input_dict},
            raw_output=e,
        )


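# Illustrative sketch (not part of the library API): wrapping a simple function in a
# FunctionTool and calling it through `call_tool_with_error_handling`. The `multiply`
# helper below is hypothetical and exists only for this example.
#
#   from llama_index.tools import FunctionTool
#
#   def multiply(a: int, b: int) -> int:
#       """Multiply two integers."""
#       return a * b
#
#   multiply_tool = FunctionTool.from_defaults(fn=multiply)
#   output = call_tool_with_error_handling(multiply_tool, {"a": 2, "b": 3})
#   print(output.content)  # "6"
#   # Bad input is caught and surfaced as a ToolOutput instead of raising:
#   output = call_tool_with_error_handling(multiply_tool, {"a": 2, "c": 3})
#   print(output.content)  # "Error: ..." message from the underlying exception

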
def call_function(
    tools: List[BaseTool],
    tool_call: OpenAIToolCall,
    verbose: bool = False,
) -> Tuple[ChatMessage, ToolOutput]:
    """Call a function and return the output as a string."""
    # validations to get passed mypy
    assert tool_call.id is not None
    assert tool_call.function is not None
    assert tool_call.function.name is not None
    assert tool_call.function.arguments is not None

    id_ = tool_call.id
    function_call = tool_call.function
    name = tool_call.function.name
    arguments_str = tool_call.function.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {name} with args: {arguments_str}")
    tool = get_function_by_name(tools, name)
    argument_dict = json.loads(arguments_str)

    # Call tool
    # Use default error message
    output = call_tool_with_error_handling(tool, argument_dict, error_message=None)
    if verbose:
        print(f"Got output: {output!s}")
        print("========================\n")
    return (
        ChatMessage(
            content=str(output),
            role=MessageRole.TOOL,
            additional_kwargs={
                "name": name,
                "tool_call_id": id_,
            },
        ),
        output,
    )


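# Illustrative sketch only: how a raw OpenAI tool call would be dispatched through
# `call_function`. The `ChatCompletionMessageToolCall` / `Function` types come from
# the `openai` package (v1.x); `tools` is assumed to be a list containing a tool
# registered under the name "multiply".
#
#   from openai.types.chat import ChatCompletionMessageToolCall
#   from openai.types.chat.chat_completion_message_tool_call import Function
#
#   tool_call = ChatCompletionMessageToolCall(
#       id="call_123",
#       type="function",
#       function=Function(name="multiply", arguments='{"a": 2, "b": 3}'),
#   )
#   chat_message, tool_output = call_function(tools, tool_call, verbose=True)
#   # `chat_message` is a TOOL-role ChatMessage carrying the tool output and the
#   # originating tool_call_id; `acall_function` behaves the same asynchronously.

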
async def acall_function(
    tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
    """Call a function and return the output as a string."""
    # validations to get passed mypy
    assert tool_call.id is not None
    assert tool_call.function is not None
    assert tool_call.function.name is not None
    assert tool_call.function.arguments is not None

    id_ = tool_call.id
    function_call = tool_call.function
    name = tool_call.function.name
    arguments_str = tool_call.function.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {name} with args: {arguments_str}")
    tool = get_function_by_name(tools, name)
    async_tool = adapt_to_async_tool(tool)
    argument_dict = json.loads(arguments_str)
    output = await async_tool.acall(**argument_dict)
    if verbose:
        print(f"Got output: {output!s}")
        print("========================\n")
    return (
        ChatMessage(
            content=str(output),
            role=MessageRole.TOOL,
            additional_kwargs={
                "name": name,
                "tool_call_id": id_,
            },
        ),
        output,
    )


class OpenAIAgentWorker(BaseAgentWorker):
    """OpenAI Agent worker."""

    def __init__(
        self,
        tools: List[BaseTool],
        llm: OpenAI,
        prefix_messages: List[ChatMessage],
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ):
        self._llm = llm
        self._verbose = verbose
        self._max_function_calls = max_function_calls
        self.prefix_messages = prefix_messages
        self.callback_manager = callback_manager or self._llm.callback_manager

        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif len(tools) > 0:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            # no tools
            self._get_tools = lambda _: []

    @classmethod
    def from_tools(
        cls,
        tools: Optional[List[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        prefix_messages: Optional[List[ChatMessage]] = None,
        **kwargs: Any,
    ) -> "OpenAIAgentWorker":
        """Create an OpenAIAgentWorker from a list of tools.

        Similar to `from_defaults` in other classes, this method will
        infer defaults for a variety of parameters, including the LLM,
        if they are not specified.
        """
        tools = tools or []

        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if not isinstance(llm, OpenAI):
            raise ValueError("llm must be an OpenAI instance")

        if callback_manager is not None:
            llm.callback_manager = callback_manager

        if not llm.metadata.is_function_calling_model:
            raise ValueError(
                f"Model name {llm.model} does not support the function calling API."
            )

        if system_prompt is not None:
            if prefix_messages is not None:
                raise ValueError(
                    "Cannot specify both system_prompt and prefix_messages"
                )
            prefix_messages = [ChatMessage(content=system_prompt, role="system")]

        prefix_messages = prefix_messages or []

        return cls(
            tools=tools,
            tool_retriever=tool_retriever,
            llm=llm,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )

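    # Illustrative usage sketch (not part of this module): building a worker from a
    # tool list and driving it with an agent runner. `multiply_tool` is a hypothetical
    # FunctionTool; the AgentRunner import assumes the runner that ships alongside
    # this worker.
    #
    #   from llama_index.agent import AgentRunner
    #   from llama_index.llms.openai import OpenAI
    #
    #   worker = OpenAIAgentWorker.from_tools(
    #       tools=[multiply_tool],
    #       llm=OpenAI(model="gpt-3.5-turbo-0613"),
    #       system_prompt="You are a calculator assistant.",
    #       verbose=True,
    #   )
    #   agent = AgentRunner(worker)
    #   response = agent.chat("What is 2 times 3?")
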
    def get_all_messages(self, task: Task) -> List[ChatMessage]:
        return (
            self.prefix_messages
            + task.memory.get()
            + task.extra_state["new_memory"].get_all()
        )

    def get_latest_tool_calls(self, task: Task) -> Optional[List[OpenAIToolCall]]:
        chat_history: List[ChatMessage] = task.extra_state["new_memory"].get_all()
        return (
            chat_history[-1].additional_kwargs.get("tool_calls", None)
            if chat_history
            else None
        )

    def _get_llm_chat_kwargs(
        self,
        task: Task,
        openai_tools: List[dict],
        tool_choice: Union[str, dict] = "auto",
    ) -> Dict[str, Any]:
        llm_chat_kwargs: dict = {"messages": self.get_all_messages(task)}
        if openai_tools:
            llm_chat_kwargs.update(
                tools=openai_tools, tool_choice=resolve_tool_choice(tool_choice)
            )
        return llm_chat_kwargs

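    # Illustrative shape of the kwargs returned above (a sketch; the exact tool schema
    # comes from `tool.metadata.to_openai_tool()` and `resolve_tool_choice`):
    #
    #   {
    #       "messages": [<prefix messages> + <task memory> + <new memory>],
    #       "tools": [
    #           {
    #               "type": "function",
    #               "function": {"name": "multiply", "description": "...", "parameters": {...}},
    #           },
    #       ],
    #       "tool_choice": "auto",
    #   }
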
    def _process_message(
        self, task: Task, chat_response: ChatResponse
    ) -> AgentChatResponse:
        ai_message = chat_response.message
        task.extra_state["new_memory"].put(ai_message)
        return AgentChatResponse(
            response=str(ai_message.content), sources=task.extra_state["sources"]
        )

    def _get_stream_ai_response(
        self, task: Task, **llm_chat_kwargs: Any
    ) -> StreamingAgentChatResponse:
        chat_stream_response = StreamingAgentChatResponse(
            chat_stream=self._llm.stream_chat(**llm_chat_kwargs),
            sources=task.extra_state["sources"],
        )
        # Get the response in a separate thread so we can yield the response
        thread = Thread(
            target=chat_stream_response.write_response_to_history,
            args=(task.extra_state["new_memory"],),
        )
        thread.start()
        # Wait for the event to be set
        chat_stream_response._is_function_not_none_thread_event.wait()
        # If it is executing an OpenAI function, wait for the thread to finish
        if chat_stream_response._is_function:
            thread.join()

        # if it's false, return the answer (to stream)
        return chat_stream_response

    async def _get_async_stream_ai_response(
        self, task: Task, **llm_chat_kwargs: Any
    ) -> StreamingAgentChatResponse:
        chat_stream_response = StreamingAgentChatResponse(
            achat_stream=await self._llm.astream_chat(**llm_chat_kwargs),
            sources=task.extra_state["sources"],
        )
        # create task to write chat response to history
        asyncio.create_task(
            chat_stream_response.awrite_response_to_history(
                task.extra_state["new_memory"]
            )
        )
        # wait until OpenAI functions stop executing
        await chat_stream_response._is_function_false_event.wait()
        # return response stream
        return chat_stream_response

    def _get_agent_response(
        self, task: Task, mode: ChatResponseMode, **llm_chat_kwargs: Any
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        if mode == ChatResponseMode.WAIT:
            chat_response: ChatResponse = self._llm.chat(**llm_chat_kwargs)
            return self._process_message(task, chat_response)
        elif mode == ChatResponseMode.STREAM:
            return self._get_stream_ai_response(task, **llm_chat_kwargs)
        else:
            raise NotImplementedError

    async def _get_async_agent_response(
        self, task: Task, mode: ChatResponseMode, **llm_chat_kwargs: Any
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        if mode == ChatResponseMode.WAIT:
            chat_response: ChatResponse = await self._llm.achat(**llm_chat_kwargs)
            return self._process_message(task, chat_response)
        elif mode == ChatResponseMode.STREAM:
            return await self._get_async_stream_ai_response(task, **llm_chat_kwargs)
        else:
            raise NotImplementedError

    def _call_function(
        self,
        tools: List[BaseTool],
        tool_call: OpenAIToolCall,
        memory: BaseMemory,
        sources: List[ToolOutput],
    ) -> None:
        function_call = tool_call.function
        # validations to get passed mypy
        assert function_call is not None
        assert function_call.name is not None
        assert function_call.arguments is not None

        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: function_call.arguments,
                EventPayload.TOOL: get_function_by_name(
                    tools, function_call.name
                ).metadata,
            },
        ) as event:
            function_message, tool_output = call_function(
                tools, tool_call, verbose=self._verbose
            )
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
        sources.append(tool_output)
        memory.put(function_message)

    async def _acall_function(
        self,
        tools: List[BaseTool],
        tool_call: OpenAIToolCall,
        memory: BaseMemory,
        sources: List[ToolOutput],
    ) -> None:
        function_call = tool_call.function
        # validations to get passed mypy
        assert function_call is not None
        assert function_call.name is not None
        assert function_call.arguments is not None

        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: function_call.arguments,
                EventPayload.TOOL: get_function_by_name(
                    tools, function_call.name
                ).metadata,
            },
        ) as event:
            function_message, tool_output = await acall_function(
                tools, tool_call, verbose=self._verbose
            )
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
        sources.append(tool_output)
        memory.put(function_message)

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize step from task."""
        sources: List[ToolOutput] = []
        # temporary memory for new messages
        new_memory = ChatMemoryBuffer.from_defaults()
        # initialize task state
        task_state = {
            "sources": sources,
            "n_function_calls": 0,
            "new_memory": new_memory,
        }
        task.extra_state.update(task_state)

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
        )

    def _should_continue(
        self, tool_calls: Optional[List[OpenAIToolCall]], n_function_calls: int
    ) -> bool:
        if n_function_calls > self._max_function_calls:
            return False
        if not tool_calls:
            return False
        return True

    def get_tools(self, input: str) -> List[BaseTool]:
        """Get tools."""
        return self._get_tools(input)

    def _run_step(
        self,
        step: TaskStep,
        task: Task,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        tool_choice: Union[str, dict] = "auto",
    ) -> TaskStepOutput:
        """Run step."""
        if step.input is not None:
            add_user_step_to_memory(
                step, task.extra_state["new_memory"], verbose=self._verbose
            )
        # TODO: see if we want to do step-based inputs
        tools = self.get_tools(task.input)
        openai_tools = [tool.metadata.to_openai_tool() for tool in tools]

        llm_chat_kwargs = self._get_llm_chat_kwargs(task, openai_tools, tool_choice)

        agent_chat_response = self._get_agent_response(
            task, mode=mode, **llm_chat_kwargs
        )

        # TODO: implement _should_continue
        latest_tool_calls = self.get_latest_tool_calls(task) or []
        if not self._should_continue(
            latest_tool_calls, task.extra_state["n_function_calls"]
        ):
            is_done = True
            new_steps = []
            # TODO: return response
        else:
            is_done = False
            for tool_call in latest_tool_calls:
                # Some validation
                if not isinstance(tool_call, get_args(OpenAIToolCall)):
                    raise ValueError("Invalid tool_call object")

                if tool_call.type != "function":
                    raise ValueError("Invalid tool type. Unsupported by OpenAI")
                # TODO: maybe execute this with multi-threading
                self._call_function(
                    tools,
                    tool_call,
                    task.extra_state["new_memory"],
                    task.extra_state["sources"],
                )
                # change function call to the default value, if a custom function was given
                # as an argument (none and auto are predefined by OpenAI)
                if tool_choice not in ("auto", "none"):
                    tool_choice = "auto"
                task.extra_state["n_function_calls"] += 1
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        # attach next step to task

        return TaskStepOutput(
            output=agent_chat_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    async def _arun_step(
        self,
        step: TaskStep,
        task: Task,
        mode: ChatResponseMode = ChatResponseMode.WAIT,
        tool_choice: Union[str, dict] = "auto",
    ) -> TaskStepOutput:
        """Run step."""
        if step.input is not None:
            add_user_step_to_memory(
                step, task.extra_state["new_memory"], verbose=self._verbose
            )

        # TODO: see if we want to do step-based inputs
        tools = self.get_tools(task.input)
        openai_tools = [tool.metadata.to_openai_tool() for tool in tools]

        llm_chat_kwargs = self._get_llm_chat_kwargs(task, openai_tools, tool_choice)
        agent_chat_response = await self._get_async_agent_response(
            task, mode=mode, **llm_chat_kwargs
        )

        # TODO: implement _should_continue
        latest_tool_calls = self.get_latest_tool_calls(task) or []
        if not self._should_continue(
            latest_tool_calls, task.extra_state["n_function_calls"]
        ):
            is_done = True

        else:
            is_done = False
            for tool_call in latest_tool_calls:
                # Some validation
                if not isinstance(tool_call, get_args(OpenAIToolCall)):
                    raise ValueError("Invalid tool_call object")

                if tool_call.type != "function":
                    raise ValueError("Invalid tool type. Unsupported by OpenAI")
                # TODO: maybe execute this with multi-threading
                await self._acall_function(
                    tools,
                    tool_call,
                    task.extra_state["new_memory"],
                    task.extra_state["sources"],
                )
                # change function call to the default value, if a custom function was given
                # as an argument (none and auto are predefined by OpenAI)
                if tool_choice not in ("auto", "none"):
                    tool_choice = "auto"
                task.extra_state["n_function_calls"] += 1

        # generate next step, append to task queue
        new_steps = (
            [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]
            if not is_done
            else []
        )

        return TaskStepOutput(
            output=agent_chat_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

@trace_method("run_step")
|
|
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
|
"""Run step."""
|
|
tool_choice = kwargs.get("tool_choice", "auto")
|
|
return self._run_step(
|
|
step, task, mode=ChatResponseMode.WAIT, tool_choice=tool_choice
|
|
)
|
|
|
|
@trace_method("run_step")
|
|
async def arun_step(
|
|
self, step: TaskStep, task: Task, **kwargs: Any
|
|
) -> TaskStepOutput:
|
|
"""Run step (async)."""
|
|
tool_choice = kwargs.get("tool_choice", "auto")
|
|
return await self._arun_step(
|
|
step, task, mode=ChatResponseMode.WAIT, tool_choice=tool_choice
|
|
)
|
|
|
|
@trace_method("run_step")
|
|
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
|
"""Run step (stream)."""
|
|
# TODO: figure out if we need a different type for TaskStepOutput
|
|
tool_choice = kwargs.get("tool_choice", "auto")
|
|
return self._run_step(
|
|
step, task, mode=ChatResponseMode.STREAM, tool_choice=tool_choice
|
|
)
|
|
|
|
@trace_method("run_step")
|
|
async def astream_step(
|
|
self, step: TaskStep, task: Task, **kwargs: Any
|
|
) -> TaskStepOutput:
|
|
"""Run step (async stream)."""
|
|
tool_choice = kwargs.get("tool_choice", "auto")
|
|
return await self._arun_step(
|
|
step, task, mode=ChatResponseMode.STREAM, tool_choice=tool_choice
|
|
)
|
|
|
|
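    # Illustrative sketch of the worker's step loop (normally driven by an agent
    # runner rather than called directly). `task` is assumed to be a Task created
    # elsewhere, e.g. by the runner.
    #
    #   step = worker.initialize_step(task)
    #   while True:
    #       step_output = worker.run_step(step, task)
    #       if step_output.is_last:
    #           break
    #       step = step_output.next_steps[0]
    #   worker.finalize_task(task)
    #   print(step_output.output)  # final AgentChatResponse
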
    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed."""
        # add new messages to memory
        task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
        # reset new memory
        task.extra_state["new_memory"].reset()

    def undo_step(self, task: Task, **kwargs: Any) -> Optional[TaskStep]:
        """Undo step from task.

        If this cannot be implemented, return None.
        """
        raise NotImplementedError("Undo is not yet implemented")
        # if len(task.completed_steps) == 0:
        #     return None

        # # pop last step output
        # last_step_output = task.completed_steps.pop()
        # # add step to the front of the queue
        # task.step_queue.appendleft(last_step_output.task_step)

        # # undo any `step_state` variables that have changed
        # last_step_output.step_state["n_function_calls"] -= 1

        # # TODO: we don't have memory pop capabilities yet
        # # # now pop the memory until we get to the state
        # # last_step_response = cast(AgentChatResponse, last_step_output.output)
        # # while last_step_response != task.memory.:
        # #     last_message = last_step_output.task_step.memory.pop()
        # #     if last_message == cast(AgentChatResponse, last_step_output.output).response:
        # #         break

        # # while cast(AgentChatResponse, last_step_output.output).response !=

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager