# faiss_rag_enterprise/llama_index/agent/legacy/context_retriever_agent.py
#
# 200 lines
# 7.2 KiB
# Python

"""Context retriever agent."""
from typing import List, Optional, Type, Union
from llama_index.agent.legacy.openai_agent import (
DEFAULT_MAX_FUNCTION_CALLS,
DEFAULT_MODEL_NAME,
BaseOpenAIAgent,
)
from llama_index.callbacks import CallbackManager
from llama_index.chat_engine.types import (
AgentChatResponse,
)
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.llms.types import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import is_function_calling_model
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.prompts import PromptTemplate
from llama_index.schema import NodeWithScore
from llama_index.tools import BaseTool
from llama_index.utils import print_text
# Adapted from DEFAULT_QA_PROMPT_TMPL in llama_index/prompts/default_prompts.py.
# The agent fills {context_str} with the joined texts of the retrieved nodes
# and {query_str} with the raw user message. The template is assembled line
# by line; every line (including the last) ends with a single "\n".
DEFAULT_QA_PROMPT_TMPL = (
    "\n".join(
        (
            "Context information is below.",
            "---------------------",
            "{context_str}",
            "---------------------",
            "Given the context information and not prior knowledge, "
            "either pick the corresponding tool or answer the function: {query_str}",
        )
    )
    + "\n"
)
# Ready-to-use PromptTemplate wrapping the template above.
DEFAULT_QA_PROMPT = PromptTemplate(DEFAULT_QA_PROMPT_TMPL)
class ContextRetrieverOpenAIAgent(BaseOpenAIAgent):
    """ContextRetriever OpenAI Agent.

    This agent performs retrieval from a BaseRetriever before calling the
    LLM, so that every user message is augmented with retrieved context
    (rendered through ``qa_prompt``) before it is handed to BaseOpenAIAgent.

    NOTE: this is a beta feature, function interfaces might change.

    Args:
        tools (List[BaseTool]): Tools returned by ``get_tools``.
        retriever (BaseRetriever): Retriever queried with each user message.
        qa_prompt (PromptTemplate): Template formatted with ``context_str``
            (retrieved texts) and ``query_str`` (the user message).
        context_separator (str): String used to join retrieved node texts.
        llm (OpenAI): An OpenAI LLM.
        memory (BaseMemory): Memory passed through to BaseOpenAIAgent.
        prefix_messages (List[ChatMessage]): Prefix messages passed through
            to BaseOpenAIAgent (``from_tools_and_retriever`` builds a single
            system message here from ``system_prompt``).
        verbose (bool): Whether to print the augmented message.
        max_function_calls (int): Maximum number of function calls.
        callback_manager (Optional[CallbackManager]): A callback manager.
    """

    def __init__(
        self,
        tools: List[BaseTool],
        retriever: BaseRetriever,
        qa_prompt: PromptTemplate,
        context_separator: str,
        llm: OpenAI,
        memory: BaseMemory,
        prefix_messages: List[ChatMessage],
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        super().__init__(
            llm=llm,
            memory=memory,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )
        self._tools = tools
        self._qa_prompt = qa_prompt
        self._retriever = retriever
        self._context_separator = context_separator

    @classmethod
    def from_tools_and_retriever(
        cls,
        tools: List[BaseTool],
        retriever: BaseRetriever,
        qa_prompt: Optional[PromptTemplate] = None,
        context_separator: str = "\n",
        llm: Optional[LLM] = None,
        chat_history: Optional[List[ChatMessage]] = None,
        memory: Optional[BaseMemory] = None,
        memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        prefix_messages: Optional[List[ChatMessage]] = None,
    ) -> "ContextRetrieverOpenAIAgent":
        """Create a ContextRetrieverOpenAIAgent from tools and a retriever.

        Args:
            tools (List[BaseTool]): Tools returned by ``get_tools``.
            retriever (BaseRetriever): Retriever queried with each message.
            qa_prompt (Optional[PromptTemplate]): QA prompt; defaults to
                ``DEFAULT_QA_PROMPT``.
            context_separator (str): String used to join retrieved texts.
            llm (Optional[LLM]): Must be an ``OpenAI`` instance whose model
                supports function calling; defaults to
                ``OpenAI(model=DEFAULT_MODEL_NAME)``.
            chat_history (Optional[List[ChatMessage]]): Initial history, only
                used when ``memory`` is not supplied.
            memory (Optional[BaseMemory]): Memory to use; when omitted one is
                built with ``memory_cls.from_defaults``.
            memory_cls (Type[BaseMemory]): Memory class used to build memory.
            verbose (bool): Whether to print the augmented message.
            max_function_calls (int): Maximum number of function calls.
            callback_manager (Optional[CallbackManager]): A callback manager;
                when given it is also assigned to ``llm.callback_manager``.
            system_prompt (Optional[str]): Turned into a single system prefix
                message. Mutually exclusive with ``prefix_messages``.
            prefix_messages (Optional[List[ChatMessage]]): Prefix messages.

        Raises:
            ValueError: If ``llm`` is not an ``OpenAI`` instance, if its model
                does not support function calling, or if both
                ``system_prompt`` and ``prefix_messages`` are given.
        """
        qa_prompt = qa_prompt or DEFAULT_QA_PROMPT
        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if not isinstance(llm, OpenAI):
            raise ValueError("llm must be a OpenAI instance")
        # Validate before touching llm / building memory so a bad config fails
        # fast without side effects (same error precedence as before).
        if not is_function_calling_model(llm.model):
            raise ValueError(
                f"Model name {llm.model} does not support function calling API."
            )
        if system_prompt is not None:
            if prefix_messages is not None:
                raise ValueError(
                    "Cannot specify both system_prompt and prefix_messages"
                )
            prefix_messages = [ChatMessage(content=system_prompt, role="system")]
        prefix_messages = prefix_messages or []

        # The callback manager is attached before memory is built because the
        # memory factory receives the llm.
        if callback_manager is not None:
            llm.callback_manager = callback_manager
        chat_history = chat_history or []
        memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)

        return cls(
            tools=tools,
            retriever=retriever,
            qa_prompt=qa_prompt,
            context_separator=context_separator,
            llm=llm,
            memory=memory,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )

    def _get_tools(self, message: str) -> List[BaseTool]:
        """Get tools (the same fixed list regardless of ``message``)."""
        return self._tools

    def _format_retrieved_context(
        self, nodes_with_scores: List[NodeWithScore], message: str
    ) -> str:
        """Render ``message`` plus the retrieved nodes' content via the QA prompt."""
        context_str = self._context_separator.join(
            scored.node.get_content() for scored in nodes_with_scores
        )
        return self._qa_prompt.format(context_str=context_str, query_str=message)

    def _build_formatted_message(self, message: str) -> str:
        """Retrieve context for ``message`` (sync) and build the augmented message."""
        retrieved: List[NodeWithScore] = self._retriever.retrieve(message)
        return self._format_retrieved_context(retrieved, message)

    async def _abuild_formatted_message(self, message: str) -> str:
        """Async variant of ``_build_formatted_message``.

        Uses ``aretrieve`` so retrieval does not block the event loop.
        """
        retrieved: List[NodeWithScore] = await self._retriever.aretrieve(message)
        return self._format_retrieved_context(retrieved, message)

    def _log_formatted_message(self, formatted_message: str) -> None:
        """Print the augmented message when verbose mode is on."""
        if self._verbose:
            print_text(formatted_message + "\n", color="yellow")

    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat, after augmenting ``message`` with retrieved context."""
        formatted_message = self._build_formatted_message(message)
        self._log_formatted_message(formatted_message)
        return super().chat(
            formatted_message, chat_history=chat_history, tool_choice=tool_choice
        )

    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Async chat, after augmenting ``message`` with retrieved context."""
        # Previously this used the blocking ``retrieve`` inside a coroutine,
        # stalling the event loop for the whole retrieval.
        formatted_message = await self._abuild_formatted_message(message)
        self._log_formatted_message(formatted_message)
        return await super().achat(
            formatted_message, chat_history=chat_history, tool_choice=tool_choice
        )

    def get_tools(self, message: str) -> List[BaseTool]:
        """Get tools."""
        return self._get_tools(message)