"""Context retriever agent."""
|
|
|
|
from typing import List, Optional, Type, Union
|
|
|
|
from llama_index.agent.legacy.openai_agent import (
|
|
DEFAULT_MAX_FUNCTION_CALLS,
|
|
DEFAULT_MODEL_NAME,
|
|
BaseOpenAIAgent,
|
|
)
|
|
from llama_index.callbacks import CallbackManager
|
|
from llama_index.chat_engine.types import (
|
|
AgentChatResponse,
|
|
)
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.core.llms.types import ChatMessage
|
|
from llama_index.llms.llm import LLM
|
|
from llama_index.llms.openai import OpenAI
|
|
from llama_index.llms.openai_utils import is_function_calling_model
|
|
from llama_index.memory import BaseMemory, ChatMemoryBuffer
|
|
from llama_index.prompts import PromptTemplate
|
|
from llama_index.schema import NodeWithScore
|
|
from llama_index.tools import BaseTool
|
|
from llama_index.utils import print_text
|
|
|
|

# inspired by DEFAULT_QA_PROMPT_TMPL from llama_index/prompts/default_prompts.py
DEFAULT_QA_PROMPT_TMPL = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "either pick the corresponding tool or answer the question: {query_str}\n"
)
DEFAULT_QA_PROMPT = PromptTemplate(DEFAULT_QA_PROMPT_TMPL)
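
# Illustrative sketch (not from the library source): what the default QA
# prompt produces once formatted. The context and question below are
# hypothetical placeholder values.
#
#     DEFAULT_QA_PROMPT.format(
#         context_str="Paris is the capital of France.",
#         query_str="What is the capital of France?",
#     )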


class ContextRetrieverOpenAIAgent(BaseOpenAIAgent):
    """ContextRetriever OpenAI Agent.

    This agent performs retrieval from a BaseRetriever before
    calling the LLM, which allows it to augment the user message
    with retrieved context.

    NOTE: this is a beta feature, function interfaces might change.

    Args:
        tools (List[BaseTool]): A list of tools.
        retriever (BaseRetriever): A retriever.
        qa_prompt (Optional[PromptTemplate]): A QA prompt.
        context_separator (str): A context separator.
        llm (Optional[OpenAI]): An OpenAI LLM.
        chat_history (Optional[List[ChatMessage]]): A chat history.
        prefix_messages (List[ChatMessage]): A list of prefix messages.
        verbose (bool): Whether to print debug statements.
        max_function_calls (int): Maximum number of function calls.
        callback_manager (Optional[CallbackManager]): A callback manager.

    """

    def __init__(
        self,
        tools: List[BaseTool],
        retriever: BaseRetriever,
        qa_prompt: PromptTemplate,
        context_separator: str,
        llm: OpenAI,
        memory: BaseMemory,
        prefix_messages: List[ChatMessage],
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        super().__init__(
            llm=llm,
            memory=memory,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )
        self._tools = tools
        self._qa_prompt = qa_prompt
        self._retriever = retriever
        self._context_separator = context_separator

    @classmethod
    def from_tools_and_retriever(
        cls,
        tools: List[BaseTool],
        retriever: BaseRetriever,
        qa_prompt: Optional[PromptTemplate] = None,
        context_separator: str = "\n",
        llm: Optional[LLM] = None,
        chat_history: Optional[List[ChatMessage]] = None,
        memory: Optional[BaseMemory] = None,
        memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        prefix_messages: Optional[List[ChatMessage]] = None,
    ) -> "ContextRetrieverOpenAIAgent":
        """Create a ContextRetrieverOpenAIAgent from a retriever.

        Args:
            tools (List[BaseTool]): A list of tools.
            retriever (BaseRetriever): A retriever.
            qa_prompt (Optional[PromptTemplate]): A QA prompt.
            context_separator (str): A context separator.
            llm (Optional[LLM]): An LLM; must be an OpenAI model that
                supports function calling.
            chat_history (Optional[List[ChatMessage]]): A chat history.
            memory (Optional[BaseMemory]): A chat memory; defaults to an
                instance of memory_cls built from chat_history.
            memory_cls (Type[BaseMemory]): Memory class used when memory is None.
            verbose (bool): Whether to print debug statements.
            max_function_calls (int): Maximum number of function calls.
            callback_manager (Optional[CallbackManager]): A callback manager.
            system_prompt (Optional[str]): A system prompt; mutually exclusive
                with prefix_messages.
            prefix_messages (Optional[List[ChatMessage]]): A list of prefix
                messages.

        """
        qa_prompt = qa_prompt or DEFAULT_QA_PROMPT
        chat_history = chat_history or []
        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if not isinstance(llm, OpenAI):
            raise ValueError("llm must be an OpenAI instance")
        if callback_manager is not None:
            llm.callback_manager = callback_manager

        memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)

        if not is_function_calling_model(llm.model):
            raise ValueError(
                f"Model name {llm.model} does not support function calling API."
            )
        if system_prompt is not None:
            if prefix_messages is not None:
                raise ValueError(
                    "Cannot specify both system_prompt and prefix_messages"
                )
            prefix_messages = [ChatMessage(content=system_prompt, role="system")]

        prefix_messages = prefix_messages or []

        return cls(
            tools=tools,
            retriever=retriever,
            qa_prompt=qa_prompt,
            context_separator=context_separator,
            llm=llm,
            memory=memory,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )
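
    # Usage sketch (illustrative; `index` and `multiply_tool` are hypothetical
    # stand-ins for your own index and tools):
    #
    #     retriever = index.as_retriever(similarity_top_k=2)
    #     agent = ContextRetrieverOpenAIAgent.from_tools_and_retriever(
    #         tools=[multiply_tool],
    #         retriever=retriever,
    #         verbose=True,
    #     )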

    def _get_tools(self, message: str) -> List[BaseTool]:
        """Get tools."""
        return self._tools

    def _build_formatted_message(self, message: str) -> str:
        # augment user message
        retrieved_nodes_w_scores: List[NodeWithScore] = self._retriever.retrieve(
            message
        )
        retrieved_nodes = [node.node for node in retrieved_nodes_w_scores]
        retrieved_texts = [node.get_content() for node in retrieved_nodes]

        # format message
        context_str = self._context_separator.join(retrieved_texts)
        return self._qa_prompt.format(context_str=context_str, query_str=message)
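
    # Illustrative walk-through (hypothetical node contents): if the retriever
    # returns nodes whose contents are "foo" and "bar" and the separator is the
    # default "\n", then _build_formatted_message("What is foo?") fills the QA
    # prompt with context_str="foo\nbar" and query_str="What is foo?".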

    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat."""
        formatted_message = self._build_formatted_message(message)
        if self._verbose:
            print_text(formatted_message + "\n", color="yellow")

        return super().chat(
            formatted_message, chat_history=chat_history, tool_choice=tool_choice
        )
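
    # Usage sketch (hypothetical question; assumes `agent` was built as in the
    # sketch above):
    #
    #     response = agent.chat("What is 2123 * 215123?")
    #     print(str(response))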

    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Async chat."""
        formatted_message = self._build_formatted_message(message)
        if self._verbose:
            print_text(formatted_message + "\n", color="yellow")

        return await super().achat(
            formatted_message, chat_history=chat_history, tool_choice=tool_choice
        )
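
    # Usage sketch (async variant; same hypothetical `agent` as above):
    #
    #     import asyncio
    #
    #     response = asyncio.run(agent.achat("What is 2123 * 215123?"))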

    def get_tools(self, message: str) -> List[BaseTool]:
        """Get tools."""
        return self._get_tools(message)