hailin 2025-05-11 00:41:07 +08:00
parent 5da1f8579d
commit ad59446d14
755 changed files with 112598 additions and 1 deletion

1
llama_index/VERSION Normal file

@ -0,0 +1 @@
0.9.48

168
llama_index/__init__.py Normal file

@ -0,0 +1,168 @@
"""Init file of LlamaIndex."""
from pathlib import Path
with open(Path(__file__).absolute().parents[0] / "VERSION") as _f:
__version__ = _f.read().strip()
import logging
from logging import NullHandler
from typing import Callable, Optional
# import global eval handler
from llama_index.callbacks.global_handlers import set_global_handler
# response
from llama_index.core.response.schema import Response
from llama_index.data_structs.struct_type import IndexStructType
# embeddings
from llama_index.embeddings import OpenAIEmbedding
# indices
# loading
from llama_index.indices import (
ComposableGraph,
DocumentSummaryIndex,
GPTDocumentSummaryIndex,
GPTKeywordTableIndex,
GPTKnowledgeGraphIndex,
GPTListIndex,
GPTRAKEKeywordTableIndex,
GPTSimpleKeywordTableIndex,
GPTTreeIndex,
GPTVectorStoreIndex,
KeywordTableIndex,
KnowledgeGraphIndex,
ListIndex,
RAKEKeywordTableIndex,
SimpleKeywordTableIndex,
SummaryIndex,
TreeIndex,
VectorStoreIndex,
load_graph_from_storage,
load_index_from_storage,
load_indices_from_storage,
)
# structured
from llama_index.indices.common.struct_store.base import SQLDocumentContextBuilder
# prompt helper
from llama_index.indices.prompt_helper import PromptHelper
from llama_index.llm_predictor import LLMPredictor
# token predictor
from llama_index.llm_predictor.mock import MockLLMPredictor
# prompts
from llama_index.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
# backwards compatibility
Prompt,
PromptTemplate,
SelectorPromptTemplate,
)
from llama_index.readers import (
SimpleDirectoryReader,
download_loader,
)
# Response Synthesizer
from llama_index.response_synthesizers.factory import get_response_synthesizer
from llama_index.schema import Document, QueryBundle
from llama_index.service_context import (
ServiceContext,
set_global_service_context,
)
# storage
from llama_index.storage.storage_context import StorageContext
from llama_index.token_counter.mock_embed_model import MockEmbedding
# sql wrapper
from llama_index.utilities.sql_wrapper import SQLDatabase
# global tokenizer
from llama_index.utils import get_tokenizer, set_global_tokenizer
# best practices for library logging:
# https://docs.python.org/3/howto/logging.html#configuring-logging-for-a-library
logging.getLogger(__name__).addHandler(NullHandler())
__all__ = [
"StorageContext",
"ServiceContext",
"ComposableGraph",
# indices
"SummaryIndex",
"VectorStoreIndex",
"SimpleKeywordTableIndex",
"KeywordTableIndex",
"RAKEKeywordTableIndex",
"TreeIndex",
"DocumentSummaryIndex",
"KnowledgeGraphIndex",
# indices - legacy names
"GPTKeywordTableIndex",
"GPTKnowledgeGraphIndex",
"GPTSimpleKeywordTableIndex",
"GPTRAKEKeywordTableIndex",
"GPTListIndex",
"ListIndex",
"GPTTreeIndex",
"GPTVectorStoreIndex",
"GPTDocumentSummaryIndex",
"Prompt",
"PromptTemplate",
"BasePromptTemplate",
"ChatPromptTemplate",
"SelectorPromptTemplate",
"OpenAIEmbedding",
"SummaryPrompt",
"TreeInsertPrompt",
"TreeSelectPrompt",
"TreeSelectMultiplePrompt",
"RefinePrompt",
"QuestionAnswerPrompt",
"KeywordExtractPrompt",
"QueryKeywordExtractPrompt",
"Response",
"Document",
"SimpleDirectoryReader",
"LLMPredictor",
"MockLLMPredictor",
"VellumPredictor",
"VellumPromptRegistry",
"MockEmbedding",
"SQLDatabase",
"SQLDocumentContextBuilder",
"SQLContextBuilder",
"PromptHelper",
"IndexStructType",
"download_loader",
"load_graph_from_storage",
"load_index_from_storage",
"load_indices_from_storage",
"QueryBundle",
"get_response_synthesizer",
"set_global_service_context",
"set_global_handler",
"set_global_tokenizer",
"get_tokenizer",
]
# eval global toggle
from llama_index.callbacks.base_handler import BaseCallbackHandler
global_handler: Optional[BaseCallbackHandler] = None
# NOTE: keep for backwards compatibility
SQLContextBuilder = SQLDocumentContextBuilder
# global service context for ServiceContext.from_defaults()
global_service_context: Optional[ServiceContext] = None
# global tokenizer
global_tokenizer: Optional[Callable[[str], list]] = None
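Taken together these exports form the package's public entry point. A minimal usage sketch (illustrative only, not part of this diff; it assumes a local `data/` directory of documents and an `OPENAI_API_KEY` in the environment):

```python
# Minimal sketch built only from the package-level exports above.
from llama_index import SimpleDirectoryReader, VectorStoreIndex

documents = SimpleDirectoryReader("data").load_data()  # load files from ./data
index = VectorStoreIndex.from_documents(documents)     # embed and index them in memory
query_engine = index.as_query_engine()                 # default retriever + response synthesizer
print(query_engine.query("What do these documents cover?"))
```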


@ -0,0 +1,2 @@
# Include this file
!.gitignore


@ -0,0 +1,2 @@
# Include this file
!.gitignore

45
llama_index/agent/__init__.py Normal file

@ -0,0 +1,45 @@
# agent runner + agent worker
from llama_index.agent.custom.pipeline_worker import QueryPipelineAgentWorker
from llama_index.agent.custom.simple import CustomSimpleAgentWorker
from llama_index.agent.legacy.context_retriever_agent import ContextRetrieverOpenAIAgent
from llama_index.agent.legacy.openai_agent import OpenAIAgent as OldOpenAIAgent
from llama_index.agent.legacy.react.base import ReActAgent as OldReActAgent
from llama_index.agent.legacy.retriever_openai_agent import FnRetrieverOpenAIAgent
from llama_index.agent.openai.base import OpenAIAgent
from llama_index.agent.openai.step import OpenAIAgentWorker
from llama_index.agent.openai_assistant_agent import OpenAIAssistantAgent
from llama_index.agent.react.base import ReActAgent
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.agent.react.step import ReActAgentWorker
from llama_index.agent.react_multimodal.step import MultimodalReActAgentWorker
from llama_index.agent.runner.base import AgentRunner
from llama_index.agent.runner.parallel import ParallelAgentRunner
from llama_index.agent.types import Task
from llama_index.chat_engine.types import AgentChatResponse
# for backwards compatibility
RetrieverOpenAIAgent = FnRetrieverOpenAIAgent
__all__ = [
"AgentRunner",
"ParallelAgentRunner",
"OpenAIAgentWorker",
"ReActAgentWorker",
"OpenAIAgent",
"ReActAgent",
"OpenAIAssistantAgent",
"FnRetrieverOpenAIAgent",
"RetrieverOpenAIAgent", # for backwards compatibility
"ContextRetrieverOpenAIAgent",
"CustomSimpleAgentWorker",
"QueryPipelineAgentWorker",
"ReActChatFormatter",
# beta
"MultimodalReActAgentWorker",
# schema-related
"AgentChatResponse",
"Task",
# legacy
"OldOpenAIAgent",
"OldReActAgent",
]

1
llama_index/agent/custom/__init__.py Normal file

@ -0,0 +1 @@
"""Init params."""

199
llama_index/agent/custom/pipeline_worker.py Normal file

@ -0,0 +1,199 @@
"""Agent worker that takes in a query pipeline."""
import uuid
from typing import (
Any,
List,
Optional,
cast,
)
from llama_index.agent.types import (
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import (
CallbackManager,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
)
from llama_index.core.query_pipeline.query_component import QueryComponent
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.query_pipeline.components.agent import (
AgentFnComponent,
AgentInputComponent,
BaseAgentComponent,
)
from llama_index.query_pipeline.query import QueryPipeline
from llama_index.tools import ToolOutput
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
def _get_agent_components(query_component: QueryComponent) -> List[BaseAgentComponent]:
"""Get agent components."""
agent_components: List[BaseAgentComponent] = []
for c in query_component.sub_query_components:
if isinstance(c, BaseAgentComponent):
agent_components.append(cast(BaseAgentComponent, c))
if len(c.sub_query_components) > 0:
agent_components.extend(_get_agent_components(c))
return agent_components
class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker):
"""Query Pipeline agent worker.
Barebones agent worker that takes in a query pipeline.
Assumes that the first component in the query pipeline is an
    `AgentInputComponent` and the last is an `AgentFnComponent`.
Args:
pipeline (QueryPipeline): Query pipeline
"""
pipeline: QueryPipeline = Field(..., description="Query pipeline")
callback_manager: CallbackManager = Field(..., exclude=True)
class Config:
arbitrary_types_allowed = True
def __init__(
self,
pipeline: QueryPipeline,
callback_manager: Optional[CallbackManager] = None,
) -> None:
"""Initialize."""
if callback_manager is not None:
# set query pipeline callback
pipeline.set_callback_manager(callback_manager)
else:
callback_manager = pipeline.callback_manager
super().__init__(
pipeline=pipeline,
callback_manager=callback_manager,
)
# validate query pipeline
self.agent_input_component
self.agent_components
@property
def agent_input_component(self) -> AgentInputComponent:
"""Get agent input component."""
root_key = self.pipeline.get_root_keys()[0]
if not isinstance(self.pipeline.module_dict[root_key], AgentInputComponent):
raise ValueError(
"Query pipeline first component must be AgentInputComponent, got "
f"{self.pipeline.module_dict[root_key]}"
)
return cast(AgentInputComponent, self.pipeline.module_dict[root_key])
@property
def agent_components(self) -> List[AgentFnComponent]:
"""Get agent output component."""
return _get_agent_components(self.pipeline)
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
sources: List[ToolOutput] = []
# temporary memory for new messages
new_memory = ChatMemoryBuffer.from_defaults()
# initialize initial state
initial_state = {
"sources": sources,
"memory": new_memory,
}
return TaskStep(
task_id=task.task_id,
step_id=str(uuid.uuid4()),
input=task.input,
step_state=initial_state,
)
def _get_task_step_response(
self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
) -> TaskStepOutput:
"""Get task step response."""
if is_done:
new_steps = []
else:
new_steps = [
step.get_next_step(
step_id=str(uuid.uuid4()),
# NOTE: input is unused
input=None,
)
]
return TaskStepOutput(
output=agent_response,
task_step=step,
is_last=is_done,
next_steps=new_steps,
)
@trace_method("run_step")
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)
agent_response, is_done = self.pipeline.run(state=step.step_state, task=task)
response = self._get_task_step_response(agent_response, step, is_done)
# sync step state with task state
task.extra_state.update(step.step_state)
return response
@trace_method("run_step")
async def arun_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async)."""
# partial agent output component with task and step
for agent_fn_component in self.agent_components:
agent_fn_component.partial(task=task, state=step.step_state)
agent_response, is_done = await self.pipeline.arun(
state=step.step_state, task=task
)
response = self._get_task_step_response(agent_response, step, is_done)
task.extra_state.update(step.step_state)
return response
@trace_method("run_step")
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step (stream)."""
raise NotImplementedError("This agent does not support streaming.")
@trace_method("run_step")
async def astream_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async stream)."""
raise NotImplementedError("This agent does not support streaming.")
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""
# add new messages to memory
task.memory.set(task.memory.get() + task.extra_state["memory"].get_all())
# reset new memory
task.extra_state["memory"].reset()
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: make this abstractmethod (right now will break some agent impls)
self.callback_manager = callback_manager
self.pipeline.set_callback_manager(callback_manager)
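For orientation, a hypothetical end-to-end wiring of this worker (the `agent_input_fn` / `agent_output_fn` names are illustrative; the sketch assumes `AgentInputComponent` and `AgentFnComponent` take their callable via `fn=` and that `QueryPipeline` accepts `chain=[...]`):

```python
# Hypothetical sketch: a two-component pipeline driven by QueryPipelineAgentWorker.
from typing import Any, Dict, Tuple

from llama_index.agent import AgentRunner, QueryPipelineAgentWorker
from llama_index.agent.types import Task
from llama_index.chat_engine.types import AgentChatResponse
from llama_index.query_pipeline.components.agent import AgentFnComponent, AgentInputComponent
from llama_index.query_pipeline.query import QueryPipeline


def agent_input_fn(task: Task, state: Dict[str, Any]) -> Dict[str, Any]:
    # First component: expose the task input to the rest of the pipeline.
    return {"input": task.input}


def agent_output_fn(
    task: Task, state: Dict[str, Any], input: str
) -> Tuple[AgentChatResponse, bool]:
    # Last component: must return (response, is_done), which run_step() above unpacks.
    return AgentChatResponse(response=f"echo: {input}"), True


pipeline = QueryPipeline(
    chain=[AgentInputComponent(fn=agent_input_fn), AgentFnComponent(fn=agent_output_fn)]
)
agent = AgentRunner(QueryPipelineAgentWorker(pipeline))
print(agent.chat("hello"))
```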

261
llama_index/agent/custom/simple.py Normal file

@ -0,0 +1,261 @@
"""Custom agent worker."""
import uuid
from abc import abstractmethod
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
cast,
)
from llama_index.agent.types import (
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field, PrivateAttr
from llama_index.callbacks import (
CallbackManager,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
)
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.tools.types import AsyncBaseTool
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
class CustomSimpleAgentWorker(BaseModel, BaseAgentWorker):
"""Custom simple agent worker.
    This is "simple" in the sense that some of the scaffolding is set up already.
Assumptions:
- assumes that the agent has tools, llm, callback manager, and tool retriever
- has a `from_tools` convenience function
- assumes that the agent is sequential, and doesn't take in any additional
intermediate inputs.
Args:
tools (Sequence[BaseTool]): Tools to use for reasoning
llm (LLM): LLM to use
callback_manager (CallbackManager): Callback manager
tool_retriever (Optional[ObjectRetriever[BaseTool]]): Tool retriever
verbose (bool): Whether to print out reasoning steps
"""
tools: Sequence[BaseTool] = Field(..., description="Tools to use for reasoning")
llm: LLM = Field(..., description="LLM to use")
callback_manager: CallbackManager = Field(
default_factory=lambda: CallbackManager([]), exclude=True
)
tool_retriever: Optional[ObjectRetriever[BaseTool]] = Field(
default=None, description="Tool retriever"
)
verbose: bool = Field(False, description="Whether to print out reasoning steps")
_get_tools: Callable[[str], Sequence[BaseTool]] = PrivateAttr()
class Config:
arbitrary_types_allowed = True
def __init__(
self,
tools: Sequence[BaseTool],
llm: LLM,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
) -> None:
if len(tools) > 0 and tool_retriever is not None:
raise ValueError("Cannot specify both tools and tool_retriever")
elif len(tools) > 0:
self._get_tools = lambda _: tools
elif tool_retriever is not None:
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
else:
self._get_tools = lambda _: []
super().__init__(
tools=tools,
llm=llm,
callback_manager=callback_manager,
tool_retriever=tool_retriever,
verbose=verbose,
)
@classmethod
def from_tools(
cls,
tools: Optional[Sequence[BaseTool]] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
llm: Optional[LLM] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
**kwargs: Any,
    ) -> "CustomSimpleAgentWorker":
        """Convenience constructor method from a set of BaseTools (Optional)."""
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
if callback_manager is not None:
llm.callback_manager = callback_manager
return cls(
tools=tools or [],
tool_retriever=tool_retriever,
llm=llm,
callback_manager=callback_manager,
verbose=verbose,
)
@abstractmethod
def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
"""Initialize state."""
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
sources: List[ToolOutput] = []
# temporary memory for new messages
new_memory = ChatMemoryBuffer.from_defaults()
# initialize initial state
initial_state = {
"sources": sources,
"memory": new_memory,
}
step_state = self._initialize_state(task, **kwargs)
# if intersecting keys, error
if set(step_state.keys()).intersection(set(initial_state.keys())):
            raise ValueError(
                f"Step state keys {step_state.keys()} and initial state keys {initial_state.keys()} intersect. "
f"*NOTE*: initial state keys {initial_state.keys()} are reserved."
)
step_state.update(initial_state)
return TaskStep(
task_id=task.task_id,
step_id=str(uuid.uuid4()),
input=task.input,
step_state=step_state,
)
def get_tools(self, input: str) -> List[AsyncBaseTool]:
"""Get tools."""
return [adapt_to_async_tool(t) for t in self._get_tools(input)]
def _get_task_step_response(
self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
) -> TaskStepOutput:
"""Get task step response."""
if is_done:
new_steps = []
else:
new_steps = [
step.get_next_step(
step_id=str(uuid.uuid4()),
# NOTE: input is unused
input=None,
)
]
return TaskStepOutput(
output=agent_response,
task_step=step,
is_last=is_done,
next_steps=new_steps,
)
@abstractmethod
def _run_step(
self, state: Dict[str, Any], task: Task, input: Optional[str] = None
) -> Tuple[AgentChatResponse, bool]:
"""Run step.
Returns:
Tuple of (agent_response, is_done)
"""
async def _arun_step(
self, state: Dict[str, Any], task: Task, input: Optional[str] = None
) -> Tuple[AgentChatResponse, bool]:
"""Run step (async).
Can override this method if you want to run the step asynchronously.
Returns:
Tuple of (agent_response, is_done)
"""
        raise NotImplementedError(
            "This agent does not support async. Please implement _arun_step."
)
@trace_method("run_step")
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
agent_response, is_done = self._run_step(
step.step_state, task, input=step.input
)
response = self._get_task_step_response(agent_response, step, is_done)
# sync step state with task state
task.extra_state.update(step.step_state)
return response
@trace_method("run_step")
async def arun_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async)."""
agent_response, is_done = await self._arun_step(
step.step_state, task, input=step.input
)
response = self._get_task_step_response(agent_response, step, is_done)
task.extra_state.update(step.step_state)
return response
@trace_method("run_step")
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step (stream)."""
raise NotImplementedError("This agent does not support streaming.")
@trace_method("run_step")
async def astream_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async stream)."""
raise NotImplementedError("This agent does not support streaming.")
@abstractmethod
def _finalize_task(self, state: Dict[str, Any], **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed.
State is all the step states.
"""
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""
# add new messages to memory
task.memory.set(task.memory.get() + task.extra_state["memory"].get_all())
# reset new memory
task.extra_state["memory"].reset()
self._finalize_task(task.extra_state, **kwargs)
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: make this abstractmethod (right now will break some agent impls)
self.callback_manager = callback_manager
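A minimal hypothetical subclass showing the three abstract hooks in use (`EchoAgentWorker` is illustrative only; running it needs an `OPENAI_API_KEY` because `from_tools` falls back to an OpenAI LLM):

```python
# Hypothetical single-step worker built on the scaffolding above.
from typing import Any, Dict, Optional, Tuple

from llama_index.agent import AgentRunner, CustomSimpleAgentWorker
from llama_index.agent.types import Task
from llama_index.chat_engine.types import AgentChatResponse


class EchoAgentWorker(CustomSimpleAgentWorker):
    """Answer every task in a single step by echoing its input."""

    def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
        # Custom per-task state; must not reuse the reserved "sources"/"memory" keys.
        return {"n_steps": 0}

    def _run_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        state["n_steps"] += 1
        return AgentChatResponse(response=f"echo: {task.input}"), True  # is_done=True

    def _finalize_task(self, state: Dict[str, Any], **kwargs: Any) -> None:
        pass  # nothing to clean up


agent = AgentRunner(EchoAgentWorker.from_tools(tools=[], verbose=True))
print(agent.chat("hello"))
```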

1
llama_index/agent/legacy/__init__.py Normal file

@ -0,0 +1 @@
"""Init params."""

199
llama_index/agent/legacy/context_retriever_agent.py Normal file

@ -0,0 +1,199 @@
"""Context retriever agent."""
from typing import List, Optional, Type, Union
from llama_index.agent.legacy.openai_agent import (
DEFAULT_MAX_FUNCTION_CALLS,
DEFAULT_MODEL_NAME,
BaseOpenAIAgent,
)
from llama_index.callbacks import CallbackManager
from llama_index.chat_engine.types import (
AgentChatResponse,
)
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.llms.types import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import is_function_calling_model
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.prompts import PromptTemplate
from llama_index.schema import NodeWithScore
from llama_index.tools import BaseTool
from llama_index.utils import print_text
# inspired by DEFAULT_QA_PROMPT_TMPL from llama_index/prompts/default_prompts.py
DEFAULT_QA_PROMPT_TMPL = (
"Context information is below.\n"
"---------------------\n"
"{context_str}\n"
"---------------------\n"
"Given the context information and not prior knowledge, "
"either pick the corresponding tool or answer the function: {query_str}\n"
)
DEFAULT_QA_PROMPT = PromptTemplate(DEFAULT_QA_PROMPT_TMPL)
class ContextRetrieverOpenAIAgent(BaseOpenAIAgent):
"""ContextRetriever OpenAI Agent.
This agent performs retrieval from BaseRetriever before
calling the LLM. Allows it to augment user message with context.
NOTE: this is a beta feature, function interfaces might change.
Args:
tools (List[BaseTool]): A list of tools.
retriever (BaseRetriever): A retriever.
qa_prompt (Optional[PromptTemplate]): A QA prompt.
context_separator (str): A context separator.
llm (Optional[OpenAI]): An OpenAI LLM.
chat_history (Optional[List[ChatMessage]]): A chat history.
        prefix_messages (List[ChatMessage]): A list of prefix messages.
verbose (bool): Whether to print debug statements.
max_function_calls (int): Maximum number of function calls.
callback_manager (Optional[CallbackManager]): A callback manager.
"""
def __init__(
self,
tools: List[BaseTool],
retriever: BaseRetriever,
qa_prompt: PromptTemplate,
context_separator: str,
llm: OpenAI,
memory: BaseMemory,
prefix_messages: List[ChatMessage],
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
callback_manager: Optional[CallbackManager] = None,
) -> None:
super().__init__(
llm=llm,
memory=memory,
prefix_messages=prefix_messages,
verbose=verbose,
max_function_calls=max_function_calls,
callback_manager=callback_manager,
)
self._tools = tools
self._qa_prompt = qa_prompt
self._retriever = retriever
self._context_separator = context_separator
@classmethod
def from_tools_and_retriever(
cls,
tools: List[BaseTool],
retriever: BaseRetriever,
qa_prompt: Optional[PromptTemplate] = None,
context_separator: str = "\n",
llm: Optional[LLM] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
callback_manager: Optional[CallbackManager] = None,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
) -> "ContextRetrieverOpenAIAgent":
"""Create a ContextRetrieverOpenAIAgent from a retriever.
Args:
retriever (BaseRetriever): A retriever.
qa_prompt (Optional[PromptTemplate]): A QA prompt.
context_separator (str): A context separator.
llm (Optional[OpenAI]): An OpenAI LLM.
chat_history (Optional[ChatMessageHistory]): A chat history.
verbose (bool): Whether to print debug statements.
max_function_calls (int): Maximum number of function calls.
callback_manager (Optional[CallbackManager]): A callback manager.
"""
qa_prompt = qa_prompt or DEFAULT_QA_PROMPT
chat_history = chat_history or []
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
if not isinstance(llm, OpenAI):
            raise ValueError("llm must be an OpenAI instance")
if callback_manager is not None:
llm.callback_manager = callback_manager
memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)
if not is_function_calling_model(llm.model):
raise ValueError(
f"Model name {llm.model} does not support function calling API."
)
if system_prompt is not None:
if prefix_messages is not None:
raise ValueError(
"Cannot specify both system_prompt and prefix_messages"
)
prefix_messages = [ChatMessage(content=system_prompt, role="system")]
prefix_messages = prefix_messages or []
return cls(
tools=tools,
retriever=retriever,
qa_prompt=qa_prompt,
context_separator=context_separator,
llm=llm,
memory=memory,
prefix_messages=prefix_messages,
verbose=verbose,
max_function_calls=max_function_calls,
callback_manager=callback_manager,
)
def _get_tools(self, message: str) -> List[BaseTool]:
"""Get tools."""
return self._tools
def _build_formatted_message(self, message: str) -> str:
# augment user message
retrieved_nodes_w_scores: List[NodeWithScore] = self._retriever.retrieve(
message
)
retrieved_nodes = [node.node for node in retrieved_nodes_w_scores]
retrieved_texts = [node.get_content() for node in retrieved_nodes]
# format message
context_str = self._context_separator.join(retrieved_texts)
return self._qa_prompt.format(context_str=context_str, query_str=message)
def chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> AgentChatResponse:
"""Chat."""
formatted_message = self._build_formatted_message(message)
if self._verbose:
print_text(formatted_message + "\n", color="yellow")
return super().chat(
formatted_message, chat_history=chat_history, tool_choice=tool_choice
)
async def achat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> AgentChatResponse:
"""Chat."""
formatted_message = self._build_formatted_message(message)
if self._verbose:
print_text(formatted_message + "\n", color="yellow")
return await super().achat(
formatted_message, chat_history=chat_history, tool_choice=tool_choice
)
def get_tools(self, message: str) -> List[BaseTool]:
"""Get tools."""
return self._get_tools(message)
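A usage sketch for the class above (illustrative; `FunctionTool` is assumed from `llama_index.tools`, and a `data/` directory plus an `OPENAI_API_KEY` are assumed in the environment; none of these appear in this diff):

```python
# Sketch: retrieve index context for each user message before the function-calling step.
from llama_index import SimpleDirectoryReader, VectorStoreIndex
from llama_index.agent import ContextRetrieverOpenAIAgent
from llama_index.tools import FunctionTool


def report_revenue(quarter: str) -> str:
    """Toy tool: return a canned revenue figure for a quarter."""
    return f"Revenue for {quarter}: $1.2M"


index = VectorStoreIndex.from_documents(SimpleDirectoryReader("data").load_data())
agent = ContextRetrieverOpenAIAgent.from_tools_and_retriever(
    tools=[FunctionTool.from_defaults(fn=report_revenue)],
    retriever=index.as_retriever(similarity_top_k=2),  # retrieved context is added to each message
    verbose=True,
)
print(agent.chat("What was revenue in Q3?"))
```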

610
llama_index/agent/legacy/openai_agent.py Normal file

@ -0,0 +1,610 @@
import asyncio
import json
import logging
from abc import abstractmethod
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast, get_args
from llama_index.agent.openai.utils import get_function_by_name
from llama_index.agent.types import BaseAgent
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
ChatResponseMode,
StreamingAgentChatResponse,
)
from llama_index.core.llms.types import ChatMessage, ChatResponse, MessageRole
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import OpenAIToolCall
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
DEFAULT_MAX_FUNCTION_CALLS = 5
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
def call_tool_with_error_handling(
tool: BaseTool,
input_dict: Dict,
error_message: Optional[str] = None,
raise_error: bool = False,
) -> ToolOutput:
"""Call tool with error handling.
Input is a dictionary with args and kwargs
"""
try:
return tool(**input_dict)
except Exception as e:
if raise_error:
raise
error_message = error_message or f"Error: {e!s}"
return ToolOutput(
content=error_message,
tool_name=tool.metadata.name,
raw_input={"kwargs": input_dict},
raw_output=e,
)
def call_function(
tools: List[BaseTool],
tool_call: OpenAIToolCall,
verbose: bool = False,
) -> Tuple[ChatMessage, ToolOutput]:
"""Call a function and return the output as a string."""
# validations to get passed mypy
assert tool_call.id is not None
assert tool_call.function is not None
assert tool_call.function.name is not None
assert tool_call.function.arguments is not None
id_ = tool_call.id
function_call = tool_call.function
name = tool_call.function.name
arguments_str = tool_call.function.arguments
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = get_function_by_name(tools, name)
argument_dict = json.loads(arguments_str)
# Call tool
# Use default error message
output = call_tool_with_error_handling(tool, argument_dict, error_message=None)
if verbose:
print(f"Got output: {output!s}")
print("========================\n")
return (
ChatMessage(
content=str(output),
role=MessageRole.TOOL,
additional_kwargs={
"name": name,
"tool_call_id": id_,
},
),
output,
)
async def acall_function(
tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
"""Call a function and return the output as a string."""
# validations to get passed mypy
assert tool_call.id is not None
assert tool_call.function is not None
assert tool_call.function.name is not None
assert tool_call.function.arguments is not None
id_ = tool_call.id
function_call = tool_call.function
name = tool_call.function.name
arguments_str = tool_call.function.arguments
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = get_function_by_name(tools, name)
async_tool = adapt_to_async_tool(tool)
argument_dict = json.loads(arguments_str)
output = await async_tool.acall(**argument_dict)
if verbose:
print(f"Got output: {output!s}")
print("========================\n")
return (
ChatMessage(
content=str(output),
role=MessageRole.TOOL,
additional_kwargs={
"name": name,
"tool_call_id": id_,
},
),
output,
)
def resolve_tool_choice(tool_choice: Union[str, dict] = "auto") -> Union[str, dict]:
"""Resolve tool choice.
If tool_choice is a function name string, return the appropriate dict.
"""
if isinstance(tool_choice, str) and tool_choice not in ["none", "auto"]:
return {"type": "function", "function": {"name": tool_choice}}
return tool_choice
class BaseOpenAIAgent(BaseAgent):
def __init__(
self,
llm: OpenAI,
memory: BaseMemory,
prefix_messages: List[ChatMessage],
verbose: bool,
max_function_calls: int,
callback_manager: Optional[CallbackManager],
):
self._llm = llm
self._verbose = verbose
self._max_function_calls = max_function_calls
self.prefix_messages = prefix_messages
self.memory = memory
self.callback_manager = callback_manager or self._llm.callback_manager
self.sources: List[ToolOutput] = []
@property
def chat_history(self) -> List[ChatMessage]:
return self.memory.get_all()
@property
def all_messages(self) -> List[ChatMessage]:
return self.prefix_messages + self.memory.get()
@property
def latest_function_call(self) -> Optional[dict]:
return self.memory.get_all()[-1].additional_kwargs.get("function_call", None)
@property
def latest_tool_calls(self) -> Optional[List[OpenAIToolCall]]:
return self.memory.get_all()[-1].additional_kwargs.get("tool_calls", None)
def reset(self) -> None:
self.memory.reset()
@abstractmethod
def get_tools(self, message: str) -> List[BaseTool]:
"""Get tools."""
def _should_continue(
self, tool_calls: Optional[List[OpenAIToolCall]], n_function_calls: int
) -> bool:
if n_function_calls > self._max_function_calls:
return False
if not tool_calls:
return False
return True
def init_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> Tuple[List[BaseTool], List[dict]]:
if chat_history is not None:
self.memory.set(chat_history)
self.sources = []
self.memory.put(ChatMessage(content=message, role=MessageRole.USER))
tools = self.get_tools(message)
openai_tools = [tool.metadata.to_openai_tool() for tool in tools]
return tools, openai_tools
def _process_message(self, chat_response: ChatResponse) -> AgentChatResponse:
ai_message = chat_response.message
self.memory.put(ai_message)
return AgentChatResponse(response=str(ai_message.content), sources=self.sources)
def _get_stream_ai_response(
self, **llm_chat_kwargs: Any
) -> StreamingAgentChatResponse:
chat_stream_response = StreamingAgentChatResponse(
chat_stream=self._llm.stream_chat(**llm_chat_kwargs),
sources=self.sources,
)
# Get the response in a separate thread so we can yield the response
thread = Thread(
target=chat_stream_response.write_response_to_history,
args=(self.memory,),
)
thread.start()
# Wait for the event to be set
chat_stream_response._is_function_not_none_thread_event.wait()
# If it is executing an openAI function, wait for the thread to finish
if chat_stream_response._is_function:
thread.join()
# if it's false, return the answer (to stream)
return chat_stream_response
async def _get_async_stream_ai_response(
self, **llm_chat_kwargs: Any
) -> StreamingAgentChatResponse:
chat_stream_response = StreamingAgentChatResponse(
achat_stream=await self._llm.astream_chat(**llm_chat_kwargs),
sources=self.sources,
)
# create task to write chat response to history
asyncio.create_task(
chat_stream_response.awrite_response_to_history(self.memory)
)
# wait until openAI functions stop executing
await chat_stream_response._is_function_false_event.wait()
# return response stream
return chat_stream_response
def _call_function(self, tools: List[BaseTool], tool_call: OpenAIToolCall) -> None:
function_call = tool_call.function
# validations to get passed mypy
assert function_call is not None
assert function_call.name is not None
assert function_call.arguments is not None
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: function_call.arguments,
EventPayload.TOOL: get_function_by_name(
tools, function_call.name
).metadata,
},
) as event:
function_message, tool_output = call_function(
tools, tool_call, verbose=self._verbose
)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
self.sources.append(tool_output)
self.memory.put(function_message)
async def _acall_function(
self, tools: List[BaseTool], tool_call: OpenAIToolCall
) -> None:
function_call = tool_call.function
# validations to get passed mypy
assert function_call is not None
assert function_call.name is not None
assert function_call.arguments is not None
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: function_call.arguments,
EventPayload.TOOL: get_function_by_name(
tools, function_call.name
).metadata,
},
) as event:
function_message, tool_output = await acall_function(
tools, tool_call, verbose=self._verbose
)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
self.sources.append(tool_output)
self.memory.put(function_message)
def _get_llm_chat_kwargs(
self, openai_tools: List[dict], tool_choice: Union[str, dict] = "auto"
) -> Dict[str, Any]:
llm_chat_kwargs: dict = {"messages": self.all_messages}
if openai_tools:
llm_chat_kwargs.update(
tools=openai_tools, tool_choice=resolve_tool_choice(tool_choice)
)
return llm_chat_kwargs
def _get_agent_response(
self, mode: ChatResponseMode, **llm_chat_kwargs: Any
) -> AGENT_CHAT_RESPONSE_TYPE:
if mode == ChatResponseMode.WAIT:
chat_response: ChatResponse = self._llm.chat(**llm_chat_kwargs)
return self._process_message(chat_response)
elif mode == ChatResponseMode.STREAM:
return self._get_stream_ai_response(**llm_chat_kwargs)
else:
raise NotImplementedError
async def _get_async_agent_response(
self, mode: ChatResponseMode, **llm_chat_kwargs: Any
) -> AGENT_CHAT_RESPONSE_TYPE:
if mode == ChatResponseMode.WAIT:
chat_response: ChatResponse = await self._llm.achat(**llm_chat_kwargs)
return self._process_message(chat_response)
elif mode == ChatResponseMode.STREAM:
return await self._get_async_stream_ai_response(**llm_chat_kwargs)
else:
raise NotImplementedError
def _chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
tools, openai_tools = self.init_chat(message, chat_history)
n_function_calls = 0
# Loop until no more function calls or max_function_calls is reached
current_tool_choice = tool_choice
ix = 0
while True:
ix += 1
if self._verbose:
print(f"STARTING TURN {ix}\n---------------\n")
llm_chat_kwargs = self._get_llm_chat_kwargs(
openai_tools, current_tool_choice
)
agent_chat_response = self._get_agent_response(mode=mode, **llm_chat_kwargs)
if not self._should_continue(self.latest_tool_calls, n_function_calls):
logger.debug("Break: should continue False")
break
# iterate through all the tool calls
logger.debug(f"Continue to tool calls: {self.latest_tool_calls}")
if self.latest_tool_calls is not None:
for tool_call in self.latest_tool_calls:
# Some validation
if not isinstance(tool_call, get_args(OpenAIToolCall)):
raise ValueError("Invalid tool_call object")
if tool_call.type != "function":
raise ValueError("Invalid tool type. Unsupported by OpenAI")
# TODO: maybe execute this with multi-threading
self._call_function(tools, tool_call)
# change function call to the default value, if a custom function was given
# as an argument (none and auto are predefined by OpenAI)
if current_tool_choice not in ("auto", "none"):
current_tool_choice = "auto"
n_function_calls += 1
return agent_chat_response
async def _achat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
tools, functions = self.init_chat(message, chat_history)
n_function_calls = 0
# Loop until no more function calls or max_function_calls is reached
current_tool_choice = tool_choice
ix = 0
while True:
ix += 1
if self._verbose:
print(f"STARTING TURN {ix}\n---------------\n")
llm_chat_kwargs = self._get_llm_chat_kwargs(functions, current_tool_choice)
agent_chat_response = await self._get_async_agent_response(
mode=mode, **llm_chat_kwargs
)
if not self._should_continue(self.latest_tool_calls, n_function_calls):
break
# iterate through all the tool calls
if self.latest_tool_calls is not None:
for tool_call in self.latest_tool_calls:
# Some validation
if not isinstance(tool_call, get_args(OpenAIToolCall)):
raise ValueError("Invalid tool_call object")
if tool_call.type != "function":
raise ValueError("Invalid tool type. Unsupported by OpenAI")
# TODO: maybe execute this with multi-threading
await self._acall_function(tools, tool_call)
# change function call to the default value, if a custom function was given
# as an argument (none and auto are predefined by OpenAI)
if current_tool_choice not in ("auto", "none"):
current_tool_choice = "auto"
n_function_calls += 1
return agent_chat_response
@trace_method("chat")
def chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> AgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = self._chat(
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
)
assert isinstance(chat_response, AgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
async def achat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> AgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = await self._achat(
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
)
assert isinstance(chat_response, AgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
def stream_chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> StreamingAgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = self._chat(
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
)
assert isinstance(chat_response, StreamingAgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
async def astream_chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> StreamingAgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = await self._achat(
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
)
assert isinstance(chat_response, StreamingAgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
class OpenAIAgent(BaseOpenAIAgent):
"""OpenAI (function calling) Agent.
Uses the OpenAI function API to reason about whether to
    use a tool, and returns the response to the user.
Supports both a flat list of tools as well as retrieval over the tools.
Args:
tools (List[BaseTool]): List of tools to use.
llm (OpenAI): OpenAI instance.
memory (BaseMemory): Memory to use.
prefix_messages (List[ChatMessage]): Prefix messages to use.
verbose (Optional[bool]): Whether to print verbose output. Defaults to False.
max_function_calls (Optional[int]): Maximum number of function calls.
Defaults to DEFAULT_MAX_FUNCTION_CALLS.
callback_manager (Optional[CallbackManager]): Callback manager to use.
Defaults to None.
tool_retriever (ObjectRetriever[BaseTool]): Object retriever to retrieve tools.
"""
def __init__(
self,
tools: List[BaseTool],
llm: OpenAI,
memory: BaseMemory,
prefix_messages: List[ChatMessage],
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
callback_manager: Optional[CallbackManager] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
) -> None:
super().__init__(
llm=llm,
memory=memory,
prefix_messages=prefix_messages,
verbose=verbose,
max_function_calls=max_function_calls,
callback_manager=callback_manager,
)
if len(tools) > 0 and tool_retriever is not None:
raise ValueError("Cannot specify both tools and tool_retriever")
elif len(tools) > 0:
self._get_tools = lambda _: tools
elif tool_retriever is not None:
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
else:
# no tools
self._get_tools = lambda _: []
@classmethod
def from_tools(
cls,
tools: Optional[List[BaseTool]] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
llm: Optional[LLM] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
callback_manager: Optional[CallbackManager] = None,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
**kwargs: Any,
) -> "OpenAIAgent":
"""Create an OpenAIAgent from a list of tools.
Similar to `from_defaults` in other classes, this method will
infer defaults for a variety of parameters, including the LLM,
if they are not specified.
"""
tools = tools or []
chat_history = chat_history or []
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
if not isinstance(llm, OpenAI):
            raise ValueError("llm must be an OpenAI instance")
if callback_manager is not None:
llm.callback_manager = callback_manager
memory = memory or memory_cls.from_defaults(chat_history, llm=llm)
if not llm.metadata.is_function_calling_model:
raise ValueError(
f"Model name {llm.model} does not support function calling API. "
)
if system_prompt is not None:
if prefix_messages is not None:
raise ValueError(
"Cannot specify both system_prompt and prefix_messages"
)
prefix_messages = [ChatMessage(content=system_prompt, role="system")]
prefix_messages = prefix_messages or []
return cls(
tools=tools,
tool_retriever=tool_retriever,
llm=llm,
memory=memory,
prefix_messages=prefix_messages,
verbose=verbose,
max_function_calls=max_function_calls,
callback_manager=callback_manager,
)
def get_tools(self, message: str) -> List[BaseTool]:
"""Get tools."""
return self._get_tools(message)
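A usage sketch for `OpenAIAgent.from_tools` as defined above (`FunctionTool` is again assumed from `llama_index.tools`; an `OPENAI_API_KEY` must be set):

```python
# Sketch: a single function tool driving the function-calling loop above.
from llama_index.agent.legacy.openai_agent import OpenAIAgent
from llama_index.tools import FunctionTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


agent = OpenAIAgent.from_tools(
    tools=[FunctionTool.from_defaults(fn=multiply)],
    verbose=True,  # prints the "=== Calling Function ===" blocks emitted by call_function()
)
print(agent.chat("What is 12 times 7?"))
```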

1
llama_index/agent/legacy/react/__init__.py Normal file

@ -0,0 +1 @@
"""Init params."""

526
llama_index/agent/legacy/react/base.py Normal file

@ -0,0 +1,526 @@
import asyncio
from itertools import chain
from threading import Thread
from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Sequence,
Tuple,
Type,
cast,
)
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.agent.react.output_parser import ReActOutputParser
from llama_index.agent.react.types import (
ActionReasoningStep,
BaseReasoningStep,
ObservationReasoningStep,
ResponseReasoningStep,
)
from llama_index.agent.types import BaseAgent
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import AgentChatResponse, StreamingAgentChatResponse
from llama_index.core.llms.types import MessageRole
from llama_index.llms.base import ChatMessage, ChatResponse
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.tools.types import AsyncBaseTool
from llama_index.utils import print_text, unit_generator
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
class ReActAgent(BaseAgent):
"""ReAct agent.
Uses a ReAct prompt that can be used in both chat and text
completion endpoints.
Can take in a set of tools that require structured inputs.
"""
def __init__(
self,
tools: Sequence[BaseTool],
llm: LLM,
memory: BaseMemory,
max_iterations: int = 10,
react_chat_formatter: Optional[ReActChatFormatter] = None,
output_parser: Optional[ReActOutputParser] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
) -> None:
super().__init__(callback_manager=callback_manager or llm.callback_manager)
self._llm = llm
self._memory = memory
self._max_iterations = max_iterations
self._react_chat_formatter = react_chat_formatter or ReActChatFormatter()
self._output_parser = output_parser or ReActOutputParser()
self._verbose = verbose
self.sources: List[ToolOutput] = []
if len(tools) > 0 and tool_retriever is not None:
raise ValueError("Cannot specify both tools and tool_retriever")
elif len(tools) > 0:
self._get_tools = lambda _: tools
elif tool_retriever is not None:
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
else:
self._get_tools = lambda _: []
@classmethod
def from_tools(
cls,
tools: Optional[List[BaseTool]] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
llm: Optional[LLM] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
max_iterations: int = 10,
react_chat_formatter: Optional[ReActChatFormatter] = None,
output_parser: Optional[ReActOutputParser] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
**kwargs: Any,
    ) -> "ReActAgent":
        """Convenience constructor method from a set of BaseTools (Optional).
        NOTE: kwargs should have been exhausted by this point. In other words,
        the various upstream components such as BaseSynthesizer (response synthesizer)
        or BaseRetriever should already have picked off their respective kwargs in
        their constructors.
Returns:
ReActAgent
"""
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
if callback_manager is not None:
llm.callback_manager = callback_manager
memory = memory or memory_cls.from_defaults(
chat_history=chat_history or [], llm=llm
)
return cls(
tools=tools or [],
tool_retriever=tool_retriever,
llm=llm,
memory=memory,
max_iterations=max_iterations,
react_chat_formatter=react_chat_formatter,
output_parser=output_parser,
callback_manager=callback_manager,
verbose=verbose,
)
@property
def chat_history(self) -> List[ChatMessage]:
"""Chat history."""
return self._memory.get_all()
def reset(self) -> None:
self._memory.reset()
def _extract_reasoning_step(
self, output: ChatResponse, is_streaming: bool = False
) -> Tuple[str, List[BaseReasoningStep], bool]:
"""
Extracts the reasoning step from the given output.
This method parses the message content from the output,
extracts the reasoning step, and determines whether the processing is
complete. It also performs validation checks on the output and
handles possible errors.
"""
if output.message.content is None:
raise ValueError("Got empty message.")
message_content = output.message.content
current_reasoning = []
try:
reasoning_step = self._output_parser.parse(message_content, is_streaming)
except BaseException as exc:
raise ValueError(f"Could not parse output: {message_content}") from exc
if self._verbose:
print_text(f"{reasoning_step.get_content()}\n", color="pink")
current_reasoning.append(reasoning_step)
if reasoning_step.is_done:
return message_content, current_reasoning, True
reasoning_step = cast(ActionReasoningStep, reasoning_step)
if not isinstance(reasoning_step, ActionReasoningStep):
raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")
return message_content, current_reasoning, False
def _process_actions(
self,
tools: Sequence[AsyncBaseTool],
output: ChatResponse,
is_streaming: bool = False,
) -> Tuple[List[BaseReasoningStep], bool]:
tools_dict: Dict[str, AsyncBaseTool] = {
tool.metadata.get_name(): tool for tool in tools
}
_, current_reasoning, is_done = self._extract_reasoning_step(
output, is_streaming
)
if is_done:
return current_reasoning, True
# call tool with input
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
tool = tools_dict[reasoning_step.action]
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
EventPayload.TOOL: tool.metadata,
},
) as event:
tool_output = tool.call(**reasoning_step.action_input)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
self.sources.append(tool_output)
observation_step = ObservationReasoningStep(observation=str(tool_output))
current_reasoning.append(observation_step)
if self._verbose:
print_text(f"{observation_step.get_content()}\n", color="blue")
return current_reasoning, False
async def _aprocess_actions(
self,
tools: Sequence[AsyncBaseTool],
output: ChatResponse,
is_streaming: bool = False,
) -> Tuple[List[BaseReasoningStep], bool]:
tools_dict = {tool.metadata.name: tool for tool in tools}
_, current_reasoning, is_done = self._extract_reasoning_step(
output, is_streaming
)
if is_done:
return current_reasoning, True
# call tool with input
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
tool = tools_dict[reasoning_step.action]
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
EventPayload.TOOL: tool.metadata,
},
) as event:
tool_output = await tool.acall(**reasoning_step.action_input)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
self.sources.append(tool_output)
observation_step = ObservationReasoningStep(observation=str(tool_output))
current_reasoning.append(observation_step)
if self._verbose:
print_text(f"{observation_step.get_content()}\n", color="blue")
return current_reasoning, False
def _get_response(
self,
current_reasoning: List[BaseReasoningStep],
) -> AgentChatResponse:
"""Get response from reasoning steps."""
if len(current_reasoning) == 0:
raise ValueError("No reasoning steps were taken.")
elif len(current_reasoning) == self._max_iterations:
raise ValueError("Reached max iterations.")
response_step = cast(ResponseReasoningStep, current_reasoning[-1])
# TODO: add sources from reasoning steps
return AgentChatResponse(response=response_step.response, sources=self.sources)
    def _infer_stream_chunk_is_final(self, chunk: ChatResponse) -> bool:
        """Infer whether a chunk from a live stream is the start of the final
        reasoning step (i.e., one that should eventually become a
        ResponseReasoningStep, although that conversion is not part of this function's logic).
Args:
chunk (ChatResponse): the current chunk stream to check
Returns:
bool: Boolean on whether the chunk is the start of the final response
"""
latest_content = chunk.message.content
if latest_content:
if not latest_content.startswith(
"Thought"
): # doesn't follow thought-action format
return True
else:
if "Answer: " in latest_content:
return True
return False
    def _add_back_chunk_to_stream(
        self, chunk: ChatResponse, chat_stream: Generator[ChatResponse, None, None]
    ) -> Generator[ChatResponse, None, None]:
        """Helper method for adding the initial chunk of the final response
        back to the rest of the chat_stream.
Args:
chunk (ChatResponse): the chunk to add back to the beginning of the
chat_stream.
Return:
Generator[ChatResponse, None, None]: the updated chat_stream
"""
updated_stream = chain.from_iterable( # need to add back partial response chunk
[
unit_generator(chunk),
chat_stream,
]
)
# use cast to avoid mypy issue with chain and Generator
updated_stream_c: Generator[ChatResponse, None, None] = cast(
Generator[ChatResponse, None, None], updated_stream
)
return updated_stream_c
    async def _async_add_back_chunk_to_stream(
        self, chunk: ChatResponse, chat_stream: AsyncGenerator[ChatResponse, None]
    ) -> AsyncGenerator[ChatResponse, None]:
        """Helper method for adding the initial chunk of the final response
        back to the rest of the chat_stream.
NOTE: this itself is not an async function.
Args:
chunk (ChatResponse): the chunk to add back to the beginning of the
chat_stream.
Return:
AsyncGenerator[ChatResponse, None]: the updated async chat_stream
"""
yield chunk
async for item in chat_stream:
yield item
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
"""Chat."""
# get tools
# TODO: do get tools dynamically at every iteration of the agent loop
self.sources = []
tools = self.get_tools(message)
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
current_reasoning: List[BaseReasoningStep] = []
# start loop
for _ in range(self._max_iterations):
# prepare inputs
input_chat = self._react_chat_formatter.format(
tools,
chat_history=self._memory.get(),
current_reasoning=current_reasoning,
)
# send prompt
chat_response = self._llm.chat(input_chat)
# given react prompt outputs, call tools or return response
reasoning_steps, is_done = self._process_actions(
tools, output=chat_response
)
current_reasoning.extend(reasoning_steps)
if is_done:
break
response = self._get_response(current_reasoning)
self._memory.put(
ChatMessage(content=response.response, role=MessageRole.ASSISTANT)
)
return response
@trace_method("chat")
async def achat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
# get tools
# TODO: do get tools dynamically at every iteration of the agent loop
self.sources = []
tools = self.get_tools(message)
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
current_reasoning: List[BaseReasoningStep] = []
# start loop
for _ in range(self._max_iterations):
# prepare inputs
input_chat = self._react_chat_formatter.format(
tools,
chat_history=self._memory.get(),
current_reasoning=current_reasoning,
)
# send prompt
chat_response = await self._llm.achat(input_chat)
# given react prompt outputs, call tools or return response
reasoning_steps, is_done = await self._aprocess_actions(
tools, output=chat_response
)
current_reasoning.extend(reasoning_steps)
if is_done:
break
response = self._get_response(current_reasoning)
self._memory.put(
ChatMessage(content=response.response, role=MessageRole.ASSISTANT)
)
return response
@trace_method("chat")
def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
# get tools
# TODO: do get tools dynamically at every iteration of the agent loop
self.sources = []
tools = self.get_tools(message)
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
current_reasoning: List[BaseReasoningStep] = []
# start loop
is_done, ix = False, 0
while (not is_done) and (ix < self._max_iterations):
ix += 1
# prepare inputs
input_chat = self._react_chat_formatter.format(
tools,
chat_history=self._memory.get(),
current_reasoning=current_reasoning,
)
# send prompt
chat_stream = self._llm.stream_chat(input_chat)
# iterate over stream, break out if is final answer after the "Answer: "
full_response = ChatResponse(
message=ChatMessage(content=None, role="assistant")
)
for latest_chunk in chat_stream:
full_response = latest_chunk
is_done = self._infer_stream_chunk_is_final(latest_chunk)
if is_done:
break
# given react prompt outputs, call tools or return response
reasoning_steps, _ = self._process_actions(
tools=tools, output=full_response, is_streaming=True
)
current_reasoning.extend(reasoning_steps)
# Get the response in a separate thread so we can yield the response
response_stream = self._add_back_chunk_to_stream(
chunk=latest_chunk, chat_stream=chat_stream
)
chat_stream_response = StreamingAgentChatResponse(
chat_stream=response_stream,
sources=self.sources,
)
thread = Thread(
target=chat_stream_response.write_response_to_history,
args=(self._memory,),
)
thread.start()
return chat_stream_response
@trace_method("chat")
async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
# get tools
# TODO: do get tools dynamically at every iteration of the agent loop
self.sources = []
tools = self.get_tools(message)
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
current_reasoning: List[BaseReasoningStep] = []
# start loop
is_done, ix = False, 0
while (not is_done) and (ix < self._max_iterations):
ix += 1
# prepare inputs
input_chat = self._react_chat_formatter.format(
tools,
chat_history=self._memory.get(),
current_reasoning=current_reasoning,
)
# send prompt
chat_stream = await self._llm.astream_chat(input_chat)
# iterate over the stream; break out once this is the final answer
is_done = False
full_response = ChatResponse(
message=ChatMessage(content=None, role="assistant")
)
async for latest_chunk in chat_stream:
full_response = latest_chunk
is_done = self._infer_stream_chunk_is_final(latest_chunk)
if is_done:
break
# given react prompt outputs, call tools or return response
reasoning_steps, _ = self._process_actions(
tools=tools, output=full_response, is_streaming=True
)
current_reasoning.extend(reasoning_steps)
# Get the response in a separate thread so we can yield the response
response_stream = self._async_add_back_chunk_to_stream(
chunk=latest_chunk, chat_stream=chat_stream
)
chat_stream_response = StreamingAgentChatResponse(
achat_stream=response_stream, sources=self.sources
)
# create task to write chat response to history
asyncio.create_task(
chat_stream_response.awrite_response_to_history(self._memory)
)
# thread.start()
return chat_stream_response
def get_tools(self, message: str) -> List[AsyncBaseTool]:
"""Get tools."""
return [adapt_to_async_tool(t) for t in self._get_tools(message)]
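
For reference, a minimal sketch of how the chat loop above is typically driven. It assumes the legacy import path referenced elsewhere in this commit (`llama_index.agent.legacy.react.base`) and that the legacy class exposes the same `from_tools` constructor as the non-legacy `ReActAgent` defined later; the `add` tool is illustrative only.

```python
from llama_index.agent.legacy.react.base import ReActAgent
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


# Hypothetical usage: one tool, default ReAct formatter and output parser.
agent = ReActAgent.from_tools(
    tools=[FunctionTool.from_defaults(fn=add)],
    llm=OpenAI(model="gpt-3.5-turbo-0613"),
    verbose=True,
)
print(agent.chat("What is 2 + 3?"))

# stream_chat returns a StreamingAgentChatResponse; tokens after "Answer:"
# are yielded as they arrive.
streaming_response = agent.stream_chat("And 7 + 5?")
for token in streaming_response.response_gen:
    print(token, end="")
```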

View File

@ -0,0 +1,31 @@
"""Retriever OpenAI agent."""
from typing import Any, cast
from llama_index.agent.legacy.openai_agent import (
OpenAIAgent,
)
from llama_index.objects.base import ObjectRetriever
from llama_index.tools.types import BaseTool
class FnRetrieverOpenAIAgent(OpenAIAgent):
"""Function Retriever OpenAI Agent.
Uses our object retriever module to retrieve tools for an OpenAI agent.
NOTE: This is deprecated; you can just use the base `OpenAIAgent` class by
specifying the following:
```
agent = OpenAIAgent.from_tools(tool_retriever=retriever, ...)
```
"""
@classmethod
def from_retriever(
cls, retriever: ObjectRetriever[BaseTool], **kwargs: Any
) -> "FnRetrieverOpenAIAgent":
return cast(
FnRetrieverOpenAIAgent, cls.from_tools(tool_retriever=retriever, **kwargs)
)
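
To make the deprecation note concrete, here is a hedged sketch of the recommended replacement: build a tool retriever and pass it to `OpenAIAgent.from_tools`. It assumes `ObjectIndex` and `SimpleToolNodeMapping` are importable from `llama_index.objects` (only `ObjectRetriever` is imported above) and uses a made-up `multiply` tool.

```python
from llama_index import VectorStoreIndex
from llama_index.agent import OpenAIAgent
from llama_index.objects import ObjectIndex, SimpleToolNodeMapping
from llama_index.tools import FunctionTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


tools = [FunctionTool.from_defaults(fn=multiply)]

# Index the tools so the agent retrieves only the relevant ones per message.
obj_index = ObjectIndex.from_objects(
    tools, SimpleToolNodeMapping.from_objects(tools), VectorStoreIndex
)
agent = OpenAIAgent.from_tools(
    tool_retriever=obj_index.as_retriever(similarity_top_k=2), verbose=True
)
print(agent.chat("What is 6 * 7?"))
```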

View File

View File

@ -0,0 +1,140 @@
"""OpenAI Agent.
Simple wrapper around AgentRunner + OpenAIAgentWorker.
For the legacy implementation see:
```python
from llama_index.agent.legacy.openai.base import OpenAIAgent
```
"""
from typing import (
Any,
List,
Optional,
Type,
)
from llama_index.agent.openai.step import OpenAIAgentWorker
from llama_index.agent.runner.base import AgentRunner
from llama_index.callbacks import (
CallbackManager,
)
from llama_index.llms.base import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
DEFAULT_MAX_FUNCTION_CALLS = 5
class OpenAIAgent(AgentRunner):
"""OpenAI agent.
Subclasses AgentRunner with an OpenAIAgentWorker.
For the legacy implementation see:
```python
from llama_index.agent.legacy.openai.base import OpenAIAgent
```
"""
def __init__(
self,
tools: List[BaseTool],
llm: OpenAI,
memory: BaseMemory,
prefix_messages: List[ChatMessage],
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
default_tool_choice: str = "auto",
callback_manager: Optional[CallbackManager] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
) -> None:
"""Init params."""
callback_manager = callback_manager or llm.callback_manager
step_engine = OpenAIAgentWorker.from_tools(
tools=tools,
tool_retriever=tool_retriever,
llm=llm,
verbose=verbose,
max_function_calls=max_function_calls,
callback_manager=callback_manager,
prefix_messages=prefix_messages,
)
super().__init__(
step_engine,
memory=memory,
llm=llm,
callback_manager=callback_manager,
default_tool_choice=default_tool_choice,
)
@classmethod
def from_tools(
cls,
tools: Optional[List[BaseTool]] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
llm: Optional[LLM] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
default_tool_choice: str = "auto",
callback_manager: Optional[CallbackManager] = None,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
**kwargs: Any,
) -> "OpenAIAgent":
"""Create an OpenAIAgent from a list of tools.
Similar to `from_defaults` in other classes, this method will
infer defaults for a variety of parameters, including the LLM,
if they are not specified.
"""
tools = tools or []
chat_history = chat_history or []
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
if not isinstance(llm, OpenAI):
raise ValueError("llm must be a OpenAI instance")
if callback_manager is not None:
llm.callback_manager = callback_manager
memory = memory or memory_cls.from_defaults(chat_history, llm=llm)
if not llm.metadata.is_function_calling_model:
raise ValueError(
f"Model name {llm.model} does not support function calling API. "
)
if system_prompt is not None:
if prefix_messages is not None:
raise ValueError(
"Cannot specify both system_prompt and prefix_messages"
)
prefix_messages = [ChatMessage(content=system_prompt, role="system")]
prefix_messages = prefix_messages or []
return cls(
tools=tools,
tool_retriever=tool_retriever,
llm=llm,
memory=memory,
prefix_messages=prefix_messages,
verbose=verbose,
max_function_calls=max_function_calls,
callback_manager=callback_manager,
default_tool_choice=default_tool_choice,
)
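
A minimal usage sketch for the class above, assuming the package re-exports `OpenAIAgent` from `llama_index.agent`; the `get_weather` tool is illustrative only.

```python
from llama_index.agent import OpenAIAgent
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool


def get_weather(city: str) -> str:
    """Return a canned weather report for a city."""
    return f"It is always sunny in {city}."


agent = OpenAIAgent.from_tools(
    tools=[FunctionTool.from_defaults(fn=get_weather)],
    llm=OpenAI(model="gpt-3.5-turbo-0613"),
    system_prompt="You are a helpful weather assistant.",
    verbose=True,
)
print(agent.chat("What is the weather in Paris?"))
```

Note that `from_tools` raises if the chosen model does not support the function calling API, so non-function-calling models must be swapped out before constructing the agent.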

View File

@ -0,0 +1,644 @@
"""OpenAI agent worker."""
import asyncio
import json
import logging
import uuid
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args
from llama_index.agent.openai.utils import resolve_tool_choice
from llama_index.agent.types import (
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.agent.utils import add_user_step_to_memory
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
ChatResponseMode,
StreamingAgentChatResponse,
)
from llama_index.core.llms.types import MessageRole
from llama_index.llms.base import ChatMessage, ChatResponse
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import OpenAIToolCall
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
DEFAULT_MAX_FUNCTION_CALLS = 5
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
def get_function_by_name(tools: List[BaseTool], name: str) -> BaseTool:
"""Get function by name."""
name_to_tool = {tool.metadata.name: tool for tool in tools}
if name not in name_to_tool:
raise ValueError(f"Tool with name {name} not found")
return name_to_tool[name]
def call_tool_with_error_handling(
tool: BaseTool,
input_dict: Dict,
error_message: Optional[str] = None,
raise_error: bool = False,
) -> ToolOutput:
"""Call tool with error handling.
The input is a dictionary of keyword arguments passed to the tool.
"""
try:
return tool(**input_dict)
except Exception as e:
if raise_error:
raise
error_message = error_message or f"Error: {e!s}"
return ToolOutput(
content=error_message,
tool_name=tool.metadata.name,
raw_input={"kwargs": input_dict},
raw_output=e,
)
def call_function(
tools: List[BaseTool],
tool_call: OpenAIToolCall,
verbose: bool = False,
) -> Tuple[ChatMessage, ToolOutput]:
"""Call a function and return the output as a string."""
# validations to get past mypy
assert tool_call.id is not None
assert tool_call.function is not None
assert tool_call.function.name is not None
assert tool_call.function.arguments is not None
id_ = tool_call.id
function_call = tool_call.function
name = tool_call.function.name
arguments_str = tool_call.function.arguments
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = get_function_by_name(tools, name)
argument_dict = json.loads(arguments_str)
# Call tool
# Use default error message
output = call_tool_with_error_handling(tool, argument_dict, error_message=None)
if verbose:
print(f"Got output: {output!s}")
print("========================\n")
return (
ChatMessage(
content=str(output),
role=MessageRole.TOOL,
additional_kwargs={
"name": name,
"tool_call_id": id_,
},
),
output,
)
async def acall_function(
tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
"""Call a function and return the output as a string."""
# validations to get past mypy
assert tool_call.id is not None
assert tool_call.function is not None
assert tool_call.function.name is not None
assert tool_call.function.arguments is not None
id_ = tool_call.id
function_call = tool_call.function
name = tool_call.function.name
arguments_str = tool_call.function.arguments
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = get_function_by_name(tools, name)
async_tool = adapt_to_async_tool(tool)
argument_dict = json.loads(arguments_str)
output = await async_tool.acall(**argument_dict)
if verbose:
print(f"Got output: {output!s}")
print("========================\n")
return (
ChatMessage(
content=str(output),
role=MessageRole.TOOL,
additional_kwargs={
"name": name,
"tool_call_id": id_,
},
),
output,
)
class OpenAIAgentWorker(BaseAgentWorker):
"""OpenAI Agent agent worker."""
def __init__(
self,
tools: List[BaseTool],
llm: OpenAI,
prefix_messages: List[ChatMessage],
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
callback_manager: Optional[CallbackManager] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
):
self._llm = llm
self._verbose = verbose
self._max_function_calls = max_function_calls
self.prefix_messages = prefix_messages
self.callback_manager = callback_manager or self._llm.callback_manager
if len(tools) > 0 and tool_retriever is not None:
raise ValueError("Cannot specify both tools and tool_retriever")
elif len(tools) > 0:
self._get_tools = lambda _: tools
elif tool_retriever is not None:
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
else:
# no tools
self._get_tools = lambda _: []
@classmethod
def from_tools(
cls,
tools: Optional[List[BaseTool]] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
llm: Optional[LLM] = None,
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
callback_manager: Optional[CallbackManager] = None,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
**kwargs: Any,
) -> "OpenAIAgentWorker":
"""Create an OpenAIAgent from a list of tools.
Similar to `from_defaults` in other classes, this method will
infer defaults for a variety of parameters, including the LLM,
if they are not specified.
"""
tools = tools or []
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
if not isinstance(llm, OpenAI):
raise ValueError("llm must be a OpenAI instance")
if callback_manager is not None:
llm.callback_manager = callback_manager
if not llm.metadata.is_function_calling_model:
raise ValueError(
f"Model name {llm.model} does not support function calling API. "
)
if system_prompt is not None:
if prefix_messages is not None:
raise ValueError(
"Cannot specify both system_prompt and prefix_messages"
)
prefix_messages = [ChatMessage(content=system_prompt, role="system")]
prefix_messages = prefix_messages or []
return cls(
tools=tools,
tool_retriever=tool_retriever,
llm=llm,
prefix_messages=prefix_messages,
verbose=verbose,
max_function_calls=max_function_calls,
callback_manager=callback_manager,
)
def get_all_messages(self, task: Task) -> List[ChatMessage]:
return (
self.prefix_messages
+ task.memory.get()
+ task.extra_state["new_memory"].get_all()
)
def get_latest_tool_calls(self, task: Task) -> Optional[List[OpenAIToolCall]]:
chat_history: List[ChatMessage] = task.extra_state["new_memory"].get_all()
return (
chat_history[-1].additional_kwargs.get("tool_calls", None)
if chat_history
else None
)
def _get_llm_chat_kwargs(
self,
task: Task,
openai_tools: List[dict],
tool_choice: Union[str, dict] = "auto",
) -> Dict[str, Any]:
llm_chat_kwargs: dict = {"messages": self.get_all_messages(task)}
if openai_tools:
llm_chat_kwargs.update(
tools=openai_tools, tool_choice=resolve_tool_choice(tool_choice)
)
return llm_chat_kwargs
def _process_message(
self, task: Task, chat_response: ChatResponse
) -> AgentChatResponse:
ai_message = chat_response.message
task.extra_state["new_memory"].put(ai_message)
return AgentChatResponse(
response=str(ai_message.content), sources=task.extra_state["sources"]
)
def _get_stream_ai_response(
self, task: Task, **llm_chat_kwargs: Any
) -> StreamingAgentChatResponse:
chat_stream_response = StreamingAgentChatResponse(
chat_stream=self._llm.stream_chat(**llm_chat_kwargs),
sources=task.extra_state["sources"],
)
# Get the response in a separate thread so we can yield the response
thread = Thread(
target=chat_stream_response.write_response_to_history,
args=(task.extra_state["new_memory"],),
)
thread.start()
# Wait for the event to be set
chat_stream_response._is_function_not_none_thread_event.wait()
# If it is executing an OpenAI function, wait for the thread to finish
if chat_stream_response._is_function:
thread.join()
# if it's false, return the answer (to stream)
return chat_stream_response
async def _get_async_stream_ai_response(
self, task: Task, **llm_chat_kwargs: Any
) -> StreamingAgentChatResponse:
chat_stream_response = StreamingAgentChatResponse(
achat_stream=await self._llm.astream_chat(**llm_chat_kwargs),
sources=task.extra_state["sources"],
)
# create task to write chat response to history
asyncio.create_task(
chat_stream_response.awrite_response_to_history(
task.extra_state["new_memory"]
)
)
# wait until OpenAI functions stop executing
await chat_stream_response._is_function_false_event.wait()
# return response stream
return chat_stream_response
def _get_agent_response(
self, task: Task, mode: ChatResponseMode, **llm_chat_kwargs: Any
) -> AGENT_CHAT_RESPONSE_TYPE:
if mode == ChatResponseMode.WAIT:
chat_response: ChatResponse = self._llm.chat(**llm_chat_kwargs)
return self._process_message(task, chat_response)
elif mode == ChatResponseMode.STREAM:
return self._get_stream_ai_response(task, **llm_chat_kwargs)
else:
raise NotImplementedError
async def _get_async_agent_response(
self, task: Task, mode: ChatResponseMode, **llm_chat_kwargs: Any
) -> AGENT_CHAT_RESPONSE_TYPE:
if mode == ChatResponseMode.WAIT:
chat_response: ChatResponse = await self._llm.achat(**llm_chat_kwargs)
return self._process_message(task, chat_response)
elif mode == ChatResponseMode.STREAM:
return await self._get_async_stream_ai_response(task, **llm_chat_kwargs)
else:
raise NotImplementedError
def _call_function(
self,
tools: List[BaseTool],
tool_call: OpenAIToolCall,
memory: BaseMemory,
sources: List[ToolOutput],
) -> None:
function_call = tool_call.function
# validations to get past mypy
assert function_call is not None
assert function_call.name is not None
assert function_call.arguments is not None
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: function_call.arguments,
EventPayload.TOOL: get_function_by_name(
tools, function_call.name
).metadata,
},
) as event:
function_message, tool_output = call_function(
tools, tool_call, verbose=self._verbose
)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
sources.append(tool_output)
memory.put(function_message)
async def _acall_function(
self,
tools: List[BaseTool],
tool_call: OpenAIToolCall,
memory: BaseMemory,
sources: List[ToolOutput],
) -> None:
function_call = tool_call.function
# validations to get past mypy
assert function_call is not None
assert function_call.name is not None
assert function_call.arguments is not None
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: function_call.arguments,
EventPayload.TOOL: get_function_by_name(
tools, function_call.name
).metadata,
},
) as event:
function_message, tool_output = await acall_function(
tools, tool_call, verbose=self._verbose
)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
sources.append(tool_output)
memory.put(function_message)
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
sources: List[ToolOutput] = []
# temporary memory for new messages
new_memory = ChatMemoryBuffer.from_defaults()
# initialize task state
task_state = {
"sources": sources,
"n_function_calls": 0,
"new_memory": new_memory,
}
task.extra_state.update(task_state)
return TaskStep(
task_id=task.task_id,
step_id=str(uuid.uuid4()),
input=task.input,
)
def _should_continue(
self, tool_calls: Optional[List[OpenAIToolCall]], n_function_calls: int
) -> bool:
if n_function_calls > self._max_function_calls:
return False
if not tool_calls:
return False
return True
def get_tools(self, input: str) -> List[BaseTool]:
"""Get tools."""
return self._get_tools(input)
def _run_step(
self,
step: TaskStep,
task: Task,
mode: ChatResponseMode = ChatResponseMode.WAIT,
tool_choice: Union[str, dict] = "auto",
) -> TaskStepOutput:
"""Run step."""
if step.input is not None:
add_user_step_to_memory(
step, task.extra_state["new_memory"], verbose=self._verbose
)
# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
openai_tools = [tool.metadata.to_openai_tool() for tool in tools]
llm_chat_kwargs = self._get_llm_chat_kwargs(task, openai_tools, tool_choice)
agent_chat_response = self._get_agent_response(
task, mode=mode, **llm_chat_kwargs
)
# TODO: implement _should_continue
latest_tool_calls = self.get_latest_tool_calls(task) or []
if not self._should_continue(
latest_tool_calls, task.extra_state["n_function_calls"]
):
is_done = True
new_steps = []
# TODO: return response
else:
is_done = False
for tool_call in latest_tool_calls:
# Some validation
if not isinstance(tool_call, get_args(OpenAIToolCall)):
raise ValueError("Invalid tool_call object")
if tool_call.type != "function":
raise ValueError("Invalid tool type. Unsupported by OpenAI")
# TODO: maybe execute this with multi-threading
self._call_function(
tools,
tool_call,
task.extra_state["new_memory"],
task.extra_state["sources"],
)
# reset tool_choice to the default if a specific function was forced as an
# argument ("none" and "auto" are the only values predefined by OpenAI)
if tool_choice not in ("auto", "none"):
tool_choice = "auto"
task.extra_state["n_function_calls"] += 1
new_steps = [
step.get_next_step(
step_id=str(uuid.uuid4()),
# NOTE: input is unused
input=None,
)
]
# attach next step to task
return TaskStepOutput(
output=agent_chat_response,
task_step=step,
is_last=is_done,
next_steps=new_steps,
)
async def _arun_step(
self,
step: TaskStep,
task: Task,
mode: ChatResponseMode = ChatResponseMode.WAIT,
tool_choice: Union[str, dict] = "auto",
) -> TaskStepOutput:
"""Run step."""
if step.input is not None:
add_user_step_to_memory(
step, task.extra_state["new_memory"], verbose=self._verbose
)
# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
openai_tools = [tool.metadata.to_openai_tool() for tool in tools]
llm_chat_kwargs = self._get_llm_chat_kwargs(task, openai_tools, tool_choice)
agent_chat_response = await self._get_async_agent_response(
task, mode=mode, **llm_chat_kwargs
)
# TODO: implement _should_continue
latest_tool_calls = self.get_latest_tool_calls(task) or []
if not self._should_continue(
latest_tool_calls, task.extra_state["n_function_calls"]
):
is_done = True
else:
is_done = False
for tool_call in latest_tool_calls:
# Some validation
if not isinstance(tool_call, get_args(OpenAIToolCall)):
raise ValueError("Invalid tool_call object")
if tool_call.type != "function":
raise ValueError("Invalid tool type. Unsupported by OpenAI")
# TODO: maybe execute this with multi-threading
await self._acall_function(
tools,
tool_call,
task.extra_state["new_memory"],
task.extra_state["sources"],
)
# reset tool_choice to the default if a specific function was forced as an
# argument ("none" and "auto" are the only values predefined by OpenAI)
if tool_choice not in ("auto", "none"):
tool_choice = "auto"
task.extra_state["n_function_calls"] += 1
# generate next step, append to task queue
new_steps = (
[
step.get_next_step(
step_id=str(uuid.uuid4()),
# NOTE: input is unused
input=None,
)
]
if not is_done
else []
)
return TaskStepOutput(
output=agent_chat_response,
task_step=step,
is_last=is_done,
next_steps=new_steps,
)
@trace_method("run_step")
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
tool_choice = kwargs.get("tool_choice", "auto")
return self._run_step(
step, task, mode=ChatResponseMode.WAIT, tool_choice=tool_choice
)
@trace_method("run_step")
async def arun_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async)."""
tool_choice = kwargs.get("tool_choice", "auto")
return await self._arun_step(
step, task, mode=ChatResponseMode.WAIT, tool_choice=tool_choice
)
@trace_method("run_step")
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step (stream)."""
# TODO: figure out if we need a different type for TaskStepOutput
tool_choice = kwargs.get("tool_choice", "auto")
return self._run_step(
step, task, mode=ChatResponseMode.STREAM, tool_choice=tool_choice
)
@trace_method("run_step")
async def astream_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async stream)."""
tool_choice = kwargs.get("tool_choice", "auto")
return await self._arun_step(
step, task, mode=ChatResponseMode.STREAM, tool_choice=tool_choice
)
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""
# add new messages to memory
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
# reset new memory
task.extra_state["new_memory"].reset()
def undo_step(self, task: Task, **kwargs: Any) -> Optional[TaskStep]:
"""Undo step from task.
If this cannot be implemented, return None.
"""
raise NotImplementedError("Undo is not yet implemented")
# if len(task.completed_steps) == 0:
# return None
# # pop last step output
# last_step_output = task.completed_steps.pop()
# # add step to the front of the queue
# task.step_queue.appendleft(last_step_output.task_step)
# # undo any `step_state` variables that have changed
# last_step_output.step_state["n_function_calls"] -= 1
# # TODO: we don't have memory pop capabilities yet
# # # now pop the memory until we get to the state
# # last_step_response = cast(AgentChatResponse, last_step_output.output)
# # while last_step_response != task.memory.:
# # last_message = last_step_output.task_step.memory.pop()
# # if last_message == cast(AgentChatResponse, last_step_output.output).response:
# # break
# # while cast(AgentChatResponse, last_step_output.output).response !=
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: make this abstractmethod (right now will break some agent impls)
self.callback_manager = callback_manager
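
A hedged sketch of driving `OpenAIAgentWorker` step by step through `AgentRunner`, assuming the runner exposes the `create_task` / `run_step` / `finalize_response` API used elsewhere in this codebase; the `add` tool is illustrative.

```python
from llama_index.agent.openai.step import OpenAIAgentWorker
from llama_index.agent.runner.base import AgentRunner
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


worker = OpenAIAgentWorker.from_tools(
    tools=[FunctionTool.from_defaults(fn=add)],
    llm=OpenAI(model="gpt-3.5-turbo-0613"),
    verbose=True,
)
agent = AgentRunner(worker)

# Each run_step call corresponds to one _run_step above: a single LLM call
# plus any tool calls it requests.
task = agent.create_task("What is 12 + 30?")
step_output = agent.run_step(task.task_id)
while not step_output.is_last:
    step_output = agent.run_step(task.task_id)
print(agent.finalize_response(task.task_id))
```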

View File

@ -0,0 +1,24 @@
"""Utils for OpenAI agent."""
from typing import List, Union
from llama_index.tools import BaseTool
def get_function_by_name(tools: List[BaseTool], name: str) -> BaseTool:
"""Get function by name."""
name_to_tool = {tool.metadata.name: tool for tool in tools}
if name not in name_to_tool:
raise ValueError(f"Tool with name {name} not found")
return name_to_tool[name]
def resolve_tool_choice(tool_choice: Union[str, dict] = "auto") -> Union[str, dict]:
"""Resolve tool choice.
If tool_choice is a function name string, return the appropriate dict.
"""
if isinstance(tool_choice, str) and tool_choice not in ["none", "auto"]:
return {"type": "function", "function": {"name": tool_choice}}
return tool_choice
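
For illustration, how `resolve_tool_choice` behaves on the two kinds of input (the `multiply` name is hypothetical):

```python
from llama_index.agent.openai.utils import resolve_tool_choice

print(resolve_tool_choice("auto"))
# -> "auto" (predefined values pass through unchanged)
print(resolve_tool_choice("multiply"))
# -> {"type": "function", "function": {"name": "multiply"}}
```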

View File

@ -0,0 +1,554 @@
"""OpenAI Assistant Agent."""
import asyncio
import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from llama_index.agent.openai.utils import get_function_by_name
from llama_index.agent.types import BaseAgent
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
ChatResponseMode,
StreamingAgentChatResponse,
)
from llama_index.core.llms.types import ChatMessage, MessageRole
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
def from_openai_thread_message(thread_message: Any) -> ChatMessage:
"""From OpenAI thread message."""
from openai.types.beta.threads import MessageContentText, ThreadMessage
thread_message = cast(ThreadMessage, thread_message)
# we don't have a way of showing images, just do text for now
text_contents = [
t for t in thread_message.content if isinstance(t, MessageContentText)
]
text_content_str = " ".join([t.text.value for t in text_contents])
return ChatMessage(
role=thread_message.role,
content=text_content_str,
additional_kwargs={
"thread_message": thread_message,
"thread_id": thread_message.thread_id,
"assistant_id": thread_message.assistant_id,
"id": thread_message.id,
"metadata": thread_message.metadata,
},
)
def from_openai_thread_messages(thread_messages: List[Any]) -> List[ChatMessage]:
"""From OpenAI thread messages."""
return [
from_openai_thread_message(thread_message) for thread_message in thread_messages
]
def call_function(
tools: List[BaseTool], fn_obj: Any, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
"""Call a function and return the output as a string."""
from openai.types.beta.threads.required_action_function_tool_call import Function
fn_obj = cast(Function, fn_obj)
# TMP: consolidate with other abstractions
name = fn_obj.name
arguments_str = fn_obj.arguments
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = get_function_by_name(tools, name)
argument_dict = json.loads(arguments_str)
output = tool(**argument_dict)
if verbose:
print(f"Got output: {output!s}")
print("========================")
return (
ChatMessage(
content=str(output),
role=MessageRole.FUNCTION,
additional_kwargs={
"name": fn_obj.name,
},
),
output,
)
async def acall_function(
tools: List[BaseTool], fn_obj: Any, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
"""Call an async function and return the output as a string."""
from openai.types.beta.threads.required_action_function_tool_call import Function
fn_obj = cast(Function, fn_obj)
# TMP: consolidate with other abstractions
name = fn_obj.name
arguments_str = fn_obj.arguments
if verbose:
print("=== Calling Function ===")
print(f"Calling function: {name} with args: {arguments_str}")
tool = get_function_by_name(tools, name)
argument_dict = json.loads(arguments_str)
async_tool = adapt_to_async_tool(tool)
output = await async_tool.acall(**argument_dict)
if verbose:
print(f"Got output: {output!s}")
print("========================")
return (
ChatMessage(
content=str(output),
role=MessageRole.FUNCTION,
additional_kwargs={
"name": fn_obj.name,
},
),
output,
)
def _process_files(client: Any, files: List[str]) -> Dict[str, str]:
"""Process files."""
from openai import OpenAI
client = cast(OpenAI, client)
file_dict = {}
for file in files:
file_obj = client.files.create(file=open(file, "rb"), purpose="assistants")
file_dict[file_obj.id] = file
return file_dict
class OpenAIAssistantAgent(BaseAgent):
"""OpenAIAssistant agent.
Wrapper around the OpenAI Assistants API: https://platform.openai.com/docs/assistants/overview
"""
def __init__(
self,
client: Any,
assistant: Any,
tools: Optional[List[BaseTool]],
callback_manager: Optional[CallbackManager] = None,
thread_id: Optional[str] = None,
instructions_prefix: Optional[str] = None,
run_retrieve_sleep_time: float = 0.1,
file_dict: Dict[str, str] = {},
verbose: bool = False,
) -> None:
"""Init params."""
from openai import OpenAI
from openai.types.beta.assistant import Assistant
self._client = cast(OpenAI, client)
self._assistant = cast(Assistant, assistant)
self._tools = tools or []
if thread_id is None:
thread = self._client.beta.threads.create()
thread_id = thread.id
self._thread_id = thread_id
self._instructions_prefix = instructions_prefix
self._run_retrieve_sleep_time = run_retrieve_sleep_time
self._verbose = verbose
self.file_dict = file_dict
self.callback_manager = callback_manager or CallbackManager([])
@classmethod
def from_new(
cls,
name: str,
instructions: str,
tools: Optional[List[BaseTool]] = None,
openai_tools: Optional[List[Dict]] = None,
thread_id: Optional[str] = None,
model: str = "gpt-4-1106-preview",
instructions_prefix: Optional[str] = None,
run_retrieve_sleep_time: float = 0.1,
files: Optional[List[str]] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
file_ids: Optional[List[str]] = None,
api_key: Optional[str] = None,
) -> "OpenAIAssistantAgent":
"""From new assistant.
Args:
name: name of assistant
instructions: instructions for assistant
tools: list of tools
openai_tools: list of openai tools
thread_id: thread id
model: model
run_retrieve_sleep_time: run retrieve sleep time
files: files
instructions_prefix: instructions prefix
callback_manager: callback manager
verbose: verbose
file_ids: list of file ids
api_key: OpenAI API key
"""
from openai import OpenAI
# this is the set of openai tools
# not to be confused with the tools we pass in for function calling
openai_tools = openai_tools or []
tools = tools or []
tool_fns = [t.metadata.to_openai_tool() for t in tools]
all_openai_tools = openai_tools + tool_fns
# initialize client
client = OpenAI(api_key=api_key)
# process files
files = files or []
file_ids = file_ids or []
file_dict = _process_files(client, files)
all_file_ids = list(file_dict.keys()) + file_ids
# TODO: openai's typing is a bit sus
all_openai_tools = cast(List[Any], all_openai_tools)
assistant = client.beta.assistants.create(
name=name,
instructions=instructions,
tools=cast(List[Any], all_openai_tools),
model=model,
file_ids=all_file_ids,
)
return cls(
client,
assistant,
tools,
callback_manager=callback_manager,
thread_id=thread_id,
instructions_prefix=instructions_prefix,
file_dict=file_dict,
run_retrieve_sleep_time=run_retrieve_sleep_time,
verbose=verbose,
)
@classmethod
def from_existing(
cls,
assistant_id: str,
tools: Optional[List[BaseTool]] = None,
thread_id: Optional[str] = None,
instructions_prefix: Optional[str] = None,
run_retrieve_sleep_time: float = 0.1,
callback_manager: Optional[CallbackManager] = None,
api_key: Optional[str] = None,
verbose: bool = False,
) -> "OpenAIAssistantAgent":
"""From existing assistant id.
Args:
assistant_id: id of assistant
tools: list of BaseTools Assistant can use
thread_id: thread id
run_retrieve_sleep_time: run retrieve sleep time
instructions_prefix: instructions prefix
callback_manager: callback manager
api_key: OpenAI API key
verbose: verbose
"""
from openai import OpenAI
# initialize client
client = OpenAI(api_key=api_key)
# get assistant
assistant = client.beta.assistants.retrieve(assistant_id)
# assistant.tools is incompatible with BaseTool, so tools have to be passed in via params
return cls(
client,
assistant,
tools=tools,
callback_manager=callback_manager,
thread_id=thread_id,
instructions_prefix=instructions_prefix,
run_retrieve_sleep_time=run_retrieve_sleep_time,
verbose=verbose,
)
@property
def assistant(self) -> Any:
"""Get assistant."""
return self._assistant
@property
def client(self) -> Any:
"""Get client."""
return self._client
@property
def thread_id(self) -> str:
"""Get thread id."""
return self._thread_id
@property
def files_dict(self) -> Dict[str, str]:
"""Get files dict."""
return self.file_dict
@property
def chat_history(self) -> List[ChatMessage]:
raw_messages = self._client.beta.threads.messages.list(
thread_id=self._thread_id, order="asc"
)
return from_openai_thread_messages(list(raw_messages))
def reset(self) -> None:
"""Delete and create a new thread."""
self._client.beta.threads.delete(self._thread_id)
thread = self._client.beta.threads.create()
thread_id = thread.id
self._thread_id = thread_id
def get_tools(self, message: str) -> List[BaseTool]:
"""Get tools."""
return self._tools
def upload_files(self, files: List[str]) -> Dict[str, Any]:
"""Upload files."""
return _process_files(self._client, files)
def add_message(self, message: str, file_ids: Optional[List[str]] = None) -> Any:
"""Add message to assistant."""
file_ids = file_ids or []
return self._client.beta.threads.messages.create(
thread_id=self._thread_id,
role="user",
content=message,
file_ids=file_ids,
)
def _run_function_calling(self, run: Any) -> List[ToolOutput]:
"""Run function calling."""
tool_calls = run.required_action.submit_tool_outputs.tool_calls
tool_output_dicts = []
tool_output_objs: List[ToolOutput] = []
for tool_call in tool_calls:
fn_obj = tool_call.function
_, tool_output = call_function(self._tools, fn_obj, verbose=self._verbose)
tool_output_dicts.append(
{"tool_call_id": tool_call.id, "output": str(tool_output)}
)
tool_output_objs.append(tool_output)
# submit tool outputs
# TODO: openai's typing is a bit sus
self._client.beta.threads.runs.submit_tool_outputs(
thread_id=self._thread_id,
run_id=run.id,
tool_outputs=cast(List[Any], tool_output_dicts),
)
return tool_output_objs
async def _arun_function_calling(self, run: Any) -> List[ToolOutput]:
"""Run function calling."""
tool_calls = run.required_action.submit_tool_outputs.tool_calls
tool_output_dicts = []
tool_output_objs: List[ToolOutput] = []
for tool_call in tool_calls:
fn_obj = tool_call.function
_, tool_output = await acall_function(
self._tools, fn_obj, verbose=self._verbose
)
tool_output_dicts.append(
{"tool_call_id": tool_call.id, "output": str(tool_output)}
)
tool_output_objs.append(tool_output)
# submit tool outputs
self._client.beta.threads.runs.submit_tool_outputs(
thread_id=self._thread_id,
run_id=run.id,
tool_outputs=cast(List[Any], tool_output_dicts),
)
return tool_output_objs
def run_assistant(
self, instructions_prefix: Optional[str] = None
) -> Tuple[Any, Dict]:
"""Run assistant."""
instructions_prefix = instructions_prefix or self._instructions_prefix
run = self._client.beta.threads.runs.create(
thread_id=self._thread_id,
assistant_id=self._assistant.id,
instructions=instructions_prefix,
)
from openai.types.beta.threads import Run
run = cast(Run, run)
sources = []
while run.status in ["queued", "in_progress", "requires_action"]:
run = self._client.beta.threads.runs.retrieve(
thread_id=self._thread_id, run_id=run.id
)
if run.status == "requires_action":
cur_tool_outputs = self._run_function_calling(run)
sources.extend(cur_tool_outputs)
time.sleep(self._run_retrieve_sleep_time)
if run.status == "failed":
raise ValueError(
f"Run failed with status {run.status}.\n" f"Error: {run.last_error}"
)
return run, {"sources": sources}
async def arun_assistant(
self, instructions_prefix: Optional[str] = None
) -> Tuple[Any, Dict]:
"""Run assistant."""
instructions_prefix = instructions_prefix or self._instructions_prefix
run = self._client.beta.threads.runs.create(
thread_id=self._thread_id,
assistant_id=self._assistant.id,
instructions=instructions_prefix,
)
from openai.types.beta.threads import Run
run = cast(Run, run)
sources = []
while run.status in ["queued", "in_progress", "requires_action"]:
run = self._client.beta.threads.runs.retrieve(
thread_id=self._thread_id, run_id=run.id
)
if run.status == "requires_action":
cur_tool_outputs = await self._arun_function_calling(run)
sources.extend(cur_tool_outputs)
await asyncio.sleep(self._run_retrieve_sleep_time)
if run.status == "failed":
raise ValueError(
f"Run failed with status {run.status}.\n" f"Error: {run.last_error}"
)
return run, {"sources": sources}
@property
def latest_message(self) -> ChatMessage:
"""Get latest message."""
raw_messages = self._client.beta.threads.messages.list(
thread_id=self._thread_id, order="desc"
)
messages = from_openai_thread_messages(list(raw_messages))
return messages[0]
def _chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
function_call: Union[str, dict] = "auto",
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Main chat interface."""
# TODO: since chat interface doesn't expose additional kwargs
# we can't pass in file_ids per message
added_message_obj = self.add_message(message)
run, metadata = self.run_assistant(
instructions_prefix=self._instructions_prefix,
)
latest_message = self.latest_message
# get most recent message content
return AgentChatResponse(
response=str(latest_message.content),
sources=metadata["sources"],
)
async def _achat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
function_call: Union[str, dict] = "auto",
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Asynchronous main chat interface."""
self.add_message(message)
run, metadata = await self.arun_assistant(
instructions_prefix=self._instructions_prefix,
)
latest_message = self.latest_message
# get most recent message content
return AgentChatResponse(
response=str(latest_message.content),
sources=metadata["sources"],
)
@trace_method("chat")
def chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
function_call: Union[str, dict] = "auto",
) -> AgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = self._chat(
message, chat_history, function_call, mode=ChatResponseMode.WAIT
)
assert isinstance(chat_response, AgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
async def achat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
function_call: Union[str, dict] = "auto",
) -> AgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = await self._achat(
message, chat_history, function_call, mode=ChatResponseMode.WAIT
)
assert isinstance(chat_response, AgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
def stream_chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
function_call: Union[str, dict] = "auto",
) -> StreamingAgentChatResponse:
raise NotImplementedError("stream_chat not implemented")
@trace_method("chat")
async def astream_chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
function_call: Union[str, dict] = "auto",
) -> StreamingAgentChatResponse:
raise NotImplementedError("astream_chat not implemented")
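
A minimal sketch of using the assistant agent above, assuming `OpenAIAssistantAgent` is re-exported from `llama_index.agent` and a valid OpenAI API key is configured; the `add` tool is illustrative.

```python
from llama_index.agent import OpenAIAssistantAgent
from llama_index.tools import FunctionTool


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


agent = OpenAIAssistantAgent.from_new(
    name="Math assistant",
    instructions="You answer math questions with the provided tools.",
    tools=[FunctionTool.from_defaults(fn=add)],
    model="gpt-4-1106-preview",
    verbose=True,
)
response = agent.chat("What is 41 + 1?")
print(response)
print(agent.thread_id)  # id of the OpenAI thread backing this conversation
```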

View File

@ -0,0 +1,5 @@
from llama_index.agent.react.base import ReActAgent
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.agent.react.step import ReActAgentWorker
__all__ = ["ReActChatFormatter", "ReActAgentWorker", "ReActAgent"]

View File

@ -0,0 +1,10 @@
"""ReAct agent.
Simple wrapper around AgentRunner + ReActAgentWorker.
For the legacy implementation see:
```python
from llama_index.agent.legacy.react.base import ReActAgent
```
"""

View File

@ -0,0 +1,135 @@
"""ReAct agent.
Simple wrapper around AgentRunner + ReActAgentWorker.
For the legacy implementation see:
```python
from llama_index.agent.legacy.react.base import ReActAgent
```
"""
from typing import (
Any,
List,
Optional,
Sequence,
Type,
)
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.agent.react.output_parser import ReActOutputParser
from llama_index.agent.react.step import ReActAgentWorker
from llama_index.agent.runner.base import AgentRunner
from llama_index.callbacks import (
CallbackManager,
)
from llama_index.core.llms.types import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.objects.base import ObjectRetriever
from llama_index.prompts.mixin import PromptMixinType
from llama_index.tools import BaseTool
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
class ReActAgent(AgentRunner):
"""ReAct agent.
Subclasses AgentRunner with a ReActAgentWorker.
For the legacy implementation see:
```python
from llama_index.agent.legacy.react.base import ReActAgent
```
"""
def __init__(
self,
tools: Sequence[BaseTool],
llm: LLM,
memory: BaseMemory,
max_iterations: int = 10,
react_chat_formatter: Optional[ReActChatFormatter] = None,
output_parser: Optional[ReActOutputParser] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
context: Optional[str] = None,
) -> None:
"""Init params."""
callback_manager = callback_manager or llm.callback_manager
if context and react_chat_formatter:
raise ValueError("Cannot provide both context and react_chat_formatter")
if context:
react_chat_formatter = ReActChatFormatter.from_context(context)
step_engine = ReActAgentWorker.from_tools(
tools=tools,
tool_retriever=tool_retriever,
llm=llm,
max_iterations=max_iterations,
react_chat_formatter=react_chat_formatter,
output_parser=output_parser,
callback_manager=callback_manager,
verbose=verbose,
)
super().__init__(
step_engine,
memory=memory,
llm=llm,
callback_manager=callback_manager,
)
@classmethod
def from_tools(
cls,
tools: Optional[List[BaseTool]] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
llm: Optional[LLM] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
max_iterations: int = 10,
react_chat_formatter: Optional[ReActChatFormatter] = None,
output_parser: Optional[ReActOutputParser] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
context: Optional[str] = None,
**kwargs: Any,
) -> "ReActAgent":
"""Convenience constructor method from set of of BaseTools (Optional).
NOTE: kwargs should have been exhausted by this point. In other words
the various upstream components such as BaseSynthesizer (response synthesizer)
or BaseRetriever should have picked up off their respective kwargs in their
constructions.
Returns:
ReActAgent
"""
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
if callback_manager is not None:
llm.callback_manager = callback_manager
memory = memory or memory_cls.from_defaults(
chat_history=chat_history or [], llm=llm
)
return cls(
tools=tools or [],
tool_retriever=tool_retriever,
llm=llm,
memory=memory,
max_iterations=max_iterations,
react_chat_formatter=react_chat_formatter,
output_parser=output_parser,
callback_manager=callback_manager,
verbose=verbose,
context=context,
)
def _get_prompt_modules(self) -> PromptMixinType:
"""Get prompt modules."""
return {"agent_worker": self.agent_worker}
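
A minimal usage sketch for the class above, using the import path from the package `__init__`; the `multiply` tool is illustrative only.

```python
from llama_index.agent.react.base import ReActAgent
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


agent = ReActAgent.from_tools(
    tools=[FunctionTool.from_defaults(fn=multiply)],
    llm=OpenAI(model="gpt-3.5-turbo-0613"),
    max_iterations=10,
    verbose=True,
)
print(agent.chat("What is 6 times 7?"))
```

Passing `context=...` instead of a custom `react_chat_formatter` selects the context-aware system header shown in the prompts module further below.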

View File

@ -0,0 +1,127 @@
# ReAct agent formatter
import logging
from abc import abstractmethod
from typing import List, Optional, Sequence
from llama_index.agent.react.prompts import (
CONTEXT_REACT_CHAT_SYSTEM_HEADER,
REACT_CHAT_SYSTEM_HEADER,
)
from llama_index.agent.react.types import BaseReasoningStep, ObservationReasoningStep
from llama_index.bridge.pydantic import BaseModel
from llama_index.core.llms.types import ChatMessage, MessageRole
from llama_index.tools import BaseTool
logger = logging.getLogger(__name__)
def get_react_tool_descriptions(tools: Sequence[BaseTool]) -> List[str]:
"""Tool."""
tool_descs = []
for tool in tools:
tool_desc = (
f"> Tool Name: {tool.metadata.name}\n"
f"Tool Description: {tool.metadata.description}\n"
f"Tool Args: {tool.metadata.fn_schema_str}\n"
)
tool_descs.append(tool_desc)
return tool_descs
# TODO: come up with better name
class BaseAgentChatFormatter(BaseModel):
"""Base chat formatter."""
class Config:
arbitrary_types_allowed = True
@abstractmethod
def format(
self,
tools: Sequence[BaseTool],
chat_history: List[ChatMessage],
current_reasoning: Optional[List[BaseReasoningStep]] = None,
) -> List[ChatMessage]:
"""Format chat history into list of ChatMessage."""
class ReActChatFormatter(BaseAgentChatFormatter):
"""ReAct chat formatter."""
system_header: str = REACT_CHAT_SYSTEM_HEADER # default
context: str = "" # not needed w/ default
def format(
self,
tools: Sequence[BaseTool],
chat_history: List[ChatMessage],
current_reasoning: Optional[List[BaseReasoningStep]] = None,
) -> List[ChatMessage]:
"""Format chat history into list of ChatMessage."""
current_reasoning = current_reasoning or []
format_args = {
"tool_desc": "\n".join(get_react_tool_descriptions(tools)),
"tool_names": ", ".join([tool.metadata.get_name() for tool in tools]),
}
if self.context:
format_args["context"] = self.context
fmt_sys_header = self.system_header.format(**format_args)
# format reasoning history as alternating user and assistant messages
# where the assistant messages are thoughts and actions and the user
# messages are observations
reasoning_history = []
for reasoning_step in current_reasoning:
if isinstance(reasoning_step, ObservationReasoningStep):
message = ChatMessage(
role=MessageRole.USER,
content=reasoning_step.get_content(),
)
else:
message = ChatMessage(
role=MessageRole.ASSISTANT,
content=reasoning_step.get_content(),
)
reasoning_history.append(message)
return [
ChatMessage(role=MessageRole.SYSTEM, content=fmt_sys_header),
*chat_history,
*reasoning_history,
]
@classmethod
def from_defaults(
cls,
system_header: Optional[str] = None,
context: Optional[str] = None,
) -> "ReActChatFormatter":
"""Create ReActChatFormatter from defaults."""
if not system_header:
system_header = (
REACT_CHAT_SYSTEM_HEADER
if not context
else CONTEXT_REACT_CHAT_SYSTEM_HEADER
)
return ReActChatFormatter(
system_header=system_header,
context=context or "",
)
@classmethod
def from_context(cls, context: str) -> "ReActChatFormatter":
"""Create ReActChatFormatter from context.
NOTE: deprecated
"""
logger.warning(
"ReActChatFormatter.from_context is deprecated, please use `from_defaults` instead."
)
return ReActChatFormatter.from_defaults(
system_header=CONTEXT_REACT_CHAT_SYSTEM_HEADER, context=context
)
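
To show what the formatter produces, a small sketch that formats one user turn with a single tool; the `add` tool is made up and the printed header is truncated only for brevity.

```python
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.core.llms.types import ChatMessage, MessageRole
from llama_index.tools import FunctionTool


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


formatter = ReActChatFormatter.from_defaults(context="You help with arithmetic.")
messages = formatter.format(
    tools=[FunctionTool.from_defaults(fn=add)],
    chat_history=[ChatMessage(role=MessageRole.USER, content="What is 1 + 1?")],
    current_reasoning=[],
)
# messages[0] is the formatted system header; the remaining messages are the
# chat history followed by any alternating thought/observation steps.
print(messages[0].content[:200])
```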

View File

@ -0,0 +1,113 @@
"""ReAct output parser."""
import re
from typing import Tuple
from llama_index.agent.react.types import (
ActionReasoningStep,
BaseReasoningStep,
ResponseReasoningStep,
)
from llama_index.output_parsers.utils import extract_json_str
from llama_index.types import BaseOutputParser
def extract_tool_use(input_text: str) -> Tuple[str, str, str]:
pattern = (
r"\s*Thought: (.*?)\nAction: ([a-zA-Z0-9_]+).*?\nAction Input: .*?(\{.*\})"
)
match = re.search(pattern, input_text, re.DOTALL)
if not match:
raise ValueError(f"Could not extract tool use from input text: {input_text}")
thought = match.group(1).strip()
action = match.group(2).strip()
action_input = match.group(3).strip()
return thought, action, action_input
def action_input_parser(json_str: str) -> dict:
processed_string = re.sub(r"(?<!\w)\'|\'(?!\w)", '"', json_str)
pattern = r'"(\w+)":\s*"([^"]*)"'
matches = re.findall(pattern, processed_string)
return dict(matches)
def extract_final_response(input_text: str) -> Tuple[str, str]:
pattern = r"\s*Thought:(.*?)Answer:(.*?)(?:$)"
match = re.search(pattern, input_text, re.DOTALL)
if not match:
raise ValueError(
f"Could not extract final answer from input text: {input_text}"
)
thought = match.group(1).strip()
answer = match.group(2).strip()
return thought, answer
def parse_action_reasoning_step(output: str) -> ActionReasoningStep:
"""
Parse an action reasoning step from the LLM output.
"""
# Weaker LLMs may generate ReActAgent steps whose Action Input is malformed JSON.
# `dirtyjson` is more lenient than `json` in parsing JSON strings.
import dirtyjson as json
thought, action, action_input = extract_tool_use(output)
json_str = extract_json_str(action_input)
# First we try dirtyjson; if this fails, fall back to the regex-based parser
try:
action_input_dict = json.loads(json_str)
except Exception:
action_input_dict = action_input_parser(json_str)
return ActionReasoningStep(
thought=thought, action=action, action_input=action_input_dict
)
class ReActOutputParser(BaseOutputParser):
"""ReAct Output parser."""
def parse(self, output: str, is_streaming: bool = False) -> BaseReasoningStep:
"""Parse output from ReAct agent.
We expect the output to be in one of the following formats:
1. If the agent needs to use a tool to answer the question:
```
Thought: <thought>
Action: <action>
Action Input: <action_input>
```
2. If the agent can answer the question without any tools:
```
Thought: <thought>
Answer: <answer>
```
"""
if "Thought:" not in output:
# NOTE: handle the case where the agent directly outputs the answer
# instead of following the thought-answer format
return ResponseReasoningStep(
thought="(Implicit) I can answer without any more tools!",
response=output,
is_streaming=is_streaming,
)
if "Answer:" in output:
thought, answer = extract_final_response(output)
return ResponseReasoningStep(
thought=thought, response=answer, is_streaming=is_streaming
)
if "Action:" in output:
return parse_action_reasoning_step(output)
raise ValueError(f"Could not parse output: {output}")
def format(self, output: str) -> str:
"""Format a query with structured output formatting instructions."""
raise NotImplementedError
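
Two worked examples of the parser above, matching the formats its docstring describes; the `add` action name and inputs are illustrative.

```python
from llama_index.agent.react.output_parser import ReActOutputParser

parser = ReActOutputParser()

step = parser.parse(
    "Thought: I need to use a tool.\n"
    "Action: add\n"
    'Action Input: {"a": 1, "b": 2}'
)
# -> ActionReasoningStep(thought="I need to use a tool.", action="add",
#                        action_input={"a": 1, "b": 2})

final = parser.parse("Thought: I can answer without any more tools.\nAnswer: 3")
# -> ResponseReasoningStep(thought="I can answer without any more tools.", response="3")
```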

View File

@ -0,0 +1,112 @@
"""Default prompt for ReAct agent."""
# ReAct chat prompt
# TODO: have formatting instructions be a part of react output parser
REACT_CHAT_SYSTEM_HEADER = """\
You are designed to help with a variety of tasks, from answering questions \
to providing summaries to other types of analyses.
## Tools
You have access to a wide variety of tools. You are responsible for using
the tools in any sequence you deem appropriate to complete the task at hand.
This may require breaking the task into subtasks and using different tools
to complete each subtask.
You have access to the following tools:
{tool_desc}
## Output Format
To answer the question, please use the following format.
```
Thought: I need to use a tool to help me answer the question.
Action: tool name (one of {tool_names}) if using a tool.
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
```
Please ALWAYS start with a Thought.
Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}.
If this format is used, the user will respond in the following format:
```
Observation: tool response
```
You should keep repeating the above format until you have enough information
to answer the question without using any more tools. At that point, you MUST respond
in one of the following two formats:
```
Thought: I can answer without using any more tools.
Answer: [your answer here]
```
```
Thought: I cannot answer the question with the provided tools.
Answer: Sorry, I cannot answer your query.
```
## Current Conversation
Below is the current conversation consisting of interleaving human and assistant messages.
"""
CONTEXT_REACT_CHAT_SYSTEM_HEADER = """\
You are designed to help with a variety of tasks, from answering questions \
to providing summaries to other types of analyses.
## Tools
You have access to a wide variety of tools. You are responsible for using
the tools in any sequence you deem appropriate to complete the task at hand.
This may require breaking the task into subtasks and using different tools
to complete each subtask.
Here is some context to help you answer the question and plan:
{context}
You have access to the following tools:
{tool_desc}
## Output Format
To answer the question, please use the following format.
```
Thought: I need to use a tool to help me answer the question.
Action: tool name (one of {tool_names}) if using a tool.
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
```
Please ALWAYS start with a Thought.
Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}.
If this format is used, the user will respond in the following format:
```
Observation: tool response
```
You should keep repeating the above format until you have enough information
to answer the question without using any more tools. At that point, you MUST respond
in one of the following two formats:
```
Thought: I can answer without using any more tools.
Answer: [your answer here]
```
```
Thought: I cannot answer the question with the provided tools.
Answer: Sorry, I cannot answer your query.
```
## Current Conversation
Below is the current conversation consisting of interleaving human and assistant messages.
"""

View File

@ -0,0 +1,640 @@
"""ReAct agent worker."""
import asyncio
import uuid
from itertools import chain
from threading import Thread
from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Sequence,
Tuple,
cast,
)
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.agent.react.output_parser import ReActOutputParser
from llama_index.agent.react.types import (
ActionReasoningStep,
BaseReasoningStep,
ObservationReasoningStep,
ResponseReasoningStep,
)
from llama_index.agent.types import (
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
StreamingAgentChatResponse,
)
from llama_index.core.llms.types import MessageRole
from llama_index.llms.base import ChatMessage, ChatResponse
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.objects.base import ObjectRetriever
from llama_index.prompts.base import PromptTemplate
from llama_index.prompts.mixin import PromptDictType
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.tools.types import AsyncBaseTool
from llama_index.utils import print_text, unit_generator
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
def add_user_step_to_reasoning(
step: TaskStep,
memory: BaseMemory,
current_reasoning: List[BaseReasoningStep],
verbose: bool = False,
) -> None:
"""Add user step to memory."""
if "is_first" in step.step_state and step.step_state["is_first"]:
# add to new memory
memory.put(ChatMessage(content=step.input, role=MessageRole.USER))
step.step_state["is_first"] = False
else:
reasoning_step = ObservationReasoningStep(observation=step.input)
current_reasoning.append(reasoning_step)
if verbose:
print(f"Added user message to memory: {step.input}")
class ReActAgentWorker(BaseAgentWorker):
"""OpenAI Agent worker."""
def __init__(
self,
tools: Sequence[BaseTool],
llm: LLM,
max_iterations: int = 10,
react_chat_formatter: Optional[ReActChatFormatter] = None,
output_parser: Optional[ReActOutputParser] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
) -> None:
self._llm = llm
self.callback_manager = callback_manager or llm.callback_manager
self._max_iterations = max_iterations
self._react_chat_formatter = react_chat_formatter or ReActChatFormatter()
self._output_parser = output_parser or ReActOutputParser()
self._verbose = verbose
if len(tools) > 0 and tool_retriever is not None:
raise ValueError("Cannot specify both tools and tool_retriever")
elif len(tools) > 0:
self._get_tools = lambda _: tools
elif tool_retriever is not None:
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
else:
self._get_tools = lambda _: []
@classmethod
def from_tools(
cls,
tools: Optional[Sequence[BaseTool]] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
llm: Optional[LLM] = None,
max_iterations: int = 10,
react_chat_formatter: Optional[ReActChatFormatter] = None,
output_parser: Optional[ReActOutputParser] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
**kwargs: Any,
) -> "ReActAgentWorker":
"""Convenience constructor method from set of of BaseTools (Optional).
NOTE: kwargs should have been exhausted by this point. In other words
the various upstream components such as BaseSynthesizer (response synthesizer)
or BaseRetriever should have picked up off their respective kwargs in their
constructions.
Returns:
ReActAgent
"""
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
if callback_manager is not None:
llm.callback_manager = callback_manager
return cls(
tools=tools or [],
tool_retriever=tool_retriever,
llm=llm,
max_iterations=max_iterations,
react_chat_formatter=react_chat_formatter,
output_parser=output_parser,
callback_manager=callback_manager,
verbose=verbose,
)
def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
# TODO: the ReAct formatter does not explicitly use PromptTemplate
# objects, so we wrap the system header in one to satisfy the interface
sys_header = self._react_chat_formatter.system_header
return {"system_prompt": PromptTemplate(sys_header)}
def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
if "system_prompt" in prompts:
sys_prompt = cast(PromptTemplate, prompts["system_prompt"])
self._react_chat_formatter.system_header = sys_prompt.template
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
sources: List[ToolOutput] = []
current_reasoning: List[BaseReasoningStep] = []
# temporary memory for new messages
new_memory = ChatMemoryBuffer.from_defaults()
# initialize task state
task_state = {
"sources": sources,
"current_reasoning": current_reasoning,
"new_memory": new_memory,
}
task.extra_state.update(task_state)
return TaskStep(
task_id=task.task_id,
step_id=str(uuid.uuid4()),
input=task.input,
step_state={"is_first": True},
)
def get_tools(self, input: str) -> List[AsyncBaseTool]:
"""Get tools."""
return [adapt_to_async_tool(t) for t in self._get_tools(input)]
def _extract_reasoning_step(
self, output: ChatResponse, is_streaming: bool = False
) -> Tuple[str, List[BaseReasoningStep], bool]:
"""
Extracts the reasoning step from the given output.
This method parses the message content from the output,
extracts the reasoning step, and determines whether the processing is
complete. It also performs validation checks on the output and
handles possible errors.
"""
if output.message.content is None:
raise ValueError("Got empty message.")
message_content = output.message.content
current_reasoning = []
try:
reasoning_step = self._output_parser.parse(message_content, is_streaming)
except BaseException as exc:
raise ValueError(f"Could not parse output: {message_content}") from exc
if self._verbose:
print_text(f"{reasoning_step.get_content()}\n", color="pink")
current_reasoning.append(reasoning_step)
if reasoning_step.is_done:
return message_content, current_reasoning, True
reasoning_step = cast(ActionReasoningStep, reasoning_step)
if not isinstance(reasoning_step, ActionReasoningStep):
raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")
return message_content, current_reasoning, False
def _process_actions(
self,
task: Task,
tools: Sequence[AsyncBaseTool],
output: ChatResponse,
is_streaming: bool = False,
) -> Tuple[List[BaseReasoningStep], bool]:
tools_dict: Dict[str, AsyncBaseTool] = {
tool.metadata.get_name(): tool for tool in tools
}
_, current_reasoning, is_done = self._extract_reasoning_step(
output, is_streaming
)
if is_done:
return current_reasoning, True
# call tool with input
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
tool = tools_dict[reasoning_step.action]
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
EventPayload.TOOL: tool.metadata,
},
) as event:
tool_output = tool.call(**reasoning_step.action_input)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
task.extra_state["sources"].append(tool_output)
observation_step = ObservationReasoningStep(observation=str(tool_output))
current_reasoning.append(observation_step)
if self._verbose:
print_text(f"{observation_step.get_content()}\n", color="blue")
return current_reasoning, False
async def _aprocess_actions(
self,
task: Task,
tools: Sequence[AsyncBaseTool],
output: ChatResponse,
is_streaming: bool = False,
) -> Tuple[List[BaseReasoningStep], bool]:
tools_dict = {tool.metadata.name: tool for tool in tools}
_, current_reasoning, is_done = self._extract_reasoning_step(
output, is_streaming
)
if is_done:
return current_reasoning, True
# call tool with input
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
tool = tools_dict[reasoning_step.action]
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
EventPayload.TOOL: tool.metadata,
},
) as event:
tool_output = await tool.acall(**reasoning_step.action_input)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
task.extra_state["sources"].append(tool_output)
observation_step = ObservationReasoningStep(observation=str(tool_output))
current_reasoning.append(observation_step)
if self._verbose:
print_text(f"{observation_step.get_content()}\n", color="blue")
return current_reasoning, False
def _get_response(
self,
current_reasoning: List[BaseReasoningStep],
sources: List[ToolOutput],
) -> AgentChatResponse:
"""Get response from reasoning steps."""
if len(current_reasoning) == 0:
raise ValueError("No reasoning steps were taken.")
elif len(current_reasoning) == self._max_iterations:
raise ValueError("Reached max iterations.")
if isinstance(current_reasoning[-1], ResponseReasoningStep):
response_step = cast(ResponseReasoningStep, current_reasoning[-1])
response_str = response_step.response
else:
response_str = current_reasoning[-1].get_content()
# TODO: add sources from reasoning steps
return AgentChatResponse(response=response_str, sources=sources)
def _get_task_step_response(
self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
) -> TaskStepOutput:
"""Get task step response."""
if is_done:
new_steps = []
else:
new_steps = [
step.get_next_step(
step_id=str(uuid.uuid4()),
# NOTE: input is unused
input=None,
)
]
return TaskStepOutput(
output=agent_response,
task_step=step,
is_last=is_done,
next_steps=new_steps,
)
def _infer_stream_chunk_is_final(self, chunk: ChatResponse) -> bool:
"""Infers if a chunk from a live stream is the start of the final
reasoning step. (i.e., and should eventually become
ResponseReasoningStep not part of this function's logic tho.).
Args:
chunk (ChatResponse): the current chunk stream to check
Returns:
bool: Boolean on whether the chunk is the start of the final response
"""
latest_content = chunk.message.content
if latest_content:
if not latest_content.startswith(
"Thought"
): # doesn't follow thought-action format
return True
else:
if "Answer: " in latest_content:
return True
return False
def _add_back_chunk_to_stream(
self, chunk: ChatResponse, chat_stream: Generator[ChatResponse, None, None]
) -> Generator[ChatResponse, None, None]:
"""Helper method for adding back initial chunk stream of final response
back to the rest of the chat_stream.
Args:
chunk (ChatResponse): the chunk to add back to the beginning of the
chat_stream.
Return:
Generator[ChatResponse, None, None]: the updated chat_stream
"""
updated_stream = chain.from_iterable( # need to add back partial response chunk
[
unit_generator(chunk),
chat_stream,
]
)
# use cast to avoid mypy issue with chain and Generator
updated_stream_c: Generator[ChatResponse, None, None] = cast(
Generator[ChatResponse, None, None], updated_stream
)
return updated_stream_c
async def _async_add_back_chunk_to_stream(
self, chunk: ChatResponse, chat_stream: AsyncGenerator[ChatResponse, None]
) -> AsyncGenerator[ChatResponse, None]:
"""Helper method for adding back initial chunk stream of final response
back to the rest of the chat_stream.
NOTE: this itself is not an async function.
Args:
chunk (ChatResponse): the chunk to add back to the beginning of the
chat_stream.
Return:
AsyncGenerator[ChatResponse, None]: the updated async chat_stream
"""
yield chunk
async for item in chat_stream:
yield item
def _run_step(
self,
step: TaskStep,
task: Task,
) -> TaskStepOutput:
"""Run step."""
if step.input is not None:
add_user_step_to_reasoning(
step,
task.extra_state["new_memory"],
task.extra_state["current_reasoning"],
verbose=self._verbose,
)
# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
input_chat = self._react_chat_formatter.format(
tools,
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
current_reasoning=task.extra_state["current_reasoning"],
)
# send prompt
chat_response = self._llm.chat(input_chat)
# given react prompt outputs, call tools or return response
reasoning_steps, is_done = self._process_actions(
task, tools, output=chat_response
)
task.extra_state["current_reasoning"].extend(reasoning_steps)
agent_response = self._get_response(
task.extra_state["current_reasoning"], task.extra_state["sources"]
)
if is_done:
task.extra_state["new_memory"].put(
ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
)
return self._get_task_step_response(agent_response, step, is_done)
async def _arun_step(
self,
step: TaskStep,
task: Task,
) -> TaskStepOutput:
"""Run step."""
if step.input is not None:
add_user_step_to_reasoning(
step,
task.extra_state["new_memory"],
task.extra_state["current_reasoning"],
verbose=self._verbose,
)
# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
input_chat = self._react_chat_formatter.format(
tools,
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
current_reasoning=task.extra_state["current_reasoning"],
)
# send prompt
chat_response = await self._llm.achat(input_chat)
# given react prompt outputs, call tools or return response
reasoning_steps, is_done = await self._aprocess_actions(
task, tools, output=chat_response
)
task.extra_state["current_reasoning"].extend(reasoning_steps)
agent_response = self._get_response(
task.extra_state["current_reasoning"], task.extra_state["sources"]
)
if is_done:
task.extra_state["new_memory"].put(
ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
)
return self._get_task_step_response(agent_response, step, is_done)
def _run_step_stream(
self,
step: TaskStep,
task: Task,
) -> TaskStepOutput:
"""Run step."""
if step.input is not None:
add_user_step_to_reasoning(
step,
task.extra_state["new_memory"],
task.extra_state["current_reasoning"],
verbose=self._verbose,
)
# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
input_chat = self._react_chat_formatter.format(
tools,
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
current_reasoning=task.extra_state["current_reasoning"],
)
chat_stream = self._llm.stream_chat(input_chat)
# iterate over stream, break out if is final answer after the "Answer: "
full_response = ChatResponse(
message=ChatMessage(content=None, role="assistant")
)
is_done = False
for latest_chunk in chat_stream:
full_response = latest_chunk
is_done = self._infer_stream_chunk_is_final(latest_chunk)
if is_done:
break
if not is_done:
# given react prompt outputs, call tools or return response
reasoning_steps, _ = self._process_actions(
task, tools=tools, output=full_response, is_streaming=True
)
task.extra_state["current_reasoning"].extend(reasoning_steps)
# use _get_response to return intermediate response
agent_response: AGENT_CHAT_RESPONSE_TYPE = self._get_response(
task.extra_state["current_reasoning"], task.extra_state["sources"]
)
else:
# Write the response to history in a separate thread so we can return the stream immediately
response_stream = self._add_back_chunk_to_stream(
chunk=latest_chunk, chat_stream=chat_stream
)
agent_response = StreamingAgentChatResponse(
chat_stream=response_stream,
sources=task.extra_state["sources"],
)
thread = Thread(
target=agent_response.write_response_to_history,
args=(task.extra_state["new_memory"],),
)
thread.start()
return self._get_task_step_response(agent_response, step, is_done)
async def _arun_step_stream(
self,
step: TaskStep,
task: Task,
) -> TaskStepOutput:
"""Run step."""
if step.input is not None:
add_user_step_to_reasoning(
step,
task.extra_state["new_memory"],
task.extra_state["current_reasoning"],
verbose=self._verbose,
)
# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
input_chat = self._react_chat_formatter.format(
tools,
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
current_reasoning=task.extra_state["current_reasoning"],
)
chat_stream = await self._llm.astream_chat(input_chat)
# iterate over stream, break out if is final answer after the "Answer: "
full_response = ChatResponse(
message=ChatMessage(content=None, role="assistant")
)
is_done = False
async for latest_chunk in chat_stream:
full_response = latest_chunk
is_done = self._infer_stream_chunk_is_final(latest_chunk)
if is_done:
break
if not is_done:
# given react prompt outputs, call tools or return response
reasoning_steps, _ = self._process_actions(
task, tools=tools, output=full_response, is_streaming=True
)
task.extra_state["current_reasoning"].extend(reasoning_steps)
# use _get_response to return intermediate response
agent_response: AGENT_CHAT_RESPONSE_TYPE = self._get_response(
task.extra_state["current_reasoning"], task.extra_state["sources"]
)
else:
# Reconstruct the stream; the response is written to history in a background asyncio task below
response_stream = self._async_add_back_chunk_to_stream(
chunk=latest_chunk, chat_stream=chat_stream
)
agent_response = StreamingAgentChatResponse(
achat_stream=response_stream,
sources=task.extra_state["sources"],
)
# create task to write chat response to history
asyncio.create_task(
agent_response.awrite_response_to_history(
task.extra_state["new_memory"]
)
)
# wait until response writing is done
await agent_response._is_function_false_event.wait()
return self._get_task_step_response(agent_response, step, is_done)
@trace_method("run_step")
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
return self._run_step(step, task)
@trace_method("run_step")
async def arun_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async)."""
return await self._arun_step(step, task)
@trace_method("run_step")
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step (stream)."""
# TODO: figure out if we need a different type for TaskStepOutput
return self._run_step_stream(step, task)
@trace_method("run_step")
async def astream_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async stream)."""
return await self._arun_step_stream(step, task)
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""
# add new messages to memory
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
# reset new memory
task.extra_state["new_memory"].reset()
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: make this abstractmethod (right now will break some agent impls)
self.callback_manager = callback_manager
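
A minimal usage sketch of the worker above, assuming the AgentRunner defined later in this commit and the FunctionTool helper from the wider llama_index 0.9.x tool API; names not shown in this diff are assumptions:

from llama_index.agent.runner.base import AgentRunner
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool  # assumed from the wider 0.9.x tool API

def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

# wrap a plain function as a tool and drive it with the ReAct worker defined above
multiply_tool = FunctionTool.from_defaults(fn=multiply)
worker = ReActAgentWorker.from_tools(
    tools=[multiply_tool],
    llm=OpenAI(model=DEFAULT_MODEL_NAME),
    verbose=True,
)
agent = AgentRunner(worker, verbose=True)
print(agent.chat("What is 7 * 6? Use the multiply tool.").response)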

View File

@ -0,0 +1,77 @@
"""Base types for ReAct agent."""
from abc import abstractmethod
from typing import Dict
from llama_index.bridge.pydantic import BaseModel
class BaseReasoningStep(BaseModel):
"""Reasoning step."""
@abstractmethod
def get_content(self) -> str:
"""Get content."""
@property
@abstractmethod
def is_done(self) -> bool:
"""Is the reasoning step the last one."""
class ActionReasoningStep(BaseReasoningStep):
"""Action Reasoning step."""
thought: str
action: str
action_input: Dict
def get_content(self) -> str:
"""Get content."""
return (
f"Thought: {self.thought}\nAction: {self.action}\n"
f"Action Input: {self.action_input}"
)
@property
def is_done(self) -> bool:
"""Is the reasoning step the last one."""
return False
class ObservationReasoningStep(BaseReasoningStep):
"""Observation reasoning step."""
observation: str
def get_content(self) -> str:
"""Get content."""
return f"Observation: {self.observation}"
@property
def is_done(self) -> bool:
"""Is the reasoning step the last one."""
return False
class ResponseReasoningStep(BaseReasoningStep):
"""Response reasoning step."""
thought: str
response: str
is_streaming: bool = False
def get_content(self) -> str:
"""Get content."""
if self.is_streaming:
return (
f"Thought: {self.thought}\n"
f"Answer (Starts With): {self.response} ..."
)
else:
return f"Thought: {self.thought}\n" f"Answer: {self.response}"
@property
def is_done(self) -> bool:
"""Is the reasoning step the last one."""
return True
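
A quick illustration of how these step types render, using only the classes defined above:

action = ActionReasoningStep(
    thought="I need to look up the dog's name.",
    action="churchill_bio_tool",
    action_input={"input": "brown dog name"},
)
observation = ObservationReasoningStep(observation="The name of the brown dog is Rufus.")
answer = ResponseReasoningStep(thought="I can answer without more tools.", response="Rufus")

print(action.get_content())       # Thought / Action / Action Input block
print(observation.get_content())  # "Observation: ..."
print(answer.is_done)             # True -> the ReAct loop terminates here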

View File

@ -0,0 +1,87 @@
"""Default prompt for ReAct agent."""
# ReAct multimodal chat prompt
# TODO: have formatting instructions be a part of react output parser
REACT_MM_CHAT_SYSTEM_HEADER = """\
You are designed to help with a variety of tasks, from answering questions \
to providing summaries to other types of analyses. You can take in both text \
and images.
## Tools
You have access to a wide variety of tools. You are responsible for using
the tools in any sequence you deem appropriate to complete the task at hand.
This may require breaking the task into subtasks and using different tools
to complete each subtask.
NOTE: you do NOT need to use a tool to understand the provided images. You can
use both the input text and images as context to decide which tool to use.
You have access to the following tools:
{tool_desc}
## Input
The user will specify a task (in text) and a set of images. Treat
the images as additional context for the task.
## Output Format
To answer the question, please use the following format.
```
Thought: I need to use a tool to help me answer the question.
Action: tool name (one of {tool_names}) if using a tool.
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
```
Please ALWAYS start with a Thought.
Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}.
If this format is used, the user will respond in the following format:
```
Observation: tool response
```
Here's a concrete example. Again, you can take in both text and images as input. This can generate a thought which can be used to decide which tool to use.
The input to the tool should not assume knowledge of the image. Therefore it is your responsibility \
to translate the input text/images into a format that the tool can understand.
For example:
```
Thought: This image is a picture of a brown dog. The text asked me to identify its name, so I need to use a tool to look up its name.
Action: churchill_bio_tool
Action Input: {{"input": "brown dog name"}}
```
Example user response:
```
Observation: The name of the brown dog is Rufus.
```
You should keep repeating the above format until you have enough information
to answer the question without using any more tools. At that point, you MUST respond
in one of the following two formats:
```
Thought: I can answer without using any more tools.
Answer: [your answer here]
```
```
Thought: I cannot answer the question with the provided tools.
Answer: Sorry, I cannot answer your query.
```
The answer MUST be grounded in the input text and images. Do not give an answer that is irrelevant to the image
provided.
## Current Conversation
Below is the current conversation consisting of interleaving human and assistant messages.
"""

View File

@ -0,0 +1,479 @@
"""ReAct multimodal agent."""
import uuid
from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Tuple,
cast,
)
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.agent.react.output_parser import ReActOutputParser
from llama_index.agent.react.types import (
ActionReasoningStep,
BaseReasoningStep,
ObservationReasoningStep,
ResponseReasoningStep,
)
from llama_index.agent.react_multimodal.prompts import REACT_MM_CHAT_SYSTEM_HEADER
from llama_index.agent.types import (
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
)
from llama_index.core.llms.types import MessageRole
from llama_index.llms.base import ChatMessage, ChatResponse
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.multi_modal_llms.base import MultiModalLLM
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.multi_modal_llms.openai_utils import (
generate_openai_multi_modal_chat_message,
)
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.tools.types import AsyncBaseTool
from llama_index.utils import print_text
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
def add_user_step_to_reasoning(
step: TaskStep,
memory: BaseMemory,
current_reasoning: List[BaseReasoningStep],
verbose: bool = False,
) -> None:
"""Add user step to reasoning.
Adds both text input and image input to reasoning.
"""
# raise error if step.input is None
if step.input is None:
raise ValueError("Step input is None.")
# TODO: support gemini as well. Currently just supports OpenAI
# TODO: currently assume that you can't generate images in the loop,
# so step_state contains the original image_docs from the task
# (it doesn't change)
image_docs = step.step_state["image_docs"]
image_kwargs = step.step_state.get("image_kwargs", {})
if "is_first" in step.step_state and step.step_state["is_first"]:
mm_message = generate_openai_multi_modal_chat_message(
prompt=step.input,
role=MessageRole.USER,
image_documents=image_docs,
**image_kwargs,
)
# add to new memory
memory.put(mm_message)
step.step_state["is_first"] = False
else:
# NOTE: this is where the user specifies an intermediate step in the middle
# TODO: don't support specifying image_docs here for now
reasoning_step = ObservationReasoningStep(observation=step.input)
current_reasoning.append(reasoning_step)
if verbose:
print(f"Added user message to memory: {step.input}")
class MultimodalReActAgentWorker(BaseAgentWorker):
"""Multimodal ReAct Agent worker.
**NOTE**: This is a BETA feature.
"""
def __init__(
self,
tools: Sequence[BaseTool],
multi_modal_llm: MultiModalLLM,
max_iterations: int = 10,
react_chat_formatter: Optional[ReActChatFormatter] = None,
output_parser: Optional[ReActOutputParser] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
) -> None:
self._multi_modal_llm = multi_modal_llm
self.callback_manager = callback_manager or CallbackManager([])
self._max_iterations = max_iterations
self._react_chat_formatter = react_chat_formatter or ReActChatFormatter(
system_header=REACT_MM_CHAT_SYSTEM_HEADER
)
self._output_parser = output_parser or ReActOutputParser()
self._verbose = verbose
if len(tools) > 0 and tool_retriever is not None:
raise ValueError("Cannot specify both tools and tool_retriever")
elif len(tools) > 0:
self._get_tools = lambda _: tools
elif tool_retriever is not None:
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
else:
self._get_tools = lambda _: []
@classmethod
def from_tools(
cls,
tools: Optional[Sequence[BaseTool]] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
multi_modal_llm: Optional[MultiModalLLM] = None,
max_iterations: int = 10,
react_chat_formatter: Optional[ReActChatFormatter] = None,
output_parser: Optional[ReActOutputParser] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
**kwargs: Any,
) -> "MultimodalReActAgentWorker":
"""Convenience constructor method from set of of BaseTools (Optional).
NOTE: kwargs should have been exhausted by this point. In other words
the various upstream components such as BaseSynthesizer (response synthesizer)
or BaseRetriever should have picked up off their respective kwargs in their
constructions.
Returns:
ReActAgent
"""
multi_modal_llm = multi_modal_llm or OpenAIMultiModal(
model="gpt-4-vision-preview", max_new_tokens=1000
)
return cls(
tools=tools or [],
tool_retriever=tool_retriever,
multi_modal_llm=multi_modal_llm,
max_iterations=max_iterations,
react_chat_formatter=react_chat_formatter,
output_parser=output_parser,
callback_manager=callback_manager,
verbose=verbose,
)
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
sources: List[ToolOutput] = []
current_reasoning: List[BaseReasoningStep] = []
# temporary memory for new messages
new_memory = ChatMemoryBuffer.from_defaults()
# validation
if "image_docs" not in task.extra_state:
raise ValueError("Image docs not found in task extra state.")
# initialize task state
task_state = {
"sources": sources,
"current_reasoning": current_reasoning,
"new_memory": new_memory,
}
task.extra_state.update(task_state)
return TaskStep(
task_id=task.task_id,
step_id=str(uuid.uuid4()),
input=task.input,
step_state={"is_first": True, "image_docs": task.extra_state["image_docs"]},
)
def get_tools(self, input: str) -> List[AsyncBaseTool]:
"""Get tools."""
return [adapt_to_async_tool(t) for t in self._get_tools(input)]
def _extract_reasoning_step(
self, output: ChatResponse, is_streaming: bool = False
) -> Tuple[str, List[BaseReasoningStep], bool]:
"""
Extracts the reasoning step from the given output.
This method parses the message content from the output,
extracts the reasoning step, and determines whether the processing is
complete. It also performs validation checks on the output and
handles possible errors.
"""
if output.message.content is None:
raise ValueError("Got empty message.")
message_content = output.message.content
current_reasoning = []
try:
reasoning_step = self._output_parser.parse(message_content, is_streaming)
except BaseException as exc:
raise ValueError(f"Could not parse output: {message_content}") from exc
if self._verbose:
print_text(f"{reasoning_step.get_content()}\n", color="pink")
current_reasoning.append(reasoning_step)
if reasoning_step.is_done:
return message_content, current_reasoning, True
reasoning_step = cast(ActionReasoningStep, reasoning_step)
if not isinstance(reasoning_step, ActionReasoningStep):
raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")
return message_content, current_reasoning, False
def _process_actions(
self,
task: Task,
tools: Sequence[AsyncBaseTool],
output: ChatResponse,
is_streaming: bool = False,
) -> Tuple[List[BaseReasoningStep], bool]:
tools_dict: Dict[str, AsyncBaseTool] = {
tool.metadata.get_name(): tool for tool in tools
}
_, current_reasoning, is_done = self._extract_reasoning_step(
output, is_streaming
)
if is_done:
return current_reasoning, True
# call tool with input
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
tool = tools_dict[reasoning_step.action]
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
EventPayload.TOOL: tool.metadata,
},
) as event:
tool_output = tool.call(**reasoning_step.action_input)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
task.extra_state["sources"].append(tool_output)
observation_step = ObservationReasoningStep(observation=str(tool_output))
current_reasoning.append(observation_step)
if self._verbose:
print_text(f"{observation_step.get_content()}\n", color="blue")
return current_reasoning, False
async def _aprocess_actions(
self,
task: Task,
tools: Sequence[AsyncBaseTool],
output: ChatResponse,
is_streaming: bool = False,
) -> Tuple[List[BaseReasoningStep], bool]:
tools_dict = {tool.metadata.name: tool for tool in tools}
_, current_reasoning, is_done = self._extract_reasoning_step(
output, is_streaming
)
if is_done:
return current_reasoning, True
# call tool with input
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
tool = tools_dict[reasoning_step.action]
with self.callback_manager.event(
CBEventType.FUNCTION_CALL,
payload={
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
EventPayload.TOOL: tool.metadata,
},
) as event:
tool_output = await tool.acall(**reasoning_step.action_input)
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
task.extra_state["sources"].append(tool_output)
observation_step = ObservationReasoningStep(observation=str(tool_output))
current_reasoning.append(observation_step)
if self._verbose:
print_text(f"{observation_step.get_content()}\n", color="blue")
return current_reasoning, False
def _get_response(
self,
current_reasoning: List[BaseReasoningStep],
sources: List[ToolOutput],
) -> AgentChatResponse:
"""Get response from reasoning steps."""
if len(current_reasoning) == 0:
raise ValueError("No reasoning steps were taken.")
elif len(current_reasoning) == self._max_iterations:
raise ValueError("Reached max iterations.")
if isinstance(current_reasoning[-1], ResponseReasoningStep):
response_step = cast(ResponseReasoningStep, current_reasoning[-1])
response_str = response_step.response
else:
response_str = current_reasoning[-1].get_content()
# TODO: add sources from reasoning steps
return AgentChatResponse(response=response_str, sources=sources)
def _get_task_step_response(
self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
) -> TaskStepOutput:
"""Get task step response."""
if is_done:
new_steps = []
else:
new_steps = [
step.get_next_step(
step_id=str(uuid.uuid4()),
# NOTE: input is unused
input=None,
)
]
return TaskStepOutput(
output=agent_response,
task_step=step,
is_last=is_done,
next_steps=new_steps,
)
def _run_step(
self,
step: TaskStep,
task: Task,
) -> TaskStepOutput:
"""Run step."""
# step.input is not None on the first step, or when the user supplies
# input for an intermediate step mid-task
if step.input is not None:
add_user_step_to_reasoning(
step,
task.extra_state["new_memory"],
task.extra_state["current_reasoning"],
verbose=self._verbose,
)
# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
input_chat = self._react_chat_formatter.format(
tools,
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
current_reasoning=task.extra_state["current_reasoning"],
)
# send prompt
chat_response = self._multi_modal_llm.chat(input_chat)
# given react prompt outputs, call tools or return response
reasoning_steps, is_done = self._process_actions(
task, tools, output=chat_response
)
task.extra_state["current_reasoning"].extend(reasoning_steps)
agent_response = self._get_response(
task.extra_state["current_reasoning"], task.extra_state["sources"]
)
if is_done:
task.extra_state["new_memory"].put(
ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
)
return self._get_task_step_response(agent_response, step, is_done)
async def _arun_step(
self,
step: TaskStep,
task: Task,
) -> TaskStepOutput:
"""Run step."""
if step.input is not None:
add_user_step_to_reasoning(
step,
task.extra_state["new_memory"],
task.extra_state["current_reasoning"],
verbose=self._verbose,
)
# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
input_chat = self._react_chat_formatter.format(
tools,
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
current_reasoning=task.extra_state["current_reasoning"],
)
# send prompt
chat_response = await self._multi_modal_llm.achat(input_chat)
# given react prompt outputs, call tools or return response
reasoning_steps, is_done = await self._aprocess_actions(
task, tools, output=chat_response
)
task.extra_state["current_reasoning"].extend(reasoning_steps)
agent_response = self._get_response(
task.extra_state["current_reasoning"], task.extra_state["sources"]
)
if is_done:
task.extra_state["new_memory"].put(
ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
)
return self._get_task_step_response(agent_response, step, is_done)
def _run_step_stream(
self,
step: TaskStep,
task: Task,
) -> TaskStepOutput:
"""Run step."""
raise NotImplementedError("Stream step not implemented yet.")
async def _arun_step_stream(
self,
step: TaskStep,
task: Task,
) -> TaskStepOutput:
"""Run step."""
raise NotImplementedError("Stream step not implemented yet.")
@trace_method("run_step")
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
return self._run_step(step, task)
@trace_method("run_step")
async def arun_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async)."""
return await self._arun_step(step, task)
@trace_method("run_step")
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step (stream)."""
# TODO: figure out if we need a different type for TaskStepOutput
return self._run_step_stream(step, task)
@trace_method("run_step")
async def astream_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async stream)."""
return await self._arun_step_stream(step, task)
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""
# add new messages to memory
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
# reset new memory
task.extra_state["new_memory"].reset()
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: make this abstractmethod (right now will break some agent impls)
self.callback_manager = callback_manager
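
A hedged sketch of driving the multimodal worker step-by-step with the AgentRunner from this commit. ImageDocument and its image_url field are assumed to live in llama_index.schema; the extra_state key matches the image_docs validation in initialize_step above:

from llama_index.agent.runner.base import AgentRunner
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.schema import ImageDocument  # assumed location of ImageDocument

# image documents loaded by the caller (assumed field: image_url)
image_documents = [ImageDocument(image_url="https://example.com/dog.png")]

worker = MultimodalReActAgentWorker.from_tools(
    tools=[],  # add BaseTool instances as needed
    multi_modal_llm=OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=1000),
    verbose=True,
)
agent = AgentRunner(worker)
task = agent.create_task(
    "What breed of dog is shown in the image?",
    extra_state={"image_docs": image_documents},
)
# run steps until the worker reports the final one, then finalize
step_output = agent.run_step(task.task_id)
while not step_output.is_last:
    step_output = agent.run_step(task.task_id)
print(agent.finalize_response(task.task_id).response)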

View File

@ -0,0 +1 @@
"""Init params."""

View File

@ -0,0 +1,631 @@
from abc import abstractmethod
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Union, cast
from llama_index.agent.types import (
BaseAgent,
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
ChatResponseMode,
StreamingAgentChatResponse,
)
from llama_index.llms.base import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.tools.types import BaseTool
class BaseAgentRunner(BaseAgent):
"""Base agent runner."""
@abstractmethod
def create_task(self, input: str, **kwargs: Any) -> Task:
"""Create task."""
@abstractmethod
def delete_task(
self,
task_id: str,
) -> None:
"""Delete task.
NOTE: this will not delete any previous executions from memory.
"""
@abstractmethod
def list_tasks(self, **kwargs: Any) -> List[Task]:
"""List tasks."""
@abstractmethod
def get_task(self, task_id: str, **kwargs: Any) -> Task:
"""Get task."""
@abstractmethod
def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
"""Get upcoming steps."""
@abstractmethod
def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
"""Get completed steps."""
def get_completed_step(
self, task_id: str, step_id: str, **kwargs: Any
) -> TaskStepOutput:
"""Get completed step."""
# call get_completed_steps, and then find the matching step
completed_steps = self.get_completed_steps(task_id, **kwargs)
for step_output in completed_steps:
if step_output.task_step.step_id == step_id:
return step_output
raise ValueError(f"Could not find step_id: {step_id}")
@abstractmethod
def run_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step."""
@abstractmethod
async def arun_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step (async)."""
@abstractmethod
def stream_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step (stream)."""
@abstractmethod
async def astream_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step (async stream)."""
@abstractmethod
def finalize_response(
self,
task_id: str,
step_output: Optional[TaskStepOutput] = None,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Finalize response."""
@abstractmethod
def undo_step(self, task_id: str) -> None:
"""Undo previous step."""
raise NotImplementedError("undo_step not implemented")
def validate_step_from_args(
task_id: str, input: Optional[str] = None, step: Optional[Any] = None, **kwargs: Any
) -> Optional[TaskStep]:
"""Validate step from args."""
if step is not None:
if input is not None:
raise ValueError("Cannot specify both `step` and `input`")
if not isinstance(step, TaskStep):
raise ValueError(f"step must be TaskStep: {step}")
return step
else:
return None
class TaskState(BaseModel):
"""Task state."""
task: Task = Field(..., description="Task.")
step_queue: Deque[TaskStep] = Field(
default_factory=deque, description="Task step queue."
)
completed_steps: List[TaskStepOutput] = Field(
default_factory=list, description="Completed step outputs."
)
class AgentState(BaseModel):
"""Agent state."""
task_dict: Dict[str, TaskState] = Field(
default_factory=dict, description="Task dictionary."
)
def get_task(self, task_id: str) -> Task:
"""Get task state."""
return self.task_dict[task_id].task
def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
"""Get completed steps."""
return self.task_dict[task_id].completed_steps
def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
"""Get step queue."""
return self.task_dict[task_id].step_queue
def reset(self) -> None:
"""Reset."""
self.task_dict = {}
class AgentRunner(BaseAgentRunner):
"""Agent runner.
Top-level agent orchestrator that can create tasks, run each step in a task,
or run a task e2e. Stores state and keeps track of tasks.
Args:
agent_worker (BaseAgentWorker): step executor
chat_history (Optional[List[ChatMessage]], optional): chat history. Defaults to None.
state (Optional[AgentState], optional): agent state. Defaults to None.
memory (Optional[BaseMemory], optional): memory. Defaults to None.
llm (Optional[LLM], optional): LLM. Defaults to None.
callback_manager (Optional[CallbackManager], optional): callback manager. Defaults to None.
init_task_state_kwargs (Optional[dict], optional): init task state kwargs. Defaults to None.
"""
# TODO: implement this in Pydantic
def __init__(
self,
agent_worker: BaseAgentWorker,
chat_history: Optional[List[ChatMessage]] = None,
state: Optional[AgentState] = None,
memory: Optional[BaseMemory] = None,
llm: Optional[LLM] = None,
callback_manager: Optional[CallbackManager] = None,
init_task_state_kwargs: Optional[dict] = None,
delete_task_on_finish: bool = False,
default_tool_choice: str = "auto",
verbose: bool = False,
) -> None:
"""Initialize."""
self.agent_worker = agent_worker
self.state = state or AgentState()
self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)
# get and set callback manager
if callback_manager is not None:
self.agent_worker.set_callback_manager(callback_manager)
self.callback_manager = callback_manager
else:
# TODO: This is *temporary*
# Stopgap before having a callback on the BaseAgentWorker interface.
# Doing that requires a bit more refactoring to make sure existing code
# doesn't break.
if hasattr(self.agent_worker, "callback_manager"):
self.callback_manager = (
self.agent_worker.callback_manager or CallbackManager()
)
else:
self.callback_manager = CallbackManager()
self.init_task_state_kwargs = init_task_state_kwargs or {}
self.delete_task_on_finish = delete_task_on_finish
self.default_tool_choice = default_tool_choice
self.verbose = verbose
@staticmethod
def from_llm(
tools: Optional[List[BaseTool]] = None,
llm: Optional[LLM] = None,
**kwargs: Any,
) -> "AgentRunner":
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import is_function_calling_model
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
from llama_index.agent import OpenAIAgent
return OpenAIAgent.from_tools(
tools=tools,
llm=llm,
**kwargs,
)
else:
from llama_index.agent import ReActAgent
return ReActAgent.from_tools(
tools=tools,
llm=llm,
**kwargs,
)
@property
def chat_history(self) -> List[ChatMessage]:
return self.memory.get_all()
def reset(self) -> None:
self.memory.reset()
self.state.reset()
def create_task(self, input: str, **kwargs: Any) -> Task:
"""Create task."""
if not self.init_task_state_kwargs:
extra_state = kwargs.pop("extra_state", {})
else:
if "extra_state" in kwargs:
raise ValueError(
"Cannot specify both `extra_state` and `init_task_state_kwargs`"
)
else:
extra_state = self.init_task_state_kwargs
callback_manager = kwargs.pop("callback_manager", self.callback_manager)
task = Task(
input=input,
memory=self.memory,
extra_state=extra_state,
callback_manager=callback_manager,
**kwargs,
)
# # put input into memory
# self.memory.put(ChatMessage(content=input, role=MessageRole.USER))
# get initial step from task, and put it in the step queue
initial_step = self.agent_worker.initialize_step(task)
task_state = TaskState(
task=task,
step_queue=deque([initial_step]),
)
# add it to state
self.state.task_dict[task.task_id] = task_state
return task
def delete_task(
self,
task_id: str,
) -> None:
"""Delete task.
NOTE: this will not delete any previous executions from memory.
"""
self.state.task_dict.pop(task_id)
def list_tasks(self, **kwargs: Any) -> List[Task]:
"""List tasks."""
return list(self.state.task_dict.values())
def get_task(self, task_id: str, **kwargs: Any) -> Task:
"""Get task."""
return self.state.get_task(task_id)
def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
"""Get upcoming steps."""
return list(self.state.get_step_queue(task_id))
def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
"""Get completed steps."""
return self.state.get_completed_steps(task_id)
def _run_step(
self,
task_id: str,
step: Optional[TaskStep] = None,
input: Optional[str] = None,
mode: ChatResponseMode = ChatResponseMode.WAIT,
**kwargs: Any,
) -> TaskStepOutput:
"""Execute step."""
task = self.state.get_task(task_id)
step_queue = self.state.get_step_queue(task_id)
step = step or step_queue.popleft()
if input is not None:
step.input = input
if self.verbose:
print(f"> Running step {step.step_id}. Step input: {step.input}")
# TODO: figure out if you can dynamically swap in different step executors
# not clear when you would do that, but it's theoretically possible
if mode == ChatResponseMode.WAIT:
cur_step_output = self.agent_worker.run_step(step, task, **kwargs)
elif mode == ChatResponseMode.STREAM:
cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
else:
raise ValueError(f"Invalid mode: {mode}")
# append cur_step_output next steps to queue
next_steps = cur_step_output.next_steps
step_queue.extend(next_steps)
# add cur_step_output to completed steps
completed_steps = self.state.get_completed_steps(task_id)
completed_steps.append(cur_step_output)
return cur_step_output
async def _arun_step(
self,
task_id: str,
step: Optional[TaskStep] = None,
input: Optional[str] = None,
mode: ChatResponseMode = ChatResponseMode.WAIT,
**kwargs: Any,
) -> TaskStepOutput:
"""Execute step."""
task = self.state.get_task(task_id)
step_queue = self.state.get_step_queue(task_id)
step = step or step_queue.popleft()
if input is not None:
step.input = input
if self.verbose:
print(f"> Running step {step.step_id}. Step input: {step.input}")
# TODO: figure out if you can dynamically swap in different step executors
# not clear when you would do that, but it's theoretically possible
if mode == ChatResponseMode.WAIT:
cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
elif mode == ChatResponseMode.STREAM:
cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
else:
raise ValueError(f"Invalid mode: {mode}")
# append cur_step_output next steps to queue
next_steps = cur_step_output.next_steps
step_queue.extend(next_steps)
# add cur_step_output to completed steps
completed_steps = self.state.get_completed_steps(task_id)
completed_steps.append(cur_step_output)
return cur_step_output
def run_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step."""
step = validate_step_from_args(task_id, input, step, **kwargs)
return self._run_step(
task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
)
async def arun_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step (async)."""
step = validate_step_from_args(task_id, input, step, **kwargs)
return await self._arun_step(
task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
)
def stream_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step (stream)."""
step = validate_step_from_args(task_id, input, step, **kwargs)
return self._run_step(
task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
)
async def astream_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step (async stream)."""
step = validate_step_from_args(task_id, input, step, **kwargs)
return await self._arun_step(
task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
)
def finalize_response(
self,
task_id: str,
step_output: Optional[TaskStepOutput] = None,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Finalize response."""
if step_output is None:
step_output = self.state.get_completed_steps(task_id)[-1]
if not step_output.is_last:
raise ValueError(
"finalize_response can only be called on the last step output"
)
if not isinstance(
step_output.output,
(AgentChatResponse, StreamingAgentChatResponse),
):
raise ValueError(
"When `is_last` is True, cur_step_output.output must be "
f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
)
# finalize task
self.agent_worker.finalize_task(self.state.get_task(task_id))
if self.delete_task_on_finish:
self.delete_task(task_id)
return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)
def _chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
if chat_history is not None:
self.memory.set(chat_history)
task = self.create_task(message)
result_output = None
while True:
# pass step queue in as argument, assume step executor is stateless
cur_step_output = self._run_step(
task.task_id, mode=mode, tool_choice=tool_choice
)
if cur_step_output.is_last:
result_output = cur_step_output
break
# ensure tool_choice does not cause endless loops
tool_choice = "auto"
return self.finalize_response(task.task_id, result_output)
async def _achat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
if chat_history is not None:
self.memory.set(chat_history)
task = self.create_task(message)
result_output = None
while True:
# pass step queue in as argument, assume step executor is stateless
cur_step_output = await self._arun_step(
task.task_id, mode=mode, tool_choice=tool_choice
)
if cur_step_output.is_last:
result_output = cur_step_output
break
# ensure tool_choice does not cause endless loops
tool_choice = "auto"
return self.finalize_response(task.task_id, result_output)
@trace_method("chat")
def chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> AgentChatResponse:
# override tool choice if provided as input.
if tool_choice is None:
tool_choice = self.default_tool_choice
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = self._chat(
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
)
assert isinstance(chat_response, AgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
async def achat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> AgentChatResponse:
# override tool choice if provided as input.
if tool_choice is None:
tool_choice = self.default_tool_choice
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = await self._achat(
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
)
assert isinstance(chat_response, AgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
def stream_chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> StreamingAgentChatResponse:
# override tool choice if provided as input.
if tool_choice is None:
tool_choice = self.default_tool_choice
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = self._chat(
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
)
assert isinstance(chat_response, StreamingAgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
async def astream_chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> StreamingAgentChatResponse:
# override tool choice if provided as input.
if tool_choice is None:
tool_choice = self.default_tool_choice
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = await self._achat(
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
)
assert isinstance(chat_response, StreamingAgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
def undo_step(self, task_id: str) -> None:
"""Undo previous step."""
raise NotImplementedError("undo_step not implemented")

View File

@ -0,0 +1,472 @@
"""Agent executor."""
import asyncio
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Union, cast
from llama_index.agent.runner.base import BaseAgentRunner
from llama_index.agent.types import (
BaseAgentWorker,
Task,
TaskStep,
TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import (
CallbackManager,
CBEventType,
EventPayload,
trace_method,
)
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
ChatResponseMode,
StreamingAgentChatResponse,
)
from llama_index.llms.base import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
class DAGTaskState(BaseModel):
"""DAG Task state."""
task: Task = Field(..., description="Task.")
root_step: TaskStep = Field(..., description="Root step.")
step_queue: Deque[TaskStep] = Field(
default_factory=deque, description="Task step queue."
)
completed_steps: List[TaskStepOutput] = Field(
default_factory=list, description="Completed step outputs."
)
@property
def task_id(self) -> str:
"""Task id."""
return self.task.task_id
class DAGAgentState(BaseModel):
"""Agent state."""
task_dict: Dict[str, DAGTaskState] = Field(
default_factory=dict, description="Task dictionary."
)
def get_task(self, task_id: str) -> Task:
"""Get task state."""
return self.task_dict[task_id].task
def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
"""Get completed steps."""
return self.task_dict[task_id].completed_steps
def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
"""Get step queue."""
return self.task_dict[task_id].step_queue
class ParallelAgentRunner(BaseAgentRunner):
"""Parallel agent runner.
Executes steps in queue in parallel. Requires async support.
"""
def __init__(
self,
agent_worker: BaseAgentWorker,
chat_history: Optional[List[ChatMessage]] = None,
state: Optional[DAGAgentState] = None,
memory: Optional[BaseMemory] = None,
llm: Optional[LLM] = None,
callback_manager: Optional[CallbackManager] = None,
init_task_state_kwargs: Optional[dict] = None,
delete_task_on_finish: bool = False,
) -> None:
"""Initialize."""
self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)
self.state = state or DAGAgentState()
self.callback_manager = callback_manager or CallbackManager([])
self.init_task_state_kwargs = init_task_state_kwargs or {}
self.agent_worker = agent_worker
self.delete_task_on_finish = delete_task_on_finish
@property
def chat_history(self) -> List[ChatMessage]:
return self.memory.get_all()
def reset(self) -> None:
self.memory.reset()
def create_task(self, input: str, **kwargs: Any) -> Task:
"""Create task."""
task = Task(
input=input,
memory=self.memory,
extra_state=self.init_task_state_kwargs,
**kwargs,
)
# # put input into memory
# self.memory.put(ChatMessage(content=input, role=MessageRole.USER))
# add it to state
# get initial step from task, and put it in the step queue
initial_step = self.agent_worker.initialize_step(task)
task_state = DAGTaskState(
task=task,
root_step=initial_step,
step_queue=deque([initial_step]),
)
self.state.task_dict[task.task_id] = task_state
return task
def delete_task(
self,
task_id: str,
) -> None:
"""Delete task.
NOTE: this will not delete any previous executions from memory.
"""
self.state.task_dict.pop(task_id)
def list_tasks(self, **kwargs: Any) -> List[Task]:
"""List tasks."""
task_states = list(self.state.task_dict.values())
return [task_state.task for task_state in task_states]
def get_task(self, task_id: str, **kwargs: Any) -> Task:
"""Get task."""
return self.state.get_task(task_id)
def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
"""Get upcoming steps."""
return list(self.state.get_step_queue(task_id))
def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
"""Get completed steps."""
return self.state.get_completed_steps(task_id)
def run_steps_in_queue(
self,
task_id: str,
mode: ChatResponseMode = ChatResponseMode.WAIT,
**kwargs: Any,
) -> List[TaskStepOutput]:
"""Execute steps in queue.
Run all steps in queue, clearing it out.
Assume that all steps can be run in parallel.
"""
return asyncio.run(self.arun_steps_in_queue(task_id, mode=mode, **kwargs))
async def arun_steps_in_queue(
self,
task_id: str,
mode: ChatResponseMode = ChatResponseMode.WAIT,
**kwargs: Any,
) -> List[TaskStepOutput]:
"""Execute all steps in queue.
All steps in queue are assumed to be ready.
"""
# first pop all steps from step_queue
steps: List[TaskStep] = []
while len(self.state.get_step_queue(task_id)) > 0:
steps.append(self.state.get_step_queue(task_id).popleft())
# take every item in the queue, and run it
tasks = []
for step in steps:
tasks.append(self._arun_step(task_id, step=step, mode=mode, **kwargs))
return await asyncio.gather(*tasks)
def _run_step(
self,
task_id: str,
step: Optional[TaskStep] = None,
mode: ChatResponseMode = ChatResponseMode.WAIT,
**kwargs: Any,
) -> TaskStepOutput:
"""Execute step."""
task = self.state.get_task(task_id)
task_queue = self.state.get_step_queue(task_id)
step = step or task_queue.popleft()
if not step.is_ready:
raise ValueError(f"Step {step.step_id} is not ready")
if mode == ChatResponseMode.WAIT:
cur_step_output: TaskStepOutput = self.agent_worker.run_step(
step, task, **kwargs
)
elif mode == ChatResponseMode.STREAM:
cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
else:
raise ValueError(f"Invalid mode: {mode}")
for next_step in cur_step_output.next_steps:
if next_step.is_ready:
task_queue.append(next_step)
# add cur_step_output to completed steps
completed_steps = self.state.get_completed_steps(task_id)
completed_steps.append(cur_step_output)
return cur_step_output
async def _arun_step(
self,
task_id: str,
step: Optional[TaskStep] = None,
mode: ChatResponseMode = ChatResponseMode.WAIT,
**kwargs: Any,
) -> TaskStepOutput:
"""Execute step."""
task = self.state.get_task(task_id)
task_queue = self.state.get_step_queue(task_id)
step = step or task_queue.popleft()
if not step.is_ready:
raise ValueError(f"Step {step.step_id} is not ready")
if mode == ChatResponseMode.WAIT:
cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
elif mode == ChatResponseMode.STREAM:
cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
else:
raise ValueError(f"Invalid mode: {mode}")
for next_step in cur_step_output.next_steps:
if next_step.is_ready:
task_queue.append(next_step)
# add cur_step_output to completed steps
completed_steps = self.state.get_completed_steps(task_id)
completed_steps.append(cur_step_output)
return cur_step_output
def run_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step."""
return self._run_step(task_id, step, mode=ChatResponseMode.WAIT, **kwargs)
async def arun_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step (async)."""
return await self._arun_step(
task_id, step, mode=ChatResponseMode.WAIT, **kwargs
)
def stream_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step (stream)."""
return self._run_step(task_id, step, mode=ChatResponseMode.STREAM, **kwargs)
async def astream_step(
self,
task_id: str,
input: Optional[str] = None,
step: Optional[TaskStep] = None,
**kwargs: Any,
) -> TaskStepOutput:
"""Run step (async stream)."""
return await self._arun_step(
task_id, step, mode=ChatResponseMode.STREAM, **kwargs
)
def finalize_response(
self,
task_id: str,
step_output: Optional[TaskStepOutput] = None,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Finalize response."""
if step_output is None:
step_output = self.state.get_completed_steps(task_id)[-1]
if not step_output.is_last:
raise ValueError(
"finalize_response can only be called on the last step output"
)
if not isinstance(
step_output.output,
(AgentChatResponse, StreamingAgentChatResponse),
):
raise ValueError(
"When `is_last` is True, cur_step_output.output must be "
f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
)
# finalize task
self.agent_worker.finalize_task(self.state.get_task(task_id))
if self.delete_task_on_finish:
self.delete_task(task_id)
return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)
def _chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
if chat_history is not None:
self.memory.set(chat_history)
task = self.create_task(message)
result_output = None
while True:
# pass step queue in as argument, assume step executor is stateless
cur_step_outputs = self.run_steps_in_queue(task.task_id, mode=mode)
# check if a step output is_last
is_last = any(
cur_step_output.is_last for cur_step_output in cur_step_outputs
)
if is_last:
if len(cur_step_outputs) > 1:
raise ValueError(
"More than one step output returned in final step."
)
cur_step_output = cur_step_outputs[0]
result_output = cur_step_output
break
return self.finalize_response(task.task_id, result_output)
async def _achat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
if chat_history is not None:
self.memory.set(chat_history)
task = self.create_task(message)
result_output = None
while True:
# pass step queue in as argument, assume step executor is stateless
cur_step_outputs = await self.arun_steps_in_queue(task.task_id, mode=mode)
# check if a step output is_last
is_last = any(
cur_step_output.is_last for cur_step_output in cur_step_outputs
)
if is_last:
if len(cur_step_outputs) > 1:
raise ValueError(
"More than one step output returned in final step."
)
cur_step_output = cur_step_outputs[0]
result_output = cur_step_output
break
return self.finalize_response(task.task_id, result_output)
@trace_method("chat")
def chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> AgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = self._chat(
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
)
assert isinstance(chat_response, AgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
async def achat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> AgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = await self._achat(
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
)
assert isinstance(chat_response, AgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
def stream_chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> StreamingAgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = self._chat(
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
)
assert isinstance(chat_response, StreamingAgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
@trace_method("chat")
async def astream_chat(
self,
message: str,
chat_history: Optional[List[ChatMessage]] = None,
tool_choice: Union[str, dict] = "auto",
) -> StreamingAgentChatResponse:
with self.callback_manager.event(
CBEventType.AGENT_STEP,
payload={EventPayload.MESSAGES: [message]},
) as e:
chat_response = await self._achat(
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
)
assert isinstance(chat_response, StreamingAgentChatResponse)
e.on_end(payload={EventPayload.RESPONSE: chat_response})
return chat_response
def undo_step(self, task_id: str) -> None:
"""Undo previous step."""
raise NotImplementedError("undo_step not implemented")

235
llama_index/agent/types.py Normal file
View File

@ -0,0 +1,235 @@
"""Base agent type."""
import uuid
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import CallbackManager, trace_method
from llama_index.chat_engine.types import BaseChatEngine, StreamingAgentChatResponse
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.llms.types import ChatMessage
from llama_index.core.response.schema import RESPONSE_TYPE, Response
from llama_index.memory.types import BaseMemory
from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType
from llama_index.schema import QueryBundle
class BaseAgent(BaseChatEngine, BaseQueryEngine):
"""Base Agent."""
def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
# TODO: the ReAct agent does not explicitly specify prompts, would need a
# refactor to expose those prompts
return {}
def _get_prompt_modules(self) -> PromptMixinType:
"""Get prompt modules."""
return {}
def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
# ===== Query Engine Interface =====
@trace_method("query")
def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
agent_response = self.chat(
query_bundle.query_str,
chat_history=[],
)
return Response(
response=str(agent_response), source_nodes=agent_response.source_nodes
)
@trace_method("query")
async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
agent_response = await self.achat(
query_bundle.query_str,
chat_history=[],
)
return Response(
response=str(agent_response), source_nodes=agent_response.source_nodes
)
def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
raise NotImplementedError("stream_chat not implemented")
async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
raise NotImplementedError("astream_chat not implemented")
class TaskStep(BaseModel):
"""Agent task step.
Represents a single input step within the execution run ("Task") of an agent
given a user input.
The output is returned as a `TaskStepOutput`.
"""
task_id: str = Field(..., description="Task ID")
step_id: str = Field(..., description="Step ID")
input: Optional[str] = Field(default=None, description="User input")
# memory: BaseMemory = Field(
# ..., type=BaseMemory, description="Conversational Memory"
# )
step_state: Dict[str, Any] = Field(
default_factory=dict, description="Additional state for a given step."
)
# NOTE: the state below may change throughout the course of execution
# this tracks the relationships to other steps
next_steps: Dict[str, "TaskStep"] = Field(
default_factory=dict, description="Next steps to be executed."
)
prev_steps: Dict[str, "TaskStep"] = Field(
default_factory=dict,
description="Previous steps that were dependencies for this step.",
)
is_ready: bool = Field(
default=True, description="Is this step ready to be executed?"
)
def get_next_step(
self,
step_id: str,
input: Optional[str] = None,
step_state: Optional[Dict[str, Any]] = None,
) -> "TaskStep":
"""Convenience function to get next step.
Preserve task_id, memory, step_state.
"""
return TaskStep(
task_id=self.task_id,
step_id=step_id,
input=input,
# memory=self.memory,
step_state=step_state or self.step_state,
)
def link_step(
self,
next_step: "TaskStep",
) -> None:
"""Link to next step.
Add link from this step to next, and from next step to current.
"""
self.next_steps[next_step.step_id] = next_step
next_step.prev_steps[self.step_id] = self
class TaskStepOutput(BaseModel):
"""Agent task step output."""
output: Any = Field(..., description="Task step output")
task_step: TaskStep = Field(..., description="Task step input")
next_steps: List[TaskStep] = Field(..., description="Next steps to be executed.")
is_last: bool = Field(default=False, description="Is this the last step?")
def __str__(self) -> str:
"""String representation."""
return str(self.output)
class Task(BaseModel):
"""Agent Task.
Represents a "run" of an agent given a user input.
"""
class Config:
arbitrary_types_allowed = True
task_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), type=str, description="Task ID"
)
input: str = Field(..., type=str, description="User input")
# NOTE: this is state that may be modified throughout the course of execution of the task
memory: BaseMemory = Field(
...,
type=BaseMemory,
description=(
"Conversational Memory. Maintains state before execution of this task."
),
)
callback_manager: CallbackManager = Field(
default_factory=CallbackManager,
exclude=True,
description="Callback manager for the task.",
)
extra_state: Dict[str, Any] = Field(
default_factory=dict,
description=(
"Additional user-specified state for a given task. "
"Can be modified throughout the execution of a task."
),
)
class BaseAgentWorker(PromptMixin):
"""Base agent worker."""
class Config:
arbitrary_types_allowed = True
def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
# TODO: the ReAct agent does not explicitly specify prompts, would need a
# refactor to expose those prompts
return {}
def _get_prompt_modules(self) -> PromptMixinType:
"""Get prompt modules."""
return {}
def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
@abstractmethod
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
@abstractmethod
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
@abstractmethod
async def arun_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async)."""
raise NotImplementedError
@abstractmethod
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step (stream)."""
# TODO: figure out if we need a different type for TaskStepOutput
raise NotImplementedError
@abstractmethod
async def astream_step(
self, step: TaskStep, task: Task, **kwargs: Any
) -> TaskStepOutput:
"""Run step (async stream)."""
raise NotImplementedError
@abstractmethod
def finalize_task(self, task: Task, **kwargs: Any) -> None:
"""Finalize task, after all the steps are completed."""
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: make this abstractmethod (right now will break some agent impls)
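# Minimal sketch of how TaskStep objects are created and chained; the ids and
# inputs below are arbitrary example values, not fixtures from this package.
if __name__ == "__main__":
    first = TaskStep(task_id="task-1", step_id="step-1", input="look up the weather")
    # derive a follow-up step that keeps the same task_id and step_state
    second = first.get_next_step(step_id="step-2", input="summarize the result")
    # record the dependency in both directions
    first.link_step(second)
    assert second.step_id in first.next_steps
    assert first.step_id in second.prev_steps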

View File

@ -0,0 +1,17 @@
"""Agent utils."""
from llama_index.agent.types import TaskStep
from llama_index.core.llms.types import MessageRole
from llama_index.llms.base import ChatMessage
from llama_index.memory import BaseMemory
def add_user_step_to_memory(
step: TaskStep, memory: BaseMemory, verbose: bool = False
) -> None:
"""Add user step to memory."""
user_message = ChatMessage(content=step.input, role=MessageRole.USER)
memory.put(user_message)
if verbose:
print(f"Added user message to memory: {step.input}")

110
llama_index/async_utils.py Normal file
View File

@ -0,0 +1,110 @@
"""Async utils."""
import asyncio
from itertools import zip_longest
from typing import Any, Coroutine, Iterable, List
def asyncio_module(show_progress: bool = False) -> Any:
if show_progress:
from tqdm.asyncio import tqdm_asyncio
module = tqdm_asyncio
else:
module = asyncio
return module
def run_async_tasks(
tasks: List[Coroutine],
show_progress: bool = False,
progress_bar_desc: str = "Running async tasks",
) -> List[Any]:
"""Run a list of async tasks."""
tasks_to_execute: List[Any] = tasks
if show_progress:
try:
import nest_asyncio
from tqdm.asyncio import tqdm
# jupyter notebooks already have an event loop running
# we need to reuse it instead of creating a new one
nest_asyncio.apply()
loop = asyncio.get_event_loop()
async def _tqdm_gather() -> List[Any]:
return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)
tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
return tqdm_outputs
# fall back to running without tqdm if a fatal error occurs;
# this may happen in some environments where tqdm.asyncio
# is not supported
except Exception:
pass
async def _gather() -> List[Any]:
return await asyncio.gather(*tasks_to_execute)
outputs: List[Any] = asyncio.run(_gather())
return outputs
def chunks(iterable: Iterable, size: int) -> Iterable:
args = [iter(iterable)] * size
return zip_longest(*args, fillvalue=None)
async def batch_gather(
tasks: List[Coroutine], batch_size: int = 10, verbose: bool = False
) -> List[Any]:
output: List[Any] = []
for task_chunk in chunks(tasks, batch_size):
# zip_longest pads the final partial chunk with None; skip those fill values
output_chunk = await asyncio.gather(*[task for task in task_chunk if task is not None])
output.extend(output_chunk)
if verbose:
print(f"Completed {len(output)} out of {len(tasks)} tasks")
return output
def get_asyncio_module(show_progress: bool = False) -> Any:
if show_progress:
from tqdm.asyncio import tqdm_asyncio
module = tqdm_asyncio
else:
module = asyncio
return module
DEFAULT_NUM_WORKERS = 4
async def run_jobs(
jobs: List[Coroutine],
show_progress: bool = False,
workers: int = DEFAULT_NUM_WORKERS,
) -> List[Any]:
"""Run jobs.
Args:
jobs (List[Coroutine]):
List of jobs to run.
show_progress (bool):
Whether to show progress bar.
Returns:
List[Any]:
List of results.
"""
asyncio_mod = get_asyncio_module(show_progress=show_progress)
semaphore = asyncio.Semaphore(workers)
async def worker(job: Coroutine) -> Any:
async with semaphore:
return await job
pool_jobs = [worker(job) for job in jobs]
return await asyncio_mod.gather(*pool_jobs)
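# Minimal sketch of the helpers above: run a handful of trivial coroutines and
# cap concurrency via run_jobs. Guarded so it only runs when executed directly.
if __name__ == "__main__":

    async def _echo(i: int) -> int:
        # stand-in for real async work
        await asyncio.sleep(0.01)
        return i

    results = asyncio.run(run_jobs([_echo(i) for i in range(10)], workers=2))
    print(results)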

View File

View File

@ -0,0 +1,108 @@
import langchain
from langchain.agents import AgentExecutor, AgentType, initialize_agent
# agents and tools
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.base_language import BaseLanguageModel
# callback
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.chat_models.base import BaseChatModel
from langchain.docstore.document import Document
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
# chat and memory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.output_parsers import ResponseSchema
# prompts
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
# schema
from langchain.schema import (
AIMessage,
BaseMemory,
BaseMessage,
BaseOutputParser,
ChatGeneration,
ChatMessage,
FunctionMessage,
HumanMessage,
LLMResult,
SystemMessage,
)
# embeddings
from langchain.schema.embeddings import Embeddings
from langchain.schema.prompt_template import BasePromptTemplate
# input & output
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.tools import BaseTool, StructuredTool, Tool
from langchain_community.chat_models import ChatAnyscale, ChatOpenAI
from langchain_community.embeddings import (
HuggingFaceBgeEmbeddings,
HuggingFaceEmbeddings,
)
# LLMs
from langchain_community.llms import AI21, BaseLLM, Cohere, FakeListLLM, OpenAI
__all__ = [
"langchain",
"BaseLLM",
"FakeListLLM",
"OpenAI",
"AI21",
"Cohere",
"BaseChatModel",
"ChatAnyscale",
"ChatOpenAI",
"BaseLanguageModel",
"Embeddings",
"HuggingFaceEmbeddings",
"HuggingFaceBgeEmbeddings",
"PromptTemplate",
"BasePromptTemplate",
"ConditionalPromptSelector",
"is_chat_model",
"AIMessagePromptTemplate",
"ChatPromptTemplate",
"HumanMessagePromptTemplate",
"BaseMessagePromptTemplate",
"SystemMessagePromptTemplate",
"BaseChatMemory",
"ConversationBufferMemory",
"ChatMessageHistory",
"BaseToolkit",
"AgentType",
"AgentExecutor",
"initialize_agent",
"StructuredTool",
"Tool",
"BaseTool",
"ResponseSchema",
"BaseCallbackHandler",
"BaseCallbackManager",
"AIMessage",
"FunctionMessage",
"BaseMessage",
"ChatMessage",
"HumanMessage",
"SystemMessage",
"BaseMemory",
"BaseOutputParser",
"LLMResult",
"ChatGeneration",
"Document",
"RecursiveCharacterTextSplitter",
"TextSplitter",
]

View File

@ -0,0 +1,51 @@
try:
import pydantic.v1 as pydantic
from pydantic.v1 import (
BaseConfig,
BaseModel,
Field,
PrivateAttr,
StrictFloat,
StrictInt,
StrictStr,
create_model,
root_validator,
validator,
)
from pydantic.v1.error_wrappers import ValidationError
from pydantic.v1.fields import FieldInfo
from pydantic.v1.generics import GenericModel
except ImportError:
import pydantic # type: ignore
from pydantic import (
BaseConfig,
BaseModel,
Field,
PrivateAttr,
StrictFloat,
StrictInt,
StrictStr,
create_model,
root_validator,
validator,
)
from pydantic.error_wrappers import ValidationError
from pydantic.fields import FieldInfo
from pydantic.generics import GenericModel
__all__ = [
"pydantic",
"BaseModel",
"Field",
"PrivateAttr",
"root_validator",
"validator",
"create_model",
"StrictFloat",
"StrictInt",
"StrictStr",
"FieldInfo",
"ValidationError",
"GenericModel",
"BaseConfig",
]

View File

@ -0,0 +1,24 @@
from .aim import AimCallback
from .base import CallbackManager
from .finetuning_handler import GradientAIFineTuningHandler, OpenAIFineTuningHandler
from .llama_debug import LlamaDebugHandler
from .open_inference_callback import OpenInferenceCallbackHandler
from .schema import CBEvent, CBEventType, EventPayload
from .token_counting import TokenCountingHandler
from .utils import trace_method
from .wandb_callback import WandbCallbackHandler
__all__ = [
"OpenInferenceCallbackHandler",
"CallbackManager",
"CBEvent",
"CBEventType",
"EventPayload",
"LlamaDebugHandler",
"AimCallback",
"WandbCallbackHandler",
"TokenCountingHandler",
"OpenAIFineTuningHandler",
"GradientAIFineTuningHandler",
"trace_method",
]

View File

@ -0,0 +1,191 @@
import logging
from typing import Any, Dict, List, Optional
try:
from aim import Run, Text
except ModuleNotFoundError:
Run, Text = None, None
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
class AimCallback(BaseCallbackHandler):
"""
AimCallback callback class.
Args:
repo (:obj:`str`, optional):
Aim repository path or Repo object to which Run object is bound.
If skipped, default Repo is used.
experiment_name (:obj:`str`, optional):
Sets Run's `experiment` property. 'default' if not specified.
Can be used later to query runs/sequences.
system_tracking_interval (:obj:`int`, optional):
Sets the tracking interval in seconds for system usage
metrics (CPU, Memory, etc.). Set to `None` to disable
system metrics tracking.
log_system_params (:obj:`bool`, optional):
Enable/Disable logging of system params such as installed packages,
git info, environment variables, etc.
capture_terminal_logs (:obj:`bool`, optional):
Enable/Disable terminal stdout logging.
event_starts_to_ignore (Optional[List[CBEventType]]):
list of event types to ignore when tracking event starts.
event_ends_to_ignore (Optional[List[CBEventType]]):
list of event types to ignore when tracking event ends.
"""
def __init__(
self,
repo: Optional[str] = None,
experiment_name: Optional[str] = None,
system_tracking_interval: Optional[int] = 1,
log_system_params: Optional[bool] = True,
capture_terminal_logs: Optional[bool] = True,
event_starts_to_ignore: Optional[List[CBEventType]] = None,
event_ends_to_ignore: Optional[List[CBEventType]] = None,
run_params: Optional[Dict[str, Any]] = None,
) -> None:
if Run is None:
raise ModuleNotFoundError(
"Please install aim to use the AimCallback: 'pip install aim'"
)
event_starts_to_ignore = (
event_starts_to_ignore if event_starts_to_ignore else []
)
event_ends_to_ignore = event_ends_to_ignore if event_ends_to_ignore else []
super().__init__(
event_starts_to_ignore=event_starts_to_ignore,
event_ends_to_ignore=event_ends_to_ignore,
)
self.repo = repo
self.experiment_name = experiment_name
self.system_tracking_interval = system_tracking_interval
self.log_system_params = log_system_params
self.capture_terminal_logs = capture_terminal_logs
self._run: Optional[Any] = None
self._run_hash = None
self._llm_response_step = 0
self.setup(run_params)
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
"""
Args:
event_type (CBEventType): event type to store.
payload (Optional[Dict[str, Any]]): payload to store.
event_id (str): event id to store.
parent_id (str): parent event id.
"""
return ""
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""
Args:
event_type (CBEventType): event type to store.
payload (Optional[Dict[str, Any]]): payload to store.
event_id (str): event id to store.
"""
if not self._run:
raise ValueError("AimCallback failed to init properly.")
if event_type is CBEventType.LLM and payload:
if EventPayload.PROMPT in payload:
llm_input = str(payload[EventPayload.PROMPT])
llm_output = str(payload[EventPayload.COMPLETION])
else:
message = payload.get(EventPayload.MESSAGES, [])
llm_input = "\n".join([str(x) for x in message])
llm_output = str(payload[EventPayload.RESPONSE])
self._run.track(
Text(llm_input),
name="prompt",
step=self._llm_response_step,
context={"event_id": event_id},
)
self._run.track(
Text(llm_output),
name="response",
step=self._llm_response_step,
context={"event_id": event_id},
)
self._llm_response_step += 1
elif event_type is CBEventType.CHUNKING and payload:
for chunk_id, chunk in enumerate(payload[EventPayload.CHUNKS]):
self._run.track(
Text(chunk),
name="chunk",
step=self._llm_response_step,
context={"chunk_id": chunk_id, "event_id": event_id},
)
@property
def experiment(self) -> Run:
if not self._run:
self.setup()
return self._run
def setup(self, args: Optional[Dict[str, Any]] = None) -> None:
if not self._run:
if self._run_hash:
self._run = Run(
self._run_hash,
repo=self.repo,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
capture_terminal_logs=self.capture_terminal_logs,
)
else:
self._run = Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
capture_terminal_logs=self.capture_terminal_logs,
)
self._run_hash = self._run.hash
# Log config parameters
if args:
try:
for key in args:
self._run.set(key, args[key], strict=False)
except Exception as e:
logger.warning(f"Aim could not log config parameters -> {e}")
def __del__(self) -> None:
if self._run and self._run.active:
self._run.close()
def start_trace(self, trace_id: Optional[str] = None) -> None:
pass
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
pass
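# Illustrative sketch (assumes the `aim` package is installed and a default
# local Aim repo is acceptable): attach the handler through a CallbackManager
# so LLM prompts and responses are tracked as Text records.
#
#     from llama_index.callbacks import CallbackManager
#
#     aim_cb = AimCallback(repo=None, experiment_name="llamaindex-demo")
#     callback_manager = CallbackManager([aim_cb])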

View File

@ -0,0 +1,12 @@
from typing import Any
from llama_index.callbacks.base_handler import BaseCallbackHandler
def argilla_callback_handler(**kwargs: Any) -> BaseCallbackHandler:
try:
# lazy import
from argilla_llama_index import ArgillaCallbackHandler
except ImportError:
raise ImportError("Please install Argilla with `pip install argilla`")
return ArgillaCallbackHandler(**kwargs)

View File

@ -0,0 +1,13 @@
from typing import Any
from llama_index.callbacks.base_handler import BaseCallbackHandler
def arize_phoenix_callback_handler(**kwargs: Any) -> BaseCallbackHandler:
try:
from phoenix.trace.llama_index import OpenInferenceTraceCallbackHandler
except ImportError:
raise ImportError(
"Please install Arize Phoenix with `pip install -q arize-phoenix`"
)
return OpenInferenceTraceCallbackHandler(**kwargs)

View File

@ -0,0 +1,274 @@
import logging
import uuid
from abc import ABC
from collections import defaultdict
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import (
BASE_TRACE_EVENT,
LEAF_EVENTS,
CBEventType,
EventPayload,
)
logger = logging.getLogger(__name__)
global_stack_trace = ContextVar("trace", default=[BASE_TRACE_EVENT])
empty_trace_ids: List[str] = []
global_stack_trace_ids = ContextVar("trace_ids", default=empty_trace_ids)
class CallbackManager(BaseCallbackHandler, ABC):
"""
Callback manager that handles callbacks for events within LlamaIndex.
The callback manager provides a way to call handlers on event starts/ends.
Additionally, the callback manager traces the current stack of events.
It does this by using a few key attributes.
- trace_stack - The current stack of events that have not ended yet.
When an event ends, it's removed from the stack.
Since this is a contextvar, it is unique to each
thread/task.
- trace_map - A mapping of event ids to their children events.
On the start of events, the bottom of the trace stack
is used as the current parent event for the trace map.
- trace_id - A simple name for the current trace, usually denoting the
entrypoint (query, index_construction, insert, etc.)
Args:
handlers (List[BaseCallbackHandler]): list of handlers to use.
Usage:
with callback_manager.event(CBEventType.QUERY) as event:
event.on_start(payload={key: val})
...
event.on_end(payload={key: val})
"""
def __init__(self, handlers: Optional[List[BaseCallbackHandler]] = None):
"""Initialize the manager with a list of handlers."""
from llama_index import global_handler
handlers = handlers or []
# add eval handlers based on global defaults
if global_handler is not None:
new_handler = global_handler
# go through existing handlers, check if any are same type as new handler
# if so, error
for existing_handler in handlers:
if isinstance(existing_handler, type(new_handler)):
raise ValueError(
"Cannot add two handlers of the same type "
f"{type(new_handler)} to the callback manager."
)
handlers.append(new_handler)
self.handlers = handlers
self._trace_map: Dict[str, List[str]] = defaultdict(list)
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: Optional[str] = None,
parent_id: Optional[str] = None,
**kwargs: Any,
) -> str:
"""Run handlers when an event starts and return id of event."""
event_id = event_id or str(uuid.uuid4())
# if no trace is running, start a default trace
try:
parent_id = parent_id or global_stack_trace.get()[-1]
except IndexError:
self.start_trace("llama-index")
parent_id = global_stack_trace.get()[-1]
self._trace_map[parent_id].append(event_id)
for handler in self.handlers:
if event_type not in handler.event_starts_to_ignore:
handler.on_event_start(
event_type,
payload,
event_id=event_id,
parent_id=parent_id,
**kwargs,
)
if event_type not in LEAF_EVENTS:
# copy the stack trace to prevent conflicts with threads/coroutines
current_trace_stack = global_stack_trace.get().copy()
current_trace_stack.append(event_id)
global_stack_trace.set(current_trace_stack)
return event_id
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run handlers when an event ends."""
event_id = event_id or str(uuid.uuid4())
for handler in self.handlers:
if event_type not in handler.event_ends_to_ignore:
handler.on_event_end(event_type, payload, event_id=event_id, **kwargs)
if event_type not in LEAF_EVENTS:
# copy the stack trace to prevent conflicts with threads/coroutines
current_trace_stack = global_stack_trace.get().copy()
current_trace_stack.pop()
global_stack_trace.set(current_trace_stack)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = handlers
@contextmanager
def event(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: Optional[str] = None,
) -> Generator["EventContext", None, None]:
"""Context manager for lanching and shutdown of events.
Handles sending on_evnt_start and on_event_end to handlers for specified event.
Usage:
with callback_manager.event(CBEventType.QUERY, payload={key, val}) as event:
...
event.on_end(payload={key, val}) # optional
"""
# create event context wrapper
event = EventContext(self, event_type, event_id=event_id)
event.on_start(payload=payload)
payload = None
try:
yield event
except Exception as e:
# data already logged to trace?
if not hasattr(e, "event_added"):
payload = {EventPayload.EXCEPTION: e}
e.event_added = True # type: ignore
if not event.finished:
event.on_end(payload=payload)
raise
finally:
# ensure event is ended
if not event.finished:
event.on_end(payload=payload)
@contextmanager
def as_trace(self, trace_id: str) -> Generator[None, None, None]:
"""Context manager tracer for lanching and shutdown of traces."""
self.start_trace(trace_id=trace_id)
try:
yield
except Exception as e:
# event already added to trace?
if not hasattr(e, "event_added"):
self.on_event_start(
CBEventType.EXCEPTION, payload={EventPayload.EXCEPTION: e}
)
e.event_added = True # type: ignore
raise
finally:
# ensure trace is ended
self.end_trace(trace_id=trace_id)
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
current_trace_stack_ids = global_stack_trace_ids.get().copy()
if trace_id is not None:
if len(current_trace_stack_ids) == 0:
self._reset_trace_events()
for handler in self.handlers:
handler.start_trace(trace_id=trace_id)
current_trace_stack_ids = [trace_id]
else:
current_trace_stack_ids.append(trace_id)
global_stack_trace_ids.set(current_trace_stack_ids)
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Run when an overall trace is exited."""
current_trace_stack_ids = global_stack_trace_ids.get().copy()
if trace_id is not None and len(current_trace_stack_ids) > 0:
current_trace_stack_ids.pop()
if len(current_trace_stack_ids) == 0:
for handler in self.handlers:
handler.end_trace(trace_id=trace_id, trace_map=self._trace_map)
current_trace_stack_ids = []
global_stack_trace_ids.set(current_trace_stack_ids)
def _reset_trace_events(self) -> None:
"""Helper function to reset the current trace."""
self._trace_map = defaultdict(list)
global_stack_trace.set([BASE_TRACE_EVENT])
@property
def trace_map(self) -> Dict[str, List[str]]:
return self._trace_map
class EventContext:
"""
Simple wrapper to call callbacks on event starts and ends
with an event type and id.
"""
def __init__(
self,
callback_manager: CallbackManager,
event_type: CBEventType,
event_id: Optional[str] = None,
):
self._callback_manager = callback_manager
self._event_type = event_type
self._event_id = event_id or str(uuid.uuid4())
self.started = False
self.finished = False
def on_start(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
if not self.started:
self.started = True
self._callback_manager.on_event_start(
self._event_type, payload=payload, event_id=self._event_id, **kwargs
)
else:
logger.warning(
f"Event {self._event_type!s}: {self._event_id} already started!"
)
def on_end(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
if not self.finished:
self.finished = True
self._callback_manager.on_event_end(
self._event_type, payload=payload, event_id=self._event_id, **kwargs
)
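# Minimal usage sketch of the event/trace API above, using the debug handler
# from this package purely as an example handler. Guarded so it only runs when
# executed directly.
if __name__ == "__main__":
    from llama_index.callbacks.llama_debug import LlamaDebugHandler

    manager = CallbackManager([LlamaDebugHandler(print_trace_on_end=True)])
    with manager.as_trace("query"):
        with manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: "hello"}
        ) as event:
            event.on_end(payload={EventPayload.RESPONSE: "hi there"})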

View File

@ -0,0 +1,55 @@
import logging
from abc import ABC, abstractmethod
from contextvars import ContextVar
from typing import Any, Dict, List, Optional
from llama_index.callbacks.schema import BASE_TRACE_EVENT, CBEventType
logger = logging.getLogger(__name__)
global_stack_trace = ContextVar("trace", default=[BASE_TRACE_EVENT])
class BaseCallbackHandler(ABC):
"""Base callback handler that can be used to track event starts and ends."""
def __init__(
self,
event_starts_to_ignore: List[CBEventType],
event_ends_to_ignore: List[CBEventType],
) -> None:
"""Initialize the base callback handler."""
self.event_starts_to_ignore = tuple(event_starts_to_ignore)
self.event_ends_to_ignore = tuple(event_ends_to_ignore)
@abstractmethod
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
"""Run when an event starts and return id of event."""
@abstractmethod
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""Run when an event ends."""
@abstractmethod
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
@abstractmethod
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Run when an overall trace is exited."""

View File

@ -0,0 +1,11 @@
from typing import Any
from llama_index.callbacks.base_handler import BaseCallbackHandler
def deepeval_callback_handler(**kwargs: Any) -> BaseCallbackHandler:
try:
from deepeval.tracing.integrations.llama_index import LlamaIndexCallbackHandler
except ImportError:
raise ImportError("Please install DeepEval with `pip install -U deepeval`")
return LlamaIndexCallbackHandler(**kwargs)

View File

@ -0,0 +1,215 @@
import json
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from llama_index.callbacks.base import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
class BaseFinetuningHandler(BaseCallbackHandler):
"""
Callback handler for finetuning.
This handler will collect all messages
sent to the LLM, along with their responses.
It also defines a `get_finetuning_events` endpoint as well as a
`save_finetuning_events` endpoint.
"""
def __init__(self) -> None:
"""Initialize the base callback handler."""
super().__init__(
event_starts_to_ignore=[],
event_ends_to_ignore=[],
)
self._finetuning_events: Dict[str, List[Any]] = {}
self._function_calls: Dict[str, List[Any]] = {}
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
"""Run when an event starts and return id of event."""
from llama_index.core.llms.types import ChatMessage, MessageRole
if event_type == CBEventType.LLM:
cur_messages = []
if payload and EventPayload.PROMPT in payload:
message = ChatMessage(
role=MessageRole.USER, content=str(payload[EventPayload.PROMPT])
)
cur_messages = [message]
elif payload and EventPayload.MESSAGES in payload:
cur_messages = payload[EventPayload.MESSAGES]
if len(cur_messages) > 0:
if event_id in self._finetuning_events:
self._finetuning_events[event_id].extend(cur_messages)
else:
self._finetuning_events[event_id] = cur_messages
# if functions exists, add that
if payload and EventPayload.ADDITIONAL_KWARGS in payload:
kwargs_dict = payload[EventPayload.ADDITIONAL_KWARGS]
if "functions" in kwargs_dict:
self._function_calls[event_id] = kwargs_dict["functions"]
return event_id
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""Run when an event ends."""
from llama_index.core.llms.types import ChatMessage, MessageRole
if (
event_type == CBEventType.LLM
and event_id in self._finetuning_events
and payload is not None
):
if isinstance(payload[EventPayload.RESPONSE], str):
response = ChatMessage(
role=MessageRole.ASSISTANT, content=str(payload[EventPayload.RESPONSE])
)
else:
response = payload[EventPayload.RESPONSE].message
self._finetuning_events[event_id].append(response)
@abstractmethod
def get_finetuning_events(self) -> Dict[str, Dict[str, Any]]:
"""Get finetuning events."""
@abstractmethod
def save_finetuning_events(self, path: str) -> None:
"""Save the finetuning events to a file."""
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Run when an overall trace is exited."""
class OpenAIFineTuningHandler(BaseFinetuningHandler):
"""
Callback handler for OpenAI fine-tuning.
This handler will collect all messages
sent to the LLM, along with their responses. It will then save these messages
in a `.jsonl` format that can be used for fine-tuning with OpenAI's API.
"""
def get_finetuning_events(self) -> Dict[str, Dict[str, Any]]:
events_dict = {}
for event_id, event in self._finetuning_events.items():
events_dict[event_id] = {"messages": event[:-1], "response": event[-1]}
return events_dict
def save_finetuning_events(self, path: str) -> None:
"""
Save the finetuning events to a file.
This saved format can be used for fine-tuning with OpenAI's API.
The structure for each json line is as follows:
{
messages: [
{ rol: "system", content: "Text"},
{ role: "user", content: "Text" },
]
},
...
"""
from llama_index.llms.openai_utils import to_openai_message_dicts
events_dict = self.get_finetuning_events()
json_strs = []
for event_id, event in events_dict.items():
all_messages = event["messages"] + [event["response"]]
message_dicts = to_openai_message_dicts(all_messages, drop_none=True)
event_dict = {"messages": message_dicts}
if event_id in self._function_calls:
event_dict["functions"] = self._function_calls[event_id]
json_strs.append(json.dumps(event_dict))
with open(path, "w") as f:
f.write("\n".join(json_strs))
print(f"Wrote {len(json_strs)} examples to {path}")
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Run when an overall trace is exited."""
class GradientAIFineTuningHandler(BaseFinetuningHandler):
"""
Callback handler for Gradient AI fine-tuning.
This handler will collect all messages
sent to the LLM, along with their responses. It will then save these messages
in a `.jsonl` format that can be used for fine-tuning with Gradient AI's API.
"""
def get_finetuning_events(self) -> Dict[str, Dict[str, Any]]:
events_dict = {}
for event_id, event in self._finetuning_events.items():
events_dict[event_id] = {"messages": event[:-1], "response": event[-1]}
return events_dict
def save_finetuning_events(self, path: str) -> None:
"""
Save the finetuning events to a file.
This saved format can be used for fine-tuning with OpenAI's API.
The structure for each json line is as follows:
{
"inputs": "<full_prompt_str>"
},
...
"""
from llama_index.llms.generic_utils import messages_to_history_str
events_dict = self.get_finetuning_events()
json_strs = []
for event in events_dict.values():
all_messages = event["messages"] + [event["response"]]
# TODO: come up with model-specific message->prompt serialization format
prompt_str = messages_to_history_str(all_messages)
input_dict = {"inputs": prompt_str}
json_strs.append(json.dumps(input_dict))
with open(path, "w") as f:
f.write("\n".join(json_strs))
print(f"Wrote {len(json_strs)} examples to {path}")
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Run when an overall trace is exited."""

View File

@ -0,0 +1,44 @@
"""Global eval handlers."""
from typing import Any
from llama_index.callbacks.argilla_callback import argilla_callback_handler
from llama_index.callbacks.arize_phoenix_callback import arize_phoenix_callback_handler
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.deepeval_callback import deepeval_callback_handler
from llama_index.callbacks.honeyhive_callback import honeyhive_callback_handler
from llama_index.callbacks.open_inference_callback import OpenInferenceCallbackHandler
from llama_index.callbacks.promptlayer_handler import PromptLayerHandler
from llama_index.callbacks.simple_llm_handler import SimpleLLMHandler
from llama_index.callbacks.wandb_callback import WandbCallbackHandler
def set_global_handler(eval_mode: str, **eval_params: Any) -> None:
"""Set global eval handlers."""
import llama_index
llama_index.global_handler = create_global_handler(eval_mode, **eval_params)
def create_global_handler(eval_mode: str, **eval_params: Any) -> BaseCallbackHandler:
"""Get global eval handler."""
if eval_mode == "wandb":
handler: BaseCallbackHandler = WandbCallbackHandler(**eval_params)
elif eval_mode == "openinference":
handler = OpenInferenceCallbackHandler(**eval_params)
elif eval_mode == "arize_phoenix":
handler = arize_phoenix_callback_handler(**eval_params)
elif eval_mode == "honeyhive":
handler = honeyhive_callback_handler(**eval_params)
elif eval_mode == "promptlayer":
handler = PromptLayerHandler(**eval_params)
elif eval_mode == "deepeval":
handler = deepeval_callback_handler(**eval_params)
elif eval_mode == "simple":
handler = SimpleLLMHandler(**eval_params)
elif eval_mode == "argilla":
handler = argilla_callback_handler(**eval_params)
else:
raise ValueError(f"Eval mode {eval_mode} not supported.")
return handler
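# Minimal sketch: the "simple" mode only relies on the built-in SimpleLLMHandler,
# so it works without any third-party integration installed. The other modes
# dispatched above ("wandb", "arize_phoenix", ...) require their packages.
if __name__ == "__main__":
    set_global_handler("simple")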

View File

@ -0,0 +1,11 @@
from typing import Any
from llama_index.callbacks.base_handler import BaseCallbackHandler
def honeyhive_callback_handler(**kwargs: Any) -> BaseCallbackHandler:
try:
from honeyhive.utils.llamaindex_tracer import HoneyHiveLlamaIndexTracer
except ImportError:
raise ImportError("Please install HoneyHive with `pip install honeyhive`")
return HoneyHiveLlamaIndexTracer(**kwargs)

View File

@ -0,0 +1,205 @@
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import (
BASE_TRACE_EVENT,
TIMESTAMP_FORMAT,
CBEvent,
CBEventType,
EventStats,
)
class LlamaDebugHandler(BaseCallbackHandler):
"""Callback handler that keeps track of debug info.
NOTE: this is a beta feature. The usage within our codebase, and the interface
may change.
This handler simply keeps track of event starts/ends, separated by event types.
You can use this callback handler to keep track of and debug events.
Args:
event_starts_to_ignore (Optional[List[CBEventType]]): list of event types to
ignore when tracking event starts.
event_ends_to_ignore (Optional[List[CBEventType]]): list of event types to
ignore when tracking event ends.
"""
def __init__(
self,
event_starts_to_ignore: Optional[List[CBEventType]] = None,
event_ends_to_ignore: Optional[List[CBEventType]] = None,
print_trace_on_end: bool = True,
) -> None:
"""Initialize the llama debug handler."""
self._event_pairs_by_type: Dict[CBEventType, List[CBEvent]] = defaultdict(list)
self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)
self._sequential_events: List[CBEvent] = []
self._cur_trace_id: Optional[str] = None
self._trace_map: Dict[str, List[str]] = defaultdict(list)
self.print_trace_on_end = print_trace_on_end
event_starts_to_ignore = (
event_starts_to_ignore if event_starts_to_ignore else []
)
event_ends_to_ignore = event_ends_to_ignore if event_ends_to_ignore else []
super().__init__(
event_starts_to_ignore=event_starts_to_ignore,
event_ends_to_ignore=event_ends_to_ignore,
)
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
"""Store event start data by event type.
Args:
event_type (CBEventType): event type to store.
payload (Optional[Dict[str, Any]]): payload to store.
event_id (str): event id to store.
parent_id (str): parent event id.
"""
event = CBEvent(event_type, payload=payload, id_=event_id)
self._event_pairs_by_type[event.event_type].append(event)
self._event_pairs_by_id[event.id_].append(event)
self._sequential_events.append(event)
return event.id_
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""Store event end data by event type.
Args:
event_type (CBEventType): event type to store.
payload (Optional[Dict[str, Any]]): payload to store.
event_id (str): event id to store.
"""
event = CBEvent(event_type, payload=payload, id_=event_id)
self._event_pairs_by_type[event.event_type].append(event)
self._event_pairs_by_id[event.id_].append(event)
self._sequential_events.append(event)
self._trace_map = defaultdict(list)
def get_events(self, event_type: Optional[CBEventType] = None) -> List[CBEvent]:
"""Get all events for a specific event type."""
if event_type is not None:
return self._event_pairs_by_type[event_type]
return self._sequential_events
def _get_event_pairs(self, events: List[CBEvent]) -> List[List[CBEvent]]:
"""Helper function to pair events according to their ID."""
event_pairs: Dict[str, List[CBEvent]] = defaultdict(list)
for event in events:
event_pairs[event.id_].append(event)
return sorted(
event_pairs.values(),
key=lambda x: datetime.strptime(x[0].time, TIMESTAMP_FORMAT),
)
def _get_time_stats_from_event_pairs(
self, event_pairs: List[List[CBEvent]]
) -> EventStats:
"""Calculate time-based stats for a set of event pairs."""
total_secs = 0.0
for event_pair in event_pairs:
start_time = datetime.strptime(event_pair[0].time, TIMESTAMP_FORMAT)
end_time = datetime.strptime(event_pair[-1].time, TIMESTAMP_FORMAT)
total_secs += (end_time - start_time).total_seconds()
return EventStats(
total_secs=total_secs,
average_secs=total_secs / len(event_pairs),
total_count=len(event_pairs),
)
def get_event_pairs(
self, event_type: Optional[CBEventType] = None
) -> List[List[CBEvent]]:
"""Pair events by ID, either all events or a specific type."""
if event_type is not None:
return self._get_event_pairs(self._event_pairs_by_type[event_type])
return self._get_event_pairs(self._sequential_events)
def get_llm_inputs_outputs(self) -> List[List[CBEvent]]:
"""Get the exact LLM inputs and outputs."""
return self._get_event_pairs(self._event_pairs_by_type[CBEventType.LLM])
def get_event_time_info(
self, event_type: Optional[CBEventType] = None
) -> EventStats:
event_pairs = self.get_event_pairs(event_type)
return self._get_time_stats_from_event_pairs(event_pairs)
def flush_event_logs(self) -> None:
"""Clear all events from memory."""
self._event_pairs_by_type = defaultdict(list)
self._event_pairs_by_id = defaultdict(list)
self._sequential_events = []
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Launch a trace."""
self._trace_map = defaultdict(list)
self._cur_trace_id = trace_id
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Shutdown the current trace."""
self._trace_map = trace_map or defaultdict(list)
if self.print_trace_on_end:
self.print_trace_map()
def _print_trace_map(self, cur_event_id: str, level: int = 0) -> None:
"""Recursively print trace map to terminal for debugging."""
event_pair = self._event_pairs_by_id[cur_event_id]
if event_pair:
time_stats = self._get_time_stats_from_event_pairs([event_pair])
indent = " " * level * 2
print(
f"{indent}|_{event_pair[0].event_type} -> ",
f"{time_stats.total_secs} seconds",
flush=True,
)
child_event_ids = self._trace_map[cur_event_id]
for child_event_id in child_event_ids:
self._print_trace_map(child_event_id, level=level + 1)
def print_trace_map(self) -> None:
"""Print simple trace map to terminal for debugging of the most recent trace."""
print("*" * 10, flush=True)
print(f"Trace: {self._cur_trace_id}", flush=True)
self._print_trace_map(BASE_TRACE_EVENT, level=1)
print("*" * 10, flush=True)
@property
def event_pairs_by_type(self) -> Dict[CBEventType, List[CBEvent]]:
return self._event_pairs_by_type
@property
def events_pairs_by_id(self) -> Dict[str, List[CBEvent]]:
return self._event_pairs_by_id
@property
def sequential_events(self) -> List[CBEvent]:
return self._sequential_events
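# Minimal usage sketch: pair the handler with a CallbackManager, emit one event
# pair by hand, and read back the aggregated timing stats.
if __name__ == "__main__":
    from llama_index.callbacks.base import CallbackManager

    debug_handler = LlamaDebugHandler(print_trace_on_end=False)
    manager = CallbackManager([debug_handler])
    with manager.event(CBEventType.QUERY) as event:
        event.on_end()
    print(debug_handler.get_event_time_info(CBEventType.QUERY))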

View File

@ -0,0 +1,247 @@
"""
Callback handler for storing generation data in OpenInference format.
OpenInference is an open standard for capturing and storing AI model inferences.
It enables production LLM app servers to seamlessly integrate with LLM
observability solutions such as Arize and Phoenix.
For more information on the specification, see
https://github.com/Arize-ai/open-inference-spec
"""
import importlib
import uuid
from dataclasses import dataclass, field, fields
from datetime import datetime
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, TypeVar
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
if TYPE_CHECKING:
from pandas import DataFrame
OPENINFERENCE_COLUMN_NAME = "openinference_column_name"
Embedding = List[float]
def _generate_random_id() -> str:
"""Generates a random ID.
Returns:
str: A random ID.
"""
return str(uuid.uuid4())
@dataclass
class QueryData:
"""
Query data with column names following the OpenInference specification.
"""
id: str = field(
default_factory=_generate_random_id,
metadata={OPENINFERENCE_COLUMN_NAME: ":id.id:"},
)
timestamp: Optional[str] = field(
default=None, metadata={OPENINFERENCE_COLUMN_NAME: ":timestamp.iso_8601:"}
)
query_text: Optional[str] = field(
default=None,
metadata={OPENINFERENCE_COLUMN_NAME: ":feature.text:prompt"},
)
query_embedding: Optional[Embedding] = field(
default=None,
metadata={OPENINFERENCE_COLUMN_NAME: ":feature.[float].embedding:prompt"},
)
response_text: Optional[str] = field(
default=None, metadata={OPENINFERENCE_COLUMN_NAME: ":prediction.text:response"}
)
node_ids: List[str] = field(
default_factory=list,
metadata={
OPENINFERENCE_COLUMN_NAME: ":feature.[str].retrieved_document_ids:prompt"
},
)
scores: List[float] = field(
default_factory=list,
metadata={
OPENINFERENCE_COLUMN_NAME: (
":feature.[float].retrieved_document_scores:prompt"
)
},
)
@dataclass
class NodeData:
"""Node data."""
id: str
node_text: Optional[str] = None
node_embedding: Optional[Embedding] = None
BaseDataType = TypeVar("BaseDataType", QueryData, NodeData)
def as_dataframe(data: Iterable[BaseDataType]) -> "DataFrame":
"""Converts a list of BaseDataType to a pandas dataframe.
Args:
data (Iterable[BaseDataType]): A list of BaseDataType.
Returns:
DataFrame: The converted pandas dataframe.
"""
pandas = _import_package("pandas")
as_dict_list = []
for datum in data:
as_dict = {
field.metadata.get(OPENINFERENCE_COLUMN_NAME, field.name): getattr(
datum, field.name
)
for field in fields(datum)
}
as_dict_list.append(as_dict)
return pandas.DataFrame(as_dict_list)
@dataclass
class TraceData:
"""Trace data."""
query_data: QueryData = field(default_factory=QueryData)
node_datas: List[NodeData] = field(default_factory=list)
def _import_package(package_name: str) -> ModuleType:
"""Dynamically imports a package.
Args:
package_name (str): Name of the package to import.
Raises:
ImportError: If the package is not installed.
Returns:
ModuleType: The imported package.
"""
try:
package = importlib.import_module(package_name)
except ImportError:
raise ImportError(f"The {package_name} package must be installed.")
return package
class OpenInferenceCallbackHandler(BaseCallbackHandler):
"""Callback handler for storing generation data in OpenInference format.
OpenInference is an open standard for capturing and storing AI model
inferences. It enables production LLM app servers to seamlessly integrate
with LLM observability solutions such as Arize and Phoenix.
For more information on the specification, see
https://github.com/Arize-ai/open-inference-spec
"""
def __init__(
self,
callback: Optional[Callable[[List[QueryData], List[NodeData]], None]] = None,
) -> None:
"""Initializes the OpenInferenceCallbackHandler.
Args:
callback (Optional[Callable[[List[QueryData], List[NodeData]], None]], optional): A
callback function that will be called when a query trace is
completed, often used for logging or persisting query data.
"""
super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
self._callback = callback
self._trace_data = TraceData()
self._query_data_buffer: List[QueryData] = []
self._node_data_buffer: List[NodeData] = []
def start_trace(self, trace_id: Optional[str] = None) -> None:
if trace_id == "query":
self._trace_data = TraceData()
self._trace_data.query_data.timestamp = datetime.now().isoformat()
self._trace_data.query_data.id = _generate_random_id()
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
if trace_id == "query":
self._query_data_buffer.append(self._trace_data.query_data)
self._node_data_buffer.extend(self._trace_data.node_datas)
self._trace_data = TraceData()
if self._callback is not None:
self._callback(self._query_data_buffer, self._node_data_buffer)
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
if payload is not None:
if event_type is CBEventType.QUERY:
query_text = payload[EventPayload.QUERY_STR]
self._trace_data.query_data.query_text = query_text
return event_id
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
if payload is None:
return
if event_type is CBEventType.RETRIEVE:
for node_with_score in payload[EventPayload.NODES]:
node = node_with_score.node
score = node_with_score.score
self._trace_data.query_data.node_ids.append(node.hash)
self._trace_data.query_data.scores.append(score)
self._trace_data.node_datas.append(
NodeData(
id=node.hash,
node_text=node.text,
)
)
elif event_type is CBEventType.LLM:
self._trace_data.query_data.response_text = str(
payload.get(EventPayload.RESPONSE, "")
) or str(payload.get(EventPayload.COMPLETION, ""))
elif event_type is CBEventType.EMBEDDING:
self._trace_data.query_data.query_embedding = payload[
EventPayload.EMBEDDINGS
][0]
def flush_query_data_buffer(self) -> List[QueryData]:
"""Clears the query data buffer and returns the data.
Returns:
List[QueryData]: The query data.
"""
query_data_buffer = self._query_data_buffer
self._query_data_buffer = []
return query_data_buffer
def flush_node_data_buffer(self) -> List[NodeData]:
"""Clears the node data buffer and returns the data.
Returns:
List[NodeData]: The node data.
"""
node_data_buffer = self._node_data_buffer
self._node_data_buffer = []
return node_data_buffer
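# Minimal usage sketch: attach the handler, run a "query" trace by hand, and
# pull the buffered rows out as a pandas DataFrame (pandas assumed installed).
if __name__ == "__main__":
    from llama_index.callbacks.base import CallbackManager

    handler = OpenInferenceCallbackHandler()
    manager = CallbackManager([handler])
    with manager.as_trace("query"):
        with manager.event(
            CBEventType.QUERY,
            payload={EventPayload.QUERY_STR: "What is OpenInference?"},
        ):
            pass
    print(as_dataframe(handler.flush_query_data_buffer()))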

View File

@ -0,0 +1,136 @@
import datetime
from typing import Any, Dict, List, Optional, Union, cast
from llama_index.bridge.pydantic import BaseModel
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.llms import ChatMessage
PROMPT_LAYER_CHAT_FUNCTION_NAME = "llamaindex.chat.openai"
PROMPT_LAYER_COMPLETION_FUNCTION_NAME = "llamaindex.completion.openai"
class PromptLayerHandler(BaseCallbackHandler):
"""Callback handler for sending to promptlayer.com."""
pl_tags: Optional[List[str]]
return_pl_id: bool = False
def __init__(self, pl_tags: List[str] = [], return_pl_id: bool = False) -> None:
try:
from promptlayer.utils import get_api_key, promptlayer_api_request
self._promptlayer_api_request = promptlayer_api_request
self._promptlayer_api_key = get_api_key()
except ImportError:
raise ImportError(
"Please install PromptLAyer with `pip install promptlayer`"
)
self.pl_tags = pl_tags
self.return_pl_id = return_pl_id
super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
def start_trace(self, trace_id: Optional[str] = None) -> None:
return
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
return
event_map: Dict[str, Dict[str, Any]] = {}
def add_event(self, event_id: str, **kwargs: Any) -> None:
self.event_map[event_id] = {
"kwargs": kwargs,
"request_start_time": datetime.datetime.now().timestamp(),
}
def get_event(
self,
event_id: str,
) -> Dict[str, Any]:
return self.event_map[event_id] or {}
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
if event_type == CBEventType.LLM and payload is not None:
self.add_event(
event_id=event_id, **payload.get(EventPayload.SERIALIZED, {})
)
return event_id
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
if event_type != CBEventType.LLM or payload is None:
return
request_end_time = datetime.datetime.now().timestamp()
prompt = str(payload.get(EventPayload.PROMPT))
completion = payload.get(EventPayload.COMPLETION)
response = payload.get(EventPayload.RESPONSE)
function_name = PROMPT_LAYER_CHAT_FUNCTION_NAME
event_data = self.get_event(event_id=event_id)
resp: Union[str, Dict]
extra_args = {}
if response:
messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
resp = response.message.dict()
assert isinstance(resp, dict)
usage_dict: Dict[str, int] = {}
try:
usage = response.raw.get("usage", None) # type: ignore
if isinstance(usage, dict):
usage_dict = {
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
}
elif isinstance(usage, BaseModel):
usage_dict = usage.dict()
except Exception:
pass
extra_args = {
"messages": [message.dict() for message in messages],
"usage": usage_dict,
}
            # PromptLayer needs tool_calls at the top level of the response dict.
if "tool_calls" in response.message.additional_kwargs:
resp["tool_calls"] = [
tool_call.dict()
for tool_call in resp["additional_kwargs"]["tool_calls"]
]
del resp["additional_kwargs"]["tool_calls"]
if completion:
function_name = PROMPT_LAYER_COMPLETION_FUNCTION_NAME
resp = str(completion)
pl_request_id = self._promptlayer_api_request(
function_name,
"openai",
[prompt],
{
**extra_args,
**event_data["kwargs"],
},
self.pl_tags,
[resp],
event_data["request_start_time"],
request_end_time,
self._promptlayer_api_key,
return_pl_id=self.return_pl_id,
)
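
A hedged usage sketch for the PromptLayer handler above: the `llama_index.callbacks.PromptLayerHandler` import path is an assumption, and a PROMPTLAYER_API_KEY plus an OpenAI key must be configured for the request logging to succeed.

from llama_index import ServiceContext, set_global_service_context
from llama_index.callbacks import CallbackManager
from llama_index.callbacks import PromptLayerHandler  # assumed export path

handler = PromptLayerHandler(pl_tags=["llama-index-demo"], return_pl_id=True)
set_global_service_context(
    ServiceContext.from_defaults(callback_manager=CallbackManager([handler]))
)
# From here on, every OpenAI chat/completion call made through llama_index is
# forwarded to promptlayer.com under the "llama-index-demo" tag.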

View File

@ -0,0 +1,98 @@
"""Base schema for callback managers."""
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional
# timestamp for callback events
TIMESTAMP_FORMAT = "%m/%d/%Y, %H:%M:%S.%f"
# base trace_id for the tracemap in callback_manager
BASE_TRACE_EVENT = "root"
class CBEventType(str, Enum):
"""Callback manager event types.
Attributes:
CHUNKING: Logs for the before and after of text splitting.
NODE_PARSING: Logs for the documents and the nodes that they are parsed into.
EMBEDDING: Logs for the number of texts embedded.
LLM: Logs for the template and response of LLM calls.
QUERY: Keeps track of the start and end of each query.
RETRIEVE: Logs for the nodes retrieved for a query.
SYNTHESIZE: Logs for the result for synthesize calls.
TREE: Logs for the summary and level of summaries generated.
SUB_QUESTION: Logs for a generated sub question and answer.
"""
CHUNKING = "chunking"
NODE_PARSING = "node_parsing"
EMBEDDING = "embedding"
LLM = "llm"
QUERY = "query"
RETRIEVE = "retrieve"
SYNTHESIZE = "synthesize"
TREE = "tree"
SUB_QUESTION = "sub_question"
TEMPLATING = "templating"
FUNCTION_CALL = "function_call"
RERANKING = "reranking"
EXCEPTION = "exception"
AGENT_STEP = "agent_step"
class EventPayload(str, Enum):
DOCUMENTS = "documents" # list of documents before parsing
CHUNKS = "chunks" # list of text chunks
NODES = "nodes" # list of nodes
PROMPT = "formatted_prompt" # formatted prompt sent to LLM
MESSAGES = "messages" # list of messages sent to LLM
COMPLETION = "completion" # completion from LLM
RESPONSE = "response" # message response from LLM
QUERY_STR = "query_str" # query used for query engine
SUB_QUESTION = "sub_question" # a sub question & answer + sources
EMBEDDINGS = "embeddings" # list of embeddings
TOP_K = "top_k" # top k nodes retrieved
ADDITIONAL_KWARGS = "additional_kwargs" # additional kwargs for event call
SERIALIZED = "serialized" # serialized object for event caller
FUNCTION_CALL = "function_call" # function call for the LLM
FUNCTION_OUTPUT = "function_call_response" # function call output
TOOL = "tool" # tool used in LLM call
MODEL_NAME = "model_name" # model name used in an event
TEMPLATE = "template" # template used in LLM call
TEMPLATE_VARS = "template_vars" # template variables used in LLM call
SYSTEM_PROMPT = "system_prompt" # system prompt used in LLM call
QUERY_WRAPPER_PROMPT = "query_wrapper_prompt" # query wrapper prompt used in LLM
EXCEPTION = "exception" # exception raised in an event
# events that will never have children events
LEAF_EVENTS = (CBEventType.CHUNKING, CBEventType.LLM, CBEventType.EMBEDDING)
@dataclass
class CBEvent:
"""Generic class to store event information."""
event_type: CBEventType
payload: Optional[Dict[str, Any]] = None
time: str = ""
id_: str = ""
def __post_init__(self) -> None:
"""Init time and id if needed."""
if not self.time:
self.time = datetime.now().strftime(TIMESTAMP_FORMAT)
if not self.id_:
            self.id_ = str(uuid.uuid4())
@dataclass
class EventStats:
"""Time-based Statistics for events."""
total_secs: float
average_secs: float
total_count: int
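
A small sketch of how the schema above is consumed, using only names defined in this module (the module path mirrors the imports used elsewhere in this commit):

from llama_index.callbacks.schema import (
    LEAF_EVENTS,
    CBEvent,
    CBEventType,
    EventPayload,
)

event = CBEvent(
    event_type=CBEventType.LLM,
    payload={EventPayload.PROMPT: "Tell me a joke."},
)
print(event.time, event.id_)           # auto-filled by __post_init__
print(CBEventType.LLM in LEAF_EVENTS)  # True: LLM events never have child events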

View File

@ -0,0 +1,65 @@
from typing import Any, Dict, List, Optional, cast
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
class SimpleLLMHandler(BaseCallbackHandler):
"""Callback handler for printing llms inputs/outputs."""
def __init__(self) -> None:
super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
def start_trace(self, trace_id: Optional[str] = None) -> None:
return
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
return
def _print_llm_event(self, payload: dict) -> None:
from llama_index.llms import ChatMessage
if EventPayload.PROMPT in payload:
prompt = str(payload.get(EventPayload.PROMPT))
completion = str(payload.get(EventPayload.COMPLETION))
print(f"** Prompt: **\n{prompt}")
print("*" * 50)
print(f"** Completion: **\n{completion}")
print("*" * 50)
print("\n")
elif EventPayload.MESSAGES in payload:
messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
messages_str = "\n".join([str(x) for x in messages])
response = str(payload.get(EventPayload.RESPONSE))
print(f"** Messages: **\n{messages_str}")
print("*" * 50)
print(f"** Response: **\n{response}")
print("*" * 50)
print("\n")
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
return event_id
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""Count the LLM or Embedding tokens as needed."""
if event_type == CBEventType.LLM and payload is not None:
self._print_llm_event(payload)
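
An illustrative way to enable this printing globally; the `llama_index.callbacks.SimpleLLMHandler` export path is an assumption.

from llama_index import ServiceContext, set_global_service_context
from llama_index.callbacks import CallbackManager
from llama_index.callbacks import SimpleLLMHandler  # assumed export path

set_global_service_context(
    ServiceContext.from_defaults(
        callback_manager=CallbackManager([SimpleLLMHandler()])
    )
)
# Every subsequent LLM call made through llama_index now prints its prompt (or
# chat messages) and the completion (or response) to stdout.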

View File

@ -0,0 +1,216 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, cast
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.utilities.token_counting import TokenCounter
from llama_index.utils import get_tokenizer
@dataclass
class TokenCountingEvent:
prompt: str
completion: str
completion_token_count: int
prompt_token_count: int
total_token_count: int = 0
event_id: str = ""
def __post_init__(self) -> None:
self.total_token_count = self.prompt_token_count + self.completion_token_count
def get_llm_token_counts(
token_counter: TokenCounter, payload: Dict[str, Any], event_id: str = ""
) -> TokenCountingEvent:
from llama_index.llms import ChatMessage
if EventPayload.PROMPT in payload:
prompt = str(payload.get(EventPayload.PROMPT))
completion = str(payload.get(EventPayload.COMPLETION))
return TokenCountingEvent(
event_id=event_id,
prompt=prompt,
prompt_token_count=token_counter.get_string_tokens(prompt),
completion=completion,
completion_token_count=token_counter.get_string_tokens(completion),
)
elif EventPayload.MESSAGES in payload:
messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
messages_str = "\n".join([str(x) for x in messages])
response = payload.get(EventPayload.RESPONSE)
response_str = str(response)
# try getting attached token counts first
try:
messages_tokens = 0
response_tokens = 0
if response is not None and response.raw is not None:
usage = response.raw.get("usage", None)
if usage is not None:
if not isinstance(usage, dict):
usage = dict(usage)
messages_tokens = usage.get("prompt_tokens", 0)
response_tokens = usage.get("completion_tokens", 0)
if messages_tokens == 0 or response_tokens == 0:
raise ValueError("Invalid token counts!")
return TokenCountingEvent(
event_id=event_id,
prompt=messages_str,
prompt_token_count=messages_tokens,
completion=response_str,
completion_token_count=response_tokens,
)
except (ValueError, KeyError):
# Invalid token counts, or no token counts attached
pass
# Should count tokens ourselves
messages_tokens = token_counter.estimate_tokens_in_messages(messages)
response_tokens = token_counter.get_string_tokens(response_str)
return TokenCountingEvent(
event_id=event_id,
prompt=messages_str,
prompt_token_count=messages_tokens,
completion=response_str,
completion_token_count=response_tokens,
)
else:
raise ValueError(
"Invalid payload! Need prompt and completion or messages and response."
)
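
An illustrative call to the helper above with a completion-style payload, using the default tokenizer; the module paths mirror the imports used elsewhere in this commit.

from llama_index.callbacks.schema import EventPayload
from llama_index.callbacks.token_counting import get_llm_token_counts
from llama_index.utilities.token_counting import TokenCounter
from llama_index.utils import get_tokenizer

counter = TokenCounter(tokenizer=get_tokenizer())
payload = {
    EventPayload.PROMPT: "Summarize Hamlet in one sentence.",
    EventPayload.COMPLETION: "A Danish prince avenges his father's murder.",
}
event = get_llm_token_counts(counter, payload, event_id="demo-event")
print(event.prompt_token_count, event.completion_token_count, event.total_token_count)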
class TokenCountingHandler(BaseCallbackHandler):
"""Callback handler for counting tokens in LLM and Embedding events.
Args:
tokenizer:
Tokenizer to use. Defaults to the global tokenizer
(see llama_index.utils.globals_helper).
event_starts_to_ignore: List of event types to ignore at the start of a trace.
event_ends_to_ignore: List of event types to ignore at the end of a trace.
"""
def __init__(
self,
tokenizer: Optional[Callable[[str], List]] = None,
event_starts_to_ignore: Optional[List[CBEventType]] = None,
event_ends_to_ignore: Optional[List[CBEventType]] = None,
verbose: bool = False,
) -> None:
self.llm_token_counts: List[TokenCountingEvent] = []
self.embedding_token_counts: List[TokenCountingEvent] = []
self.tokenizer = tokenizer or get_tokenizer()
self._token_counter = TokenCounter(tokenizer=self.tokenizer)
self._verbose = verbose
super().__init__(
event_starts_to_ignore=event_starts_to_ignore or [],
event_ends_to_ignore=event_ends_to_ignore or [],
)
def start_trace(self, trace_id: Optional[str] = None) -> None:
return
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
return
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
return event_id
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""Count the LLM or Embedding tokens as needed."""
if (
event_type == CBEventType.LLM
and event_type not in self.event_ends_to_ignore
and payload is not None
):
self.llm_token_counts.append(
get_llm_token_counts(
token_counter=self._token_counter,
payload=payload,
event_id=event_id,
)
)
if self._verbose:
print(
"LLM Prompt Token Usage: "
f"{self.llm_token_counts[-1].prompt_token_count}\n"
"LLM Completion Token Usage: "
f"{self.llm_token_counts[-1].completion_token_count}",
flush=True,
)
elif (
event_type == CBEventType.EMBEDDING
and event_type not in self.event_ends_to_ignore
and payload is not None
):
total_chunk_tokens = 0
for chunk in payload.get(EventPayload.CHUNKS, []):
self.embedding_token_counts.append(
TokenCountingEvent(
event_id=event_id,
prompt=chunk,
prompt_token_count=self._token_counter.get_string_tokens(chunk),
completion="",
completion_token_count=0,
)
)
total_chunk_tokens += self.embedding_token_counts[-1].total_token_count
if self._verbose:
print(f"Embedding Token Usage: {total_chunk_tokens}", flush=True)
@property
def total_llm_token_count(self) -> int:
"""Get the current total LLM token count."""
return sum([x.total_token_count for x in self.llm_token_counts])
@property
def prompt_llm_token_count(self) -> int:
"""Get the current total LLM prompt token count."""
return sum([x.prompt_token_count for x in self.llm_token_counts])
@property
def completion_llm_token_count(self) -> int:
"""Get the current total LLM completion token count."""
return sum([x.completion_token_count for x in self.llm_token_counts])
@property
def total_embedding_token_count(self) -> int:
"""Get the current total Embedding token count."""
return sum([x.total_token_count for x in self.embedding_token_counts])
def reset_counts(self) -> None:
"""Reset the token counts."""
self.llm_token_counts = []
self.embedding_token_counts = []
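
A minimal end-to-end sketch for the counter above, assuming tiktoken is installed, an OpenAI key is configured, and the handler is re-exported from `llama_index.callbacks`; the document text is made up.

import tiktoken

from llama_index import Document, ServiceContext, VectorStoreIndex
from llama_index.callbacks import CallbackManager, TokenCountingHandler

token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)
service_context = ServiceContext.from_defaults(
    callback_manager=CallbackManager([token_counter])
)
documents = [Document(text="LlamaIndex is a data framework for LLM applications.")]
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
index.as_query_engine().query("What is LlamaIndex?")

print("embedding tokens:", token_counter.total_embedding_token_count)
print("prompt tokens:", token_counter.prompt_llm_token_count)
print("completion tokens:", token_counter.completion_llm_token_count)
token_counter.reset_counts()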

View File

@ -0,0 +1,60 @@
import asyncio
import functools
import logging
from typing import Any, Callable, cast
from llama_index.callbacks.base import CallbackManager
logger = logging.getLogger(__name__)
def trace_method(
trace_id: str, callback_manager_attr: str = "callback_manager"
) -> Callable[[Callable], Callable]:
"""
Decorator to trace a method.
Example:
@trace_method("my_trace_id")
def my_method(self):
pass
Assumes that the self instance has a CallbackManager instance in an attribute
named `callback_manager`.
This can be overridden by passing in a `callback_manager_attr` keyword argument.
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func) # preserve signature, name, etc. of func
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
try:
callback_manager = getattr(self, callback_manager_attr)
except AttributeError:
logger.warning(
"Could not find attribute %s on %s.",
callback_manager_attr,
type(self),
)
return func(self, *args, **kwargs)
callback_manager = cast(CallbackManager, callback_manager)
with callback_manager.as_trace(trace_id):
return func(self, *args, **kwargs)
@functools.wraps(func) # preserve signature, name, etc. of func
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
try:
callback_manager = getattr(self, callback_manager_attr)
except AttributeError:
logger.warning(
"Could not find attribute %s on %s.",
callback_manager_attr,
type(self),
)
return await func(self, *args, **kwargs)
callback_manager = cast(CallbackManager, callback_manager)
with callback_manager.as_trace(trace_id):
return await func(self, *args, **kwargs)
return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper
return decorator
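
A small sketch of applying this decorator; the class and method names below are made up, and the decorator is imported from `llama_index.callbacks` as done elsewhere in this commit.

from llama_index.callbacks import CallbackManager, trace_method


class EchoComponent:
    """Toy component that owns a callback manager, as the decorator expects."""

    def __init__(self) -> None:
        self.callback_manager = CallbackManager([])

    @trace_method("query")
    def run(self, question: str) -> str:
        # Everything executed here is grouped under a single "query" trace.
        return f"echo: {question}"


print(EchoComponent().run("hello"))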

View File

@ -0,0 +1,570 @@
import os
import shutil
from collections import defaultdict
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
TypedDict,
Union,
)
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import (
TIMESTAMP_FORMAT,
CBEvent,
CBEventType,
EventPayload,
)
from llama_index.callbacks.token_counting import get_llm_token_counts
from llama_index.utilities.token_counting import TokenCounter
from llama_index.utils import get_tokenizer
if TYPE_CHECKING:
from wandb import Settings as WBSettings
from wandb.sdk.data_types import trace_tree
from llama_index.indices import (
ComposableGraph,
GPTEmptyIndex,
GPTKeywordTableIndex,
GPTRAKEKeywordTableIndex,
GPTSimpleKeywordTableIndex,
GPTSQLStructStoreIndex,
GPTTreeIndex,
GPTVectorStoreIndex,
SummaryIndex,
)
from llama_index.storage.storage_context import StorageContext
IndexType = Union[
ComposableGraph,
GPTKeywordTableIndex,
GPTSimpleKeywordTableIndex,
GPTRAKEKeywordTableIndex,
SummaryIndex,
GPTEmptyIndex,
GPTTreeIndex,
GPTVectorStoreIndex,
GPTSQLStructStoreIndex,
]
# TODO: remove this class
class WandbRunArgs(TypedDict):
job_type: Optional[str]
dir: Optional[str]
config: Union[Dict, str, None]
project: Optional[str]
entity: Optional[str]
reinit: Optional[bool]
tags: Optional[Sequence]
group: Optional[str]
name: Optional[str]
notes: Optional[str]
magic: Optional[Union[dict, str, bool]]
config_exclude_keys: Optional[List[str]]
config_include_keys: Optional[List[str]]
anonymous: Optional[str]
mode: Optional[str]
allow_val_change: Optional[bool]
resume: Optional[Union[bool, str]]
force: Optional[bool]
tensorboard: Optional[bool]
sync_tensorboard: Optional[bool]
monitor_gym: Optional[bool]
save_code: Optional[bool]
id: Optional[str]
settings: Union["WBSettings", Dict[str, Any], None]
class WandbCallbackHandler(BaseCallbackHandler):
"""Callback handler that logs events to wandb.
    NOTE: this is a beta feature. The usage within our codebase and the interface
    may change.
Use the `WandbCallbackHandler` to log trace events to wandb. This handler is
useful for debugging and visualizing the trace events. It captures the payload of
the events and logs them to wandb. The handler also tracks the start and end of
events. This is particularly useful for debugging your LLM calls.
The `WandbCallbackHandler` can also be used to log the indices and graphs to wandb
using the `persist_index` method. This will save the indexes as artifacts in wandb.
The `load_storage_context` method can be used to load the indexes from wandb
artifacts. This method will return a `StorageContext` object that can be used to
build the index, using `load_index_from_storage`, `load_indices_from_storage` or
`load_graph_from_storage` functions.
Args:
event_starts_to_ignore (Optional[List[CBEventType]]): list of event types to
ignore when tracking event starts.
event_ends_to_ignore (Optional[List[CBEventType]]): list of event types to
ignore when tracking event ends.
"""
def __init__(
self,
run_args: Optional[WandbRunArgs] = None,
tokenizer: Optional[Callable[[str], List]] = None,
event_starts_to_ignore: Optional[List[CBEventType]] = None,
event_ends_to_ignore: Optional[List[CBEventType]] = None,
) -> None:
try:
import wandb
from wandb.sdk.data_types import trace_tree
self._wandb = wandb
self._trace_tree = trace_tree
except ImportError:
raise ImportError(
"WandbCallbackHandler requires wandb. "
"Please install it with `pip install wandb`."
)
from llama_index.indices import (
ComposableGraph,
GPTEmptyIndex,
GPTKeywordTableIndex,
GPTRAKEKeywordTableIndex,
GPTSimpleKeywordTableIndex,
GPTSQLStructStoreIndex,
GPTTreeIndex,
GPTVectorStoreIndex,
SummaryIndex,
)
self._IndexType = (
ComposableGraph,
GPTKeywordTableIndex,
GPTSimpleKeywordTableIndex,
GPTRAKEKeywordTableIndex,
SummaryIndex,
GPTEmptyIndex,
GPTTreeIndex,
GPTVectorStoreIndex,
GPTSQLStructStoreIndex,
)
self._run_args = run_args
# Check if a W&B run is already initialized; if not, initialize one
self._ensure_run(should_print_url=(self._wandb.run is None))
self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)
self._cur_trace_id: Optional[str] = None
self._trace_map: Dict[str, List[str]] = defaultdict(list)
self.tokenizer = tokenizer or get_tokenizer()
self._token_counter = TokenCounter(tokenizer=self.tokenizer)
event_starts_to_ignore = (
event_starts_to_ignore if event_starts_to_ignore else []
)
event_ends_to_ignore = event_ends_to_ignore if event_ends_to_ignore else []
super().__init__(
event_starts_to_ignore=event_starts_to_ignore,
event_ends_to_ignore=event_ends_to_ignore,
)
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
"""Store event start data by event type.
Args:
event_type (CBEventType): event type to store.
payload (Optional[Dict[str, Any]]): payload to store.
event_id (str): event id to store.
parent_id (str): parent event id.
"""
event = CBEvent(event_type, payload=payload, id_=event_id)
self._event_pairs_by_id[event.id_].append(event)
return event.id_
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""Store event end data by event type.
Args:
event_type (CBEventType): event type to store.
payload (Optional[Dict[str, Any]]): payload to store.
event_id (str): event id to store.
"""
event = CBEvent(event_type, payload=payload, id_=event_id)
self._event_pairs_by_id[event.id_].append(event)
self._trace_map = defaultdict(list)
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Launch a trace."""
self._trace_map = defaultdict(list)
self._cur_trace_id = trace_id
self._start_time = datetime.now()
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
# Ensure W&B run is initialized
self._ensure_run()
self._trace_map = trace_map or defaultdict(list)
self._end_time = datetime.now()
# Log the trace map to wandb
# We can control what trace ids we want to log here.
self.log_trace_tree()
# TODO (ayulockin): Log the LLM token counts to wandb when weave is ready
def log_trace_tree(self) -> None:
"""Log the trace tree to wandb."""
try:
child_nodes = self._trace_map["root"]
root_span = self._convert_event_pair_to_wb_span(
self._event_pairs_by_id[child_nodes[0]],
trace_id=self._cur_trace_id if len(child_nodes) > 1 else None,
)
if len(child_nodes) == 1:
child_nodes = self._trace_map[child_nodes[0]]
root_span = self._build_trace_tree(child_nodes, root_span)
else:
root_span = self._build_trace_tree(child_nodes, root_span)
if root_span:
root_trace = self._trace_tree.WBTraceTree(root_span)
if self._wandb.run:
self._wandb.run.log({"trace": root_trace})
self._wandb.termlog("Logged trace tree to W&B.")
except Exception as e:
print(f"Failed to log trace tree to W&B: {e}")
# ignore errors to not break user code
def persist_index(
self, index: "IndexType", index_name: str, persist_dir: Union[str, None] = None
) -> None:
"""Upload an index to wandb as an artifact. You can learn more about W&B
artifacts here: https://docs.wandb.ai/guides/artifacts.
For the `ComposableGraph` index, the root id is stored as artifact metadata.
Args:
index (IndexType): index to upload.
index_name (str): name of the index. This will be used as the artifact name.
persist_dir (Union[str, None]): directory to persist the index. If None, a
temporary directory will be created and used.
"""
        _default_persist_dir = False
        if persist_dir is None:
persist_dir = f"{self._wandb.run.dir}/storage" # type: ignore
_default_persist_dir = True
if not os.path.exists(persist_dir):
os.makedirs(persist_dir)
if isinstance(index, self._IndexType):
try:
index.storage_context.persist(persist_dir) # type: ignore
metadata = None
# For the `ComposableGraph` index, store the root id as metadata
if isinstance(index, self._IndexType[0]):
root_id = index.root_id
metadata = {"root_id": root_id}
self._upload_index_as_wb_artifact(persist_dir, index_name, metadata)
except Exception as e:
# Silently ignore errors to not break user code
self._print_upload_index_fail_message(e)
# clear the default storage dir
if _default_persist_dir:
shutil.rmtree(persist_dir, ignore_errors=True)
def load_storage_context(
self, artifact_url: str, index_download_dir: Union[str, None] = None
) -> "StorageContext":
"""Download an index from wandb and return a storage context.
Use this storage context to load the index into memory using
`load_index_from_storage`, `load_indices_from_storage` or
`load_graph_from_storage` functions.
Args:
artifact_url (str): url of the artifact to download. The artifact url will
be of the form: `entity/project/index_name:version` and can be found in
the W&B UI.
index_download_dir (Union[str, None]): directory to download the index to.
"""
from llama_index.storage.storage_context import StorageContext
artifact = self._wandb.use_artifact(artifact_url, type="storage_context")
artifact_dir = artifact.download(root=index_download_dir)
return StorageContext.from_defaults(persist_dir=artifact_dir)
def _upload_index_as_wb_artifact(
self, dir_path: str, artifact_name: str, metadata: Optional[Dict]
) -> None:
"""Utility function to upload a dir to W&B as an artifact."""
artifact = self._wandb.Artifact(artifact_name, type="storage_context")
if metadata:
artifact.metadata = metadata
artifact.add_dir(dir_path)
self._wandb.run.log_artifact(artifact) # type: ignore
def _build_trace_tree(
self, events: List[str], span: "trace_tree.Span"
) -> "trace_tree.Span":
"""Build the trace tree from the trace map."""
for child_event in events:
child_span = self._convert_event_pair_to_wb_span(
self._event_pairs_by_id[child_event]
)
child_span = self._build_trace_tree(
self._trace_map[child_event], child_span
)
span.add_child_span(child_span)
return span
def _convert_event_pair_to_wb_span(
self,
event_pair: List[CBEvent],
trace_id: Optional[str] = None,
) -> "trace_tree.Span":
"""Convert a pair of events to a wandb trace tree span."""
start_time_ms, end_time_ms = self._get_time_in_ms(event_pair)
if trace_id is None:
event_type = event_pair[0].event_type
span_kind = self._map_event_type_to_span_kind(event_type)
else:
event_type = trace_id # type: ignore
span_kind = None
wb_span = self._trace_tree.Span(
name=f"{event_type}",
span_kind=span_kind,
start_time_ms=start_time_ms,
end_time_ms=end_time_ms,
)
inputs, outputs, wb_span = self._add_payload_to_span(wb_span, event_pair)
wb_span.add_named_result(inputs=inputs, outputs=outputs) # type: ignore
return wb_span
def _map_event_type_to_span_kind(
self, event_type: CBEventType
) -> Union[None, "trace_tree.SpanKind"]:
"""Map a CBEventType to a wandb trace tree SpanKind."""
if event_type == CBEventType.CHUNKING:
span_kind = None
elif event_type == CBEventType.NODE_PARSING:
span_kind = None
elif event_type == CBEventType.EMBEDDING:
# TODO: add span kind for EMBEDDING when it's available
span_kind = None
elif event_type == CBEventType.LLM:
span_kind = self._trace_tree.SpanKind.LLM
elif event_type == CBEventType.QUERY:
span_kind = self._trace_tree.SpanKind.AGENT
elif event_type == CBEventType.AGENT_STEP:
span_kind = self._trace_tree.SpanKind.AGENT
elif event_type == CBEventType.RETRIEVE:
span_kind = self._trace_tree.SpanKind.TOOL
elif event_type == CBEventType.SYNTHESIZE:
span_kind = self._trace_tree.SpanKind.CHAIN
elif event_type == CBEventType.TREE:
span_kind = self._trace_tree.SpanKind.CHAIN
elif event_type == CBEventType.SUB_QUESTION:
span_kind = self._trace_tree.SpanKind.CHAIN
elif event_type == CBEventType.RERANKING:
span_kind = self._trace_tree.SpanKind.CHAIN
elif event_type == CBEventType.FUNCTION_CALL:
span_kind = self._trace_tree.SpanKind.TOOL
else:
span_kind = None
return span_kind
def _add_payload_to_span(
self, span: "trace_tree.Span", event_pair: List[CBEvent]
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], "trace_tree.Span"]:
"""Add the event's payload to the span."""
assert len(event_pair) == 2
event_type = event_pair[0].event_type
inputs = None
outputs = None
if event_type == CBEventType.NODE_PARSING:
# TODO: disabled full detailed inputs/outputs due to UI lag
inputs, outputs = self._handle_node_parsing_payload(event_pair)
elif event_type == CBEventType.LLM:
inputs, outputs, span = self._handle_llm_payload(event_pair, span)
elif event_type == CBEventType.QUERY:
inputs, outputs = self._handle_query_payload(event_pair)
elif event_type == CBEventType.EMBEDDING:
inputs, outputs = self._handle_embedding_payload(event_pair)
return inputs, outputs, span
def _handle_node_parsing_payload(
self, event_pair: List[CBEvent]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Handle the payload of a NODE_PARSING event."""
inputs = event_pair[0].payload
outputs = event_pair[-1].payload
if inputs and EventPayload.DOCUMENTS in inputs:
documents = inputs.pop(EventPayload.DOCUMENTS)
inputs["num_documents"] = len(documents)
if outputs and EventPayload.NODES in outputs:
nodes = outputs.pop(EventPayload.NODES)
outputs["num_nodes"] = len(nodes)
return inputs or {}, outputs or {}
def _handle_llm_payload(
self, event_pair: List[CBEvent], span: "trace_tree.Span"
) -> Tuple[Dict[str, Any], Dict[str, Any], "trace_tree.Span"]:
"""Handle the payload of a LLM event."""
inputs = event_pair[0].payload
outputs = event_pair[-1].payload
assert isinstance(inputs, dict) and isinstance(outputs, dict)
# Get `original_template` from Prompt
if EventPayload.PROMPT in inputs:
inputs[EventPayload.PROMPT] = inputs[EventPayload.PROMPT]
# Format messages
if EventPayload.MESSAGES in inputs:
inputs[EventPayload.MESSAGES] = "\n".join(
[str(x) for x in inputs[EventPayload.MESSAGES]]
)
token_counts = get_llm_token_counts(self._token_counter, outputs)
metadata = {
"formatted_prompt_tokens_count": token_counts.prompt_token_count,
"prediction_tokens_count": token_counts.completion_token_count,
"total_tokens_used": token_counts.total_token_count,
}
span.attributes = metadata
# Make `response` part of `outputs`
outputs = {EventPayload.RESPONSE: str(outputs[EventPayload.RESPONSE])}
return inputs, outputs, span
def _handle_query_payload(
self, event_pair: List[CBEvent]
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any]]:
"""Handle the payload of a QUERY event."""
inputs = event_pair[0].payload
outputs = event_pair[-1].payload
if outputs:
response_obj = outputs[EventPayload.RESPONSE]
response = str(outputs[EventPayload.RESPONSE])
            if type(response_obj).__name__ == "Response":
                response = response_obj.response
            elif type(response_obj).__name__ == "StreamingResponse":
                response = response_obj.get_response().response
else:
response = " "
outputs = {"response": response}
return inputs, outputs
def _handle_embedding_payload(
self,
event_pair: List[CBEvent],
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any]]:
event_pair[0].payload
outputs = event_pair[-1].payload
chunks = []
if outputs:
chunks = outputs.get(EventPayload.CHUNKS, [])
return {}, {"num_chunks": len(chunks)}
def _get_time_in_ms(self, event_pair: List[CBEvent]) -> Tuple[int, int]:
"""Get the start and end time of an event pair in milliseconds."""
start_time = datetime.strptime(event_pair[0].time, TIMESTAMP_FORMAT)
end_time = datetime.strptime(event_pair[1].time, TIMESTAMP_FORMAT)
start_time_in_ms = int(
(start_time - datetime(1970, 1, 1)).total_seconds() * 1000
)
end_time_in_ms = int((end_time - datetime(1970, 1, 1)).total_seconds() * 1000)
return start_time_in_ms, end_time_in_ms
def _ensure_run(self, should_print_url: bool = False) -> None:
"""Ensures an active W&B run exists.
If not, will start a new run with the provided run_args.
"""
if self._wandb.run is None:
# Make a shallow copy of the run args, so we don't modify the original
run_args = self._run_args or {} # type: ignore
run_args: dict = {**run_args} # type: ignore
# Prefer to run in silent mode since W&B has a lot of output
# which can be undesirable when dealing with text-based models.
if "settings" not in run_args: # type: ignore
run_args["settings"] = {"silent": True} # type: ignore
# Start the run and add the stream table
self._wandb.init(**run_args)
self._wandb.run._label(repo="llama_index") # type: ignore
if should_print_url:
self._print_wandb_init_message(
self._wandb.run.settings.run_url # type: ignore
)
def _print_wandb_init_message(self, run_url: str) -> None:
"""Print a message to the terminal when W&B is initialized."""
self._wandb.termlog(
f"Streaming LlamaIndex events to W&B at {run_url}\n"
"`WandbCallbackHandler` is currently in beta.\n"
"Please report any issues to https://github.com/wandb/wandb/issues "
"with the tag `llamaindex`."
)
def _print_upload_index_fail_message(self, e: Exception) -> None:
"""Print a message to the terminal when uploading the index fails."""
self._wandb.termlog(
f"Failed to upload index to W&B with the following error: {e}\n"
)
def finish(self) -> None:
"""Finish the callback handler."""
self._wandb.finish()
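
A hedged usage sketch for the handler above: it assumes `wandb login` has been run, an OpenAI key is set, and that the handler is re-exported as `llama_index.callbacks.WandbCallbackHandler`; the project name and document contents are made up.

from llama_index import Document, ServiceContext, VectorStoreIndex
from llama_index.callbacks import CallbackManager, WandbCallbackHandler  # assumed export

wandb_handler = WandbCallbackHandler(run_args={"project": "llamaindex-demo"})
service_context = ServiceContext.from_defaults(
    callback_manager=CallbackManager([wandb_handler])
)
documents = [Document(text="The trace of this query will be logged to W&B.")]
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
index.as_query_engine().query("What happens to this query?")

# Optionally save the index as a W&B artifact, then close the run.
wandb_handler.persist_index(index, index_name="demo-index")
wandb_handler.finish()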

View File

@ -0,0 +1,11 @@
from llama_index.chat_engine.condense_plus_context import CondensePlusContextChatEngine
from llama_index.chat_engine.condense_question import CondenseQuestionChatEngine
from llama_index.chat_engine.context import ContextChatEngine
from llama_index.chat_engine.simple import SimpleChatEngine
__all__ = [
"SimpleChatEngine",
"CondenseQuestionChatEngine",
"ContextChatEngine",
"CondensePlusContextChatEngine",
]

View File

@ -0,0 +1,362 @@
import asyncio
import logging
from threading import Thread
from typing import Any, List, Optional, Tuple
from llama_index.callbacks import CallbackManager, trace_method
from llama_index.chat_engine.types import (
AgentChatResponse,
BaseChatEngine,
StreamingAgentChatResponse,
ToolOutput,
)
from llama_index.core.llms.types import ChatMessage, MessageRole
from llama_index.indices.base_retriever import BaseRetriever
from llama_index.indices.query.schema import QueryBundle
from llama_index.indices.service_context import ServiceContext
from llama_index.llms.generic_utils import messages_to_history_str
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts.base import PromptTemplate
from llama_index.schema import MetadataMode, NodeWithScore
from llama_index.utilities.token_counting import TokenCounter
logger = logging.getLogger(__name__)
DEFAULT_CONTEXT_PROMPT_TEMPLATE = """
The following is a friendly conversation between a user and an AI assistant.
The assistant is talkative and provides lots of specific details from its context.
If the assistant does not know the answer to a question, it truthfully says it
does not know.
Here are the relevant documents for the context:
{context_str}
Instruction: Based on the above documents, provide a detailed answer for the user question below.
Answer "don't know" if not present in the document.
"""
DEFAULT_CONDENSE_PROMPT_TEMPLATE = """
Given the following conversation between a user and an AI assistant and a follow up question from user,
rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
class CondensePlusContextChatEngine(BaseChatEngine):
"""Condensed Conversation & Context Chat Engine.
    First condense the conversation history and the latest user message into a
    standalone question. Then build context for that standalone question from a
    retriever. Finally, pass the context along with the prompt and user message
    to the LLM to generate a response.
"""
def __init__(
self,
retriever: BaseRetriever,
llm: LLM,
memory: BaseMemory,
context_prompt: Optional[str] = None,
condense_prompt: Optional[str] = None,
system_prompt: Optional[str] = None,
skip_condense: bool = False,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
):
self._retriever = retriever
self._llm = llm
self._memory = memory
self._context_prompt_template = (
context_prompt or DEFAULT_CONTEXT_PROMPT_TEMPLATE
)
condense_prompt_str = condense_prompt or DEFAULT_CONDENSE_PROMPT_TEMPLATE
self._condense_prompt_template = PromptTemplate(condense_prompt_str)
self._system_prompt = system_prompt
self._skip_condense = skip_condense
self._node_postprocessors = node_postprocessors or []
self.callback_manager = callback_manager or CallbackManager([])
for node_postprocessor in self._node_postprocessors:
node_postprocessor.callback_manager = self.callback_manager
self._token_counter = TokenCounter()
self._verbose = verbose
@classmethod
def from_defaults(
cls,
retriever: BaseRetriever,
service_context: Optional[ServiceContext] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
system_prompt: Optional[str] = None,
context_prompt: Optional[str] = None,
condense_prompt: Optional[str] = None,
skip_condense: bool = False,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
**kwargs: Any,
) -> "CondensePlusContextChatEngine":
"""Initialize a CondensePlusContextChatEngine from default parameters."""
service_context = service_context or ServiceContext.from_defaults()
llm = service_context.llm
chat_history = chat_history or []
memory = memory or ChatMemoryBuffer.from_defaults(
chat_history=chat_history, token_limit=llm.metadata.context_window - 256
)
return cls(
retriever=retriever,
llm=llm,
memory=memory,
context_prompt=context_prompt,
condense_prompt=condense_prompt,
skip_condense=skip_condense,
callback_manager=service_context.callback_manager,
node_postprocessors=node_postprocessors,
system_prompt=system_prompt,
verbose=verbose,
)
def _condense_question(
self, chat_history: List[ChatMessage], latest_message: str
) -> str:
"""Condense a conversation history and latest user message to a standalone question."""
if self._skip_condense or len(chat_history) == 0:
return latest_message
chat_history_str = messages_to_history_str(chat_history)
logger.debug(chat_history_str)
return self._llm.predict(
self._condense_prompt_template,
question=latest_message,
chat_history=chat_history_str,
)
async def _acondense_question(
self, chat_history: List[ChatMessage], latest_message: str
) -> str:
"""Condense a conversation history and latest user message to a standalone question."""
if self._skip_condense or len(chat_history) == 0:
return latest_message
chat_history_str = messages_to_history_str(chat_history)
logger.debug(chat_history_str)
return await self._llm.apredict(
self._condense_prompt_template,
question=latest_message,
chat_history=chat_history_str,
)
def _retrieve_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
"""Build context for a message from retriever."""
nodes = self._retriever.retrieve(message)
for postprocessor in self._node_postprocessors:
nodes = postprocessor.postprocess_nodes(
nodes, query_bundle=QueryBundle(message)
)
context_str = "\n\n".join(
[n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
)
return context_str, nodes
async def _aretrieve_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
"""Build context for a message from retriever."""
nodes = await self._retriever.aretrieve(message)
context_str = "\n\n".join(
[n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
)
return context_str, nodes
def _run_c3(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> Tuple[List[ChatMessage], ToolOutput, List[NodeWithScore]]:
if chat_history is not None:
self._memory.set(chat_history)
chat_history = self._memory.get()
# Condense conversation history and latest message to a standalone question
condensed_question = self._condense_question(chat_history, message)
logger.info(f"Condensed question: {condensed_question}")
if self._verbose:
print(f"Condensed question: {condensed_question}")
# Build context for the standalone question from a retriever
context_str, context_nodes = self._retrieve_context(condensed_question)
context_source = ToolOutput(
tool_name="retriever",
content=context_str,
raw_input={"message": condensed_question},
raw_output=context_str,
)
logger.debug(f"Context: {context_str}")
if self._verbose:
print(f"Context: {context_str}")
system_message_content = self._context_prompt_template.format(
context_str=context_str
)
if self._system_prompt:
system_message_content = self._system_prompt + "\n" + system_message_content
system_message = ChatMessage(
content=system_message_content, role=self._llm.metadata.system_role
)
initial_token_count = self._token_counter.estimate_tokens_in_messages(
[system_message]
)
self._memory.put(ChatMessage(content=message, role=MessageRole.USER))
chat_messages = [
system_message,
*self._memory.get(initial_token_count=initial_token_count),
]
return chat_messages, context_source, context_nodes
async def _arun_c3(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> Tuple[List[ChatMessage], ToolOutput, List[NodeWithScore]]:
if chat_history is not None:
self._memory.set(chat_history)
chat_history = self._memory.get()
# Condense conversation history and latest message to a standalone question
condensed_question = await self._acondense_question(chat_history, message)
logger.info(f"Condensed question: {condensed_question}")
if self._verbose:
print(f"Condensed question: {condensed_question}")
# Build context for the standalone question from a retriever
context_str, context_nodes = await self._aretrieve_context(condensed_question)
context_source = ToolOutput(
tool_name="retriever",
content=context_str,
raw_input={"message": condensed_question},
raw_output=context_str,
)
logger.debug(f"Context: {context_str}")
if self._verbose:
print(f"Context: {context_str}")
system_message_content = self._context_prompt_template.format(
context_str=context_str
)
if self._system_prompt:
system_message_content = self._system_prompt + "\n" + system_message_content
system_message = ChatMessage(
content=system_message_content, role=self._llm.metadata.system_role
)
initial_token_count = self._token_counter.estimate_tokens_in_messages(
[system_message]
)
self._memory.put(ChatMessage(content=message, role=MessageRole.USER))
chat_messages = [
system_message,
*self._memory.get(initial_token_count=initial_token_count),
]
return chat_messages, context_source, context_nodes
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
chat_messages, context_source, context_nodes = self._run_c3(
message, chat_history
)
# pass the context, system prompt and user message as chat to LLM to generate a response
chat_response = self._llm.chat(chat_messages)
assistant_message = chat_response.message
self._memory.put(assistant_message)
return AgentChatResponse(
response=str(assistant_message.content),
sources=[context_source],
source_nodes=context_nodes,
)
@trace_method("chat")
def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
chat_messages, context_source, context_nodes = self._run_c3(
message, chat_history
)
# pass the context, system prompt and user message as chat to LLM to generate a response
chat_response = StreamingAgentChatResponse(
chat_stream=self._llm.stream_chat(chat_messages),
sources=[context_source],
source_nodes=context_nodes,
)
thread = Thread(
target=chat_response.write_response_to_history, args=(self._memory,)
)
thread.start()
return chat_response
@trace_method("chat")
async def achat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
chat_messages, context_source, context_nodes = await self._arun_c3(
message, chat_history
)
# pass the context, system prompt and user message as chat to LLM to generate a response
chat_response = await self._llm.achat(chat_messages)
assistant_message = chat_response.message
self._memory.put(assistant_message)
return AgentChatResponse(
response=str(assistant_message.content),
sources=[context_source],
source_nodes=context_nodes,
)
@trace_method("chat")
async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
chat_messages, context_source, context_nodes = await self._arun_c3(
message, chat_history
)
# pass the context, system prompt and user message as chat to LLM to generate a response
chat_response = StreamingAgentChatResponse(
achat_stream=await self._llm.astream_chat(chat_messages),
sources=[context_source],
source_nodes=context_nodes,
)
thread = Thread(
target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)),
args=(self._memory,),
)
thread.start()
return chat_response
def reset(self) -> None:
# Clear chat history
self._memory.reset()
@property
def chat_history(self) -> List[ChatMessage]:
"""Get chat history."""
return self._memory.get_all()
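
An illustrative way to build the engine above from an index retriever; the document text and system prompt are made up, and an OpenAI key is assumed for the default LLM and embeddings.

from llama_index import Document, ServiceContext, VectorStoreIndex
from llama_index.chat_engine import CondensePlusContextChatEngine

documents = [Document(text="The Eiffel Tower is located in Paris, France.")]
index = VectorStoreIndex.from_documents(documents)
chat_engine = CondensePlusContextChatEngine.from_defaults(
    retriever=index.as_retriever(),
    service_context=ServiceContext.from_defaults(),
    system_prompt="You are a concise travel assistant.",
    verbose=True,
)
print(chat_engine.chat("Where is the Eiffel Tower?"))
# The follow-up is condensed into a standalone question before retrieval.
print(chat_engine.chat("Which country is that city in?"))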

View File

@ -0,0 +1,362 @@
import logging
from threading import Thread
from typing import Any, List, Optional, Type
from llama_index.callbacks import CallbackManager, trace_method
from llama_index.chat_engine.types import (
AgentChatResponse,
BaseChatEngine,
StreamingAgentChatResponse,
)
from llama_index.chat_engine.utils import response_gen_from_query_engine
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.llms.types import ChatMessage, MessageRole
from llama_index.core.response.schema import RESPONSE_TYPE, StreamingResponse
from llama_index.llm_predictor.base import LLMPredictorType
from llama_index.llms.generic_utils import messages_to_history_str
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.service_context import ServiceContext
from llama_index.token_counter.mock_embed_model import MockEmbedding
from llama_index.tools import ToolOutput
logger = logging.getLogger(__name__)
DEFAULT_TEMPLATE = """\
Given a conversation (between Human and Assistant) and a follow up message from Human, \
rewrite the message to be a standalone question that captures all relevant context \
from the conversation.
<Chat History>
{chat_history}
<Follow Up Message>
{question}
<Standalone question>
"""
DEFAULT_PROMPT = PromptTemplate(DEFAULT_TEMPLATE)
class CondenseQuestionChatEngine(BaseChatEngine):
"""Condense Question Chat Engine.
First generate a standalone question from conversation context and last message,
then query the query engine for a response.
"""
def __init__(
self,
query_engine: BaseQueryEngine,
condense_question_prompt: BasePromptTemplate,
memory: BaseMemory,
llm: LLMPredictorType,
verbose: bool = False,
callback_manager: Optional[CallbackManager] = None,
) -> None:
self._query_engine = query_engine
self._condense_question_prompt = condense_question_prompt
self._memory = memory
self._llm = llm
self._verbose = verbose
self.callback_manager = callback_manager or CallbackManager([])
@classmethod
def from_defaults(
cls,
query_engine: BaseQueryEngine,
condense_question_prompt: Optional[BasePromptTemplate] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
service_context: Optional[ServiceContext] = None,
verbose: bool = False,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
llm: Optional[LLM] = None,
**kwargs: Any,
) -> "CondenseQuestionChatEngine":
"""Initialize a CondenseQuestionChatEngine from default parameters."""
condense_question_prompt = condense_question_prompt or DEFAULT_PROMPT
if llm is None:
service_context = service_context or ServiceContext.from_defaults(
embed_model=MockEmbedding(embed_dim=2)
)
llm = service_context.llm
else:
service_context = service_context or ServiceContext.from_defaults(
llm=llm, embed_model=MockEmbedding(embed_dim=2)
)
chat_history = chat_history or []
memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)
if system_prompt is not None:
raise NotImplementedError(
"system_prompt is not supported for CondenseQuestionChatEngine."
)
if prefix_messages is not None:
raise NotImplementedError(
"prefix_messages is not supported for CondenseQuestionChatEngine."
)
return cls(
query_engine,
condense_question_prompt,
memory,
llm,
verbose=verbose,
callback_manager=service_context.callback_manager,
)
def _condense_question(
self, chat_history: List[ChatMessage], last_message: str
) -> str:
"""
Generate standalone question from conversation context and last message.
"""
chat_history_str = messages_to_history_str(chat_history)
logger.debug(chat_history_str)
return self._llm.predict(
self._condense_question_prompt,
question=last_message,
chat_history=chat_history_str,
)
async def _acondense_question(
self, chat_history: List[ChatMessage], last_message: str
) -> str:
"""
Generate standalone question from conversation context and last message.
"""
chat_history_str = messages_to_history_str(chat_history)
logger.debug(chat_history_str)
return await self._llm.apredict(
self._condense_question_prompt,
question=last_message,
chat_history=chat_history_str,
)
def _get_tool_output_from_response(
self, query: str, response: RESPONSE_TYPE
) -> ToolOutput:
if isinstance(response, StreamingResponse):
return ToolOutput(
content="",
tool_name="query_engine",
raw_input={"query": query},
raw_output=response,
)
else:
return ToolOutput(
content=str(response),
tool_name="query_engine",
raw_input={"query": query},
raw_output=response,
)
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
chat_history = chat_history or self._memory.get()
# Generate standalone question from conversation context and last message
condensed_question = self._condense_question(chat_history, message)
log_str = f"Querying with: {condensed_question}"
logger.info(log_str)
if self._verbose:
print(log_str)
# TODO: right now, query engine uses class attribute to configure streaming,
# we are moving towards separate streaming and non-streaming methods.
# In the meanwhile, use this hack to toggle streaming.
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
if isinstance(self._query_engine, RetrieverQueryEngine):
is_streaming = self._query_engine._response_synthesizer._streaming
self._query_engine._response_synthesizer._streaming = False
# Query with standalone question
query_response = self._query_engine.query(condensed_question)
# NOTE: reset streaming flag
if isinstance(self._query_engine, RetrieverQueryEngine):
self._query_engine._response_synthesizer._streaming = is_streaming
tool_output = self._get_tool_output_from_response(
condensed_question, query_response
)
# Record response
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
self._memory.put(
ChatMessage(role=MessageRole.ASSISTANT, content=str(query_response))
)
return AgentChatResponse(response=str(query_response), sources=[tool_output])
@trace_method("chat")
def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
chat_history = chat_history or self._memory.get()
# Generate standalone question from conversation context and last message
condensed_question = self._condense_question(chat_history, message)
log_str = f"Querying with: {condensed_question}"
logger.info(log_str)
if self._verbose:
print(log_str)
# TODO: right now, query engine uses class attribute to configure streaming,
# we are moving towards separate streaming and non-streaming methods.
# In the meanwhile, use this hack to toggle streaming.
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
if isinstance(self._query_engine, RetrieverQueryEngine):
is_streaming = self._query_engine._response_synthesizer._streaming
self._query_engine._response_synthesizer._streaming = True
# Query with standalone question
query_response = self._query_engine.query(condensed_question)
# NOTE: reset streaming flag
if isinstance(self._query_engine, RetrieverQueryEngine):
self._query_engine._response_synthesizer._streaming = is_streaming
tool_output = self._get_tool_output_from_response(
condensed_question, query_response
)
# Record response
if (
isinstance(query_response, StreamingResponse)
and query_response.response_gen is not None
):
# override the generator to include writing to chat history
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
response = StreamingAgentChatResponse(
chat_stream=response_gen_from_query_engine(query_response.response_gen),
sources=[tool_output],
)
thread = Thread(
target=response.write_response_to_history, args=(self._memory, True)
)
thread.start()
else:
raise ValueError("Streaming is not enabled. Please use chat() instead.")
return response
@trace_method("chat")
async def achat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
chat_history = chat_history or self._memory.get()
# Generate standalone question from conversation context and last message
condensed_question = await self._acondense_question(chat_history, message)
log_str = f"Querying with: {condensed_question}"
logger.info(log_str)
if self._verbose:
print(log_str)
# TODO: right now, query engine uses class attribute to configure streaming,
# we are moving towards separate streaming and non-streaming methods.
# In the meanwhile, use this hack to toggle streaming.
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
if isinstance(self._query_engine, RetrieverQueryEngine):
is_streaming = self._query_engine._response_synthesizer._streaming
self._query_engine._response_synthesizer._streaming = False
# Query with standalone question
query_response = await self._query_engine.aquery(condensed_question)
# NOTE: reset streaming flag
if isinstance(self._query_engine, RetrieverQueryEngine):
self._query_engine._response_synthesizer._streaming = is_streaming
tool_output = self._get_tool_output_from_response(
condensed_question, query_response
)
# Record response
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
self._memory.put(
ChatMessage(role=MessageRole.ASSISTANT, content=str(query_response))
)
return AgentChatResponse(response=str(query_response), sources=[tool_output])
@trace_method("chat")
async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
chat_history = chat_history or self._memory.get()
# Generate standalone question from conversation context and last message
condensed_question = await self._acondense_question(chat_history, message)
log_str = f"Querying with: {condensed_question}"
logger.info(log_str)
if self._verbose:
print(log_str)
# TODO: right now, query engine uses class attribute to configure streaming,
# we are moving towards separate streaming and non-streaming methods.
# In the meanwhile, use this hack to toggle streaming.
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
if isinstance(self._query_engine, RetrieverQueryEngine):
is_streaming = self._query_engine._response_synthesizer._streaming
self._query_engine._response_synthesizer._streaming = True
# Query with standalone question
query_response = await self._query_engine.aquery(condensed_question)
# NOTE: reset streaming flag
if isinstance(self._query_engine, RetrieverQueryEngine):
self._query_engine._response_synthesizer._streaming = is_streaming
tool_output = self._get_tool_output_from_response(
condensed_question, query_response
)
# Record response
if (
isinstance(query_response, StreamingResponse)
and query_response.response_gen is not None
):
# override the generator to include writing to chat history
# TODO: query engine does not support async generator yet
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
response = StreamingAgentChatResponse(
chat_stream=response_gen_from_query_engine(query_response.response_gen),
sources=[tool_output],
)
thread = Thread(
target=response.write_response_to_history, args=(self._memory,)
)
thread.start()
else:
raise ValueError("Streaming is not enabled. Please use achat() instead.")
return response
def reset(self) -> None:
# Clear chat history
self._memory.reset()
@property
def chat_history(self) -> List[ChatMessage]:
"""Get chat history."""
return self._memory.get_all()
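
A matching sketch for the condense-question engine, wrapping a query engine built over a made-up document; an OpenAI key is assumed for the default LLM and embeddings.

from llama_index import Document, VectorStoreIndex
from llama_index.chat_engine import CondenseQuestionChatEngine

documents = [Document(text="Python was created by Guido van Rossum in 1991.")]
query_engine = VectorStoreIndex.from_documents(documents).as_query_engine()
chat_engine = CondenseQuestionChatEngine.from_defaults(
    query_engine=query_engine, verbose=True
)
print(chat_engine.chat("Who created Python?"))
# The follow-up is rewritten into a standalone question before querying.
print(chat_engine.chat("When did he create it?"))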

View File

@ -0,0 +1,301 @@
import asyncio
from threading import Thread
from typing import Any, List, Optional, Tuple
from llama_index.callbacks import CallbackManager, trace_method
from llama_index.chat_engine.types import (
AgentChatResponse,
BaseChatEngine,
StreamingAgentChatResponse,
ToolOutput,
)
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.llms.types import ChatMessage, MessageRole
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.service_context import ServiceContext
DEFAULT_CONTEXT_TEMPLATE = (
"Context information is below."
"\n--------------------\n"
"{context_str}"
"\n--------------------\n"
)
class ContextChatEngine(BaseChatEngine):
"""Context Chat Engine.
Uses a retriever to retrieve a context, sets the context in the system prompt,
and then uses an LLM to generate a response for a fluid chat experience.
"""
def __init__(
self,
retriever: BaseRetriever,
llm: LLM,
memory: BaseMemory,
prefix_messages: List[ChatMessage],
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
context_template: Optional[str] = None,
callback_manager: Optional[CallbackManager] = None,
) -> None:
self._retriever = retriever
self._llm = llm
self._memory = memory
self._prefix_messages = prefix_messages
self._node_postprocessors = node_postprocessors or []
self._context_template = context_template or DEFAULT_CONTEXT_TEMPLATE
self.callback_manager = callback_manager or CallbackManager([])
for node_postprocessor in self._node_postprocessors:
node_postprocessor.callback_manager = self.callback_manager
@classmethod
def from_defaults(
cls,
retriever: BaseRetriever,
service_context: Optional[ServiceContext] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
context_template: Optional[str] = None,
**kwargs: Any,
) -> "ContextChatEngine":
"""Initialize a ContextChatEngine from default parameters."""
service_context = service_context or ServiceContext.from_defaults()
llm = service_context.llm
chat_history = chat_history or []
memory = memory or ChatMemoryBuffer.from_defaults(
chat_history=chat_history, token_limit=llm.metadata.context_window - 256
)
if system_prompt is not None:
if prefix_messages is not None:
raise ValueError(
"Cannot specify both system_prompt and prefix_messages"
)
prefix_messages = [
ChatMessage(content=system_prompt, role=llm.metadata.system_role)
]
prefix_messages = prefix_messages or []
node_postprocessors = node_postprocessors or []
return cls(
retriever,
llm=llm,
memory=memory,
prefix_messages=prefix_messages,
node_postprocessors=node_postprocessors,
callback_manager=service_context.callback_manager,
context_template=context_template,
)
def _generate_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
"""Generate context information from a message."""
nodes = self._retriever.retrieve(message)
for postprocessor in self._node_postprocessors:
nodes = postprocessor.postprocess_nodes(
nodes, query_bundle=QueryBundle(message)
)
context_str = "\n\n".join(
[n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
)
return self._context_template.format(context_str=context_str), nodes
async def _agenerate_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
"""Generate context information from a message."""
nodes = await self._retriever.aretrieve(message)
for postprocessor in self._node_postprocessors:
nodes = postprocessor.postprocess_nodes(
nodes, query_bundle=QueryBundle(message)
)
context_str = "\n\n".join(
[n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
)
return self._context_template.format(context_str=context_str), nodes
def _get_prefix_messages_with_context(self, context_str: str) -> List[ChatMessage]:
"""Get the prefix messages with context."""
# ensure we grab the user-configured system prompt
system_prompt = ""
prefix_messages = self._prefix_messages
if (
len(self._prefix_messages) != 0
and self._prefix_messages[0].role == MessageRole.SYSTEM
):
system_prompt = str(self._prefix_messages[0].content)
prefix_messages = self._prefix_messages[1:]
context_str_w_sys_prompt = system_prompt.strip() + "\n" + context_str
return [
ChatMessage(
content=context_str_w_sys_prompt, role=self._llm.metadata.system_role
),
*prefix_messages,
]
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
context_str_template, nodes = self._generate_context(message)
prefix_messages = self._get_prefix_messages_with_context(context_str_template)
prefix_messages_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in prefix_messages])
)
)
all_messages = prefix_messages + self._memory.get(
initial_token_count=prefix_messages_token_count
)
chat_response = self._llm.chat(all_messages)
ai_message = chat_response.message
self._memory.put(ai_message)
return AgentChatResponse(
response=str(chat_response.message.content),
sources=[
ToolOutput(
tool_name="retriever",
content=str(prefix_messages[0]),
raw_input={"message": message},
raw_output=prefix_messages[0],
)
],
source_nodes=nodes,
)
@trace_method("chat")
def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
context_str_template, nodes = self._generate_context(message)
prefix_messages = self._get_prefix_messages_with_context(context_str_template)
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in prefix_messages])
)
)
all_messages = prefix_messages + self._memory.get(
initial_token_count=initial_token_count
)
chat_response = StreamingAgentChatResponse(
chat_stream=self._llm.stream_chat(all_messages),
sources=[
ToolOutput(
tool_name="retriever",
content=str(prefix_messages[0]),
raw_input={"message": message},
raw_output=prefix_messages[0],
)
],
source_nodes=nodes,
)
thread = Thread(
target=chat_response.write_response_to_history, args=(self._memory,)
)
thread.start()
return chat_response
@trace_method("chat")
async def achat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
context_str_template, nodes = await self._agenerate_context(message)
prefix_messages = self._get_prefix_messages_with_context(context_str_template)
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in prefix_messages])
)
)
all_messages = prefix_messages + self._memory.get(
initial_token_count=initial_token_count
)
chat_response = await self._llm.achat(all_messages)
ai_message = chat_response.message
self._memory.put(ai_message)
return AgentChatResponse(
response=str(chat_response.message.content),
sources=[
ToolOutput(
tool_name="retriever",
content=str(prefix_messages[0]),
raw_input={"message": message},
raw_output=prefix_messages[0],
)
],
source_nodes=nodes,
)
@trace_method("chat")
async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
context_str_template, nodes = await self._agenerate_context(message)
prefix_messages = self._get_prefix_messages_with_context(context_str_template)
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in prefix_messages])
)
)
all_messages = prefix_messages + self._memory.get(
initial_token_count=initial_token_count
)
chat_response = StreamingAgentChatResponse(
achat_stream=await self._llm.astream_chat(all_messages),
sources=[
ToolOutput(
tool_name="retriever",
content=str(prefix_messages[0]),
raw_input={"message": message},
raw_output=prefix_messages[0],
)
],
source_nodes=nodes,
)
thread = Thread(
target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)),
args=(self._memory,),
)
thread.start()
return chat_response
def reset(self) -> None:
self._memory.reset()
@property
def chat_history(self) -> List[ChatMessage]:
"""Get chat history."""
return self._memory.get_all()
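
A minimal sketch of wiring up the ContextChatEngine above from a retriever (illustrative; it assumes the default OpenAI settings resolved by ServiceContext.from_defaults() and uses a made-up document):

from llama_index import Document, VectorStoreIndex
from llama_index.chat_engine import ContextChatEngine

index = VectorStoreIndex.from_documents(
    [Document(text="The company handbook says Fridays are remote.")]
)
chat_engine = ContextChatEngine.from_defaults(
    retriever=index.as_retriever(similarity_top_k=2),
    system_prompt="Answer strictly from the provided context.",
)
print(chat_engine.chat("Which day is remote?"))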

View File

@ -0,0 +1,175 @@
import asyncio
from threading import Thread
from typing import Any, List, Optional, Type
from llama_index.callbacks import CallbackManager, trace_method
from llama_index.chat_engine.types import (
AgentChatResponse,
BaseChatEngine,
StreamingAgentChatResponse,
)
from llama_index.core.llms.types import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.service_context import ServiceContext
class SimpleChatEngine(BaseChatEngine):
"""Simple Chat Engine.
Have a conversation with the LLM.
This does not make use of a knowledge base.
"""
def __init__(
self,
llm: LLM,
memory: BaseMemory,
prefix_messages: List[ChatMessage],
callback_manager: Optional[CallbackManager] = None,
) -> None:
self._llm = llm
self._memory = memory
self._prefix_messages = prefix_messages
self.callback_manager = callback_manager or CallbackManager([])
@classmethod
def from_defaults(
cls,
service_context: Optional[ServiceContext] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
**kwargs: Any,
) -> "SimpleChatEngine":
"""Initialize a SimpleChatEngine from default parameters."""
service_context = service_context or ServiceContext.from_defaults()
llm = service_context.llm
chat_history = chat_history or []
memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)
if system_prompt is not None:
if prefix_messages is not None:
raise ValueError(
"Cannot specify both system_prompt and prefix_messages"
)
prefix_messages = [
ChatMessage(content=system_prompt, role=llm.metadata.system_role)
]
prefix_messages = prefix_messages or []
return cls(
llm=llm,
memory=memory,
prefix_messages=prefix_messages,
callback_manager=service_context.callback_manager,
)
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in self._prefix_messages])
)
)
all_messages = self._prefix_messages + self._memory.get(
initial_token_count=initial_token_count
)
chat_response = self._llm.chat(all_messages)
ai_message = chat_response.message
self._memory.put(ai_message)
return AgentChatResponse(response=str(chat_response.message.content))
@trace_method("chat")
def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in self._prefix_messages])
)
)
all_messages = self._prefix_messages + self._memory.get(
initial_token_count=initial_token_count
)
chat_response = StreamingAgentChatResponse(
chat_stream=self._llm.stream_chat(all_messages)
)
thread = Thread(
target=chat_response.write_response_to_history, args=(self._memory,)
)
thread.start()
return chat_response
@trace_method("chat")
async def achat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in self._prefix_messages])
)
)
all_messages = self._prefix_messages + self._memory.get(
initial_token_count=initial_token_count
)
chat_response = await self._llm.achat(all_messages)
ai_message = chat_response.message
self._memory.put(ai_message)
return AgentChatResponse(response=str(chat_response.message.content))
@trace_method("chat")
async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
if chat_history is not None:
self._memory.set(chat_history)
self._memory.put(ChatMessage(content=message, role="user"))
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in self._prefix_messages])
)
)
all_messages = self._prefix_messages + self._memory.get(
initial_token_count=initial_token_count
)
chat_response = StreamingAgentChatResponse(
achat_stream=await self._llm.astream_chat(all_messages)
)
thread = Thread(
target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)),
args=(self._memory,),
)
thread.start()
return chat_response
def reset(self) -> None:
self._memory.reset()
@property
def chat_history(self) -> List[ChatMessage]:
"""Get chat history."""
return self._memory.get_all()
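
A short usage sketch for SimpleChatEngine (illustrative; it relies on the default LLM resolved by ServiceContext.from_defaults()):

from llama_index.chat_engine import SimpleChatEngine

chat_engine = SimpleChatEngine.from_defaults(system_prompt="You are a terse assistant.")
print(chat_engine.chat("Say hello in three words."))
chat_engine.reset()  # clears the in-memory chat history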

View File

@ -0,0 +1,314 @@
import asyncio
import logging
import queue
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from threading import Event
from typing import AsyncGenerator, Generator, List, Optional, Union
from llama_index.core.llms.types import (
ChatMessage,
ChatResponseAsyncGen,
ChatResponseGen,
)
from llama_index.core.response.schema import Response, StreamingResponse
from llama_index.memory import BaseMemory
from llama_index.schema import NodeWithScore
from llama_index.tools import ToolOutput
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
def is_function(message: ChatMessage) -> bool:
"""Utility for ChatMessage responses from OpenAI models."""
return "tool_calls" in message.additional_kwargs
class ChatResponseMode(str, Enum):
"""Flag toggling waiting/streaming in `Agent._chat`."""
WAIT = "wait"
STREAM = "stream"
@dataclass
class AgentChatResponse:
"""Agent chat response."""
response: str = ""
sources: List[ToolOutput] = field(default_factory=list)
source_nodes: List[NodeWithScore] = field(default_factory=list)
def __post_init__(self) -> None:
if self.sources and not self.source_nodes:
for tool_output in self.sources:
if isinstance(tool_output.raw_output, (Response, StreamingResponse)):
self.source_nodes.extend(tool_output.raw_output.source_nodes)
def __str__(self) -> str:
return self.response
@dataclass
class StreamingAgentChatResponse:
"""Streaming chat response to user and writing to chat history."""
response: str = ""
sources: List[ToolOutput] = field(default_factory=list)
chat_stream: Optional[ChatResponseGen] = None
achat_stream: Optional[ChatResponseAsyncGen] = None
source_nodes: List[NodeWithScore] = field(default_factory=list)
_unformatted_response: str = ""
_queue: queue.Queue = field(default_factory=queue.Queue)
_aqueue: asyncio.Queue = field(default_factory=asyncio.Queue)
# flag when chat message is a function call
_is_function: Optional[bool] = None
# flag when processing done
_is_done = False
# signal when a new item is added to the queue
_new_item_event: asyncio.Event = field(default_factory=asyncio.Event)
# NOTE: async code uses two events rather than one since it yields
# control when waiting for queue item
# signal when the OpenAI functions stop executing
_is_function_false_event: asyncio.Event = field(default_factory=asyncio.Event)
# signal when an OpenAI function is being executed
_is_function_not_none_thread_event: Event = field(default_factory=Event)
def __post_init__(self) -> None:
if self.sources and not self.source_nodes:
for tool_output in self.sources:
if isinstance(tool_output.raw_output, (Response, StreamingResponse)):
self.source_nodes.extend(tool_output.raw_output.source_nodes)
def __str__(self) -> str:
if self._is_done and not self._queue.empty() and not self._is_function:
while self._queue.queue:
delta = self._queue.queue.popleft()
self._unformatted_response += delta
self.response = self._unformatted_response.strip()
return self.response
def put_in_queue(self, delta: Optional[str]) -> None:
self._queue.put_nowait(delta)
self._is_function_not_none_thread_event.set()
def aput_in_queue(self, delta: Optional[str]) -> None:
self._aqueue.put_nowait(delta)
self._new_item_event.set()
def write_response_to_history(
self, memory: BaseMemory, raise_error: bool = False
) -> None:
if self.chat_stream is None:
raise ValueError(
"chat_stream is None. Cannot write to history without chat_stream."
)
# try/except to prevent hanging on error
try:
final_text = ""
for chat in self.chat_stream:
self._is_function = is_function(chat.message)
self.put_in_queue(chat.delta)
final_text += chat.delta or ""
if self._is_function is not None: # if loop has gone through iteration
# NOTE: this is to handle the special case where we consume some of the
# chat stream, but not all of it (e.g. in react agent)
chat.message.content = final_text.strip() # final message
memory.put(chat.message)
except Exception as e:
if not raise_error:
logger.warning(
f"Encountered exception writing response to history: {e}"
)
else:
raise
self._is_done = True
# This acts as an is_done event for any consumers waiting
self._is_function_not_none_thread_event.set()
async def awrite_response_to_history(
self,
memory: BaseMemory,
) -> None:
if self.achat_stream is None:
raise ValueError(
"achat_stream is None. Cannot asynchronously write to "
"history without achat_stream."
)
# try/except to prevent hanging on error
try:
final_text = ""
async for chat in self.achat_stream:
self._is_function = is_function(chat.message)
self.aput_in_queue(chat.delta)
final_text += chat.delta or ""
self._new_item_event.set()
if self._is_function is False:
self._is_function_false_event.set()
if self._is_function is not None: # if loop has gone through iteration
# NOTE: this is to handle the special case where we consume some of the
# chat stream, but not all of it (e.g. in react agent)
chat.message.content = final_text.strip() # final message
memory.put(chat.message)
except Exception as e:
logger.warning(f"Encountered exception writing response to history: {e}")
self._is_done = True
# These act as is_done events for any consumers waiting
self._is_function_false_event.set()
self._new_item_event.set()
@property
def response_gen(self) -> Generator[str, None, None]:
while not self._is_done or not self._queue.empty():
try:
delta = self._queue.get(block=False)
self._unformatted_response += delta
yield delta
except queue.Empty:
# Queue is empty, but we're not done yet
time.sleep(0.01)
self.response = self._unformatted_response.strip()
async def async_response_gen(self) -> AsyncGenerator[str, None]:
while not self._is_done or not self._aqueue.empty():
if not self._aqueue.empty():
delta = self._aqueue.get_nowait()
self._unformatted_response += delta
yield delta
else:
await self._new_item_event.wait() # Wait until a new item is added
self._new_item_event.clear() # Clear the event for the next wait
self.response = self._unformatted_response.strip()
def print_response_stream(self) -> None:
for token in self.response_gen:
print(token, end="", flush=True)
async def aprint_response_stream(self) -> None:
async for token in self.async_response_gen():
print(token, end="", flush=True)
AGENT_CHAT_RESPONSE_TYPE = Union[AgentChatResponse, StreamingAgentChatResponse]
class BaseChatEngine(ABC):
"""Base Chat Engine."""
@abstractmethod
def reset(self) -> None:
"""Reset conversation state."""
@abstractmethod
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Main chat interface."""
@abstractmethod
def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
"""Stream chat interface."""
@abstractmethod
async def achat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Async version of main chat interface."""
@abstractmethod
async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
"""Async version of main chat interface."""
def chat_repl(self) -> None:
"""Enter interactive chat REPL."""
print("===== Entering Chat REPL =====")
print('Type "exit" to exit.\n')
self.reset()
message = input("Human: ")
while message != "exit":
response = self.chat(message)
print(f"Assistant: {response}\n")
message = input("Human: ")
def streaming_chat_repl(self) -> None:
"""Enter interactive chat REPL with streaming responses."""
print("===== Entering Chat REPL =====")
print('Type "exit" to exit.\n')
self.reset()
message = input("Human: ")
while message != "exit":
response = self.stream_chat(message)
print("Assistant: ", end="", flush=True)
response.print_response_stream()
print("\n")
message = input("Human: ")
@property
@abstractmethod
def chat_history(self) -> List[ChatMessage]:
pass
class ChatMode(str, Enum):
"""Chat Engine Modes."""
SIMPLE = "simple"
"""Corresponds to `SimpleChatEngine`.
Chat with LLM, without making use of a knowledge base.
"""
CONDENSE_QUESTION = "condense_question"
"""Corresponds to `CondenseQuestionChatEngine`.
First generate a standalone question from conversation context and last message,
then query the query engine for a response.
"""
CONTEXT = "context"
"""Corresponds to `ContextChatEngine`.
First retrieve text from the index using the user's message, then use the context
in the system prompt to generate a response.
"""
CONDENSE_PLUS_CONTEXT = "condense_plus_context"
"""Corresponds to `CondensePlusContextChatEngine`.
First condense a conversation and latest user message to a standalone question.
Then build a context for the standalone question from a retriever,
and finally pass the context along with the prompt and user message to the LLM to generate a response.
"""
REACT = "react"
"""Corresponds to `ReActAgent`.
Use a ReAct agent loop with query engine tools.
"""
OPENAI = "openai"
"""Corresponds to `OpenAIAgent`.
Use an OpenAI function calling agent loop.
NOTE: only works with OpenAI models that support the function calling API.
"""
BEST = "best"
"""Select the best chat engine based on the current LLM.
Corresponds to `OpenAIAgent` if using an OpenAI model that supports the
function calling API; otherwise, corresponds to `ReActAgent`.
"""

View File

@ -0,0 +1,17 @@
from llama_index.core.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseGen,
MessageRole,
)
from llama_index.types import TokenGen
def response_gen_from_query_engine(response_gen: TokenGen) -> ChatResponseGen:
response_str = ""
for token in response_gen:
response_str += token
yield ChatResponse(
message=ChatMessage(role=MessageRole.ASSISTANT, content=response_str),
delta=token,
)
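
A tiny illustration of the adapter above: any plain token generator can be lifted into a ChatResponseGen whose deltas stream like a chat response (the generator below is made up for the example):

def _fake_tokens():
    yield from ["stream", "ing", " works"]

for chat_response in response_gen_from_query_engine(_fake_tokens()):
    print(chat_response.delta, end="")  # prints the tokens as they arrive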

View File

View File

@ -0,0 +1,172 @@
import argparse
from typing import Any, Optional
from llama_index.command_line.rag import RagCLI, default_ragcli_persist_dir
from llama_index.embeddings import OpenAIEmbedding
from llama_index.ingestion import IngestionCache, IngestionPipeline
from llama_index.llama_dataset.download import (
LLAMA_DATASETS_LFS_URL,
LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
download_llama_dataset,
)
from llama_index.llama_pack.download import LLAMA_HUB_URL, download_llama_pack
from llama_index.storage.docstore import SimpleDocumentStore
from llama_index.text_splitter import SentenceSplitter
from llama_index.vector_stores import ChromaVectorStore
def handle_download_llama_pack(
llama_pack_class: Optional[str] = None,
download_dir: Optional[str] = None,
llama_hub_url: str = LLAMA_HUB_URL,
**kwargs: Any,
) -> None:
assert llama_pack_class is not None
assert download_dir is not None
download_llama_pack(
llama_pack_class=llama_pack_class,
download_dir=download_dir,
llama_hub_url=llama_hub_url,
)
print(f"Successfully downloaded {llama_pack_class} to {download_dir}")
def handle_download_llama_dataset(
llama_dataset_class: Optional[str] = None,
download_dir: Optional[str] = None,
llama_hub_url: str = LLAMA_HUB_URL,
llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
**kwargs: Any,
) -> None:
assert llama_dataset_class is not None
assert download_dir is not None
download_llama_dataset(
llama_dataset_class=llama_dataset_class,
download_dir=download_dir,
llama_hub_url=llama_hub_url,
llama_datasets_lfs_url=llama_datasets_lfs_url,
llama_datasets_source_files_tree_url=llama_datasets_source_files_tree_url,
show_progress=True,
load_documents=False,
)
print(f"Successfully downloaded {llama_dataset_class} to {download_dir}")
def default_rag_cli() -> RagCLI:
import chromadb
persist_dir = default_ragcli_persist_dir()
chroma_client = chromadb.PersistentClient(path=persist_dir)
chroma_collection = chroma_client.create_collection("default", get_or_create=True)
vector_store = ChromaVectorStore(
chroma_collection=chroma_collection, persist_dir=persist_dir
)
docstore = SimpleDocumentStore()
ingestion_pipeline = IngestionPipeline(
transformations=[SentenceSplitter(), OpenAIEmbedding()],
vector_store=vector_store,
docstore=docstore,
cache=IngestionCache(),
)
try:
ingestion_pipeline.load(persist_dir=persist_dir)
except FileNotFoundError:
pass
return RagCLI(
ingestion_pipeline=ingestion_pipeline,
verbose=False,
persist_dir=persist_dir,
)
def main() -> None:
parser = argparse.ArgumentParser(description="LlamaIndex CLI tool.")
# Subparsers for the main commands
subparsers = parser.add_subparsers(title="commands", dest="command", required=True)
# llama rag command
llamarag_parser = subparsers.add_parser(
"rag", help="Ask a question to a document / a directory of documents."
)
RagCLI.add_parser_args(llamarag_parser, default_rag_cli)
# download llamapacks command
llamapack_parser = subparsers.add_parser(
"download-llamapack", help="Download a llama-pack"
)
llamapack_parser.add_argument(
"llama_pack_class",
type=str,
help=(
"The name of the llama-pack class you want to download, "
"such as `GmailOpenAIAgentPack`."
),
)
llamapack_parser.add_argument(
"-d",
"--download-dir",
type=str,
default="./llama_packs",
help="Custom dirpath to download the pack into.",
)
llamapack_parser.add_argument(
"--llama-hub-url",
type=str,
default=LLAMA_HUB_URL,
help="URL to llama hub.",
)
llamapack_parser.set_defaults(
func=lambda args: handle_download_llama_pack(**vars(args))
)
# download llamadatasets command
llamadataset_parser = subparsers.add_parser(
"download-llamadataset", help="Download a llama-dataset"
)
llamadataset_parser.add_argument(
"llama_dataset_class",
type=str,
help=(
"The name of the llama-dataset class you want to download, "
"such as `PaulGrahamEssayDataset`."
),
)
llamadataset_parser.add_argument(
"-d",
"--download-dir",
type=str,
default="./llama_datasets",
help="Custom dirpath to download the pack into.",
)
llamadataset_parser.add_argument(
"--llama-hub-url",
type=str,
default=LLAMA_HUB_URL,
help="URL to llama hub.",
)
llamadataset_parser.add_argument(
"--llama-datasets-lfs-url",
type=str,
default=LLAMA_DATASETS_LFS_URL,
help="URL to llama datasets.",
)
llamadataset_parser.set_defaults(
func=lambda args: handle_download_llama_dataset(**vars(args))
)
# Parse the command-line arguments
args = parser.parse_args()
# Call the appropriate function based on the command
args.func(args)
if __name__ == "__main__":
main()
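
The same handlers can also be driven programmatically (a sketch; the pack name is just the example from the help text above, and the download requires network access):

# equivalent to: llamaindex-cli download-llamapack GmailOpenAIAgentPack -d ./llama_packs
handle_download_llama_pack(
    llama_pack_class="GmailOpenAIAgentPack",
    download_dir="./llama_packs",
)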

View File

@ -0,0 +1,373 @@
import asyncio
import os
import shutil
from argparse import ArgumentParser
from glob import iglob
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union, cast
from llama_index import (
Response,
ServiceContext,
SimpleDirectoryReader,
VectorStoreIndex,
)
from llama_index.bridge.pydantic import BaseModel, Field, validator
from llama_index.chat_engine import CondenseQuestionChatEngine
from llama_index.core.response.schema import RESPONSE_TYPE, StreamingResponse
from llama_index.embeddings.base import BaseEmbedding
from llama_index.ingestion import IngestionPipeline
from llama_index.llms import LLM, OpenAI
from llama_index.query_engine import CustomQueryEngine
from llama_index.query_pipeline import FnComponent
from llama_index.query_pipeline.query import QueryPipeline
from llama_index.readers.base import BaseReader
from llama_index.response_synthesizers import CompactAndRefine
from llama_index.utils import get_cache_dir
RAG_HISTORY_FILE_NAME = "files_history.txt"
def default_ragcli_persist_dir() -> str:
return str(Path(get_cache_dir()) / "rag_cli")
def query_input(query_str: Optional[str] = None) -> str:
return query_str or ""
class QueryPipelineQueryEngine(CustomQueryEngine):
query_pipeline: QueryPipeline = Field(
description="Query Pipeline to use for Q&A.",
)
def custom_query(self, query_str: str) -> RESPONSE_TYPE:
return self.query_pipeline.run(query_str=query_str)
async def acustom_query(self, query_str: str) -> RESPONSE_TYPE:
return await self.query_pipeline.arun(query_str=query_str)
class RagCLI(BaseModel):
"""
CLI tool for chatting with the output of an IngestionPipeline via a QueryPipeline.
"""
ingestion_pipeline: IngestionPipeline = Field(
description="Ingestion pipeline to run for RAG ingestion."
)
verbose: bool = Field(
description="Whether to print out verbose information during execution.",
default=False,
)
persist_dir: str = Field(
description="Directory to persist ingestion pipeline.",
default_factory=default_ragcli_persist_dir,
)
llm: LLM = Field(
description="Language model to use for response generation.",
default_factory=lambda: OpenAI(model="gpt-3.5-turbo", streaming=True),
)
query_pipeline: Optional[QueryPipeline] = Field(
description="Query Pipeline to use for Q&A.",
default=None,
)
chat_engine: Optional[CondenseQuestionChatEngine] = Field(
description="Chat engine to use for chatting.",
default_factory=None,
)
file_extractor: Optional[Dict[str, BaseReader]] = Field(
description="File extractor to use for extracting text from files.",
default=None,
)
class Config:
arbitrary_types_allowed = True
@validator("query_pipeline", always=True)
def query_pipeline_from_ingestion_pipeline(
cls, query_pipeline: Any, values: Dict[str, Any]
) -> Optional[QueryPipeline]:
"""
If query_pipeline is not provided, create one from ingestion_pipeline.
"""
if query_pipeline is not None:
return query_pipeline
ingestion_pipeline = cast(IngestionPipeline, values["ingestion_pipeline"])
if ingestion_pipeline.vector_store is None:
return None
verbose = cast(bool, values["verbose"])
query_component = FnComponent(
fn=query_input, output_key="output", req_params={"query_str"}
)
llm = cast(LLM, values["llm"])
# get embed_model from transformations if possible
embed_model = None
if ingestion_pipeline.transformations is not None:
for transformation in ingestion_pipeline.transformations:
if isinstance(transformation, BaseEmbedding):
embed_model = transformation
break
service_context = ServiceContext.from_defaults(
llm=llm, embed_model=embed_model or "default"
)
retriever = VectorStoreIndex.from_vector_store(
ingestion_pipeline.vector_store, service_context=service_context
).as_retriever(similarity_top_k=8)
response_synthesizer = CompactAndRefine(
service_context=service_context, streaming=True, verbose=verbose
)
# define query pipeline
query_pipeline = QueryPipeline(verbose=verbose)
query_pipeline.add_modules(
{
"query": query_component,
"retriever": retriever,
"summarizer": response_synthesizer,
}
)
query_pipeline.add_link("query", "retriever")
query_pipeline.add_link("retriever", "summarizer", dest_key="nodes")
query_pipeline.add_link("query", "summarizer", dest_key="query_str")
return query_pipeline
@validator("chat_engine", always=True)
def chat_engine_from_query_pipeline(
cls, chat_engine: Any, values: Dict[str, Any]
) -> Optional[CondenseQuestionChatEngine]:
"""
If chat_engine is not provided, create one from query_pipeline.
"""
if chat_engine is not None:
return chat_engine
if values.get("query_pipeline", None) is None:
values["query_pipeline"] = cls.query_pipeline_from_ingestion_pipeline(
query_pipeline=None, values=values
)
query_pipeline = cast(QueryPipeline, values["query_pipeline"])
if query_pipeline is None:
return None
query_engine = QueryPipelineQueryEngine(query_pipeline=query_pipeline) # type: ignore
verbose = cast(bool, values["verbose"])
llm = cast(LLM, values["llm"])
return CondenseQuestionChatEngine.from_defaults(
query_engine=query_engine, llm=llm, verbose=verbose
)
async def handle_cli(
self,
files: Optional[str] = None,
question: Optional[str] = None,
chat: bool = False,
verbose: bool = False,
clear: bool = False,
create_llama: bool = False,
**kwargs: Dict[str, Any],
) -> None:
"""
Entrypoint for local document RAG CLI tool.
"""
if clear:
# delete self.persist_dir directory including all subdirectories and files
if os.path.exists(self.persist_dir):
# Ask for confirmation
response = input(
f"Are you sure you want to delete data within {self.persist_dir}? [y/N] "
)
if response.strip().lower() != "y":
print("Aborted.")
return
os.system(f"rm -rf {self.persist_dir}")
print(f"Successfully cleared {self.persist_dir}")
self.verbose = verbose
ingestion_pipeline = cast(IngestionPipeline, self.ingestion_pipeline)
if self.verbose:
print("Saving/Loading from persist_dir: ", self.persist_dir)
if files is not None:
documents = []
for _file in iglob(files, recursive=True):
_file = os.path.abspath(_file)
if os.path.isdir(_file):
reader = SimpleDirectoryReader(
input_dir=_file,
filename_as_id=True,
file_extractor=self.file_extractor,
)
else:
reader = SimpleDirectoryReader(
input_files=[_file],
filename_as_id=True,
file_extractor=self.file_extractor,
)
documents.extend(reader.load_data(show_progress=verbose))
await ingestion_pipeline.arun(show_progress=verbose, documents=documents)
ingestion_pipeline.persist(persist_dir=self.persist_dir)
# Append the `--files` argument to the history file
with open(f"{self.persist_dir}/{RAG_HISTORY_FILE_NAME}", "a") as f:
f.write(files + "\n")
if create_llama:
if shutil.which("npx") is None:
print(
"`npx` is not installed. Please install it by calling `npm install -g npx`"
)
else:
history_file_path = Path(f"{self.persist_dir}/{RAG_HISTORY_FILE_NAME}")
if not history_file_path.exists():
print(
"No data has been ingested, "
"please specify `--files` to create llama dataset."
)
else:
with open(history_file_path) as f:
stored_paths = {line.strip() for line in f if line.strip()}
if len(stored_paths) == 0:
print(
"No data has been ingested, "
"please specify `--files` to create llama dataset."
)
elif len(stored_paths) > 1:
print(
"Multiple files or folders were ingested, which is not supported by create-llama. "
"Please call `llamaindex-cli rag --clear` to clear the cache first, "
"then call `llamaindex-cli rag --files` again with a single folder or file"
)
else:
path = stored_paths.pop()
if "*" in path:
print(
"Glob pattern is not supported by create-llama. "
"Please call `llamaindex-cli rag --clear` to clear the cache first, "
"then call `llamaindex-cli rag --files` again with a single folder or file."
)
elif not os.path.exists(path):
print(
f"The path {path} does not exist. "
"Please call `llamaindex-cli rag --clear` to clear the cache first, "
"then call `llamaindex-cli rag --files` again with a single folder or file."
)
else:
print(f"Calling create-llama using data from {path} ...")
command_args = [
"npx",
"create-llama@latest",
"--frontend",
"--template",
"streaming",
"--framework",
"fastapi",
"--ui",
"shadcn",
"--vector-db",
"none",
"--engine",
"context",
f"--files {path}",
]
os.system(" ".join(command_args))
if question is not None:
await self.handle_question(question)
if chat:
await self.start_chat_repl()
async def handle_question(self, question: str) -> None:
if self.query_pipeline is None:
raise ValueError("query_pipeline is not defined.")
query_pipeline = cast(QueryPipeline, self.query_pipeline)
query_pipeline.verbose = self.verbose
chat_engine = cast(CondenseQuestionChatEngine, self.chat_engine)
response = chat_engine.chat(question)
if isinstance(response, StreamingResponse):
response.print_response_stream()
else:
response = cast(Response, response)
print(response)
async def start_chat_repl(self) -> None:
"""
Start a REPL for chatting with the agent.
"""
if self.query_pipeline is None:
raise ValueError("query_pipeline is not defined.")
chat_engine = cast(CondenseQuestionChatEngine, self.chat_engine)
chat_engine.streaming_chat_repl()
@classmethod
def add_parser_args(
cls,
parser: Union[ArgumentParser, Any],
instance_generator: Callable[[], "RagCLI"],
) -> None:
parser.add_argument(
"-q",
"--question",
type=str,
help="The question you want to ask.",
required=False,
)
parser.add_argument(
"-f",
"--files",
type=str,
help=(
"The name of the file or directory you want to ask a question about,"
'such as "file.pdf".'
),
)
parser.add_argument(
"-c",
"--chat",
help="If flag is present, opens a chat REPL.",
action="store_true",
)
parser.add_argument(
"-v",
"--verbose",
help="Whether to print out verbose information during execution.",
action="store_true",
)
parser.add_argument(
"--clear",
help="Clears out all currently embedded data.",
action="store_true",
)
parser.add_argument(
"--create-llama",
help="Create a LlamaIndex application with your embedded data.",
required=False,
action="store_true",
)
parser.set_defaults(
func=lambda args: asyncio.run(instance_generator().handle_cli(**vars(args)))
)
def cli(self) -> None:
"""
Entrypoint for CLI tool.
"""
parser = ArgumentParser(description="LlamaIndex RAG Q&A tool.")
subparsers = parser.add_subparsers(
title="commands", dest="command", required=True
)
llamarag_parser = subparsers.add_parser(
"rag", help="Ask a question to a document / a directory of documents."
)
self.add_parser_args(llamarag_parser, lambda: self)
# Parse the command-line arguments
args = parser.parse_args()
# Call the appropriate function based on the command
args.func(args)
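
A condensed sketch of driving RagCLI from Python rather than from the command line (illustrative; it reuses default_rag_cli() from the CLI module above, so it assumes chromadb and OpenAI credentials are available, and the glob path is hypothetical):

import asyncio

rag_cli = default_rag_cli()  # Chroma-backed ingestion + query pipeline
asyncio.run(
    rag_cli.handle_cli(files="./docs/**/*.md", question="What topics are covered?")
)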

View File

@ -0,0 +1,7 @@
"""Init composability."""
from llama_index.composability.base import ComposableGraph
from llama_index.composability.joint_qa_summary import QASummaryQueryEngineBuilder
__all__ = ["ComposableGraph", "QASummaryQueryEngineBuilder"]

View File

@ -0,0 +1,4 @@
"""Composable graph."""
# TODO: remove this file, only keep for backwards compatibility
from llama_index.indices.composability.graph import ComposableGraph # noqa

View File

@ -0,0 +1,98 @@
"""Joint QA Summary graph."""
from typing import Optional, Sequence
from llama_index.indices.list.base import SummaryIndex
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.ingestion import run_transformations
from llama_index.query_engine.router_query_engine import RouterQueryEngine
from llama_index.schema import Document
from llama_index.service_context import ServiceContext
from llama_index.storage.storage_context import StorageContext
from llama_index.tools.query_engine import QueryEngineTool
DEFAULT_SUMMARY_TEXT = "Use this index for summarization queries"
DEFAULT_QA_TEXT = (
"Use this index for queries that require retrieval of specific "
"context from documents."
)
class QASummaryQueryEngineBuilder:
"""Joint QA Summary graph builder.
Can build a graph that provides a unified query interface
for both QA and summarization tasks.
NOTE: this is a beta feature. The API may change in the future.
Args:
storage_context (StorageContext): A StorageContext to use for
storing nodes.
service_context (ServiceContext): A ServiceContext to use for
building indices.
summary_text (str): Text to use for the summary index.
qa_text (str): Text to use for the QA index.
"""
def __init__(
self,
storage_context: Optional[StorageContext] = None,
service_context: Optional[ServiceContext] = None,
summary_text: str = DEFAULT_SUMMARY_TEXT,
qa_text: str = DEFAULT_QA_TEXT,
) -> None:
"""Init params."""
self._storage_context = storage_context or StorageContext.from_defaults()
self._service_context = service_context or ServiceContext.from_defaults()
self._summary_text = summary_text
self._qa_text = qa_text
def build_from_documents(
self,
documents: Sequence[Document],
) -> RouterQueryEngine:
"""Build query engine."""
# parse nodes
nodes = run_transformations(
documents, self._service_context.transformations # type: ignore
)
# ingest nodes
self._storage_context.docstore.add_documents(nodes, allow_update=True)
# build indices
vector_index = VectorStoreIndex(
nodes,
service_context=self._service_context,
storage_context=self._storage_context,
)
summary_index = SummaryIndex(
nodes,
service_context=self._service_context,
storage_context=self._storage_context,
)
vector_query_engine = vector_index.as_query_engine(
service_context=self._service_context
)
list_query_engine = summary_index.as_query_engine(
service_context=self._service_context,
response_mode="tree_summarize",
)
# build query engine
return RouterQueryEngine.from_defaults(
query_engine_tools=[
QueryEngineTool.from_defaults(
vector_query_engine, description=self._qa_text
),
QueryEngineTool.from_defaults(
list_query_engine, description=self._summary_text
),
],
service_context=self._service_context,
select_multi=False,
)
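
A short usage sketch for the builder above (illustrative; it assumes default OpenAI settings and a made-up document, and routing decisions are made by the underlying RouterQueryEngine):

from llama_index import Document
from llama_index.composability import QASummaryQueryEngineBuilder

docs = [Document(text="LlamaIndex supports both QA and summarization over documents.")]
query_engine = QASummaryQueryEngineBuilder().build_from_documents(docs)
print(query_engine.query("Summarize the document."))        # expected to route to the summary index
print(query_engine.query("What does LlamaIndex support?"))  # expected to route to the vector index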

29
llama_index/constants.py Normal file
View File

@ -0,0 +1,29 @@
"""Set of constants."""
DEFAULT_TEMPERATURE = 0.1
DEFAULT_CONTEXT_WINDOW = 3900 # tokens
DEFAULT_NUM_OUTPUTS = 256 # tokens
DEFAULT_NUM_INPUT_FILES = 10 # files
DEFAULT_EMBED_BATCH_SIZE = 10
DEFAULT_CHUNK_SIZE = 1024 # tokens
DEFAULT_CHUNK_OVERLAP = 20 # tokens
DEFAULT_SIMILARITY_TOP_K = 2
DEFAULT_IMAGE_SIMILARITY_TOP_K = 2
# NOTE: for text-embedding-ada-002
DEFAULT_EMBEDDING_DIM = 1536
# context window size for llm predictor
COHERE_CONTEXT_WINDOW = 2048
AI21_J2_CONTEXT_WINDOW = 8192
TYPE_KEY = "__type__"
DATA_KEY = "__data__"
VECTOR_STORE_KEY = "vector_store"
IMAGE_STORE_KEY = "image_store"
GRAPH_STORE_KEY = "graph_store"
INDEX_STORE_KEY = "index_store"
DOC_STORE_KEY = "doc_store"

View File

View File

@ -0,0 +1,43 @@
from abc import abstractmethod
from typing import Any, List, Tuple
from llama_index.bridge.pydantic import BaseModel
from llama_index.core.base_retriever import BaseRetriever
from llama_index.schema import NodeWithScore, QueryBundle
class BaseAutoRetriever(BaseRetriever):
"""Base auto retriever."""
@abstractmethod
def generate_retrieval_spec(
self, query_bundle: QueryBundle, **kwargs: Any
) -> BaseModel:
"""Generate retrieval spec synchronously."""
...
@abstractmethod
async def agenerate_retrieval_spec(
self, query_bundle: QueryBundle, **kwargs: Any
) -> BaseModel:
"""Generate retrieval spec asynchronously."""
...
@abstractmethod
def _build_retriever_from_spec(
self, retrieval_spec: BaseModel
) -> Tuple[BaseRetriever, QueryBundle]:
"""Build retriever from spec and provide query bundle."""
...
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve using generated spec."""
retrieval_spec = self.generate_retrieval_spec(query_bundle)
retriever, new_query_bundle = self._build_retriever_from_spec(retrieval_spec)
return retriever.retrieve(new_query_bundle)
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve using generated spec asynchronously."""
retrieval_spec = await self.agenerate_retrieval_spec(query_bundle)
retriever, new_query_bundle = self._build_retriever_from_spec(retrieval_spec)
return await retriever.aretrieve(new_query_bundle)
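
A skeletal subclass illustrating the contract above (purely illustrative; the spec model and the keyword heuristic are made up, and the underlying retriever is supplied by the caller):

class _KeywordSpec(BaseModel):
    keywords: List[str]

class KeywordAutoRetriever(BaseAutoRetriever):
    """Toy auto-retriever: derive keywords from the query, then delegate."""

    def __init__(self, underlying: BaseRetriever) -> None:
        super().__init__()
        self._underlying = underlying

    def generate_retrieval_spec(self, query_bundle: QueryBundle, **kwargs: Any) -> BaseModel:
        # naive "spec generation": treat each query word as a keyword
        return _KeywordSpec(keywords=query_bundle.query_str.split())

    async def agenerate_retrieval_spec(self, query_bundle: QueryBundle, **kwargs: Any) -> BaseModel:
        return self.generate_retrieval_spec(query_bundle, **kwargs)

    def _build_retriever_from_spec(self, retrieval_spec: BaseModel) -> Tuple[BaseRetriever, QueryBundle]:
        new_query = " ".join(retrieval_spec.keywords)  # type: ignore[attr-defined]
        return self._underlying, QueryBundle(query_str=new_query)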

View File

@ -0,0 +1,70 @@
"""base multi modal retriever."""
from abc import abstractmethod
from typing import List
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.image_retriever import BaseImageRetriever
from llama_index.indices.query.schema import QueryType
from llama_index.schema import NodeWithScore
class MultiModalRetriever(BaseRetriever, BaseImageRetriever):
"""Multi Modal base retriever."""
@abstractmethod
def text_retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
"""Retrieve text nodes given text query.
Implemented by the user.
"""
@abstractmethod
def text_to_image_retrieve(
self, str_or_query_bundle: QueryType
) -> List[NodeWithScore]:
"""Retrieve image nodes given text query.
Implemented by the user.
"""
@abstractmethod
def image_to_image_retrieve(
self, str_or_query_bundle: QueryType
) -> List[NodeWithScore]:
"""Retrieve image nodes given image query.
Implemented by the user.
"""
@abstractmethod
async def atext_retrieve(
self, str_or_query_bundle: QueryType
) -> List[NodeWithScore]:
"""Async Retrieve text nodes given text query.
Implemented by the user.
"""
@abstractmethod
async def atext_to_image_retrieve(
self, str_or_query_bundle: QueryType
) -> List[NodeWithScore]:
"""Async Retrieve image nodes given text query.
Implemented by the user.
"""
@abstractmethod
async def aimage_to_image_retrieve(
self, str_or_query_bundle: QueryType
) -> List[NodeWithScore]:
"""Async Retrieve image nodes given image query.
Implemented by the user.
"""

View File

@ -0,0 +1,122 @@
"""Base query engine."""
import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence
from llama_index.bridge.pydantic import Field
from llama_index.callbacks.base import CallbackManager
from llama_index.core.query_pipeline.query_component import (
ChainableMixin,
InputKeys,
OutputKeys,
QueryComponent,
validate_and_convert_stringable,
)
from llama_index.core.response.schema import RESPONSE_TYPE
from llama_index.prompts.mixin import PromptDictType, PromptMixin
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
logger = logging.getLogger(__name__)
class BaseQueryEngine(ChainableMixin, PromptMixin):
"""Base query engine."""
def __init__(self, callback_manager: Optional[CallbackManager]) -> None:
self.callback_manager = callback_manager or CallbackManager([])
def _get_prompts(self) -> Dict[str, Any]:
"""Get prompts."""
return {}
def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
with self.callback_manager.as_trace("query"):
if isinstance(str_or_query_bundle, str):
str_or_query_bundle = QueryBundle(str_or_query_bundle)
return self._query(str_or_query_bundle)
async def aquery(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
with self.callback_manager.as_trace("query"):
if isinstance(str_or_query_bundle, str):
str_or_query_bundle = QueryBundle(str_or_query_bundle)
return await self._aquery(str_or_query_bundle)
def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
raise NotImplementedError(
"This query engine does not support retrieve, use query directly"
)
def synthesize(
self,
query_bundle: QueryBundle,
nodes: List[NodeWithScore],
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
) -> RESPONSE_TYPE:
raise NotImplementedError(
"This query engine does not support synthesize, use query directly"
)
async def asynthesize(
self,
query_bundle: QueryBundle,
nodes: List[NodeWithScore],
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
) -> RESPONSE_TYPE:
raise NotImplementedError(
"This query engine does not support asynthesize, use aquery directly"
)
@abstractmethod
def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
pass
@abstractmethod
async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
pass
def _as_query_component(self, **kwargs: Any) -> QueryComponent:
"""Return a query component."""
return QueryEngineComponent(query_engine=self)
class QueryEngineComponent(QueryComponent):
"""Query engine component."""
query_engine: BaseQueryEngine = Field(..., description="Query engine")
class Config:
arbitrary_types_allowed = True
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
self.query_engine.callback_manager = callback_manager
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
# make sure input is a string
input["input"] = validate_and_convert_stringable(input["input"])
return input
def _run_component(self, **kwargs: Any) -> Any:
"""Run component."""
output = self.query_engine.query(kwargs["input"])
return {"output": output}
async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component."""
output = await self.query_engine.aquery(kwargs["input"])
return {"output": output}
@property
def input_keys(self) -> InputKeys:
"""Input keys."""
return InputKeys.from_keys({"input"})
@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
return OutputKeys.from_keys({"output"})
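
A minimal concrete engine built on the base class above (illustrative; it simply echoes the query, and the component usage at the end relies on the ChainableMixin hook shown in `_as_query_component`):

from llama_index.core.response.schema import Response

class EchoQueryEngine(BaseQueryEngine):
    """Toy engine that echoes the query back as the response."""

    def __init__(self) -> None:
        super().__init__(callback_manager=None)

    def _get_prompt_modules(self) -> Dict[str, Any]:
        # satisfy PromptMixin; this toy engine has no prompt sub-modules
        return {}

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        return Response(response=f"You asked: {query_bundle.query_str}")

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        return self._query(query_bundle)

engine = EchoQueryEngine()
print(engine.query("hello"))                   # prints "You asked: hello"
component = engine.as_query_component()        # wraps the engine in QueryEngineComponent
print(component.run_component(input="hello"))  # {"output": Response(...)}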

View File

@ -0,0 +1,324 @@
"""Base retriever."""
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from llama_index.bridge.pydantic import Field
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.query_pipeline.query_component import (
ChainableMixin,
InputKeys,
OutputKeys,
QueryComponent,
validate_and_convert_stringable,
)
from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType
from llama_index.schema import (
BaseNode,
IndexNode,
NodeWithScore,
QueryBundle,
QueryType,
TextNode,
)
from llama_index.service_context import ServiceContext
from llama_index.utils import print_text
class BaseRetriever(ChainableMixin, PromptMixin):
"""Base retriever."""
def __init__(
self,
callback_manager: Optional[CallbackManager] = None,
object_map: Optional[Dict] = None,
objects: Optional[List[IndexNode]] = None,
verbose: bool = False,
) -> None:
self.callback_manager = callback_manager or CallbackManager()
if objects is not None:
object_map = {obj.index_id: obj.obj for obj in objects}
self.object_map = object_map or {}
self._verbose = verbose
def _check_callback_manager(self) -> None:
"""Check callback manager."""
if not hasattr(self, "callback_manager"):
self.callback_manager = CallbackManager()
def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
return {}
def _get_prompt_modules(self) -> PromptMixinType:
"""Get prompt modules."""
return {}
def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
def _retrieve_from_object(
self,
obj: Any,
query_bundle: QueryBundle,
score: float,
) -> List[NodeWithScore]:
"""Retrieve nodes from object."""
if self._verbose:
print_text(
f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n",
color="llama_pink",
)
if isinstance(obj, NodeWithScore):
return [obj]
elif isinstance(obj, BaseNode):
return [NodeWithScore(node=obj, score=score)]
elif isinstance(obj, BaseQueryEngine):
response = obj.query(query_bundle)
return [
NodeWithScore(
node=TextNode(text=str(response), metadata=response.metadata or {}),
score=score,
)
]
elif isinstance(obj, BaseRetriever):
return obj.retrieve(query_bundle)
elif isinstance(obj, QueryComponent):
component_keys = obj.input_keys.required_keys
if len(component_keys) > 1:
raise ValueError(
f"QueryComponent {obj} has more than one input key: {component_keys}"
)
elif len(component_keys) == 0:
component_response = obj.run_component()
else:
kwargs = {next(iter(component_keys)): query_bundle.query_str}
component_response = obj.run_component(**kwargs)
result_output = str(next(iter(component_response.values())))
return [NodeWithScore(node=TextNode(text=result_output), score=score)]
else:
raise ValueError(f"Object {obj} is not retrievable.")
async def _aretrieve_from_object(
self,
obj: Any,
query_bundle: QueryBundle,
score: float,
) -> List[NodeWithScore]:
"""Retrieve nodes from object."""
if isinstance(obj, NodeWithScore):
return [obj]
elif isinstance(obj, BaseNode):
return [NodeWithScore(node=obj, score=score)]
elif isinstance(obj, BaseQueryEngine):
response = await obj.aquery(query_bundle)
return [NodeWithScore(node=TextNode(text=str(response)), score=score)]
elif isinstance(obj, BaseRetriever):
return await obj.aretrieve(query_bundle)
elif isinstance(obj, QueryComponent):
component_keys = obj.input_keys.required_keys
if len(component_keys) > 1:
raise ValueError(
f"QueryComponent {obj} has more than one input key: {component_keys}"
)
elif len(component_keys) == 0:
component_response = await obj.arun_component()
else:
kwargs = {next(iter(component_keys)): query_bundle.query_str}
component_response = await obj.arun_component(**kwargs)
result_output = str(next(iter(component_response.values())))
return [NodeWithScore(node=TextNode(text=result_output), score=score)]
else:
raise ValueError(f"Object {obj} is not retrievable.")
def _handle_recursive_retrieval(
self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
) -> List[NodeWithScore]:
retrieved_nodes: List[NodeWithScore] = []
for n in nodes:
node = n.node
score = n.score or 1.0
if isinstance(node, IndexNode):
obj = self.object_map.get(node.index_id, None)
if obj is not None:
if self._verbose:
print_text(
f"Retrieval entering {node.index_id}: {obj.__class__.__name__}\n",
color="llama_turquoise",
)
retrieved_nodes.extend(
self._retrieve_from_object(
obj, query_bundle=query_bundle, score=score
)
)
else:
retrieved_nodes.append(n)
else:
retrieved_nodes.append(n)
seen = set()
return [
n
for n in retrieved_nodes
if not (n.node.hash in seen or seen.add(n.node.hash)) # type: ignore[func-returns-value]
]
async def _ahandle_recursive_retrieval(
self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
) -> List[NodeWithScore]:
retrieved_nodes: List[NodeWithScore] = []
for n in nodes:
node = n.node
score = n.score or 1.0
if isinstance(node, IndexNode):
obj = self.object_map.get(node.index_id, None)
if obj is not None:
if self._verbose:
print_text(
f"Retrieval entering {node.index_id}: {obj.__class__.__name__}\n",
color="llama_turquoise",
)
# TODO: Add concurrent execution via `run_jobs()` ?
retrieved_nodes.extend(
await self._aretrieve_from_object(
obj, query_bundle=query_bundle, score=score
)
)
else:
retrieved_nodes.append(n)
else:
retrieved_nodes.append(n)
# remove any duplicates based on hash
seen = set()
return [
n
for n in retrieved_nodes
if not (n.node.hash in seen or seen.add(n.node.hash)) # type: ignore[func-returns-value]
]
def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
"""Retrieve nodes given query.
Args:
str_or_query_bundle (QueryType): Either a query string or
a QueryBundle object.
"""
self._check_callback_manager()
if isinstance(str_or_query_bundle, str):
query_bundle = QueryBundle(str_or_query_bundle)
else:
query_bundle = str_or_query_bundle
with self.callback_manager.as_trace("query"):
with self.callback_manager.event(
CBEventType.RETRIEVE,
payload={EventPayload.QUERY_STR: query_bundle.query_str},
) as retrieve_event:
nodes = self._retrieve(query_bundle)
nodes = self._handle_recursive_retrieval(query_bundle, nodes)
retrieve_event.on_end(
payload={EventPayload.NODES: nodes},
)
return nodes
async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
self._check_callback_manager()
if isinstance(str_or_query_bundle, str):
query_bundle = QueryBundle(str_or_query_bundle)
else:
query_bundle = str_or_query_bundle
with self.callback_manager.as_trace("query"):
with self.callback_manager.event(
CBEventType.RETRIEVE,
payload={EventPayload.QUERY_STR: query_bundle.query_str},
) as retrieve_event:
nodes = await self._aretrieve(query_bundle)
nodes = await self._ahandle_recursive_retrieval(query_bundle, nodes)
retrieve_event.on_end(
payload={EventPayload.NODES: nodes},
)
return nodes
@abstractmethod
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve nodes given query.
Implemented by the user.
"""
# TODO: make this abstract
# @abstractmethod
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Asynchronously retrieve nodes given query.
Implemented by the user.
"""
return self._retrieve(query_bundle)
def get_service_context(self) -> Optional[ServiceContext]:
"""Attempts to resolve a service context.
Short-circuits at self.service_context, self._service_context,
or self._index.service_context.
"""
if hasattr(self, "service_context"):
return self.service_context
if hasattr(self, "_service_context"):
return self._service_context
elif hasattr(self, "_index") and hasattr(self._index, "service_context"):
return self._index.service_context
return None
def _as_query_component(self, **kwargs: Any) -> QueryComponent:
"""Return a query component."""
return RetrieverComponent(retriever=self)
class RetrieverComponent(QueryComponent):
"""Retriever component."""
retriever: BaseRetriever = Field(..., description="Retriever")
class Config:
arbitrary_types_allowed = True
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
self.retriever.callback_manager = callback_manager
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
# make sure input is a string
input["input"] = validate_and_convert_stringable(input["input"])
return input
def _run_component(self, **kwargs: Any) -> Any:
"""Run component."""
output = self.retriever.retrieve(kwargs["input"])
return {"output": output}
async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component."""
output = await self.retriever.aretrieve(kwargs["input"])
return {"output": output}
@property
def input_keys(self) -> InputKeys:
"""Input keys."""
return InputKeys.from_keys({"input"})
@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
return OutputKeys.from_keys({"output"})

View File

@ -0,0 +1,112 @@
from abc import abstractmethod
from typing import Any, List, Sequence, Union
from llama_index.bridge.pydantic import BaseModel
from llama_index.core.query_pipeline.query_component import (
ChainableMixin,
QueryComponent,
)
from llama_index.prompts.mixin import PromptMixin, PromptMixinType
from llama_index.schema import QueryBundle, QueryType
from llama_index.tools.types import ToolMetadata
MetadataType = Union[str, ToolMetadata]
class SingleSelection(BaseModel):
"""A single selection of a choice."""
index: int
reason: str
class MultiSelection(BaseModel):
"""A multi-selection of choices."""
selections: List[SingleSelection]
@property
def ind(self) -> int:
if len(self.selections) != 1:
raise ValueError(
f"There are {len(self.selections)} selections, " "please use .inds."
)
return self.selections[0].index
@property
def reason(self) -> str:
if len(self.reasons) != 1:
raise ValueError(
f"There are {len(self.reasons)} selections, " "please use .reasons."
)
return self.selections[0].reason
@property
def inds(self) -> List[int]:
return [x.index for x in self.selections]
@property
def reasons(self) -> List[str]:
return [x.reason for x in self.selections]
# separate name for clarity and to not confuse function calling model
SelectorResult = MultiSelection
def _wrap_choice(choice: MetadataType) -> ToolMetadata:
if isinstance(choice, ToolMetadata):
return choice
elif isinstance(choice, str):
return ToolMetadata(description=choice)
else:
raise ValueError(f"Unexpected type: {type(choice)}")
def _wrap_query(query: QueryType) -> QueryBundle:
if isinstance(query, QueryBundle):
return query
elif isinstance(query, str):
return QueryBundle(query_str=query)
else:
raise ValueError(f"Unexpected type: {type(query)}")
class BaseSelector(PromptMixin, ChainableMixin):
"""Base selector."""
def _get_prompt_modules(self) -> PromptMixinType:
"""Get prompt sub-modules."""
return {}
def select(
self, choices: Sequence[MetadataType], query: QueryType
) -> SelectorResult:
metadatas = [_wrap_choice(choice) for choice in choices]
query_bundle = _wrap_query(query)
return self._select(choices=metadatas, query=query_bundle)
async def aselect(
self, choices: Sequence[MetadataType], query: QueryType
) -> SelectorResult:
metadatas = [_wrap_choice(choice) for choice in choices]
query_bundle = _wrap_query(query)
return await self._aselect(choices=metadatas, query=query_bundle)
@abstractmethod
def _select(
self, choices: Sequence[ToolMetadata], query: QueryBundle
) -> SelectorResult:
pass
@abstractmethod
async def _aselect(
self, choices: Sequence[ToolMetadata], query: QueryBundle
) -> SelectorResult:
pass
def _as_query_component(self, **kwargs: Any) -> QueryComponent:
"""As query component."""
from llama_index.query_pipeline.components.router import SelectorComponent
return SelectorComponent(selector=self)
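
To make the contract concrete, a rough sketch of a rule-based selector (the class, its heuristic, and the import path are assumptions for illustration, not library code). A subclass returns a `SelectorResult` from `_select`/`_aselect`; `select()` already wraps plain strings into `ToolMetadata` and `QueryBundle`, and `PromptMixin`'s prompt hooks can be stubbed out for a selector that uses no prompts.

from typing import Sequence

from llama_index.core.base_selector import (  # assumed import path
    BaseSelector,
    SelectorResult,
    SingleSelection,
)
from llama_index.schema import QueryBundle
from llama_index.tools.types import ToolMetadata


class FirstKeywordSelector(BaseSelector):
    """Toy selector: pick the first choice sharing a word with the query."""

    def _get_prompts(self) -> dict:
        return {}

    def _update_prompts(self, prompts_dict: dict) -> None:
        """No prompts to update for this rule-based selector."""

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        query_terms = set(query.query_str.lower().split())
        for i, choice in enumerate(choices):
            if query_terms & set((choice.description or "").lower().split()):
                return SelectorResult(
                    selections=[SingleSelection(index=i, reason="keyword overlap")]
                )
        return SelectorResult(selections=[SingleSelection(index=0, reason="fallback")])

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        return self._select(choices, query)


# result = FirstKeywordSelector().select(
#     ["useful for SQL questions", "useful for vector search over documents"],
#     "search documents about llamas",
# )
# result.ind == 1 and result.reason == "keyword overlap"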

View File

View File

@ -0,0 +1,354 @@
"""Base embeddings file."""
import asyncio
from abc import abstractmethod
from enum import Enum
from typing import Any, Callable, Coroutine, List, Optional, Tuple
import numpy as np
from llama_index.bridge.pydantic import Field, validator
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.constants import (
DEFAULT_EMBED_BATCH_SIZE,
)
from llama_index.schema import BaseNode, MetadataMode, TransformComponent
from llama_index.utils import get_tqdm_iterable
# TODO: change to numpy array
Embedding = List[float]
class SimilarityMode(str, Enum):
"""Modes for similarity/distance."""
DEFAULT = "cosine"
DOT_PRODUCT = "dot_product"
EUCLIDEAN = "euclidean"
def mean_agg(embeddings: List[Embedding]) -> Embedding:
"""Mean aggregation for embeddings."""
return list(np.array(embeddings).mean(axis=0))
def similarity(
embedding1: Embedding,
embedding2: Embedding,
mode: SimilarityMode = SimilarityMode.DEFAULT,
) -> float:
"""Get embedding similarity."""
if mode == SimilarityMode.EUCLIDEAN:
# Using -euclidean distance as similarity to achieve same ranking order
return -float(np.linalg.norm(np.array(embedding1) - np.array(embedding2)))
elif mode == SimilarityMode.DOT_PRODUCT:
return np.dot(embedding1, embedding2)
else:
product = np.dot(embedding1, embedding2)
norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
return product / norm
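# A quick worked illustration of the modes above, with made-up vectors (not from the
# library's tests): for embedding1 = [1.0, 0.0] and embedding2 = [1.0, 1.0],
#   similarity(embedding1, embedding2, SimilarityMode.DOT_PRODUCT)  -> 1.0
#   similarity(embedding1, embedding2)  # DEFAULT, i.e. cosine      -> ~0.7071
#   similarity(embedding1, embedding2, SimilarityMode.EUCLIDEAN)    -> -1.0 (negated distance)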
class BaseEmbedding(TransformComponent):
"""Base class for embeddings."""
model_name: str = Field(
default="unknown", description="The name of the embedding model."
)
embed_batch_size: int = Field(
default=DEFAULT_EMBED_BATCH_SIZE,
description="The batch size for embedding calls.",
gt=0,
        le=2048,
)
callback_manager: CallbackManager = Field(
default_factory=lambda: CallbackManager([]), exclude=True
)
class Config:
arbitrary_types_allowed = True
@validator("callback_manager", pre=True)
def _validate_callback_manager(
cls, v: Optional[CallbackManager]
) -> CallbackManager:
if v is None:
return CallbackManager([])
return v
@abstractmethod
def _get_query_embedding(self, query: str) -> Embedding:
"""
Embed the input query synchronously.
Subclasses should implement this method. Reference get_query_embedding's
docstring for more information.
"""
@abstractmethod
async def _aget_query_embedding(self, query: str) -> Embedding:
"""
Embed the input query asynchronously.
Subclasses should implement this method. Reference get_query_embedding's
docstring for more information.
"""
def get_query_embedding(self, query: str) -> Embedding:
"""
Embed the input query.
When embedding a query, depending on the model, a special instruction
can be prepended to the raw query string. For example, "Represent the
question for retrieving supporting documents: ". If you're curious,
other examples of predefined instructions can be found in
embeddings/huggingface_utils.py.
"""
with self.callback_manager.event(
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
) as event:
query_embedding = self._get_query_embedding(query)
event.on_end(
payload={
EventPayload.CHUNKS: [query],
EventPayload.EMBEDDINGS: [query_embedding],
},
)
return query_embedding
async def aget_query_embedding(self, query: str) -> Embedding:
"""Get query embedding."""
with self.callback_manager.event(
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
) as event:
query_embedding = await self._aget_query_embedding(query)
event.on_end(
payload={
EventPayload.CHUNKS: [query],
EventPayload.EMBEDDINGS: [query_embedding],
},
)
return query_embedding
def get_agg_embedding_from_queries(
self,
queries: List[str],
agg_fn: Optional[Callable[..., Embedding]] = None,
) -> Embedding:
"""Get aggregated embedding from multiple queries."""
query_embeddings = [self.get_query_embedding(query) for query in queries]
agg_fn = agg_fn or mean_agg
return agg_fn(query_embeddings)
async def aget_agg_embedding_from_queries(
self,
queries: List[str],
agg_fn: Optional[Callable[..., Embedding]] = None,
) -> Embedding:
"""Async get aggregated embedding from multiple queries."""
query_embeddings = [await self.aget_query_embedding(query) for query in queries]
agg_fn = agg_fn or mean_agg
return agg_fn(query_embeddings)
@abstractmethod
def _get_text_embedding(self, text: str) -> Embedding:
"""
Embed the input text synchronously.
Subclasses should implement this method. Reference get_text_embedding's
docstring for more information.
"""
async def _aget_text_embedding(self, text: str) -> Embedding:
"""
Embed the input text asynchronously.
Subclasses can implement this method if there is a true async
implementation. Reference get_text_embedding's docstring for more
information.
"""
# Default implementation just falls back on _get_text_embedding
return self._get_text_embedding(text)
def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
"""
Embed the input sequence of text synchronously.
Subclasses can implement this method if batch queries are supported.
"""
# Default implementation just loops over _get_text_embedding
return [self._get_text_embedding(text) for text in texts]
async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
"""
Embed the input sequence of text asynchronously.
Subclasses can implement this method if batch queries are supported.
"""
return await asyncio.gather(
*[self._aget_text_embedding(text) for text in texts]
)
def get_text_embedding(self, text: str) -> Embedding:
"""
Embed the input text.
When embedding text, depending on the model, a special instruction
can be prepended to the raw text string. For example, "Represent the
document for retrieval: ". If you're curious, other examples of
predefined instructions can be found in embeddings/huggingface_utils.py.
"""
with self.callback_manager.event(
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
) as event:
text_embedding = self._get_text_embedding(text)
event.on_end(
payload={
EventPayload.CHUNKS: [text],
EventPayload.EMBEDDINGS: [text_embedding],
}
)
return text_embedding
async def aget_text_embedding(self, text: str) -> Embedding:
"""Async get text embedding."""
with self.callback_manager.event(
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
) as event:
text_embedding = await self._aget_text_embedding(text)
event.on_end(
payload={
EventPayload.CHUNKS: [text],
EventPayload.EMBEDDINGS: [text_embedding],
}
)
return text_embedding
def get_text_embedding_batch(
self,
texts: List[str],
show_progress: bool = False,
**kwargs: Any,
) -> List[Embedding]:
"""Get a list of text embeddings, with batching."""
cur_batch: List[str] = []
result_embeddings: List[Embedding] = []
queue_with_progress = enumerate(
get_tqdm_iterable(texts, show_progress, "Generating embeddings")
)
for idx, text in queue_with_progress:
cur_batch.append(text)
if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
# flush
with self.callback_manager.event(
CBEventType.EMBEDDING,
payload={EventPayload.SERIALIZED: self.to_dict()},
) as event:
embeddings = self._get_text_embeddings(cur_batch)
result_embeddings.extend(embeddings)
event.on_end(
payload={
EventPayload.CHUNKS: cur_batch,
EventPayload.EMBEDDINGS: embeddings,
},
)
cur_batch = []
return result_embeddings
async def aget_text_embedding_batch(
self, texts: List[str], show_progress: bool = False
) -> List[Embedding]:
"""Asynchronously get a list of text embeddings, with batching."""
cur_batch: List[str] = []
callback_payloads: List[Tuple[str, List[str]]] = []
result_embeddings: List[Embedding] = []
embeddings_coroutines: List[Coroutine] = []
for idx, text in enumerate(texts):
cur_batch.append(text)
if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
# flush
event_id = self.callback_manager.on_event_start(
CBEventType.EMBEDDING,
payload={EventPayload.SERIALIZED: self.to_dict()},
)
callback_payloads.append((event_id, cur_batch))
embeddings_coroutines.append(self._aget_text_embeddings(cur_batch))
cur_batch = []
# flatten the results of asyncio.gather, which is a list of embeddings lists
nested_embeddings = []
        if show_progress:
            try:
                from tqdm.asyncio import tqdm_asyncio

                # tqdm_asyncio.gather preserves input order, so results still line up
                # with callback_payloads and the original text order below
                nested_embeddings = await tqdm_asyncio.gather(
                    *embeddings_coroutines,
                    total=len(embeddings_coroutines),
                    desc="Generating embeddings",
                )
            except ImportError:
                nested_embeddings = await asyncio.gather(*embeddings_coroutines)
else:
nested_embeddings = await asyncio.gather(*embeddings_coroutines)
result_embeddings = [
embedding for embeddings in nested_embeddings for embedding in embeddings
]
for (event_id, text_batch), embeddings in zip(
callback_payloads, nested_embeddings
):
self.callback_manager.on_event_end(
CBEventType.EMBEDDING,
payload={
EventPayload.CHUNKS: text_batch,
EventPayload.EMBEDDINGS: embeddings,
},
event_id=event_id,
)
return result_embeddings
def similarity(
self,
embedding1: Embedding,
embedding2: Embedding,
mode: SimilarityMode = SimilarityMode.DEFAULT,
) -> float:
"""Get embedding similarity."""
return similarity(embedding1=embedding1, embedding2=embedding2, mode=mode)
def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
embeddings = self.get_text_embedding_batch(
[node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes],
**kwargs,
)
for node, embedding in zip(nodes, embeddings):
node.embedding = embedding
return nodes
async def acall(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
embeddings = await self.aget_text_embedding_batch(
[node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes],
**kwargs,
)
for node, embedding in zip(nodes, embeddings):
node.embedding = embedding
return nodes
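
As a rough sketch of how little a concrete embedding has to provide (the hashing scheme and import path are assumptions for illustration; real subclasses call out to an embedding model): implement the three private hooks and the base class supplies batching, callback events, query aggregation, similarity, and the `__call__`/`acall` node transforms.

import hashlib

from llama_index.core.embeddings.base import BaseEmbedding, Embedding  # assumed import path


class HashEmbedding(BaseEmbedding):
    """Toy deterministic embedding: 8 floats derived from a SHA-256 digest."""

    def _hash(self, text: str) -> Embedding:
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        return [byte / 255.0 for byte in digest[:8]]

    def _get_query_embedding(self, query: str) -> Embedding:
        return self._hash(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        return self._hash(query)

    def _get_text_embedding(self, text: str) -> Embedding:
        return self._hash(text)


# embed_model = HashEmbedding(model_name="hash-demo", embed_batch_size=16)
# embed_model.get_text_embedding_batch(["hello", "world"])  -> two 8-dim vectors
# embed_model.similarity(embed_model.get_text_embedding("a"),
#                        embed_model.get_text_embedding("a"))  -> 1.0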

View File

@ -0,0 +1,103 @@
from abc import abstractmethod
from typing import List
from llama_index.indices.query.schema import QueryBundle, QueryType
from llama_index.prompts.mixin import PromptMixin
from llama_index.schema import NodeWithScore
class BaseImageRetriever(PromptMixin):
"""Base Image Retriever Abstraction."""
def text_to_image_retrieve(
self, str_or_query_bundle: QueryType
) -> List[NodeWithScore]:
"""Retrieve image nodes given query or single image input.
Args:
str_or_query_bundle (QueryType): a query text
string or a QueryBundle object.
"""
if isinstance(str_or_query_bundle, str):
str_or_query_bundle = QueryBundle(query_str=str_or_query_bundle)
return self._text_to_image_retrieve(str_or_query_bundle)
@abstractmethod
def _text_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
"""Retrieve image nodes or documents given query text.
Implemented by the user.
"""
def image_to_image_retrieve(
self, str_or_query_bundle: QueryType
) -> List[NodeWithScore]:
"""Retrieve image nodes given single image input.
Args:
            str_or_query_bundle (QueryType): an image path
string or a QueryBundle object.
"""
if isinstance(str_or_query_bundle, str):
# leave query_str as empty since we are using image_path for image retrieval
str_or_query_bundle = QueryBundle(
query_str="", image_path=str_or_query_bundle
)
return self._image_to_image_retrieve(str_or_query_bundle)
@abstractmethod
def _image_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
"""Retrieve image nodes or documents given image.
Implemented by the user.
"""
# Async Methods
async def atext_to_image_retrieve(
self,
str_or_query_bundle: QueryType,
) -> List[NodeWithScore]:
if isinstance(str_or_query_bundle, str):
str_or_query_bundle = QueryBundle(query_str=str_or_query_bundle)
return await self._atext_to_image_retrieve(str_or_query_bundle)
@abstractmethod
async def _atext_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
"""Async retrieve image nodes or documents given query text.
Implemented by the user.
"""
async def aimage_to_image_retrieve(
self,
str_or_query_bundle: QueryType,
) -> List[NodeWithScore]:
if isinstance(str_or_query_bundle, str):
# leave query_str as empty since we are using image_path for image retrieval
str_or_query_bundle = QueryBundle(
query_str="", image_path=str_or_query_bundle
)
return await self._aimage_to_image_retrieve(str_or_query_bundle)
@abstractmethod
async def _aimage_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
"""Async retrieve image nodes or documents given image.
Implemented by the user.
"""

View File

View File

@ -0,0 +1,116 @@
from enum import Enum
from typing import Any, AsyncGenerator, Generator, Optional
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
class MessageRole(str, Enum):
"""Message role."""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
FUNCTION = "function"
TOOL = "tool"
CHATBOT = "chatbot"
# ===== Generic Model Input - Chat =====
class ChatMessage(BaseModel):
"""Chat message."""
role: MessageRole = MessageRole.USER
content: Optional[Any] = ""
additional_kwargs: dict = Field(default_factory=dict)
def __str__(self) -> str:
return f"{self.role.value}: {self.content}"
# ===== Generic Model Output - Chat =====
class ChatResponse(BaseModel):
"""Chat response."""
message: ChatMessage
raw: Optional[dict] = None
delta: Optional[str] = None
additional_kwargs: dict = Field(default_factory=dict)
def __str__(self) -> str:
return str(self.message)
ChatResponseGen = Generator[ChatResponse, None, None]
ChatResponseAsyncGen = AsyncGenerator[ChatResponse, None]
# ===== Generic Model Output - Completion =====
class CompletionResponse(BaseModel):
"""
Completion response.
Fields:
text: Text content of the response if not streaming, or if streaming,
the current extent of streamed text.
        additional_kwargs: Additional information on the response (i.e. token
            counts, function calling information).
raw: Optional raw JSON that was parsed to populate text, if relevant.
delta: New text that just streamed in (only relevant when streaming).
"""
text: str
additional_kwargs: dict = Field(default_factory=dict)
raw: Optional[dict] = None
delta: Optional[str] = None
def __str__(self) -> str:
return self.text
CompletionResponseGen = Generator[CompletionResponse, None, None]
CompletionResponseAsyncGen = AsyncGenerator[CompletionResponse, None]
class LLMMetadata(BaseModel):
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description=(
"Total number of tokens the model can be input and output for one response."
),
)
num_output: int = Field(
default=DEFAULT_NUM_OUTPUTS,
description="Number of tokens the model can output when generating a response.",
)
is_chat_model: bool = Field(
default=False,
description=(
"Set True if the model exposes a chat interface (i.e. can be passed a"
" sequence of messages, rather than text), like OpenAI's"
" /v1/chat/completions endpoint."
),
)
is_function_calling_model: bool = Field(
default=False,
# SEE: https://openai.com/blog/function-calling-and-other-api-updates
description=(
"Set True if the model supports function calling messages, similar to"
" OpenAI's function calling API. For example, converting 'Email Anya to"
" see if she wants to get coffee next Friday' to a function call like"
" `send_email(to: string, body: string)`."
),
)
model_name: str = Field(
default="unknown",
description=(
"The model's name used for logging, testing, and sanity checking. For some"
" models this can be automatically discerned. For other models, like"
" locally loaded models, this must be manually specified."
),
)
system_role: MessageRole = Field(
default=MessageRole.SYSTEM,
description="The role this specific LLM provider"
"expects for system prompt. E.g. 'SYSTEM' for OpenAI, 'CHATBOT' for Cohere",
)
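
A small illustration of how these generic containers are typically populated (the model name, token counts, and message text are invented for the example; the import path is an assumption):

from llama_index.core.llms.types import (  # assumed import path
    ChatMessage,
    ChatResponse,
    LLMMetadata,
    MessageRole,
)

history = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are a terse assistant."),
    ChatMessage(role=MessageRole.USER, content="Summarize RAG in one sentence."),
]

# a non-streaming chat result wraps the assistant message plus optional provider extras
response = ChatResponse(
    message=ChatMessage(
        role=MessageRole.ASSISTANT,
        content="RAG retrieves relevant context and feeds it to the model at query time.",
    ),
    additional_kwargs={"prompt_tokens": 42},  # illustrative only
)
print(str(response))  # "assistant: RAG retrieves relevant context ..."

# metadata a hypothetical chat model might advertise
metadata = LLMMetadata(
    context_window=8192, num_output=512, is_chat_model=True, model_name="demo-llm"
)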

View File

@ -0,0 +1,266 @@
"""Query pipeline components."""
from inspect import signature
from typing import Any, Callable, Dict, Optional, Set, Tuple
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks.base import CallbackManager
from llama_index.core.query_pipeline.query_component import (
InputKeys,
OutputKeys,
QueryComponent,
)
def get_parameters(fn: Callable) -> Tuple[Set[str], Set[str]]:
"""Get parameters from function.
Returns:
Tuple[Set[str], Set[str]]: required and optional parameters
"""
    # inspect the signature and split parameters into required vs. optional
params = signature(fn).parameters
required_params = set()
optional_params = set()
for param_name in params:
param_default = params[param_name].default
if param_default is params[param_name].empty:
required_params.add(param_name)
else:
optional_params.add(param_name)
return required_params, optional_params
class FnComponent(QueryComponent):
"""Query component that takes in an arbitrary function."""
fn: Callable = Field(..., description="Function to run.")
async_fn: Optional[Callable] = Field(
None, description="Async function to run. If not provided, will run `fn`."
)
output_key: str = Field(
default="output", description="Output key for component output."
)
_req_params: Set[str] = PrivateAttr()
_opt_params: Set[str] = PrivateAttr()
def __init__(
self,
fn: Callable,
async_fn: Optional[Callable] = None,
req_params: Optional[Set[str]] = None,
opt_params: Optional[Set[str]] = None,
output_key: str = "output",
**kwargs: Any,
) -> None:
"""Initialize."""
# determine parameters
default_req_params, default_opt_params = get_parameters(fn)
if req_params is None:
req_params = default_req_params
if opt_params is None:
opt_params = default_opt_params
self._req_params = req_params
self._opt_params = opt_params
super().__init__(fn=fn, async_fn=async_fn, output_key=output_key, **kwargs)
class Config:
arbitrary_types_allowed = True
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: implement
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
# check that all required parameters are present
missing_params = self._req_params - set(input.keys())
if missing_params:
raise ValueError(
f"Missing required parameters: {missing_params}. "
f"Input keys: {input.keys()}"
)
# check that no extra parameters are present
extra_params = set(input.keys()) - self._req_params - self._opt_params
if extra_params:
raise ValueError(
f"Extra parameters: {extra_params}. " f"Input keys: {input.keys()}"
)
return input
def _run_component(self, **kwargs: Any) -> Dict:
"""Run component."""
return {self.output_key: self.fn(**kwargs)}
async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component (async)."""
if self.async_fn is None:
return self._run_component(**kwargs)
else:
return {self.output_key: await self.async_fn(**kwargs)}
@property
def input_keys(self) -> InputKeys:
"""Input keys."""
return InputKeys.from_keys(
required_keys=self._req_params, optional_keys=self._opt_params
)
@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
return OutputKeys.from_keys({self.output_key})
class InputComponent(QueryComponent):
"""Input component."""
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
return input
def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
return input
def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs."""
# NOTE: we override this to do nothing
return input
def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component outputs."""
# NOTE: we override this to do nothing
return output
def set_callback_manager(self, callback_manager: Any) -> None:
"""Set callback manager."""
def _run_component(self, **kwargs: Any) -> Any:
"""Run component."""
return kwargs
async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component (async)."""
return self._run_component(**kwargs)
@property
def input_keys(self) -> InputKeys:
"""Input keys."""
# NOTE: this shouldn't be used
return InputKeys.from_keys(set(), optional_keys=set())
# return InputComponentKeys.from_keys(set(), optional_keys=set())
@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
return OutputKeys.from_keys(set())
class ArgPackComponent(QueryComponent):
"""Arg pack component.
Packs arbitrary number of args into a list.
"""
convert_fn: Optional[Callable] = Field(
default=None, description="Function to convert output."
)
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
raise NotImplementedError
def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs."""
return input
def _validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component outputs."""
# make sure output value is a list
if not isinstance(output["output"], list):
raise ValueError(f"Output is not a list.")
return output
def set_callback_manager(self, callback_manager: Any) -> None:
"""Set callback manager."""
def _run_component(self, **kwargs: Any) -> Any:
"""Run component."""
# combine all lists into one
output = []
for v in kwargs.values():
if self.convert_fn is not None:
v = self.convert_fn(v)
output.append(v)
return {"output": output}
async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component (async)."""
return self._run_component(**kwargs)
@property
def input_keys(self) -> InputKeys:
"""Input keys."""
# NOTE: this shouldn't be used
return InputKeys.from_keys(set(), optional_keys=set())
@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
return OutputKeys.from_keys({"output"})
class KwargPackComponent(QueryComponent):
"""Kwarg pack component.
Packs arbitrary number of kwargs into a dict.
"""
convert_fn: Optional[Callable] = Field(
default=None, description="Function to convert output."
)
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
raise NotImplementedError
def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs."""
return input
def _validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component outputs."""
        # make sure output value is a dict
        if not isinstance(output["output"], dict):
            raise ValueError("Output is not a dict.")
return output
def set_callback_manager(self, callback_manager: Any) -> None:
"""Set callback manager."""
def _run_component(self, **kwargs: Any) -> Any:
"""Run component."""
if self.convert_fn is not None:
for k, v in kwargs.items():
kwargs[k] = self.convert_fn(v)
return {"output": kwargs}
async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component (async)."""
return self._run_component(**kwargs)
@property
def input_keys(self) -> InputKeys:
"""Input keys."""
# NOTE: this shouldn't be used
return InputKeys.from_keys(set(), optional_keys=set())
@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
return OutputKeys.from_keys({"output"})

View File

@ -0,0 +1,338 @@
"""Pipeline schema."""
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Set,
Union,
cast,
get_args,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks.base import CallbackManager
from llama_index.core.llms.types import (
ChatResponse,
CompletionResponse,
)
from llama_index.core.response.schema import Response
from llama_index.schema import NodeWithScore, QueryBundle, TextNode
## Define common types used throughout these components
StringableInput = Union[
CompletionResponse,
ChatResponse,
str,
QueryBundle,
Response,
Generator,
NodeWithScore,
TextNode,
]
def validate_and_convert_stringable(input: Any) -> str:
# special handling for generator
if isinstance(input, Generator):
# iterate through each element, make sure is stringable
new_input = ""
for elem in input:
if not isinstance(elem, get_args(StringableInput)):
raise ValueError(f"Input {elem} is not stringable.")
elif isinstance(elem, (ChatResponse, CompletionResponse)):
new_input += cast(str, elem.delta)
else:
new_input += str(elem)
return new_input
elif isinstance(input, List):
# iterate through each element, make sure is stringable
# do this recursively
new_input_list = []
for elem in input:
new_input_list.append(validate_and_convert_stringable(elem))
return str(new_input_list)
elif isinstance(input, ChatResponse):
return input.message.content or ""
elif isinstance(input, get_args(StringableInput)):
return str(input)
else:
raise ValueError(f"Input {input} is not stringable.")
class InputKeys(BaseModel):
"""Input keys."""
required_keys: Set[str] = Field(default_factory=set)
optional_keys: Set[str] = Field(default_factory=set)
@classmethod
def from_keys(
cls, required_keys: Set[str], optional_keys: Optional[Set[str]] = None
) -> "InputKeys":
"""Create InputKeys from tuple."""
return cls(required_keys=required_keys, optional_keys=optional_keys or set())
def validate(self, input_keys: Set[str]) -> None:
"""Validate input keys."""
# check if required keys are present, and that keys all are in required or optional
if not self.required_keys.issubset(input_keys):
raise ValueError(
f"Required keys {self.required_keys} are not present in input keys {input_keys}"
)
if not input_keys.issubset(self.required_keys.union(self.optional_keys)):
raise ValueError(
f"Input keys {input_keys} contain keys not in required or optional keys {self.required_keys.union(self.optional_keys)}"
)
def __len__(self) -> int:
"""Length of input keys."""
return len(self.required_keys) + len(self.optional_keys)
def all(self) -> Set[str]:
"""Get all input keys."""
return self.required_keys.union(self.optional_keys)
class OutputKeys(BaseModel):
"""Output keys."""
required_keys: Set[str] = Field(default_factory=set)
@classmethod
def from_keys(
cls,
required_keys: Set[str],
) -> "InputKeys":
"""Create InputKeys from tuple."""
return cls(required_keys=required_keys)
def validate(self, input_keys: Set[str]) -> None:
"""Validate input keys."""
# validate that input keys exactly match required keys
if input_keys != self.required_keys:
raise ValueError(
f"Input keys {input_keys} do not match required keys {self.required_keys}"
)
class ChainableMixin(ABC):
"""Chainable mixin.
A module that can produce a `QueryComponent` from a set of inputs through
`as_query_component`.
If plugged in directly into a `QueryPipeline`, the `ChainableMixin` will be
converted into a `QueryComponent` with default parameters.
"""
@abstractmethod
def _as_query_component(self, **kwargs: Any) -> "QueryComponent":
"""Get query component."""
def as_query_component(
self, partial: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> "QueryComponent":
"""Get query component."""
component = self._as_query_component(**kwargs)
component.partial(**(partial or {}))
return component
class QueryComponent(BaseModel):
"""Query component.
Represents a component that can be run in a `QueryPipeline`.
"""
partial_dict: Dict[str, Any] = Field(
default_factory=dict, description="Partial arguments to run_component"
)
# TODO: make this a subclass of BaseComponent (e.g. use Pydantic)
def partial(self, **kwargs: Any) -> None:
"""Update with partial arguments."""
self.partial_dict.update(kwargs)
@abstractmethod
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# TODO: refactor so that callback_manager is always passed in during runtime.
@property
def free_req_input_keys(self) -> Set[str]:
"""Get free input keys."""
return self.input_keys.required_keys.difference(self.partial_dict.keys())
@abstractmethod
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
def _validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component outputs during run_component."""
# override if needed
return output
def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs."""
# make sure set of input keys == self.input_keys
self.input_keys.validate(set(input.keys()))
return self._validate_component_inputs(input)
def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component outputs."""
# make sure set of output keys == self.output_keys
self.output_keys.validate(set(output.keys()))
return self._validate_component_outputs(output)
def run_component(self, **kwargs: Any) -> Dict[str, Any]:
"""Run component."""
kwargs.update(self.partial_dict)
kwargs = self.validate_component_inputs(kwargs)
component_outputs = self._run_component(**kwargs)
return self.validate_component_outputs(component_outputs)
async def arun_component(self, **kwargs: Any) -> Dict[str, Any]:
"""Run component."""
kwargs.update(self.partial_dict)
kwargs = self.validate_component_inputs(kwargs)
component_outputs = await self._arun_component(**kwargs)
return self.validate_component_outputs(component_outputs)
@abstractmethod
def _run_component(self, **kwargs: Any) -> Dict:
"""Run component."""
@abstractmethod
async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component (async)."""
@property
@abstractmethod
def input_keys(self) -> InputKeys:
"""Input keys."""
@property
@abstractmethod
def output_keys(self) -> OutputKeys:
"""Output keys."""
@property
def sub_query_components(self) -> List["QueryComponent"]:
"""Get sub query components.
Certain query components may have sub query components, e.g. a
query pipeline will have sub query components, and so will
an IfElseComponent.
"""
return []
class CustomQueryComponent(QueryComponent):
"""Custom query component."""
callback_manager: CallbackManager = Field(
default_factory=CallbackManager, description="Callback manager"
)
class Config:
arbitrary_types_allowed = True
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
self.callback_manager = callback_manager
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
# NOTE: user can override this method to validate inputs
# but we do this by default for convenience
return input
async def _arun_component(self, **kwargs: Any) -> Any:
"""Run component (async)."""
raise NotImplementedError("This component does not support async run.")
@property
def _input_keys(self) -> Set[str]:
"""Input keys dict."""
raise NotImplementedError("Not implemented yet. Please override this method.")
@property
def _optional_input_keys(self) -> Set[str]:
"""Optional input keys dict."""
return set()
@property
def _output_keys(self) -> Set[str]:
"""Output keys dict."""
raise NotImplementedError("Not implemented yet. Please override this method.")
@property
def input_keys(self) -> InputKeys:
"""Input keys."""
# NOTE: user can override this too, but we have them implement an
# abstract method to make sure they do it
return InputKeys.from_keys(
required_keys=self._input_keys, optional_keys=self._optional_input_keys
)
@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
# NOTE: user can override this too, but we have them implement an
# abstract method to make sure they do it
return OutputKeys.from_keys(self._output_keys)
class Link(BaseModel):
"""Link between two components."""
src: str = Field(..., description="Source component name")
dest: str = Field(..., description="Destination component name")
src_key: Optional[str] = Field(
default=None, description="Source component output key"
)
dest_key: Optional[str] = Field(
default=None, description="Destination component input key"
)
condition_fn: Optional[Callable] = Field(
default=None, description="Condition to determine if link should be followed"
)
input_fn: Optional[Callable] = Field(
default=None, description="Input to destination component"
)
def __init__(
self,
src: str,
dest: str,
src_key: Optional[str] = None,
dest_key: Optional[str] = None,
condition_fn: Optional[Callable] = None,
input_fn: Optional[Callable] = None,
) -> None:
"""Init params."""
# NOTE: This is to enable positional args.
super().__init__(
src=src,
dest=dest,
src_key=src_key,
dest_key=dest_key,
condition_fn=condition_fn,
input_fn=input_fn,
)
# accept both QueryComponent and ChainableMixin as inputs to query pipeline
# ChainableMixin modules will be converted to components via `as_query_component`
QUERY_COMPONENT_TYPE = Union[QueryComponent, ChainableMixin]
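
For concreteness, a minimal `CustomQueryComponent` subclass against the contract above (the component and its behavior are invented for the example): declaring `_input_keys`, `_output_keys`, and `_run_component` is enough, and `partial()`, key validation, and callback wiring come from the base classes defined in this module.

from typing import Any, Dict, Set


class UppercaseComponent(CustomQueryComponent):
    """Toy component that upper-cases its input string."""

    @property
    def _input_keys(self) -> Set[str]:
        return {"text"}

    @property
    def _output_keys(self) -> Set[str]:
        return {"output"}

    def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
        return {"output": str(kwargs["text"]).upper()}


# component = UppercaseComponent()
# component.partial(text="hello pipelines")
# component.run_component()  -> {"output": "HELLO PIPELINES"}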

View File

View File

@ -0,0 +1,142 @@
"""Response schema."""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from llama_index.bridge.pydantic import BaseModel
from llama_index.schema import NodeWithScore
from llama_index.types import TokenGen
from llama_index.utils import truncate_text
@dataclass
class Response:
"""Response object.
Returned if streaming=False.
Attributes:
response: The response text.
"""
response: Optional[str]
source_nodes: List[NodeWithScore] = field(default_factory=list)
metadata: Optional[Dict[str, Any]] = None
def __str__(self) -> str:
"""Convert to string representation."""
return self.response or "None"
def get_formatted_sources(self, length: int = 100) -> str:
"""Get formatted sources text."""
texts = []
for source_node in self.source_nodes:
fmt_text_chunk = truncate_text(source_node.node.get_content(), length)
doc_id = source_node.node.node_id or "None"
source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}"
texts.append(source_text)
return "\n\n".join(texts)
@dataclass
class PydanticResponse:
"""PydanticResponse object.
Returned if streaming=False.
Attributes:
response: The response text.
"""
response: Optional[BaseModel]
source_nodes: List[NodeWithScore] = field(default_factory=list)
metadata: Optional[Dict[str, Any]] = None
def __str__(self) -> str:
"""Convert to string representation."""
return self.response.json() if self.response else "None"
def __getattr__(self, name: str) -> Any:
"""Get attribute, but prioritize the pydantic response object."""
if self.response is not None and name in self.response.dict():
return getattr(self.response, name)
else:
return None
def get_formatted_sources(self, length: int = 100) -> str:
"""Get formatted sources text."""
texts = []
for source_node in self.source_nodes:
fmt_text_chunk = truncate_text(source_node.node.get_content(), length)
doc_id = source_node.node.node_id or "None"
source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}"
texts.append(source_text)
return "\n\n".join(texts)
def get_response(self) -> Response:
"""Get a standard response object."""
response_txt = self.response.json() if self.response else "None"
return Response(response_txt, self.source_nodes, self.metadata)
@dataclass
class StreamingResponse:
"""StreamingResponse object.
Returned if streaming=True.
Attributes:
response_gen: The response generator.
"""
response_gen: TokenGen
source_nodes: List[NodeWithScore] = field(default_factory=list)
metadata: Optional[Dict[str, Any]] = None
response_txt: Optional[str] = None
def __str__(self) -> str:
"""Convert to string representation."""
if self.response_txt is None and self.response_gen is not None:
response_txt = ""
for text in self.response_gen:
response_txt += text
self.response_txt = response_txt
return self.response_txt or "None"
def get_response(self) -> Response:
"""Get a standard response object."""
if self.response_txt is None and self.response_gen is not None:
response_txt = ""
for text in self.response_gen:
response_txt += text
self.response_txt = response_txt
return Response(self.response_txt, self.source_nodes, self.metadata)
def print_response_stream(self) -> None:
"""Print the response stream."""
if self.response_txt is None and self.response_gen is not None:
response_txt = ""
for text in self.response_gen:
print(text, end="", flush=True)
response_txt += text
self.response_txt = response_txt
else:
print(self.response_txt)
    def get_formatted_sources(self, length: int = 100, trim_text: bool = True) -> str:
"""Get formatted sources text."""
texts = []
for source_node in self.source_nodes:
fmt_text_chunk = source_node.node.get_content()
if trim_text:
fmt_text_chunk = truncate_text(fmt_text_chunk, length)
node_id = source_node.node.node_id or "None"
source_text = f"> Source (Node id: {node_id}): {fmt_text_chunk}"
texts.append(source_text)
return "\n\n".join(texts)
RESPONSE_TYPE = Union[Response, StreamingResponse, PydanticResponse]
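
A short illustration of the non-streaming container (node text and score are fabricated for the example):

from llama_index.schema import NodeWithScore, TextNode

nodes = [
    NodeWithScore(
        node=TextNode(text="LlamaIndex builds indices over your documents."), score=0.83
    )
]
resp = Response(response="It indexes documents for retrieval.", source_nodes=nodes)
print(str(resp))                     # the answer text
print(resp.get_formatted_sources())  # "> Source (Doc id: <node id>): LlamaIndex builds ..."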

View File

@ -0,0 +1,19 @@
"""Init file."""
from llama_index.data_structs.data_structs import (
IndexDict,
IndexGraph,
IndexList,
KeywordTable,
Node,
)
from llama_index.data_structs.table import StructDatapoint
__all__ = [
"IndexGraph",
"KeywordTable",
"IndexList",
"IndexDict",
"StructDatapoint",
"Node",
]

View File

@ -0,0 +1,267 @@
"""Data structures.
Nodes are decoupled from the indices.
"""
import uuid
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Set
from dataclasses_json import DataClassJsonMixin
from llama_index.data_structs.struct_type import IndexStructType
from llama_index.schema import BaseNode, TextNode
# TODO: legacy backport of old Node class
Node = TextNode
@dataclass
class IndexStruct(DataClassJsonMixin):
"""A base data struct for a LlamaIndex."""
index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
summary: Optional[str] = None
def get_summary(self) -> str:
"""Get text summary."""
if self.summary is None:
raise ValueError("summary field of the index_struct not set.")
return self.summary
@classmethod
@abstractmethod
def get_type(cls) -> IndexStructType:
"""Get index struct type."""
@dataclass
class IndexGraph(IndexStruct):
"""A graph representing the tree-structured index."""
# mapping from index in tree to Node doc id.
all_nodes: Dict[int, str] = field(default_factory=dict)
root_nodes: Dict[int, str] = field(default_factory=dict)
node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)
@property
def node_id_to_index(self) -> Dict[str, int]:
"""Map from node id to index."""
return {node_id: index for index, node_id in self.all_nodes.items()}
@property
def size(self) -> int:
"""Get the size of the graph."""
return len(self.all_nodes)
def get_index(self, node: BaseNode) -> int:
"""Get index of node."""
return self.node_id_to_index[node.node_id]
def insert(
self,
node: BaseNode,
index: Optional[int] = None,
children_nodes: Optional[Sequence[BaseNode]] = None,
) -> None:
"""Insert node."""
index = index or self.size
node_id = node.node_id
self.all_nodes[index] = node_id
if children_nodes is None:
children_nodes = []
children_ids = [n.node_id for n in children_nodes]
self.node_id_to_children_ids[node_id] = children_ids
def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
"""Get children nodes."""
if parent_node is None:
return self.root_nodes
else:
parent_id = parent_node.node_id
children_ids = self.node_id_to_children_ids[parent_id]
return {
self.node_id_to_index[child_id]: child_id for child_id in children_ids
}
def insert_under_parent(
self,
node: BaseNode,
parent_node: Optional[BaseNode],
new_index: Optional[int] = None,
) -> None:
"""Insert under parent node."""
new_index = new_index or self.size
if parent_node is None:
self.root_nodes[new_index] = node.node_id
self.node_id_to_children_ids[node.node_id] = []
else:
if parent_node.node_id not in self.node_id_to_children_ids:
self.node_id_to_children_ids[parent_node.node_id] = []
self.node_id_to_children_ids[parent_node.node_id].append(node.node_id)
self.all_nodes[new_index] = node.node_id
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.TREE
@dataclass
class KeywordTable(IndexStruct):
"""A table of keywords mapping keywords to text chunks."""
table: Dict[str, Set[str]] = field(default_factory=dict)
def add_node(self, keywords: List[str], node: BaseNode) -> None:
"""Add text to table."""
for keyword in keywords:
if keyword not in self.table:
self.table[keyword] = set()
self.table[keyword].add(node.node_id)
@property
def node_ids(self) -> Set[str]:
"""Get all node ids."""
return set.union(*self.table.values())
@property
def keywords(self) -> Set[str]:
"""Get all keywords in the table."""
return set(self.table.keys())
@property
def size(self) -> int:
"""Get the size of the table."""
return len(self.table)
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.KEYWORD_TABLE
@dataclass
class IndexList(IndexStruct):
"""A list of documents."""
nodes: List[str] = field(default_factory=list)
def add_node(self, node: BaseNode) -> None:
"""Add text to table, return current position in list."""
# don't worry about child indices for now, nodes are all in order
self.nodes.append(node.node_id)
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.LIST
@dataclass
class IndexDict(IndexStruct):
"""A simple dictionary of documents."""
# TODO: slightly deprecated, should likely be a list or set now
# mapping from vector store id to node doc_id
nodes_dict: Dict[str, str] = field(default_factory=dict)
# TODO: deprecated, not used
# mapping from node doc_id to vector store id
doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)
# TODO: deprecated, not used
# this should be empty for all other indices
embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)
def add_node(
self,
node: BaseNode,
text_id: Optional[str] = None,
) -> str:
"""Add text to table, return current position in list."""
# # don't worry about child indices for now, nodes are all in order
# self.nodes_dict[int_id] = node
vector_id = text_id if text_id is not None else node.node_id
self.nodes_dict[vector_id] = node.node_id
return vector_id
def delete(self, doc_id: str) -> None:
"""Delete a Node."""
del self.nodes_dict[doc_id]
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.VECTOR_STORE
@dataclass
class MultiModelIndexDict(IndexDict):
"""A simple dictionary of documents, but loads a MultiModelVectorStore."""
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.MULTIMODAL_VECTOR_STORE
@dataclass
class KG(IndexStruct):
"""A table of keywords mapping keywords to text chunks."""
# Unidirectional
# table of keywords to node ids
table: Dict[str, Set[str]] = field(default_factory=dict)
# TODO: legacy attribute, remove in future releases
rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)
# TBD, should support vector store, now we just persist the embedding memory
# maybe chainable abstractions for *_stores could be designed
embedding_dict: Dict[str, List[float]] = field(default_factory=dict)
@property
def node_ids(self) -> Set[str]:
"""Get all node ids."""
return set.union(*self.table.values())
def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
"""Add embedding to dict."""
self.embedding_dict[triplet_str] = embedding
def add_node(self, keywords: List[str], node: BaseNode) -> None:
"""Add text to table."""
node_id = node.node_id
for keyword in keywords:
if keyword not in self.table:
self.table[keyword] = set()
self.table[keyword].add(node_id)
def search_node_by_keyword(self, keyword: str) -> List[str]:
"""Search for nodes by keyword."""
if keyword not in self.table:
return []
return list(self.table[keyword])
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.KG
@dataclass
class EmptyIndexStruct(IndexStruct):
"""Empty index."""
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.EMPTY
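
To make the bookkeeping concrete (node texts are invented for the example): these structs only track node ids and relationships; the node content itself lives in the docstore.

from llama_index.schema import TextNode

root = TextNode(text="summary of everything")
child_a = TextNode(text="chapter A")
child_b = TextNode(text="chapter B")

graph = IndexGraph()
graph.insert_under_parent(root, parent_node=None)
graph.insert_under_parent(child_a, parent_node=root)
graph.insert_under_parent(child_b, parent_node=root)
# graph.size == 3; graph.get_children(root) maps tree indices 1 and 2 to the child ids

table = KeywordTable()
table.add_node(["chapter", "alpha"], child_a)
table.add_node(["chapter", "beta"], child_b)
# table.keywords == {"chapter", "alpha", "beta"}; table.node_ids holds both child ids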

View File

@ -0,0 +1,73 @@
"""Data struct for document summary index."""
from dataclasses import dataclass, field
from typing import Dict, List
from llama_index.data_structs.data_structs import IndexStruct
from llama_index.data_structs.struct_type import IndexStructType
from llama_index.schema import BaseNode
@dataclass
class IndexDocumentSummary(IndexStruct):
"""A simple struct containing a mapping from summary node_id to doc node_ids.
Also mapping vice versa.
"""
summary_id_to_node_ids: Dict[str, List[str]] = field(default_factory=dict)
node_id_to_summary_id: Dict[str, str] = field(default_factory=dict)
# track mapping from doc id to node summary id
doc_id_to_summary_id: Dict[str, str] = field(default_factory=dict)
def add_summary_and_nodes(
self,
summary_node: BaseNode,
nodes: List[BaseNode],
) -> str:
"""Add node and summary."""
summary_id = summary_node.node_id
ref_doc_id = summary_node.ref_doc_id
if ref_doc_id is None:
raise ValueError(
"ref_doc_id of node cannot be None when building a document "
"summary index"
)
self.doc_id_to_summary_id[ref_doc_id] = summary_id
for node in nodes:
node_id = node.node_id
if summary_id not in self.summary_id_to_node_ids:
self.summary_id_to_node_ids[summary_id] = []
self.summary_id_to_node_ids[summary_id].append(node_id)
self.node_id_to_summary_id[node_id] = summary_id
return summary_id
@property
def summary_ids(self) -> List[str]:
"""Get summary ids."""
return list(self.summary_id_to_node_ids.keys())
def delete(self, doc_id: str) -> None:
"""Delete a document and its nodes."""
summary_id = self.doc_id_to_summary_id[doc_id]
del self.doc_id_to_summary_id[doc_id]
node_ids = self.summary_id_to_node_ids[summary_id]
for node_id in node_ids:
del self.node_id_to_summary_id[node_id]
del self.summary_id_to_node_ids[summary_id]
def delete_nodes(self, node_ids: List[str]) -> None:
for node_id in node_ids:
summary_id = self.node_id_to_summary_id[node_id]
self.summary_id_to_node_ids[summary_id].remove(node_id)
del self.node_id_to_summary_id[node_id]
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.DOCUMENT_SUMMARY
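
A compact walkthrough of the mappings above (node contents and the doc id are invented for the example; the summary node must carry a SOURCE relationship so that `ref_doc_id` resolves):

from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode

chunk_1 = TextNode(text="chunk one")
chunk_2 = TextNode(text="chunk two")
summary = TextNode(
    text="summary of doc-1",
    relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="doc-1")},
)

struct = IndexDocumentSummary()
summary_id = struct.add_summary_and_nodes(summary, [chunk_1, chunk_2])
# struct.summary_id_to_node_ids[summary_id] == [chunk_1.node_id, chunk_2.node_id]
# struct.doc_id_to_summary_id["doc-1"] == summary_id
# struct.delete("doc-1") removes the summary and both chunk mappings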

View File

@ -0,0 +1,30 @@
"""Index registry."""
from typing import Dict, Type
from llama_index.data_structs.data_structs import (
KG,
EmptyIndexStruct,
IndexDict,
IndexGraph,
IndexList,
IndexStruct,
KeywordTable,
MultiModelIndexDict,
)
from llama_index.data_structs.document_summary import IndexDocumentSummary
from llama_index.data_structs.struct_type import IndexStructType
from llama_index.data_structs.table import PandasStructTable, SQLStructTable
INDEX_STRUCT_TYPE_TO_INDEX_STRUCT_CLASS: Dict[IndexStructType, Type[IndexStruct]] = {
IndexStructType.TREE: IndexGraph,
IndexStructType.LIST: IndexList,
IndexStructType.KEYWORD_TABLE: KeywordTable,
IndexStructType.VECTOR_STORE: IndexDict,
IndexStructType.SQL: SQLStructTable,
IndexStructType.PANDAS: PandasStructTable,
IndexStructType.KG: KG,
IndexStructType.EMPTY: EmptyIndexStruct,
IndexStructType.DOCUMENT_SUMMARY: IndexDocumentSummary,
IndexStructType.MULTIMODAL_VECTOR_STORE: MultiModelIndexDict,
}

View File

@ -0,0 +1,110 @@
"""IndexStructType class."""
from enum import Enum
class IndexStructType(str, Enum):
"""Index struct type. Identifier for a "type" of index.
Attributes:
TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
KEYWORD_TABLE ("keyword_table"): Keyword table index. See
:ref:`Ref-Indices-Table`
for keyword table indices.
DICT ("dict"): Faiss Vector Store Index. See
:ref:`Ref-Indices-VectorStore`
for more information on the faiss vector store index.
SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
:ref:`Ref-Indices-VectorStore`
for more information on the simple vector store index.
WEAVIATE ("weaviate"): Weaviate Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
for more information on the Weaviate vector store index.
PINECONE ("pinecone"): Pinecone Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
for more information on the Pinecone vector store index.
DEEPLAKE ("deeplake"): DeepLake Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
            for more information on the DeepLake vector store index.
QDRANT ("qdrant"): Qdrant Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
for more information on the Qdrant vector store index.
LANCEDB ("lancedb"): LanceDB Vector Store Index
See :ref:`Ref-Indices-VectorStore`
for more information on the LanceDB vector store index.
MILVUS ("milvus"): Milvus Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
for more information on the Milvus vector store index.
CHROMA ("chroma"): Chroma Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
for more information on the Chroma vector store index.
OPENSEARCH ("opensearch"): Opensearch Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
for more information on the Opensearch vector store index.
MYSCALE ("myscale"): MyScale Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
for more information on the MyScale vector store index.
EPSILLA ("epsilla"): Epsilla Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
for more information on the Epsilla vector store index.
CHATGPT_RETRIEVAL_PLUGIN ("chatgpt_retrieval_plugin"): ChatGPT
retrieval plugin index.
SQL ("SQL"): SQL Structured Store Index.
See :ref:`Ref-Indices-StructStore`
            for more information on the SQL structured store index.
DASHVECTOR ("dashvector"): DashVector Vector Store Index.
See :ref:`Ref-Indices-VectorStore`
            for more information on the DashVector vector store index.
KG ("kg"): Knowledge Graph index.
See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
See :ref:`Ref-Indices-Document-Summary` for Summary Indices.
"""
# TODO: refactor so these are properties on the base class
NODE = "node"
TREE = "tree"
LIST = "list"
KEYWORD_TABLE = "keyword_table"
# faiss
DICT = "dict"
# simple
SIMPLE_DICT = "simple_dict"
WEAVIATE = "weaviate"
PINECONE = "pinecone"
QDRANT = "qdrant"
LANCEDB = "lancedb"
MILVUS = "milvus"
CHROMA = "chroma"
MYSCALE = "myscale"
VECTOR_STORE = "vector_store"
OPENSEARCH = "opensearch"
DASHVECTOR = "dashvector"
CHATGPT_RETRIEVAL_PLUGIN = "chatgpt_retrieval_plugin"
DEEPLAKE = "deeplake"
EPSILLA = "epsilla"
# multimodal
MULTIMODAL_VECTOR_STORE = "multimodal"
# for SQL index
SQL = "sql"
# for KG index
KG = "kg"
SIMPLE_KG = "simple_kg"
NEBULAGRAPH = "nebulagraph"
FALKORDB = "falkordb"
# EMPTY
EMPTY = "empty"
COMPOSITE = "composite"
PANDAS = "pandas"
DOCUMENT_SUMMARY = "document_summary"
# Managed
VECTARA = "vectara"
ZILLIZ_CLOUD_PIPELINE = "zilliz_cloud_pipeline"

View File

@ -0,0 +1,45 @@
"""Struct store schema."""
from dataclasses import dataclass, field
from typing import Any, Dict
from dataclasses_json import DataClassJsonMixin
from llama_index.data_structs.data_structs import IndexStruct
from llama_index.data_structs.struct_type import IndexStructType
@dataclass
class StructDatapoint(DataClassJsonMixin):
"""Struct outputs."""
# map from field name to StructValue
fields: Dict[str, Any]
@dataclass
class BaseStructTable(IndexStruct):
"""Struct outputs."""
@dataclass
class SQLStructTable(BaseStructTable):
"""SQL struct outputs."""
context_dict: Dict[str, str] = field(default_factory=dict)
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
# TODO: consolidate with IndexStructType
return IndexStructType.SQL
@dataclass
class PandasStructTable(BaseStructTable):
"""Pandas struct outputs."""
@classmethod
def get_type(cls) -> IndexStructType:
"""Get type."""
return IndexStructType.PANDAS

View File

View File

@ -0,0 +1,262 @@
"""Download."""
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import requests
import tqdm
from llama_index.download.module import LLAMA_HUB_URL
from llama_index.download.utils import (
get_file_content,
get_file_content_bytes,
initialize_directory,
)
LLAMA_DATASETS_LFS_URL = (
f"https://media.githubusercontent.com/media/run-llama/llama-datasets/main"
)
LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL = (
"https://github.com/run-llama/llama-datasets/tree/main"
)
LLAMA_SOURCE_FILES_PATH = "source_files"
DATASET_CLASS_FILENAME_REGISTRY = {
"LabelledRagDataset": "rag_dataset.json",
"LabeledRagDataset": "rag_dataset.json",
"LabelledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
"LabeledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
"LabelledEvaluatorDataset": "evaluator_dataset.json",
"LabeledEvaluatorDataset": "evaluator_dataset.json",
}
PATH_TYPE = Union[str, Path]
def _resolve_dataset_file_name(class_name: str) -> str:
"""Resolve filename based on dataset class."""
try:
return DATASET_CLASS_FILENAME_REGISTRY[class_name]
except KeyError as err:
raise ValueError("Invalid dataset filename.") from err
def _get_source_files_list(source_tree_url: str, path: str) -> List[str]:
"""Get the list of source files to download."""
resp = requests.get(source_tree_url + path + "?recursive=1")
payload = resp.json()["payload"]
return [item["name"] for item in payload["tree"]["items"]]
def get_dataset_info(
local_dir_path: PATH_TYPE,
remote_dir_path: PATH_TYPE,
remote_source_dir_path: PATH_TYPE,
dataset_class: str,
refresh_cache: bool = False,
library_path: str = "library.json",
source_files_path: str = "source_files",
disable_library_cache: bool = False,
) -> Dict:
"""Get dataset info."""
if isinstance(local_dir_path, str):
local_dir_path = Path(local_dir_path)
local_library_path = f"{local_dir_path}/{library_path}"
dataset_id = None
source_files = []
# Check cache first
if not refresh_cache and os.path.exists(local_library_path):
with open(local_library_path) as f:
library = json.load(f)
if dataset_class in library:
dataset_id = library[dataset_class]["id"]
source_files = library[dataset_class].get("source_files", [])
# Fetch up-to-date library from remote repo if dataset_id not found
if dataset_id is None:
library_raw_content, _ = get_file_content(
str(remote_dir_path), f"/{library_path}"
)
library = json.loads(library_raw_content)
if dataset_class not in library:
raise ValueError("Loader class name not found in library")
dataset_id = library[dataset_class]["id"]
# get data card
raw_card_content, _ = get_file_content(
str(remote_dir_path), f"/{dataset_id}/card.json"
)
card = json.loads(raw_card_content)
dataset_class_name = card["className"]
source_files = []
if dataset_class_name == "LabelledRagDataset":
source_files = _get_source_files_list(
str(remote_source_dir_path), f"/{dataset_id}/{source_files_path}"
)
# create cache dir if needed
local_library_dir = os.path.dirname(local_library_path)
if not disable_library_cache:
if not os.path.exists(local_library_dir):
os.makedirs(local_library_dir)
# Update cache
with open(local_library_path, "w") as f:
f.write(library_raw_content)
if dataset_id is None:
raise ValueError("Dataset class name not found in library")
return {
"dataset_id": dataset_id,
"dataset_class_name": dataset_class_name,
"source_files": source_files,
}
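The dict returned by `get_dataset_info` feeds directly into `download_dataset_and_source_files` below. A hedged sketch of a call follows (it performs network requests; `PaulGrahamEssayDataset` stands in for any entry in the remote `library.json`, and the module path for this file is assumed).

```python
from llama_index.download.dataset import (  # assumed module path for this file
    LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
    get_dataset_info,
)
from llama_index.download.module import LLAMA_HUB_URL

info = get_dataset_info(
    local_dir_path="./llamadatasets",
    remote_dir_path=LLAMA_HUB_URL,
    remote_source_dir_path=LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
    dataset_class="PaulGrahamEssayDataset",
    library_path="llama_datasets/library.json",
)
# expected shape:
# {"dataset_id": "...", "dataset_class_name": "LabelledRagDataset", "source_files": [...]}
```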
def download_dataset_and_source_files(
local_dir_path: PATH_TYPE,
remote_lfs_dir_path: PATH_TYPE,
source_files_dir_path: PATH_TYPE,
dataset_id: str,
dataset_class_name: str,
source_files: List[str],
refresh_cache: bool = False,
base_file_name: str = "rag_dataset.json",
override_path: bool = False,
show_progress: bool = False,
) -> None:
"""Download dataset and source files."""
if isinstance(local_dir_path, str):
local_dir_path = Path(local_dir_path)
if override_path:
module_path = str(local_dir_path)
else:
module_path = f"{local_dir_path}/{dataset_id}"
if refresh_cache or not os.path.exists(module_path):
os.makedirs(module_path, exist_ok=True)
base_file_name = _resolve_dataset_file_name(dataset_class_name)
dataset_raw_content, _ = get_file_content(
str(remote_lfs_dir_path), f"/{dataset_id}/{base_file_name}"
)
with open(f"{module_path}/{base_file_name}", "w") as f:
f.write(dataset_raw_content)
# Get content of source files
if dataset_class_name == "LabelledRagDataset":
os.makedirs(f"{module_path}/{source_files_dir_path}", exist_ok=True)
if show_progress:
source_files_iterator = tqdm.tqdm(source_files)
else:
source_files_iterator = source_files
for source_file in source_files_iterator:
if ".pdf" in source_file:
source_file_raw_content_bytes, _ = get_file_content_bytes(
str(remote_lfs_dir_path),
f"/{dataset_id}/{source_files_dir_path}/{source_file}",
)
with open(
f"{module_path}/{source_files_dir_path}/{source_file}", "wb"
) as f:
f.write(source_file_raw_content_bytes)
else:
source_file_raw_content, _ = get_file_content(
str(remote_lfs_dir_path),
f"/{dataset_id}/{source_files_dir_path}/{source_file}",
)
with open(
f"{module_path}/{source_files_dir_path}/{source_file}", "w"
) as f:
f.write(source_file_raw_content)
def download_llama_dataset(
dataset_class: str,
llama_hub_url: str = LLAMA_HUB_URL,
llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
refresh_cache: bool = False,
custom_dir: Optional[str] = None,
custom_path: Optional[str] = None,
source_files_dirpath: str = LLAMA_SOURCE_FILES_PATH,
library_path: str = "llama_datasets/library.json",
disable_library_cache: bool = False,
override_path: bool = False,
show_progress: bool = False,
) -> Any:
"""
Download a llama dataset (and its source files) from llama-datasets.
Args:
dataset_class: The name of the llama dataset you want to download,
such as `PaulGrahamEssayDataset`.
refresh_cache: If true, the local cache will be skipped and the
dataset will be fetched directly from the remote repo.
custom_dir: Custom dir name to download the dataset into (under parent folder).
custom_path: Custom dirpath to download the dataset into.
source_files_dirpath: Name of the subdirectory that holds the dataset's source files.
library_path: Path of the library file within the llama-hub repo.
disable_library_cache: If true, do not write the fetched library file to the local cache.
override_path: If true, download directly into the target directory rather than
a per-dataset subdirectory.
show_progress: If true, show a progress bar while downloading source files.
Returns:
A tuple of (path to the downloaded dataset json file, path to the downloaded
source files directory).
"""
# create directory / get path
dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)
# fetch info from library.json file
dataset_info = get_dataset_info(
local_dir_path=dirpath,
remote_dir_path=llama_hub_url,
remote_source_dir_path=llama_datasets_source_files_tree_url,
dataset_class=dataset_class,
refresh_cache=refresh_cache,
library_path=library_path,
disable_library_cache=disable_library_cache,
)
dataset_id = dataset_info["dataset_id"]
source_files = dataset_info["source_files"]
dataset_class_name = dataset_info["dataset_class_name"]
dataset_filename = _resolve_dataset_file_name(dataset_class_name)
download_dataset_and_source_files(
local_dir_path=dirpath,
remote_lfs_dir_path=llama_datasets_lfs_url,
source_files_dir_path=source_files_dirpath,
dataset_id=dataset_id,
dataset_class_name=dataset_class_name,
source_files=source_files,
refresh_cache=refresh_cache,
override_path=override_path,
show_progress=show_progress,
)
if override_path:
module_path = str(dirpath)
else:
module_path = f"{dirpath}/{dataset_id}"
return (
f"{module_path}/{dataset_filename}",
f"{module_path}/{LLAMA_SOURCE_FILES_PATH}",
)
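Putting the pieces together, `download_llama_dataset` returns a pair of paths rather than loaded objects. A minimal usage sketch follows; the dataset name is illustrative and the module path for this file is assumed.

```python
from llama_index import SimpleDirectoryReader
from llama_index.download.dataset import download_llama_dataset  # assumed module path

dataset_json_path, source_files_dir = download_llama_dataset(
    "PaulGrahamEssayDataset",  # illustrative dataset name from llama-datasets
    custom_path="./data/paul_graham",
)
documents = SimpleDirectoryReader(source_files_dir).load_data()
print(dataset_json_path)  # .../rag_dataset.json
```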

View File

@ -0,0 +1,273 @@
"""Download."""
import json
import logging
import os
import subprocess
import sys
from enum import Enum
from importlib import util
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import pkg_resources
import requests
from pkg_resources import DistributionNotFound
from llama_index.download.utils import (
get_exports,
get_file_content,
initialize_directory,
rewrite_exports,
)
LLAMA_HUB_CONTENTS_URL = "https://raw.githubusercontent.com/run-llama/llama-hub/main"
LLAMA_HUB_PATH = "/llama_hub"
LLAMA_HUB_URL = LLAMA_HUB_CONTENTS_URL + LLAMA_HUB_PATH
PATH_TYPE = Union[str, Path]
logger = logging.getLogger(__name__)
LLAMAHUB_ANALYTICS_PROXY_SERVER = "https://llamahub.ai/api/analytics/downloads"
class MODULE_TYPE(str, Enum):
LOADER = "loader"
TOOL = "tool"
LLAMAPACK = "llamapack"
DATASETS = "datasets"
def get_module_info(
local_dir_path: PATH_TYPE,
remote_dir_path: PATH_TYPE,
module_class: str,
refresh_cache: bool = False,
library_path: str = "library.json",
disable_library_cache: bool = False,
) -> Dict:
"""Get module info."""
if isinstance(local_dir_path, str):
local_dir_path = Path(local_dir_path)
local_library_path = f"{local_dir_path}/{library_path}"
module_id = None # e.g. `web/simple_web`
extra_files = [] # e.g. `web/simple_web/utils.py`
# Check cache first
if not refresh_cache and os.path.exists(local_library_path):
with open(local_library_path) as f:
library = json.load(f)
if module_class in library:
module_id = library[module_class]["id"]
extra_files = library[module_class].get("extra_files", [])
# Fetch up-to-date library from remote repo if module_id not found
if module_id is None:
library_raw_content, _ = get_file_content(
str(remote_dir_path), f"/{library_path}"
)
library = json.loads(library_raw_content)
if module_class not in library:
raise ValueError("Loader class name not found in library")
module_id = library[module_class]["id"]
extra_files = library[module_class].get("extra_files", [])
# create cache dir if needed
local_library_dir = os.path.dirname(local_library_path)
if not disable_library_cache:
if not os.path.exists(local_library_dir):
os.makedirs(local_library_dir)
# Update cache
with open(local_library_path, "w") as f:
f.write(library_raw_content)
if module_id is None:
raise ValueError("Loader class name not found in library")
return {
"module_id": module_id,
"extra_files": extra_files,
}
def download_module_and_reqs(
local_dir_path: PATH_TYPE,
remote_dir_path: PATH_TYPE,
module_id: str,
extra_files: List[str],
refresh_cache: bool = False,
use_gpt_index_import: bool = False,
base_file_name: str = "base.py",
override_path: bool = False,
) -> None:
"""Load module."""
if isinstance(local_dir_path, str):
local_dir_path = Path(local_dir_path)
if override_path:
module_path = str(local_dir_path)
else:
module_path = f"{local_dir_path}/{module_id}"
if refresh_cache or not os.path.exists(module_path):
os.makedirs(module_path, exist_ok=True)
basepy_raw_content, _ = get_file_content(
str(remote_dir_path), f"/{module_id}/{base_file_name}"
)
if use_gpt_index_import:
# NOTE: these replacements are currently no-ops (both the search and the
# replacement string are `llama_index`); the branch is kept only for
# backwards compatibility of the `use_gpt_index_import` flag.
basepy_raw_content = basepy_raw_content.replace(
"import llama_index", "import llama_index"
)
basepy_raw_content = basepy_raw_content.replace(
"from llama_index", "from llama_index"
)
with open(f"{module_path}/{base_file_name}", "w") as f:
f.write(basepy_raw_content)
# Get content of extra files if there are any
# and write them under the loader directory
for extra_file in extra_files:
extra_file_raw_content, _ = get_file_content(
str(remote_dir_path), f"/{module_id}/{extra_file}"
)
# If the extra file is an __init__.py file, we need to
# add the exports to the __init__.py file in the modules directory
if extra_file == "__init__.py":
loader_exports = get_exports(extra_file_raw_content)
existing_exports = []
init_file_path = local_dir_path / "__init__.py"
# if the __init__.py file does not exist, we need to create it
mode = "a+" if not os.path.exists(init_file_path) else "r+"
with open(init_file_path, mode) as f:
f.write(f"from .{module_id} import {', '.join(loader_exports)}")
existing_exports = get_exports(f.read())
rewrite_exports(existing_exports + loader_exports, str(local_dir_path))
with open(f"{module_path}/{extra_file}", "w") as f:
f.write(extra_file_raw_content)
# install requirements
requirements_path = f"{local_dir_path}/requirements.txt"
if not os.path.exists(requirements_path):
# NOTE: need to check the status code
response_txt, status_code = get_file_content(
str(remote_dir_path), f"/{module_id}/requirements.txt"
)
if status_code == 200:
with open(requirements_path, "w") as f:
f.write(response_txt)
# Install dependencies if there are any and not already installed
if os.path.exists(requirements_path):
try:
requirements = pkg_resources.parse_requirements(
Path(requirements_path).open()
)
pkg_resources.require([str(r) for r in requirements])
except DistributionNotFound:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", requirements_path]
)
def download_llama_module(
module_class: str,
llama_hub_url: str = LLAMA_HUB_URL,
refresh_cache: bool = False,
custom_dir: Optional[str] = None,
custom_path: Optional[str] = None,
library_path: str = "library.json",
base_file_name: str = "base.py",
use_gpt_index_import: bool = False,
disable_library_cache: bool = False,
override_path: bool = False,
skip_load: bool = False,
) -> Any:
"""Download a module from LlamaHub.
Can be a loader, tool, pack, or more.
Args:
module_class: The name of the llama module class you want to download,
such as `GmailOpenAIAgentPack`.
refresh_cache: If true, the local cache will be skipped and the
module will be fetched directly from the remote repo.
custom_dir: Custom dir name to download the module into (under parent folder).
custom_path: Custom dirpath to download the module into.
library_path: File name of the library file.
base_file_name: Name of the module's entry-point file. Defaults to "base.py".
use_gpt_index_import: Legacy flag kept for backwards compatibility; the import
rewrite it triggers is currently a no-op (both settings resolve to
`llama_index` imports).
disable_library_cache: If true, do not write the fetched library file to the
local cache.
override_path: If true, download directly into the target directory rather
than a per-module subdirectory.
skip_load: If true, download the module (and its requirements) but do not
import and return it.
Returns:
The downloaded module class (e.g. a loader, tool, or llama pack), or None
if `skip_load` is True.
"""
# create directory / get path
dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)
# fetch info from library.json file
module_info = get_module_info(
local_dir_path=dirpath,
remote_dir_path=llama_hub_url,
module_class=module_class,
refresh_cache=refresh_cache,
library_path=library_path,
disable_library_cache=disable_library_cache,
)
module_id = module_info["module_id"]
extra_files = module_info["extra_files"]
# download the module, install requirements
download_module_and_reqs(
local_dir_path=dirpath,
remote_dir_path=llama_hub_url,
module_id=module_id,
extra_files=extra_files,
refresh_cache=refresh_cache,
use_gpt_index_import=use_gpt_index_import,
base_file_name=base_file_name,
override_path=override_path,
)
if skip_load:
return None
# loads the module into memory
if override_path:
path = f"{dirpath}/{base_file_name}"
spec = util.spec_from_file_location("custom_module", location=path)
if spec is None:
raise ValueError(f"Could not find file: {path}.")
else:
path = f"{dirpath}/{module_id}/{base_file_name}"
spec = util.spec_from_file_location("custom_module", location=path)
if spec is None:
raise ValueError(f"Could not find file: {path}.")
module = util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return getattr(module, module_class)
def track_download(module_class: str, module_type: str) -> None:
"""Tracks number of downloads via Llamahub proxy.
Args:
module_class: The name of the llama module being downloaded, e.g., `GmailOpenAIAgentPack`.
module_type: Can be "loader", "tool", "llamapack", or "datasets"
"""
try:
requests.post(
LLAMAHUB_ANALYTICS_PROXY_SERVER,
json={"type": module_type, "plugin": module_class},
)
except Exception as e:
logger.info(f"Error tracking downloads for {module_class} : {e}")

View File

@ -0,0 +1,88 @@
import os
from pathlib import Path
from typing import List, Optional, Tuple
import requests
def get_file_content(url: str, path: str) -> Tuple[str, int]:
"""Get the content of a file from the GitHub REST API."""
resp = requests.get(url + path)
return resp.text, resp.status_code
def get_file_content_bytes(url: str, path: str) -> Tuple[bytes, int]:
"""Get the content of a file from the GitHub REST API."""
resp = requests.get(url + path)
return resp.content, resp.status_code
def get_exports(raw_content: str) -> List:
"""Read content of a Python file and returns a list of exported class names.
For example:
```python
from .a import A
from .b import B
__all__ = ["A", "B"]
```
will return `["A", "B"]`.
Args:
- raw_content: The content of a Python file as a string.
Returns:
A list of exported class names.
"""
exports = []
for line in raw_content.splitlines():
line = line.strip()
if line.startswith("__all__"):
exports = line.split("=")[1].strip().strip("[").strip("]").split(",")
exports = [export.strip().strip("'").strip('"') for export in exports]
return exports
def rewrite_exports(exports: List[str], dirpath: str) -> None:
"""Write the `__all__` variable to the `__init__.py` file in the modules dir.
Removes the line that contains `__all__` and appends a new line with the updated
`__all__` variable.
Args:
- exports: A list of exported class names.
- dirpath: Path to the directory whose `__init__.py` file should be rewritten.
init_path = f"{dirpath}/__init__.py"
with open(init_path) as f:
lines = f.readlines()
with open(init_path, "w") as f:
for line in lines:
line = line.strip()
if line.startswith("__all__"):
continue
f.write(line + os.linesep)
f.write(f"__all__ = {list(set(exports))}" + os.linesep)
def initialize_directory(
custom_path: Optional[str] = None, custom_dir: Optional[str] = None
) -> Path:
"""Initialize directory."""
if custom_path is not None and custom_dir is not None:
raise ValueError(
"You cannot specify both `custom_path` and `custom_dir` at the same time."
)
custom_dir = custom_dir or "llamadatasets"
if custom_path is not None:
dirpath = Path(custom_path)
else:
dirpath = Path(__file__).parent / custom_dir
if not os.path.exists(dirpath):
# Create a new directory because it does not exist
os.makedirs(dirpath)
return dirpath
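Either argument may be used on its own; passing both raises `ValueError`. For example:

```python
from llama_index.download.utils import initialize_directory

dirpath = initialize_directory(custom_dir="my_datasets")  # created next to this module
dirpath = initialize_directory(custom_path="./data/downloads")  # explicit location
```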

View File

@ -0,0 +1,96 @@
"""Init file."""
from llama_index.embeddings.adapter import (
AdapterEmbeddingModel,
LinearAdapterEmbeddingModel,
)
from llama_index.embeddings.anyscale import AnyscaleEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.base import BaseEmbedding, SimilarityMode
from llama_index.embeddings.bedrock import BedrockEmbedding
from llama_index.embeddings.clarifai import ClarifaiEmbedding
from llama_index.embeddings.clip import ClipEmbedding
from llama_index.embeddings.cohereai import CohereEmbedding
from llama_index.embeddings.dashscope import (
DashScopeBatchTextEmbeddingModels,
DashScopeEmbedding,
DashScopeMultiModalEmbeddingModels,
DashScopeTextEmbeddingModels,
DashScopeTextEmbeddingType,
)
from llama_index.embeddings.elasticsearch import (
ElasticsearchEmbedding,
ElasticsearchEmbeddings,
)
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.embeddings.google import GoogleUnivSentEncoderEmbedding
from llama_index.embeddings.google_palm import GooglePaLMEmbedding
from llama_index.embeddings.gradient import GradientEmbedding
from llama_index.embeddings.huggingface import (
HuggingFaceEmbedding,
HuggingFaceInferenceAPIEmbedding,
HuggingFaceInferenceAPIEmbeddings,
)
from llama_index.embeddings.huggingface_optimum import OptimumEmbedding
from llama_index.embeddings.huggingface_utils import DEFAULT_HUGGINGFACE_EMBEDDING_MODEL
from llama_index.embeddings.instructor import InstructorEmbedding
from llama_index.embeddings.langchain import LangchainEmbedding
from llama_index.embeddings.llm_rails import LLMRailsEmbedding, LLMRailsEmbeddings
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.embeddings.nomic import NomicEmbedding
from llama_index.embeddings.ollama_embedding import OllamaEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.pooling import Pooling
from llama_index.embeddings.sagemaker_embedding_endpoint import (
SageMakerEmbedding,
)
from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference
from llama_index.embeddings.together import TogetherEmbedding
from llama_index.embeddings.utils import resolve_embed_model
from llama_index.embeddings.voyageai import VoyageEmbedding
__all__ = [
"AdapterEmbeddingModel",
"BedrockEmbedding",
"ClarifaiEmbedding",
"ClipEmbedding",
"CohereEmbedding",
"BaseEmbedding",
"DEFAULT_HUGGINGFACE_EMBEDDING_MODEL",
"ElasticsearchEmbedding",
"FastEmbedEmbedding",
"GoogleUnivSentEncoderEmbedding",
"GradientEmbedding",
"HuggingFaceInferenceAPIEmbedding",
"HuggingFaceEmbedding",
"InstructorEmbedding",
"LangchainEmbedding",
"LinearAdapterEmbeddingModel",
"LLMRailsEmbedding",
"MistralAIEmbedding",
"OpenAIEmbedding",
"AzureOpenAIEmbedding",
"AnyscaleEmbedding",
"OptimumEmbedding",
"Pooling",
"SageMakerEmbedding",
"GooglePaLMEmbedding",
"SimilarityMode",
"TextEmbeddingsInference",
"TogetherEmbedding",
"resolve_embed_model",
"NomicEmbedding",
# Deprecated, kept for backwards compatibility
"LLMRailsEmbeddings",
"ElasticsearchEmbeddings",
"HuggingFaceInferenceAPIEmbeddings",
"VoyageEmbedding",
"OllamaEmbedding",
"GeminiEmbedding",
"DashScopeEmbedding",
"DashScopeTextEmbeddingModels",
"DashScopeTextEmbeddingType",
"DashScopeBatchTextEmbeddingModels",
"DashScopeMultiModalEmbeddingModels",
]
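A short sketch of how this package is typically used: either construct an embedding class directly, or resolve one from a string shorthand. The local model name is illustrative, and `OpenAIEmbedding` expects `OPENAI_API_KEY` in the environment.

```python
from llama_index.embeddings import OpenAIEmbedding, resolve_embed_model

embed_model = OpenAIEmbedding()  # requires OPENAI_API_KEY

# "local:..." resolves to a locally loaded HuggingFace model; the model name is illustrative
embed_model = resolve_embed_model("local:BAAI/bge-small-en-v1.5")

vector = embed_model.get_text_embedding("hello world")
print(len(vector))
```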

View File

@ -0,0 +1,116 @@
"""Embedding adapter model."""
import logging
from typing import Any, List, Optional, Type, cast
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.embeddings.base import BaseEmbedding
from llama_index.utils import infer_torch_device
logger = logging.getLogger(__name__)
class AdapterEmbeddingModel(BaseEmbedding):
"""Adapter for any embedding model.
This is a wrapper around any embedding model that adds an adapter layer \
on top of it.
This is useful for finetuning an embedding model on a downstream task.
The embedding model can be any model - it does not need to expose gradients.
Args:
base_embed_model (BaseEmbedding): Base embedding model.
adapter_path (str): Path to adapter.
adapter_cls (Optional[Type[Any]]): Adapter class. Defaults to None, in which \
case a linear adapter is used.
transform_query (bool): Whether to transform query embeddings. Defaults to True.
device (Optional[str]): Device to use. Defaults to None.
embed_batch_size (int): Batch size for embedding. Defaults to 10.
callback_manager (Optional[CallbackManager]): Callback manager. \
Defaults to None.
"""
_base_embed_model: BaseEmbedding = PrivateAttr()
_adapter: Any = PrivateAttr()
_transform_query: bool = PrivateAttr()
_device: Optional[str] = PrivateAttr()
_target_device: Any = PrivateAttr()
def __init__(
self,
base_embed_model: BaseEmbedding,
adapter_path: str,
adapter_cls: Optional[Type[Any]] = None,
transform_query: bool = True,
device: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
) -> None:
"""Init params."""
import torch
from llama_index.embeddings.adapter_utils import BaseAdapter, LinearLayer
if device is None:
device = infer_torch_device()
logger.info(f"Use pytorch device: {device}")
self._target_device = torch.device(device)
self._base_embed_model = base_embed_model
if adapter_cls is None:
adapter_cls = LinearLayer
else:
adapter_cls = cast(Type[BaseAdapter], adapter_cls)
adapter = adapter_cls.load(adapter_path)
self._adapter = cast(BaseAdapter, adapter)
self._adapter.to(self._target_device)
self._transform_query = transform_query
super().__init__(
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
model_name=f"Adapter for {base_embed_model.model_name}",
)
@classmethod
def class_name(cls) -> str:
return "AdapterEmbeddingModel"
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
import torch
query_embedding = self._base_embed_model._get_query_embedding(query)
if self._transform_query:
query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
query_embedding_t = self._adapter.forward(query_embedding_t)
query_embedding = query_embedding_t.tolist()
return query_embedding
async def _aget_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
import torch
query_embedding = await self._base_embed_model._aget_query_embedding(query)
if self._transform_query:
query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
query_embedding_t = self._adapter.forward(query_embedding_t)
query_embedding = query_embedding_t.tolist()
return query_embedding
def _get_text_embedding(self, text: str) -> List[float]:
return self._base_embed_model._get_text_embedding(text)
async def _aget_text_embedding(self, text: str) -> List[float]:
return await self._base_embed_model._aget_text_embedding(text)
# Maintain for backwards compatibility
LinearAdapterEmbeddingModel = AdapterEmbeddingModel
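A minimal usage sketch: wrap a base embedding model with a previously finetuned adapter. The `model_output` directory is hypothetical; it stands for the folder produced by an embedding-adapter finetuning run and must contain `config.json` and `pytorch_model.bin`.

```python
from llama_index.embeddings import OpenAIEmbedding
from llama_index.embeddings.adapter import AdapterEmbeddingModel

base_embed_model = OpenAIEmbedding()
adapter_model = AdapterEmbeddingModel(
    base_embed_model=base_embed_model,
    adapter_path="model_output",  # hypothetical adapter checkpoint directory
)
query_vector = adapter_model.get_query_embedding("What did the author work on?")
```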

View File

@ -0,0 +1,179 @@
"""Adapter utils."""
import json
import logging
import os
from abc import abstractmethod
from typing import Callable, Dict
import torch
import torch.nn.functional as F
from torch import Tensor, nn
logger = logging.getLogger(__name__)
class BaseAdapter(nn.Module):
"""Base adapter.
Can be subclassed to implement custom adapters.
To implement a custom adapter, subclass this class and implement the
following methods:
- get_config_dict
- forward
"""
@abstractmethod
def get_config_dict(self) -> Dict:
"""Get config dict."""
@abstractmethod
def forward(self, embed: Tensor) -> Tensor:
"""Forward pass."""
def save(self, output_path: str) -> None:
"""Save model."""
os.makedirs(output_path, exist_ok=True)
with open(os.path.join(output_path, "config.json"), "w") as fOut:
json.dump(self.get_config_dict(), fOut)
torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
@classmethod
def load(cls, input_path: str) -> "BaseAdapter":
"""Load model."""
with open(os.path.join(input_path, "config.json")) as fIn:
config = json.load(fIn)
model = cls(**config)
model.load_state_dict(
torch.load(
os.path.join(input_path, "pytorch_model.bin"),
map_location=torch.device("cpu"),
)
)
return model
class LinearLayer(BaseAdapter):
"""Linear transformation.
Args:
in_features (int): Input dimension.
out_features (int): Output dimension.
bias (bool): Whether to use bias. Defaults to False.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.bias = bias
self.linear = nn.Linear(in_features, out_features, bias=bias)
# seed with identity matrix and 0 bias
# only works for square matrices
self.linear.weight.data.copy_(torch.eye(in_features, out_features))
if bias:
self.linear.bias.data.copy_(torch.zeros(out_features))
def forward(self, embed: Tensor) -> Tensor:
"""Forward pass (Wv)."""
return self.linear(embed)
def get_config_dict(self) -> Dict:
return {
"in_features": self.in_features,
"out_features": self.out_features,
"bias": self.bias,
}
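Because the weight is seeded with `torch.eye` and the bias with zeros, a freshly constructed square `LinearLayer` is the identity map, which makes it a safe starting point for finetuning. A quick check:

```python
import torch

from llama_index.embeddings.adapter_utils import LinearLayer

layer = LinearLayer(in_features=8, out_features=8, bias=True)
x = torch.randn(8)
assert torch.allclose(layer(x), x)  # untrained adapter leaves embeddings unchanged
```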
def get_activation_function(name: str) -> Callable:
"""Get activation function.
Args:
name (str): Name of activation function.
"""
activations: Dict[str, Callable] = {
"relu": F.relu,
"sigmoid": torch.sigmoid,
"tanh": torch.tanh,
"leaky_relu": F.leaky_relu,
# add more activations here as needed
}
if name not in activations:
raise ValueError(f"Unknown activation function: {name}")
return activations[name]
class TwoLayerNN(BaseAdapter):
"""Two-layer transformation.
Args:
in_features (int): Input dimension.
hidden_features (int): Hidden dimension.
out_features (int): Output dimension.
bias (bool): Whether to use bias. Defaults to False.
activation_fn_str (str): Name of activation function. Defaults to "relu".
"""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int,
bias: bool = False,
activation_fn_str: str = "relu",
add_residual: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
self.hidden_features = hidden_features
self.out_features = out_features
self.bias = bias
self.activation_fn_str = activation_fn_str
self.linear1 = nn.Linear(in_features, hidden_features, bias=True)
self.linear2 = nn.Linear(hidden_features, out_features, bias=True)
# self.linear1.weight.data.copy_(torch.zeros(hidden_features, in_features))
# self.linear2.weight.data.copy_(torch.zeros(out_features, hidden_features))
# if bias:
# self.linear1.bias.data.copy_(torch.zeros(hidden_features))
# self.linear2.bias.data.copy_(torch.zeros(out_features))
self._activation_function = get_activation_function(activation_fn_str)
self._add_residual = add_residual
# if add_residual, then add residual_weight (init to 0)
self.residual_weight = nn.Parameter(torch.zeros(1))
def forward(self, embed: Tensor) -> Tensor:
"""Forward pass (Wv).
Args:
embed (Tensor): Input tensor.
"""
output1 = self.linear1(embed)
output1 = self._activation_function(output1)
output2 = self.linear2(output1)
if self._add_residual:
# print(output2)
# print(self.residual_weight)
# print(self.linear2.weight.data)
output2 = self.residual_weight * output2 + embed
return output2
def get_config_dict(self) -> Dict:
"""Get config dict."""
return {
"in_features": self.in_features,
"hidden_features": self.hidden_features,
"out_features": self.out_features,
"bias": self.bias,
"activation_fn_str": self.activation_fn_str,
"add_residual": self._add_residual,
}
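Similarly, with `add_residual=True` the `residual_weight` parameter starts at zero, so a freshly constructed `TwoLayerNN` passes embeddings through unchanged; training then learns how much of the nonlinear branch to mix in. A short sketch, including a save/load round trip:

```python
import torch

from llama_index.embeddings.adapter_utils import TwoLayerNN

adapter = TwoLayerNN(
    in_features=384, hidden_features=1024, out_features=384, add_residual=True
)
embeddings = torch.randn(4, 384)
assert torch.allclose(adapter(embeddings), embeddings)  # residual_weight starts at 0

adapter.save("/tmp/twolayer_adapter")  # writes config.json and pytorch_model.bin
restored = TwoLayerNN.load("/tmp/twolayer_adapter")
```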

Some files were not shown because too many files have changed in this diff.