This commit is contained in:
parent 5da1f8579d
commit ad59446d14

@@ -0,0 +1 @@
0.9.48

@@ -0,0 +1,168 @@
"""Init file of LlamaIndex."""
from pathlib import Path

with open(Path(__file__).absolute().parents[0] / "VERSION") as _f:
    __version__ = _f.read().strip()


import logging
from logging import NullHandler
from typing import Callable, Optional

# import global eval handler
from llama_index.callbacks.global_handlers import set_global_handler

# response
from llama_index.core.response.schema import Response
from llama_index.data_structs.struct_type import IndexStructType

# embeddings
from llama_index.embeddings import OpenAIEmbedding

# indices
# loading
from llama_index.indices import (
    ComposableGraph,
    DocumentSummaryIndex,
    GPTDocumentSummaryIndex,
    GPTKeywordTableIndex,
    GPTKnowledgeGraphIndex,
    GPTListIndex,
    GPTRAKEKeywordTableIndex,
    GPTSimpleKeywordTableIndex,
    GPTTreeIndex,
    GPTVectorStoreIndex,
    KeywordTableIndex,
    KnowledgeGraphIndex,
    ListIndex,
    RAKEKeywordTableIndex,
    SimpleKeywordTableIndex,
    SummaryIndex,
    TreeIndex,
    VectorStoreIndex,
    load_graph_from_storage,
    load_index_from_storage,
    load_indices_from_storage,
)

# structured
from llama_index.indices.common.struct_store.base import SQLDocumentContextBuilder

# prompt helper
from llama_index.indices.prompt_helper import PromptHelper
from llama_index.llm_predictor import LLMPredictor

# token predictor
from llama_index.llm_predictor.mock import MockLLMPredictor

# prompts
from llama_index.prompts import (
    BasePromptTemplate,
    ChatPromptTemplate,
    # backwards compatibility
    Prompt,
    PromptTemplate,
    SelectorPromptTemplate,
)
from llama_index.readers import (
    SimpleDirectoryReader,
    download_loader,
)

# Response Synthesizer
from llama_index.response_synthesizers.factory import get_response_synthesizer
from llama_index.schema import Document, QueryBundle
from llama_index.service_context import (
    ServiceContext,
    set_global_service_context,
)

# storage
from llama_index.storage.storage_context import StorageContext
from llama_index.token_counter.mock_embed_model import MockEmbedding

# sql wrapper
from llama_index.utilities.sql_wrapper import SQLDatabase

# global tokenizer
from llama_index.utils import get_tokenizer, set_global_tokenizer

# best practices for library logging:
# https://docs.python.org/3/howto/logging.html#configuring-logging-for-a-library
logging.getLogger(__name__).addHandler(NullHandler())

__all__ = [
    "StorageContext",
    "ServiceContext",
    "ComposableGraph",
    # indices
    "SummaryIndex",
    "VectorStoreIndex",
    "SimpleKeywordTableIndex",
    "KeywordTableIndex",
    "RAKEKeywordTableIndex",
    "TreeIndex",
    "DocumentSummaryIndex",
    "KnowledgeGraphIndex",
    # indices - legacy names
    "GPTKeywordTableIndex",
    "GPTKnowledgeGraphIndex",
    "GPTSimpleKeywordTableIndex",
    "GPTRAKEKeywordTableIndex",
    "GPTListIndex",
    "ListIndex",
    "GPTTreeIndex",
    "GPTVectorStoreIndex",
    "GPTDocumentSummaryIndex",
    "Prompt",
    "PromptTemplate",
    "BasePromptTemplate",
    "ChatPromptTemplate",
    "SelectorPromptTemplate",
    "OpenAIEmbedding",
    "SummaryPrompt",
    "TreeInsertPrompt",
    "TreeSelectPrompt",
    "TreeSelectMultiplePrompt",
    "RefinePrompt",
    "QuestionAnswerPrompt",
    "KeywordExtractPrompt",
    "QueryKeywordExtractPrompt",
    "Response",
    "Document",
    "SimpleDirectoryReader",
    "LLMPredictor",
    "MockLLMPredictor",
    "VellumPredictor",
    "VellumPromptRegistry",
    "MockEmbedding",
    "SQLDatabase",
    "SQLDocumentContextBuilder",
    "SQLContextBuilder",
    "PromptHelper",
    "IndexStructType",
    "download_loader",
    "load_graph_from_storage",
    "load_index_from_storage",
    "load_indices_from_storage",
    "QueryBundle",
    "get_response_synthesizer",
    "set_global_service_context",
    "set_global_handler",
    "set_global_tokenizer",
    "get_tokenizer",
]

# eval global toggle
from llama_index.callbacks.base_handler import BaseCallbackHandler

global_handler: Optional[BaseCallbackHandler] = None

# NOTE: keep for backwards compatibility
SQLContextBuilder = SQLDocumentContextBuilder

# global service context for ServiceContext.from_defaults()
global_service_context: Optional[ServiceContext] = None

# global tokenizer
global_tokenizer: Optional[Callable[[str], list]] = None
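
For orientation, the package root above re-exports the high-level 0.9 API; the usual quickstart below uses only names exported here (the local "data" directory and an OPENAI_API_KEY in the environment are assumptions, not part of this diff).

# sketch only: index a folder of documents and query it with the default settings
from llama_index import SimpleDirectoryReader, VectorStoreIndex

documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine()
print(query_engine.query("What is this project about?"))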

@@ -0,0 +1,2 @@
# Include this file
!.gitignore

@@ -0,0 +1,2 @@
# Include this file
!.gitignore

@@ -0,0 +1,45 @@
# agent runner + agent worker
from llama_index.agent.custom.pipeline_worker import QueryPipelineAgentWorker
from llama_index.agent.custom.simple import CustomSimpleAgentWorker
from llama_index.agent.legacy.context_retriever_agent import ContextRetrieverOpenAIAgent
from llama_index.agent.legacy.openai_agent import OpenAIAgent as OldOpenAIAgent
from llama_index.agent.legacy.react.base import ReActAgent as OldReActAgent
from llama_index.agent.legacy.retriever_openai_agent import FnRetrieverOpenAIAgent
from llama_index.agent.openai.base import OpenAIAgent
from llama_index.agent.openai.step import OpenAIAgentWorker
from llama_index.agent.openai_assistant_agent import OpenAIAssistantAgent
from llama_index.agent.react.base import ReActAgent
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.agent.react.step import ReActAgentWorker
from llama_index.agent.react_multimodal.step import MultimodalReActAgentWorker
from llama_index.agent.runner.base import AgentRunner
from llama_index.agent.runner.parallel import ParallelAgentRunner
from llama_index.agent.types import Task
from llama_index.chat_engine.types import AgentChatResponse

# for backwards compatibility
RetrieverOpenAIAgent = FnRetrieverOpenAIAgent

__all__ = [
    "AgentRunner",
    "ParallelAgentRunner",
    "OpenAIAgentWorker",
    "ReActAgentWorker",
    "OpenAIAgent",
    "ReActAgent",
    "OpenAIAssistantAgent",
    "FnRetrieverOpenAIAgent",
    "RetrieverOpenAIAgent",  # for backwards compatibility
    "ContextRetrieverOpenAIAgent",
    "CustomSimpleAgentWorker",
    "QueryPipelineAgentWorker",
    "ReActChatFormatter",
    # beta
    "MultimodalReActAgentWorker",
    # schema-related
    "AgentChatResponse",
    "Task",
    # legacy
    "OldOpenAIAgent",
    "OldReActAgent",
]

@@ -0,0 +1 @@
"""Init params."""

@@ -0,0 +1,199 @@
"""Agent worker that takes in a query pipeline."""

import uuid
from typing import (
    Any,
    List,
    Optional,
    cast,
)

from llama_index.agent.types import (
    BaseAgentWorker,
    Task,
    TaskStep,
    TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks import (
    CallbackManager,
    trace_method,
)
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
)
from llama_index.core.query_pipeline.query_component import QueryComponent
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.query_pipeline.components.agent import (
    AgentFnComponent,
    AgentInputComponent,
    BaseAgentComponent,
)
from llama_index.query_pipeline.query import QueryPipeline
from llama_index.tools import ToolOutput

DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"


def _get_agent_components(query_component: QueryComponent) -> List[BaseAgentComponent]:
    """Get agent components."""
    agent_components: List[BaseAgentComponent] = []
    for c in query_component.sub_query_components:
        if isinstance(c, BaseAgentComponent):
            agent_components.append(cast(BaseAgentComponent, c))

        if len(c.sub_query_components) > 0:
            agent_components.extend(_get_agent_components(c))

    return agent_components


class QueryPipelineAgentWorker(BaseModel, BaseAgentWorker):
    """Query Pipeline agent worker.

    Barebones agent worker that takes in a query pipeline.

    Assumes that the first component in the query pipeline is an
    `AgentInputComponent` and last is `AgentFnComponent`.

    Args:
        pipeline (QueryPipeline): Query pipeline

    """

    pipeline: QueryPipeline = Field(..., description="Query pipeline")
    callback_manager: CallbackManager = Field(..., exclude=True)

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        pipeline: QueryPipeline,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Initialize."""
        if callback_manager is not None:
            # set query pipeline callback
            pipeline.set_callback_manager(callback_manager)
        else:
            callback_manager = pipeline.callback_manager
        super().__init__(
            pipeline=pipeline,
            callback_manager=callback_manager,
        )
        # validate query pipeline
        self.agent_input_component
        self.agent_components

    @property
    def agent_input_component(self) -> AgentInputComponent:
        """Get agent input component."""
        root_key = self.pipeline.get_root_keys()[0]
        if not isinstance(self.pipeline.module_dict[root_key], AgentInputComponent):
            raise ValueError(
                "Query pipeline first component must be AgentInputComponent, got "
                f"{self.pipeline.module_dict[root_key]}"
            )

        return cast(AgentInputComponent, self.pipeline.module_dict[root_key])

    @property
    def agent_components(self) -> List[AgentFnComponent]:
        """Get agent output component."""
        return _get_agent_components(self.pipeline)

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize step from task."""
        sources: List[ToolOutput] = []
        # temporary memory for new messages
        new_memory = ChatMemoryBuffer.from_defaults()

        # initialize initial state
        initial_state = {
            "sources": sources,
            "memory": new_memory,
        }

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
            step_state=initial_state,
        )

    def _get_task_step_response(
        self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
    ) -> TaskStepOutput:
        """Get task step response."""
        if is_done:
            new_steps = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        return TaskStepOutput(
            output=agent_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step."""
        # partial agent output component with task and step
        for agent_fn_component in self.agent_components:
            agent_fn_component.partial(task=task, state=step.step_state)

        agent_response, is_done = self.pipeline.run(state=step.step_state, task=task)
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async)."""
        # partial agent output component with task and step
        for agent_fn_component in self.agent_components:
            agent_fn_component.partial(task=task, state=step.step_state)

        agent_response, is_done = await self.pipeline.arun(
            state=step.step_state, task=task
        )
        response = self._get_task_step_response(agent_response, step, is_done)
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream)."""
        raise NotImplementedError("This agent does not support streaming.")

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        raise NotImplementedError("This agent does not support streaming.")

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed."""
        # add new messages to memory
        task.memory.set(task.memory.get() + task.extra_state["memory"].get_all())
        # reset new memory
        task.extra_state["memory"].reset()

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager
        self.pipeline.set_callback_manager(callback_manager)
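
As a quick orientation for reviewers, a minimal sketch of how the worker above is meant to be driven (not part of this diff). The two pipeline callables are hypothetical placeholders; the only assumption is that the root component is an AgentInputComponent and the final one is an AgentFnComponent returning (AgentChatResponse, is_done), as the class docstring requires.

# sketch only: a two-component pipeline wired into the worker via AgentRunner
from typing import Any, Dict, Optional, Tuple

from llama_index.agent import AgentRunner, QueryPipelineAgentWorker
from llama_index.chat_engine.types import AgentChatResponse
from llama_index.query_pipeline.components.agent import (
    AgentFnComponent,
    AgentInputComponent,
)
from llama_index.query_pipeline.query import QueryPipeline


def agent_input_fn(task: Any, state: Dict[str, Any]) -> Dict[str, Any]:
    # hypothetical input component: expose the task input to the next component
    return {"input": task.input}


def agent_fn(
    task: Any, state: Dict[str, Any], input: Optional[str] = None
) -> Tuple[AgentChatResponse, bool]:
    # hypothetical terminal component: echo the input and mark the task done
    return AgentChatResponse(response=f"echo: {input}"), True


pipeline = QueryPipeline(chain=[
    AgentInputComponent(fn=agent_input_fn),
    AgentFnComponent(fn=agent_fn),
])
agent = AgentRunner(QueryPipelineAgentWorker(pipeline))
print(agent.chat("What is 1 + 1?"))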

@@ -0,0 +1,261 @@
"""Custom agent worker."""

import uuid
from abc import abstractmethod
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    cast,
)

from llama_index.agent.types import (
    BaseAgentWorker,
    Task,
    TaskStep,
    TaskStepOutput,
)
from llama_index.bridge.pydantic import BaseModel, Field, PrivateAttr
from llama_index.callbacks import (
    CallbackManager,
    trace_method,
)
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
)
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.tools.types import AsyncBaseTool

DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"


class CustomSimpleAgentWorker(BaseModel, BaseAgentWorker):
    """Custom simple agent worker.

    This is "simple" in the sense that some of the scaffolding is setup already.
    Assumptions:
    - assumes that the agent has tools, llm, callback manager, and tool retriever
    - has a `from_tools` convenience function
    - assumes that the agent is sequential, and doesn't take in any additional
      intermediate inputs.

    Args:
        tools (Sequence[BaseTool]): Tools to use for reasoning
        llm (LLM): LLM to use
        callback_manager (CallbackManager): Callback manager
        tool_retriever (Optional[ObjectRetriever[BaseTool]]): Tool retriever
        verbose (bool): Whether to print out reasoning steps

    """

    tools: Sequence[BaseTool] = Field(..., description="Tools to use for reasoning")
    llm: LLM = Field(..., description="LLM to use")
    callback_manager: CallbackManager = Field(
        default_factory=lambda: CallbackManager([]), exclude=True
    )
    tool_retriever: Optional[ObjectRetriever[BaseTool]] = Field(
        default=None, description="Tool retriever"
    )
    verbose: bool = Field(False, description="Whether to print out reasoning steps")

    _get_tools: Callable[[str], Sequence[BaseTool]] = PrivateAttr()

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        tools: Sequence[BaseTool],
        llm: LLM,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ) -> None:
        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif len(tools) > 0:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            self._get_tools = lambda _: []

        super().__init__(
            tools=tools,
            llm=llm,
            callback_manager=callback_manager,
            tool_retriever=tool_retriever,
            verbose=verbose,
        )

    @classmethod
    def from_tools(
        cls,
        tools: Optional[Sequence[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "CustomSimpleAgentWorker":
        """Convenience constructor method from set of BaseTools (Optional)."""
        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if callback_manager is not None:
            llm.callback_manager = callback_manager
        return cls(
            tools=tools or [],
            tool_retriever=tool_retriever,
            llm=llm,
            callback_manager=callback_manager,
            verbose=verbose,
        )

    @abstractmethod
    def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
        """Initialize state."""

    def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
        """Initialize step from task."""
        sources: List[ToolOutput] = []
        # temporary memory for new messages
        new_memory = ChatMemoryBuffer.from_defaults()

        # initialize initial state
        initial_state = {
            "sources": sources,
            "memory": new_memory,
        }

        step_state = self._initialize_state(task, **kwargs)
        # if intersecting keys, error
        if set(step_state.keys()).intersection(set(initial_state.keys())):
            raise ValueError(
                f"Step state keys {step_state.keys()} and initial state keys {initial_state.keys()} intersect. "
                f"*NOTE*: initial state keys {initial_state.keys()} are reserved."
            )
        step_state.update(initial_state)

        return TaskStep(
            task_id=task.task_id,
            step_id=str(uuid.uuid4()),
            input=task.input,
            step_state=step_state,
        )

    def get_tools(self, input: str) -> List[AsyncBaseTool]:
        """Get tools."""
        return [adapt_to_async_tool(t) for t in self._get_tools(input)]

    def _get_task_step_response(
        self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
    ) -> TaskStepOutput:
        """Get task step response."""
        if is_done:
            new_steps = []
        else:
            new_steps = [
                step.get_next_step(
                    step_id=str(uuid.uuid4()),
                    # NOTE: input is unused
                    input=None,
                )
            ]

        return TaskStepOutput(
            output=agent_response,
            task_step=step,
            is_last=is_done,
            next_steps=new_steps,
        )

    @abstractmethod
    def _run_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run step.

        Returns:
            Tuple of (agent_response, is_done)

        """

    async def _arun_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run step (async).

        Can override this method if you want to run the step asynchronously.

        Returns:
            Tuple of (agent_response, is_done)

        """
        raise NotImplementedError(
            "This agent does not support async. Please implement _arun_step."
        )

    @trace_method("run_step")
    def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step."""
        agent_response, is_done = self._run_step(
            step.step_state, task, input=step.input
        )
        response = self._get_task_step_response(agent_response, step, is_done)
        # sync step state with task state
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    async def arun_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async)."""
        agent_response, is_done = await self._arun_step(
            step.step_state, task, input=step.input
        )
        response = self._get_task_step_response(agent_response, step, is_done)
        task.extra_state.update(step.step_state)
        return response

    @trace_method("run_step")
    def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
        """Run step (stream)."""
        raise NotImplementedError("This agent does not support streaming.")

    @trace_method("run_step")
    async def astream_step(
        self, step: TaskStep, task: Task, **kwargs: Any
    ) -> TaskStepOutput:
        """Run step (async stream)."""
        raise NotImplementedError("This agent does not support streaming.")

    @abstractmethod
    def _finalize_task(self, state: Dict[str, Any], **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed.

        State is all the step states.

        """

    def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed."""
        # add new messages to memory
        task.memory.set(task.memory.get() + task.extra_state["memory"].get_all())
        # reset new memory
        task.extra_state["memory"].reset()
        self._finalize_task(task.extra_state, **kwargs)

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        # TODO: make this abstractmethod (right now will break some agent impls)
        self.callback_manager = callback_manager
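
Because the worker above is abstract, here is a reviewer-oriented sketch of the smallest useful subclass (the class name and its trivial logic are illustrative, not part of this diff; from_tools defaults to an OpenAI LLM, so an OPENAI_API_KEY in the environment is assumed even though the tool list is empty).

# sketch only: a one-step subclass that echoes its input
from typing import Any, Dict, Optional, Tuple

from llama_index.agent import AgentRunner, CustomSimpleAgentWorker
from llama_index.agent.types import Task
from llama_index.chat_engine.types import AgentChatResponse


class EchoAgentWorker(CustomSimpleAgentWorker):
    """Hypothetical worker: answers in a single step."""

    def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
        # custom state; must not collide with the reserved "sources"/"memory" keys
        return {"n_steps": 0}

    def _run_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        state["n_steps"] += 1
        # return the response and is_done=True so the runner stops after one step
        return AgentChatResponse(response=f"echo: {input or task.input}"), True

    def _finalize_task(self, state: Dict[str, Any], **kwargs: Any) -> None:
        # nothing to clean up in this toy example
        pass


agent = AgentRunner(EchoAgentWorker.from_tools(tools=[], verbose=True))
print(agent.chat("hello"))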

@@ -0,0 +1 @@
"""Init params."""

@@ -0,0 +1,199 @@
"""Context retriever agent."""

from typing import List, Optional, Type, Union

from llama_index.agent.legacy.openai_agent import (
    DEFAULT_MAX_FUNCTION_CALLS,
    DEFAULT_MODEL_NAME,
    BaseOpenAIAgent,
)
from llama_index.callbacks import CallbackManager
from llama_index.chat_engine.types import (
    AgentChatResponse,
)
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.llms.types import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import is_function_calling_model
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.prompts import PromptTemplate
from llama_index.schema import NodeWithScore
from llama_index.tools import BaseTool
from llama_index.utils import print_text

# inspired by DEFAULT_QA_PROMPT_TMPL from llama_index/prompts/default_prompts.py
DEFAULT_QA_PROMPT_TMPL = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "either pick the corresponding tool or answer the function: {query_str}\n"
)
DEFAULT_QA_PROMPT = PromptTemplate(DEFAULT_QA_PROMPT_TMPL)


class ContextRetrieverOpenAIAgent(BaseOpenAIAgent):
    """ContextRetriever OpenAI Agent.

    This agent performs retrieval from BaseRetriever before
    calling the LLM. Allows it to augment user message with context.

    NOTE: this is a beta feature, function interfaces might change.

    Args:
        tools (List[BaseTool]): A list of tools.
        retriever (BaseRetriever): A retriever.
        qa_prompt (Optional[PromptTemplate]): A QA prompt.
        context_separator (str): A context separator.
        llm (Optional[OpenAI]): An OpenAI LLM.
        chat_history (Optional[List[ChatMessage]]): A chat history.
        prefix_messages (List[ChatMessage]): A list of prefix messages.
        verbose (bool): Whether to print debug statements.
        max_function_calls (int): Maximum number of function calls.
        callback_manager (Optional[CallbackManager]): A callback manager.

    """

    def __init__(
        self,
        tools: List[BaseTool],
        retriever: BaseRetriever,
        qa_prompt: PromptTemplate,
        context_separator: str,
        llm: OpenAI,
        memory: BaseMemory,
        prefix_messages: List[ChatMessage],
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        super().__init__(
            llm=llm,
            memory=memory,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )
        self._tools = tools
        self._qa_prompt = qa_prompt
        self._retriever = retriever
        self._context_separator = context_separator

    @classmethod
    def from_tools_and_retriever(
        cls,
        tools: List[BaseTool],
        retriever: BaseRetriever,
        qa_prompt: Optional[PromptTemplate] = None,
        context_separator: str = "\n",
        llm: Optional[LLM] = None,
        chat_history: Optional[List[ChatMessage]] = None,
        memory: Optional[BaseMemory] = None,
        memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        prefix_messages: Optional[List[ChatMessage]] = None,
    ) -> "ContextRetrieverOpenAIAgent":
        """Create a ContextRetrieverOpenAIAgent from a retriever.

        Args:
            retriever (BaseRetriever): A retriever.
            qa_prompt (Optional[PromptTemplate]): A QA prompt.
            context_separator (str): A context separator.
            llm (Optional[OpenAI]): An OpenAI LLM.
            chat_history (Optional[List[ChatMessage]]): A chat history.
            verbose (bool): Whether to print debug statements.
            max_function_calls (int): Maximum number of function calls.
            callback_manager (Optional[CallbackManager]): A callback manager.

        """
        qa_prompt = qa_prompt or DEFAULT_QA_PROMPT
        chat_history = chat_history or []
        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if not isinstance(llm, OpenAI):
            raise ValueError("llm must be a OpenAI instance")
        if callback_manager is not None:
            llm.callback_manager = callback_manager

        memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)

        if not is_function_calling_model(llm.model):
            raise ValueError(
                f"Model name {llm.model} does not support function calling API."
            )
        if system_prompt is not None:
            if prefix_messages is not None:
                raise ValueError(
                    "Cannot specify both system_prompt and prefix_messages"
                )
            prefix_messages = [ChatMessage(content=system_prompt, role="system")]

        prefix_messages = prefix_messages or []

        return cls(
            tools=tools,
            retriever=retriever,
            qa_prompt=qa_prompt,
            context_separator=context_separator,
            llm=llm,
            memory=memory,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )

    def _get_tools(self, message: str) -> List[BaseTool]:
        """Get tools."""
        return self._tools

    def _build_formatted_message(self, message: str) -> str:
        # augment user message
        retrieved_nodes_w_scores: List[NodeWithScore] = self._retriever.retrieve(
            message
        )
        retrieved_nodes = [node.node for node in retrieved_nodes_w_scores]
        retrieved_texts = [node.get_content() for node in retrieved_nodes]

        # format message
        context_str = self._context_separator.join(retrieved_texts)
        return self._qa_prompt.format(context_str=context_str, query_str=message)

    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat."""
        formatted_message = self._build_formatted_message(message)
        if self._verbose:
            print_text(formatted_message + "\n", color="yellow")

        return super().chat(
            formatted_message, chat_history=chat_history, tool_choice=tool_choice
        )

    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        """Chat (async)."""
        formatted_message = self._build_formatted_message(message)
        if self._verbose:
            print_text(formatted_message + "\n", color="yellow")

        return await super().achat(
            formatted_message, chat_history=chat_history, tool_choice=tool_choice
        )

    def get_tools(self, message: str) -> List[BaseTool]:
        """Get tools."""
        return self._get_tools(message)
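
A minimal usage sketch of the agent above for reviewers (the "data" directory, the OPENAI_API_KEY environment variable, and the multiply tool are assumptions, not part of this diff).

# sketch only: retrieval-augmented user messages combined with OpenAI function calling
from llama_index import SimpleDirectoryReader, VectorStoreIndex
from llama_index.agent import ContextRetrieverOpenAIAgent
from llama_index.tools import FunctionTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


index = VectorStoreIndex.from_documents(SimpleDirectoryReader("data").load_data())
agent = ContextRetrieverOpenAIAgent.from_tools_and_retriever(
    tools=[FunctionTool.from_defaults(fn=multiply)],
    retriever=index.as_retriever(similarity_top_k=2),
    verbose=True,
)
# each user message is rewritten with retrieved context via DEFAULT_QA_PROMPT before the LLM call
print(agent.chat("Using the context, what is the reported unit price times 3?"))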

@@ -0,0 +1,610 @@
import asyncio
import json
import logging
from abc import abstractmethod
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast, get_args

from llama_index.agent.openai.utils import get_function_by_name
from llama_index.agent.types import BaseAgent
from llama_index.callbacks import (
    CallbackManager,
    CBEventType,
    EventPayload,
    trace_method,
)
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
    ChatResponseMode,
    StreamingAgentChatResponse,
)
from llama_index.core.llms.types import ChatMessage, ChatResponse, MessageRole
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import OpenAIToolCall
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.objects.base import ObjectRetriever
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

DEFAULT_MAX_FUNCTION_CALLS = 5
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"


def call_tool_with_error_handling(
    tool: BaseTool,
    input_dict: Dict,
    error_message: Optional[str] = None,
    raise_error: bool = False,
) -> ToolOutput:
    """Call tool with error handling.

    Input is a dictionary with args and kwargs.

    """
    try:
        return tool(**input_dict)
    except Exception as e:
        if raise_error:
            raise
        error_message = error_message or f"Error: {e!s}"
        return ToolOutput(
            content=error_message,
            tool_name=tool.metadata.name,
            raw_input={"kwargs": input_dict},
            raw_output=e,
        )


def call_function(
    tools: List[BaseTool],
    tool_call: OpenAIToolCall,
    verbose: bool = False,
) -> Tuple[ChatMessage, ToolOutput]:
    """Call a function and return the output as a string."""
    # validations to get passed mypy
    assert tool_call.id is not None
    assert tool_call.function is not None
    assert tool_call.function.name is not None
    assert tool_call.function.arguments is not None

    id_ = tool_call.id
    function_call = tool_call.function
    name = tool_call.function.name
    arguments_str = tool_call.function.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {name} with args: {arguments_str}")
    tool = get_function_by_name(tools, name)
    argument_dict = json.loads(arguments_str)

    # Call tool
    # Use default error message
    output = call_tool_with_error_handling(tool, argument_dict, error_message=None)
    if verbose:
        print(f"Got output: {output!s}")
        print("========================\n")
    return (
        ChatMessage(
            content=str(output),
            role=MessageRole.TOOL,
            additional_kwargs={
                "name": name,
                "tool_call_id": id_,
            },
        ),
        output,
    )


async def acall_function(
    tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False
) -> Tuple[ChatMessage, ToolOutput]:
    """Call a function and return the output as a string."""
    # validations to get passed mypy
    assert tool_call.id is not None
    assert tool_call.function is not None
    assert tool_call.function.name is not None
    assert tool_call.function.arguments is not None

    id_ = tool_call.id
    function_call = tool_call.function
    name = tool_call.function.name
    arguments_str = tool_call.function.arguments
    if verbose:
        print("=== Calling Function ===")
        print(f"Calling function: {name} with args: {arguments_str}")
    tool = get_function_by_name(tools, name)
    async_tool = adapt_to_async_tool(tool)
    argument_dict = json.loads(arguments_str)
    output = await async_tool.acall(**argument_dict)
    if verbose:
        print(f"Got output: {output!s}")
        print("========================\n")
    return (
        ChatMessage(
            content=str(output),
            role=MessageRole.TOOL,
            additional_kwargs={
                "name": name,
                "tool_call_id": id_,
            },
        ),
        output,
    )


def resolve_tool_choice(tool_choice: Union[str, dict] = "auto") -> Union[str, dict]:
    """Resolve tool choice.

    If tool_choice is a function name string, return the appropriate dict.
    """
    if isinstance(tool_choice, str) and tool_choice not in ["none", "auto"]:
        return {"type": "function", "function": {"name": tool_choice}}

    return tool_choice


class BaseOpenAIAgent(BaseAgent):
    def __init__(
        self,
        llm: OpenAI,
        memory: BaseMemory,
        prefix_messages: List[ChatMessage],
        verbose: bool,
        max_function_calls: int,
        callback_manager: Optional[CallbackManager],
    ):
        self._llm = llm
        self._verbose = verbose
        self._max_function_calls = max_function_calls
        self.prefix_messages = prefix_messages
        self.memory = memory
        self.callback_manager = callback_manager or self._llm.callback_manager
        self.sources: List[ToolOutput] = []

    @property
    def chat_history(self) -> List[ChatMessage]:
        return self.memory.get_all()

    @property
    def all_messages(self) -> List[ChatMessage]:
        return self.prefix_messages + self.memory.get()

    @property
    def latest_function_call(self) -> Optional[dict]:
        return self.memory.get_all()[-1].additional_kwargs.get("function_call", None)

    @property
    def latest_tool_calls(self) -> Optional[List[OpenAIToolCall]]:
        return self.memory.get_all()[-1].additional_kwargs.get("tool_calls", None)

    def reset(self) -> None:
        self.memory.reset()

    @abstractmethod
    def get_tools(self, message: str) -> List[BaseTool]:
        """Get tools."""

    def _should_continue(
        self, tool_calls: Optional[List[OpenAIToolCall]], n_function_calls: int
    ) -> bool:
        if n_function_calls > self._max_function_calls:
            return False
        if not tool_calls:
            return False
        return True

    def init_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> Tuple[List[BaseTool], List[dict]]:
        if chat_history is not None:
            self.memory.set(chat_history)
        self.sources = []
        self.memory.put(ChatMessage(content=message, role=MessageRole.USER))
        tools = self.get_tools(message)
        openai_tools = [tool.metadata.to_openai_tool() for tool in tools]
        return tools, openai_tools

    def _process_message(self, chat_response: ChatResponse) -> AgentChatResponse:
        ai_message = chat_response.message
        self.memory.put(ai_message)
        return AgentChatResponse(response=str(ai_message.content), sources=self.sources)

    def _get_stream_ai_response(
        self, **llm_chat_kwargs: Any
    ) -> StreamingAgentChatResponse:
        chat_stream_response = StreamingAgentChatResponse(
            chat_stream=self._llm.stream_chat(**llm_chat_kwargs),
            sources=self.sources,
        )
        # Get the response in a separate thread so we can yield the response
        thread = Thread(
            target=chat_stream_response.write_response_to_history,
            args=(self.memory,),
        )
        thread.start()
        # Wait for the event to be set
        chat_stream_response._is_function_not_none_thread_event.wait()
        # If it is executing an openAI function, wait for the thread to finish
        if chat_stream_response._is_function:
            thread.join()

        # if it's false, return the answer (to stream)
        return chat_stream_response

    async def _get_async_stream_ai_response(
        self, **llm_chat_kwargs: Any
    ) -> StreamingAgentChatResponse:
        chat_stream_response = StreamingAgentChatResponse(
            achat_stream=await self._llm.astream_chat(**llm_chat_kwargs),
            sources=self.sources,
        )
        # create task to write chat response to history
        asyncio.create_task(
            chat_stream_response.awrite_response_to_history(self.memory)
        )
        # wait until openAI functions stop executing
        await chat_stream_response._is_function_false_event.wait()
        # return response stream
        return chat_stream_response

    def _call_function(self, tools: List[BaseTool], tool_call: OpenAIToolCall) -> None:
        function_call = tool_call.function
        # validations to get passed mypy
        assert function_call is not None
        assert function_call.name is not None
        assert function_call.arguments is not None

        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: function_call.arguments,
                EventPayload.TOOL: get_function_by_name(
                    tools, function_call.name
                ).metadata,
            },
        ) as event:
            function_message, tool_output = call_function(
                tools, tool_call, verbose=self._verbose
            )
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
        self.sources.append(tool_output)
        self.memory.put(function_message)

    async def _acall_function(
        self, tools: List[BaseTool], tool_call: OpenAIToolCall
    ) -> None:
        function_call = tool_call.function
        # validations to get passed mypy
        assert function_call is not None
        assert function_call.name is not None
        assert function_call.arguments is not None

        with self.callback_manager.event(
            CBEventType.FUNCTION_CALL,
            payload={
                EventPayload.FUNCTION_CALL: function_call.arguments,
                EventPayload.TOOL: get_function_by_name(
                    tools, function_call.name
                ).metadata,
            },
        ) as event:
            function_message, tool_output = await acall_function(
                tools, tool_call, verbose=self._verbose
            )
            event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
        self.sources.append(tool_output)
        self.memory.put(function_message)

    def _get_llm_chat_kwargs(
        self, openai_tools: List[dict], tool_choice: Union[str, dict] = "auto"
    ) -> Dict[str, Any]:
        llm_chat_kwargs: dict = {"messages": self.all_messages}
        if openai_tools:
            llm_chat_kwargs.update(
                tools=openai_tools, tool_choice=resolve_tool_choice(tool_choice)
            )
        return llm_chat_kwargs

    def _get_agent_response(
        self, mode: ChatResponseMode, **llm_chat_kwargs: Any
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        if mode == ChatResponseMode.WAIT:
            chat_response: ChatResponse = self._llm.chat(**llm_chat_kwargs)
            return self._process_message(chat_response)
        elif mode == ChatResponseMode.STREAM:
            return self._get_stream_ai_response(**llm_chat_kwargs)
        else:
            raise NotImplementedError

    async def _get_async_agent_response(
        self, mode: ChatResponseMode, **llm_chat_kwargs: Any
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        if mode == ChatResponseMode.WAIT:
            chat_response: ChatResponse = await self._llm.achat(**llm_chat_kwargs)
            return self._process_message(chat_response)
        elif mode == ChatResponseMode.STREAM:
            return await self._get_async_stream_ai_response(**llm_chat_kwargs)
        else:
            raise NotImplementedError

    def _chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        tools, openai_tools = self.init_chat(message, chat_history)
        n_function_calls = 0

        # Loop until no more function calls or max_function_calls is reached
        current_tool_choice = tool_choice
        ix = 0
        while True:
            ix += 1
            if self._verbose:
                print(f"STARTING TURN {ix}\n---------------\n")
            llm_chat_kwargs = self._get_llm_chat_kwargs(
                openai_tools, current_tool_choice
            )
            agent_chat_response = self._get_agent_response(mode=mode, **llm_chat_kwargs)
            if not self._should_continue(self.latest_tool_calls, n_function_calls):
                logger.debug("Break: should continue False")
                break
            # iterate through all the tool calls
            logger.debug(f"Continue to tool calls: {self.latest_tool_calls}")
            if self.latest_tool_calls is not None:
                for tool_call in self.latest_tool_calls:
                    # Some validation
                    if not isinstance(tool_call, get_args(OpenAIToolCall)):
                        raise ValueError("Invalid tool_call object")

                    if tool_call.type != "function":
                        raise ValueError("Invalid tool type. Unsupported by OpenAI")
                    # TODO: maybe execute this with multi-threading
                    self._call_function(tools, tool_call)
                    # change function call to the default value, if a custom function was given
                    # as an argument (none and auto are predefined by OpenAI)
                    if current_tool_choice not in ("auto", "none"):
                        current_tool_choice = "auto"
                    n_function_calls += 1

        return agent_chat_response

    async def _achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
        mode: ChatResponseMode = ChatResponseMode.WAIT,
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        tools, functions = self.init_chat(message, chat_history)
        n_function_calls = 0

        # Loop until no more function calls or max_function_calls is reached
        current_tool_choice = tool_choice
        ix = 0
        while True:
            ix += 1
            if self._verbose:
                print(f"STARTING TURN {ix}\n---------------\n")
            llm_chat_kwargs = self._get_llm_chat_kwargs(functions, current_tool_choice)
            agent_chat_response = await self._get_async_agent_response(
                mode=mode, **llm_chat_kwargs
            )
            if not self._should_continue(self.latest_tool_calls, n_function_calls):
                break
            # iterate through all the tool calls
            if self.latest_tool_calls is not None:
                for tool_call in self.latest_tool_calls:
                    # Some validation
                    if not isinstance(tool_call, get_args(OpenAIToolCall)):
                        raise ValueError("Invalid tool_call object")

                    if tool_call.type != "function":
                        raise ValueError("Invalid tool type. Unsupported by OpenAI")

                    # TODO: maybe execute this with multi-threading
                    await self._acall_function(tools, tool_call)
                    # change function call to the default value, if a custom function was given
                    # as an argument (none and auto are predefined by OpenAI)
                    if current_tool_choice not in ("auto", "none"):
                        current_tool_choice = "auto"
                    n_function_calls += 1

        return agent_chat_response

    @trace_method("chat")
    def chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def achat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> AgentChatResponse:
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
            )
            assert isinstance(chat_response, AgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    def stream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = self._chat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response

    @trace_method("chat")
    async def astream_chat(
        self,
        message: str,
        chat_history: Optional[List[ChatMessage]] = None,
        tool_choice: Union[str, dict] = "auto",
    ) -> StreamingAgentChatResponse:
        with self.callback_manager.event(
            CBEventType.AGENT_STEP,
            payload={EventPayload.MESSAGES: [message]},
        ) as e:
            chat_response = await self._achat(
                message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
            )
            assert isinstance(chat_response, StreamingAgentChatResponse)
            e.on_end(payload={EventPayload.RESPONSE: chat_response})
        return chat_response


class OpenAIAgent(BaseOpenAIAgent):
    """OpenAI (function calling) Agent.

    Uses the OpenAI function API to reason about whether to
    use a tool, and return the response to the user.

    Supports both a flat list of tools as well as retrieval over the tools.

    Args:
        tools (List[BaseTool]): List of tools to use.
        llm (OpenAI): OpenAI instance.
        memory (BaseMemory): Memory to use.
        prefix_messages (List[ChatMessage]): Prefix messages to use.
        verbose (Optional[bool]): Whether to print verbose output. Defaults to False.
        max_function_calls (Optional[int]): Maximum number of function calls.
            Defaults to DEFAULT_MAX_FUNCTION_CALLS.
        callback_manager (Optional[CallbackManager]): Callback manager to use.
            Defaults to None.
        tool_retriever (ObjectRetriever[BaseTool]): Object retriever to retrieve tools.

    """

    def __init__(
        self,
        tools: List[BaseTool],
        llm: OpenAI,
        memory: BaseMemory,
        prefix_messages: List[ChatMessage],
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
    ) -> None:
        super().__init__(
            llm=llm,
            memory=memory,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )
        if len(tools) > 0 and tool_retriever is not None:
            raise ValueError("Cannot specify both tools and tool_retriever")
        elif len(tools) > 0:
            self._get_tools = lambda _: tools
        elif tool_retriever is not None:
            tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
            self._get_tools = lambda message: tool_retriever_c.retrieve(message)
        else:
            # no tools
            self._get_tools = lambda _: []

    @classmethod
    def from_tools(
        cls,
        tools: Optional[List[BaseTool]] = None,
        tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
        llm: Optional[LLM] = None,
        chat_history: Optional[List[ChatMessage]] = None,
        memory: Optional[BaseMemory] = None,
        memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
        verbose: bool = False,
        max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        prefix_messages: Optional[List[ChatMessage]] = None,
        **kwargs: Any,
    ) -> "OpenAIAgent":
        """Create an OpenAIAgent from a list of tools.

        Similar to `from_defaults` in other classes, this method will
        infer defaults for a variety of parameters, including the LLM,
        if they are not specified.

        """
        tools = tools or []

        chat_history = chat_history or []
        llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
        if not isinstance(llm, OpenAI):
            raise ValueError("llm must be a OpenAI instance")

        if callback_manager is not None:
            llm.callback_manager = callback_manager

        memory = memory or memory_cls.from_defaults(chat_history, llm=llm)

        if not llm.metadata.is_function_calling_model:
            raise ValueError(
                f"Model name {llm.model} does not support function calling API. "
            )

        if system_prompt is not None:
            if prefix_messages is not None:
                raise ValueError(
                    "Cannot specify both system_prompt and prefix_messages"
                )
            prefix_messages = [ChatMessage(content=system_prompt, role="system")]

        prefix_messages = prefix_messages or []

        return cls(
            tools=tools,
            tool_retriever=tool_retriever,
            llm=llm,
            memory=memory,
            prefix_messages=prefix_messages,
            verbose=verbose,
            max_function_calls=max_function_calls,
            callback_manager=callback_manager,
        )

    def get_tools(self, message: str) -> List[BaseTool]:
        """Get tools."""
        return self._get_tools(message)
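
A minimal usage sketch of the legacy function-calling agent above (the add tool and the OPENAI_API_KEY environment variable are assumptions, not part of this diff).

# sketch only: flat tool list; from_tools falls back to DEFAULT_MODEL_NAME if no llm is given
from llama_index.agent.legacy.openai_agent import OpenAIAgent
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


agent = OpenAIAgent.from_tools(
    tools=[FunctionTool.from_defaults(fn=add)],
    llm=OpenAI(model="gpt-3.5-turbo-0613"),
    verbose=True,
)
# the agent will call `add` at most DEFAULT_MAX_FUNCTION_CALLS times before answering
print(agent.chat("What is 121 + 8?"))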

@@ -0,0 +1 @@
"""Init params."""
@ -0,0 +1,526 @@
|
|||
import asyncio
|
||||
from itertools import chain
|
||||
from threading import Thread
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
from llama_index.agent.react.formatter import ReActChatFormatter
|
||||
from llama_index.agent.react.output_parser import ReActOutputParser
|
||||
from llama_index.agent.react.types import (
|
||||
ActionReasoningStep,
|
||||
BaseReasoningStep,
|
||||
ObservationReasoningStep,
|
||||
ResponseReasoningStep,
|
||||
)
|
||||
from llama_index.agent.types import BaseAgent
|
||||
from llama_index.callbacks import (
|
||||
CallbackManager,
|
||||
CBEventType,
|
||||
EventPayload,
|
||||
trace_method,
|
||||
)
|
||||
from llama_index.chat_engine.types import AgentChatResponse, StreamingAgentChatResponse
|
||||
from llama_index.core.llms.types import MessageRole
|
||||
from llama_index.llms.base import ChatMessage, ChatResponse
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
|
||||
from llama_index.memory.types import BaseMemory
|
||||
from llama_index.objects.base import ObjectRetriever
|
||||
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
|
||||
from llama_index.tools.types import AsyncBaseTool
|
||||
from llama_index.utils import print_text, unit_generator
|
||||
|
||||
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
|
||||
|
||||
|
||||
class ReActAgent(BaseAgent):
|
||||
"""ReAct agent.
|
||||
|
||||
Uses a ReAct prompt that can be used in both chat and text
|
||||
completion endpoints.
|
||||
|
||||
Can take in a set of tools that require structured inputs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: Sequence[BaseTool],
|
||||
llm: LLM,
|
||||
memory: BaseMemory,
|
||||
max_iterations: int = 10,
|
||||
react_chat_formatter: Optional[ReActChatFormatter] = None,
|
||||
output_parser: Optional[ReActOutputParser] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
) -> None:
|
||||
super().__init__(callback_manager=callback_manager or llm.callback_manager)
|
||||
self._llm = llm
|
||||
self._memory = memory
|
||||
self._max_iterations = max_iterations
|
||||
self._react_chat_formatter = react_chat_formatter or ReActChatFormatter()
|
||||
self._output_parser = output_parser or ReActOutputParser()
|
||||
self._verbose = verbose
|
||||
self.sources: List[ToolOutput] = []
|
||||
|
||||
if len(tools) > 0 and tool_retriever is not None:
|
||||
raise ValueError("Cannot specify both tools and tool_retriever")
|
||||
elif len(tools) > 0:
|
||||
self._get_tools = lambda _: tools
|
||||
elif tool_retriever is not None:
|
||||
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
|
||||
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
|
||||
else:
|
||||
self._get_tools = lambda _: []
|
||||
|
||||
@classmethod
|
||||
def from_tools(
|
||||
cls,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
llm: Optional[LLM] = None,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
|
||||
max_iterations: int = 10,
|
||||
react_chat_formatter: Optional[ReActChatFormatter] = None,
|
||||
output_parser: Optional[ReActOutputParser] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> "ReActAgent":
|
||||
"""Convenience constructor method from set of of BaseTools (Optional).
|
||||
|
||||
NOTE: kwargs should have been exhausted by this point. In other words,
upstream components such as BaseSynthesizer (response synthesizer)
or BaseRetriever should already have consumed their respective kwargs
during construction.
|
||||
|
||||
Returns:
|
||||
ReActAgent
|
||||
"""
|
||||
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
|
||||
if callback_manager is not None:
|
||||
llm.callback_manager = callback_manager
|
||||
memory = memory or memory_cls.from_defaults(
|
||||
chat_history=chat_history or [], llm=llm
|
||||
)
|
||||
return cls(
|
||||
tools=tools or [],
|
||||
tool_retriever=tool_retriever,
|
||||
llm=llm,
|
||||
memory=memory,
|
||||
max_iterations=max_iterations,
|
||||
react_chat_formatter=react_chat_formatter,
|
||||
output_parser=output_parser,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
"""Chat history."""
|
||||
return self._memory.get_all()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._memory.reset()
|
||||
|
||||
def _extract_reasoning_step(
|
||||
self, output: ChatResponse, is_streaming: bool = False
|
||||
) -> Tuple[str, List[BaseReasoningStep], bool]:
|
||||
"""
|
||||
Extracts the reasoning step from the given output.
|
||||
|
||||
This method parses the message content from the output,
|
||||
extracts the reasoning step, and determines whether the processing is
|
||||
complete. It also performs validation checks on the output and
|
||||
handles possible errors.
|
||||
"""
|
||||
if output.message.content is None:
|
||||
raise ValueError("Got empty message.")
|
||||
message_content = output.message.content
|
||||
current_reasoning = []
|
||||
try:
|
||||
reasoning_step = self._output_parser.parse(message_content, is_streaming)
|
||||
except BaseException as exc:
|
||||
raise ValueError(f"Could not parse output: {message_content}") from exc
|
||||
if self._verbose:
|
||||
print_text(f"{reasoning_step.get_content()}\n", color="pink")
|
||||
current_reasoning.append(reasoning_step)
|
||||
|
||||
if reasoning_step.is_done:
|
||||
return message_content, current_reasoning, True
|
||||
|
||||
reasoning_step = cast(ActionReasoningStep, reasoning_step)
|
||||
if not isinstance(reasoning_step, ActionReasoningStep):
|
||||
raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")
|
||||
|
||||
return message_content, current_reasoning, False
|
||||
|
||||
def _process_actions(
|
||||
self,
|
||||
tools: Sequence[AsyncBaseTool],
|
||||
output: ChatResponse,
|
||||
is_streaming: bool = False,
|
||||
) -> Tuple[List[BaseReasoningStep], bool]:
|
||||
tools_dict: Dict[str, AsyncBaseTool] = {
|
||||
tool.metadata.get_name(): tool for tool in tools
|
||||
}
|
||||
_, current_reasoning, is_done = self._extract_reasoning_step(
|
||||
output, is_streaming
|
||||
)
|
||||
|
||||
if is_done:
|
||||
return current_reasoning, True
|
||||
|
||||
# call tool with input
|
||||
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
|
||||
tool = tools_dict[reasoning_step.action]
|
||||
with self.callback_manager.event(
|
||||
CBEventType.FUNCTION_CALL,
|
||||
payload={
|
||||
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
|
||||
EventPayload.TOOL: tool.metadata,
|
||||
},
|
||||
) as event:
|
||||
tool_output = tool.call(**reasoning_step.action_input)
|
||||
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
||||
|
||||
self.sources.append(tool_output)
|
||||
|
||||
observation_step = ObservationReasoningStep(observation=str(tool_output))
|
||||
current_reasoning.append(observation_step)
|
||||
if self._verbose:
|
||||
print_text(f"{observation_step.get_content()}\n", color="blue")
|
||||
return current_reasoning, False
|
||||
|
||||
async def _aprocess_actions(
|
||||
self,
|
||||
tools: Sequence[AsyncBaseTool],
|
||||
output: ChatResponse,
|
||||
is_streaming: bool = False,
|
||||
) -> Tuple[List[BaseReasoningStep], bool]:
|
||||
tools_dict = {tool.metadata.name: tool for tool in tools}
|
||||
_, current_reasoning, is_done = self._extract_reasoning_step(
|
||||
output, is_streaming
|
||||
)
|
||||
|
||||
if is_done:
|
||||
return current_reasoning, True
|
||||
|
||||
# call tool with input
|
||||
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
|
||||
tool = tools_dict[reasoning_step.action]
|
||||
with self.callback_manager.event(
|
||||
CBEventType.FUNCTION_CALL,
|
||||
payload={
|
||||
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
|
||||
EventPayload.TOOL: tool.metadata,
|
||||
},
|
||||
) as event:
|
||||
tool_output = await tool.acall(**reasoning_step.action_input)
|
||||
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
||||
|
||||
self.sources.append(tool_output)
|
||||
|
||||
observation_step = ObservationReasoningStep(observation=str(tool_output))
|
||||
current_reasoning.append(observation_step)
|
||||
if self._verbose:
|
||||
print_text(f"{observation_step.get_content()}\n", color="blue")
|
||||
return current_reasoning, False
|
||||
|
||||
def _get_response(
|
||||
self,
|
||||
current_reasoning: List[BaseReasoningStep],
|
||||
) -> AgentChatResponse:
|
||||
"""Get response from reasoning steps."""
|
||||
if len(current_reasoning) == 0:
|
||||
raise ValueError("No reasoning steps were taken.")
|
||||
elif len(current_reasoning) == self._max_iterations:
|
||||
raise ValueError("Reached max iterations.")
|
||||
|
||||
response_step = cast(ResponseReasoningStep, current_reasoning[-1])
|
||||
|
||||
# TODO: add sources from reasoning steps
|
||||
return AgentChatResponse(response=response_step.response, sources=self.sources)
|
||||
|
||||
def _infer_stream_chunk_is_final(self, chunk: ChatResponse) -> bool:
|
||||
"""Infers if a chunk from a live stream is the start of the final
|
||||
reasoning step. (i.e., and should eventually become
|
||||
ResponseReasoningStep — not part of this function's logic tho.).
|
||||
|
||||
Args:
|
||||
chunk (ChatResponse): the current chunk stream to check
|
||||
|
||||
Returns:
|
||||
bool: whether the chunk is the start of the final response
|
||||
"""
|
||||
latest_content = chunk.message.content
|
||||
if latest_content:
|
||||
if not latest_content.startswith(
|
||||
"Thought"
|
||||
): # doesn't follow thought-action format
|
||||
return True
|
||||
else:
|
||||
if "Answer: " in latest_content:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _add_back_chunk_to_stream(
|
||||
self, chunk: ChatResponse, chat_stream: Generator[ChatResponse, None, None]
|
||||
) -> Generator[ChatResponse, None, None]:
|
||||
"""Helper method for adding back initial chunk stream of final response
|
||||
back to the rest of the chat_stream.
|
||||
|
||||
Args:
|
||||
chunk (ChatResponse): the chunk to add back to the beginning of the
|
||||
chat_stream.
|
||||
|
||||
Return:
|
||||
Generator[ChatResponse, None, None]: the updated chat_stream
|
||||
"""
|
||||
updated_stream = chain.from_iterable( # need to add back partial response chunk
|
||||
[
|
||||
unit_generator(chunk),
|
||||
chat_stream,
|
||||
]
|
||||
)
|
||||
# use cast to avoid mypy issue with chain and Generator
|
||||
updated_stream_c: Generator[ChatResponse, None, None] = cast(
|
||||
Generator[ChatResponse, None, None], updated_stream
|
||||
)
|
||||
return updated_stream_c
|
||||
|
||||
async def _async_add_back_chunk_to_stream(
|
||||
self, chunk: ChatResponse, chat_stream: AsyncGenerator[ChatResponse, None]
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Helper method for adding back initial chunk stream of final response
|
||||
back to the rest of the chat_stream.
|
||||
|
||||
NOTE: this itself is not an async function.
|
||||
|
||||
Args:
|
||||
chunk (ChatResponse): the chunk to add back to the beginning of the
|
||||
chat_stream.
|
||||
|
||||
Return:
|
||||
AsyncGenerator[ChatResponse, None]: the updated async chat_stream
|
||||
"""
|
||||
yield chunk
|
||||
async for item in chat_stream:
|
||||
yield item
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
"""Chat."""
|
||||
# get tools
|
||||
# TODO: do get tools dynamically at every iteration of the agent loop
|
||||
self.sources = []
|
||||
tools = self.get_tools(message)
|
||||
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
|
||||
current_reasoning: List[BaseReasoningStep] = []
|
||||
# start loop
|
||||
for _ in range(self._max_iterations):
|
||||
# prepare inputs
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=self._memory.get(),
|
||||
current_reasoning=current_reasoning,
|
||||
)
|
||||
# send prompt
|
||||
chat_response = self._llm.chat(input_chat)
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, is_done = self._process_actions(
|
||||
tools, output=chat_response
|
||||
)
|
||||
current_reasoning.extend(reasoning_steps)
|
||||
if is_done:
|
||||
break
|
||||
|
||||
response = self._get_response(current_reasoning)
|
||||
self._memory.put(
|
||||
ChatMessage(content=response.response, role=MessageRole.ASSISTANT)
|
||||
)
|
||||
return response
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
# get tools
|
||||
# TODO: do get tools dynamically at every iteration of the agent loop
|
||||
self.sources = []
|
||||
tools = self.get_tools(message)
|
||||
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
|
||||
current_reasoning: List[BaseReasoningStep] = []
|
||||
# start loop
|
||||
for _ in range(self._max_iterations):
|
||||
# prepare inputs
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=self._memory.get(),
|
||||
current_reasoning=current_reasoning,
|
||||
)
|
||||
# send prompt
|
||||
chat_response = await self._llm.achat(input_chat)
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, is_done = await self._aprocess_actions(
|
||||
tools, output=chat_response
|
||||
)
|
||||
current_reasoning.extend(reasoning_steps)
|
||||
if is_done:
|
||||
break
|
||||
|
||||
response = self._get_response(current_reasoning)
|
||||
self._memory.put(
|
||||
ChatMessage(content=response.response, role=MessageRole.ASSISTANT)
|
||||
)
|
||||
return response
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
# get tools
|
||||
# TODO: do get tools dynamically at every iteration of the agent loop
|
||||
self.sources = []
|
||||
tools = self.get_tools(message)
|
||||
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
|
||||
current_reasoning: List[BaseReasoningStep] = []
|
||||
# start loop
|
||||
is_done, ix = False, 0
|
||||
while (not is_done) and (ix < self._max_iterations):
|
||||
ix += 1
|
||||
|
||||
# prepare inputs
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=self._memory.get(),
|
||||
current_reasoning=current_reasoning,
|
||||
)
|
||||
# send prompt
|
||||
chat_stream = self._llm.stream_chat(input_chat)
|
||||
|
||||
# iterate over stream; break out if this is the final answer (after "Answer: ")
|
||||
full_response = ChatResponse(
|
||||
message=ChatMessage(content=None, role="assistant")
|
||||
)
|
||||
for latest_chunk in chat_stream:
|
||||
full_response = latest_chunk
|
||||
is_done = self._infer_stream_chunk_is_final(latest_chunk)
|
||||
if is_done:
|
||||
break
|
||||
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, _ = self._process_actions(
|
||||
tools=tools, output=full_response, is_streaming=True
|
||||
)
|
||||
current_reasoning.extend(reasoning_steps)
|
||||
|
||||
# Get the response in a separate thread so we can yield the response
|
||||
response_stream = self._add_back_chunk_to_stream(
|
||||
chunk=latest_chunk, chat_stream=chat_stream
|
||||
)
|
||||
|
||||
chat_stream_response = StreamingAgentChatResponse(
|
||||
chat_stream=response_stream,
|
||||
sources=self.sources,
|
||||
)
|
||||
thread = Thread(
|
||||
target=chat_stream_response.write_response_to_history,
|
||||
args=(self._memory,),
|
||||
)
|
||||
thread.start()
|
||||
return chat_stream_response
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
# get tools
|
||||
# TODO: do get tools dynamically at every iteration of the agent loop
|
||||
self.sources = []
|
||||
tools = self.get_tools(message)
|
||||
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
|
||||
current_reasoning: List[BaseReasoningStep] = []
|
||||
# start loop
|
||||
is_done, ix = False, 0
|
||||
while (not is_done) and (ix < self._max_iterations):
|
||||
ix += 1
|
||||
|
||||
# prepare inputs
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=self._memory.get(),
|
||||
current_reasoning=current_reasoning,
|
||||
)
|
||||
# send prompt
|
||||
chat_stream = await self._llm.astream_chat(input_chat)
|
||||
|
||||
# iterate over stream; break out if this is the final answer
|
||||
is_done = False
|
||||
full_response = ChatResponse(
|
||||
message=ChatMessage(content=None, role="assistant")
|
||||
)
|
||||
async for latest_chunk in chat_stream:
|
||||
full_response = latest_chunk
|
||||
is_done = self._infer_stream_chunk_is_final(latest_chunk)
|
||||
if is_done:
|
||||
break
|
||||
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, _ = self._process_actions(
|
||||
tools=tools, output=full_response, is_streaming=True
|
||||
)
|
||||
current_reasoning.extend(reasoning_steps)
|
||||
|
||||
# Get the response in a separate thread so we can yield the response
|
||||
response_stream = self._async_add_back_chunk_to_stream(
|
||||
chunk=latest_chunk, chat_stream=chat_stream
|
||||
)
|
||||
|
||||
chat_stream_response = StreamingAgentChatResponse(
|
||||
achat_stream=response_stream, sources=self.sources
|
||||
)
|
||||
# create task to write chat response to history
|
||||
asyncio.create_task(
|
||||
chat_stream_response.awrite_response_to_history(self._memory)
|
||||
)
|
||||
|
||||
return chat_stream_response
|
||||
|
||||
def get_tools(self, message: str) -> List[AsyncBaseTool]:
|
||||
"""Get tools."""
|
||||
return [adapt_to_async_tool(t) for t in self._get_tools(message)]
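

# Illustrative usage sketch: builds a ReActAgent from a single FunctionTool and runs one
# chat turn. Assumes an OpenAI API key is configured in the environment; the `multiply`
# tool is a hypothetical helper defined only for this example.
if __name__ == "__main__":
    from llama_index.tools import FunctionTool

    def multiply(a: int, b: int) -> int:
        """Multiply two integers and return the result."""
        return a * b

    multiply_tool = FunctionTool.from_defaults(fn=multiply)
    agent = ReActAgent.from_tools(
        tools=[multiply_tool],
        llm=OpenAI(model=DEFAULT_MODEL_NAME),
        verbose=True,
    )
    print(agent.chat("What is 7 multiplied by 6?"))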
@@ -0,0 +1,31 @@
"""Retriever OpenAI agent."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from llama_index.agent.legacy.openai_agent import (
|
||||
OpenAIAgent,
|
||||
)
|
||||
from llama_index.objects.base import ObjectRetriever
|
||||
from llama_index.tools.types import BaseTool
|
||||
|
||||
|
||||
class FnRetrieverOpenAIAgent(OpenAIAgent):
|
||||
"""Function Retriever OpenAI Agent.
|
||||
|
||||
Uses our object retriever module to retrieve tools for the OpenAI agent.
|
||||
|
||||
NOTE: This class is deprecated; you can simply use the base `OpenAIAgent` class by
specifying the following:
|
||||
```
|
||||
agent = OpenAIAgent.from_tools(tool_retriever=retriever, ...)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_retriever(
|
||||
cls, retriever: ObjectRetriever[BaseTool], **kwargs: Any
|
||||
) -> "FnRetrieverOpenAIAgent":
|
||||
return cast(
|
||||
FnRetrieverOpenAIAgent, cls.from_tools(tool_retriever=retriever, **kwargs)
|
||||
)
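

# Illustrative sketch of the (deprecated) retriever-based construction. Assumes an OpenAI
# API key is configured; ObjectIndex/SimpleToolNodeMapping come from llama_index.objects,
# and the `add`/`subtract` tools are hypothetical helpers defined only for this example.
# The preferred path remains OpenAIAgent.from_tools(tool_retriever=...).
if __name__ == "__main__":
    from llama_index import VectorStoreIndex
    from llama_index.objects import ObjectIndex, SimpleToolNodeMapping
    from llama_index.tools import FunctionTool

    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    def subtract(a: int, b: int) -> int:
        """Subtract b from a."""
        return a - b

    tools = [FunctionTool.from_defaults(fn=fn) for fn in (add, subtract)]
    obj_index = ObjectIndex.from_objects(
        tools, SimpleToolNodeMapping.from_objects(tools), VectorStoreIndex
    )
    agent = FnRetrieverOpenAIAgent.from_retriever(obj_index.as_retriever(), verbose=True)
    print(agent.chat("Use a tool to add 40 and 2."))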
@@ -0,0 +1,140 @@
"""OpenAI Agent.
|
||||
|
||||
Simple wrapper around AgentRunner + OpenAIAgentWorker.
|
||||
|
||||
For the legacy implementation see:
|
||||
```python
|
||||
from llama_index.agent.legacy.openai.base import OpenAIAgent
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
)
|
||||
|
||||
from llama_index.agent.openai.step import OpenAIAgentWorker
|
||||
from llama_index.agent.runner.base import AgentRunner
|
||||
from llama_index.callbacks import (
|
||||
CallbackManager,
|
||||
)
|
||||
from llama_index.llms.base import ChatMessage
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
|
||||
from llama_index.memory.types import BaseMemory
|
||||
from llama_index.objects.base import ObjectRetriever
|
||||
from llama_index.tools import BaseTool
|
||||
|
||||
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
|
||||
|
||||
DEFAULT_MAX_FUNCTION_CALLS = 5
|
||||
|
||||
|
||||
class OpenAIAgent(AgentRunner):
|
||||
"""OpenAI agent.
|
||||
|
||||
Subclasses AgentRunner with an OpenAIAgentWorker.
|
||||
|
||||
For the legacy implementation see:
|
||||
```python
|
||||
from llama_index.agent.legacy.openai.base import OpenAIAgent
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[BaseTool],
|
||||
llm: OpenAI,
|
||||
memory: BaseMemory,
|
||||
prefix_messages: List[ChatMessage],
|
||||
verbose: bool = False,
|
||||
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
|
||||
default_tool_choice: str = "auto",
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
callback_manager = callback_manager or llm.callback_manager
|
||||
step_engine = OpenAIAgentWorker.from_tools(
|
||||
tools=tools,
|
||||
tool_retriever=tool_retriever,
|
||||
llm=llm,
|
||||
verbose=verbose,
|
||||
max_function_calls=max_function_calls,
|
||||
callback_manager=callback_manager,
|
||||
prefix_messages=prefix_messages,
|
||||
)
|
||||
super().__init__(
|
||||
step_engine,
|
||||
memory=memory,
|
||||
llm=llm,
|
||||
callback_manager=callback_manager,
|
||||
default_tool_choice=default_tool_choice,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tools(
|
||||
cls,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
llm: Optional[LLM] = None,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
|
||||
verbose: bool = False,
|
||||
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
|
||||
default_tool_choice: str = "auto",
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
prefix_messages: Optional[List[ChatMessage]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "OpenAIAgent":
|
||||
"""Create an OpenAIAgent from a list of tools.
|
||||
|
||||
Similar to `from_defaults` in other classes, this method will
|
||||
infer defaults for a variety of parameters, including the LLM,
|
||||
if they are not specified.
|
||||
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
chat_history = chat_history or []
|
||||
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
|
||||
if not isinstance(llm, OpenAI):
|
||||
raise ValueError("llm must be a OpenAI instance")
|
||||
|
||||
if callback_manager is not None:
|
||||
llm.callback_manager = callback_manager
|
||||
|
||||
memory = memory or memory_cls.from_defaults(chat_history, llm=llm)
|
||||
|
||||
if not llm.metadata.is_function_calling_model:
|
||||
raise ValueError(
|
||||
f"Model name {llm.model} does not support function calling API. "
|
||||
)
|
||||
|
||||
if system_prompt is not None:
|
||||
if prefix_messages is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both system_prompt and prefix_messages"
|
||||
)
|
||||
prefix_messages = [ChatMessage(content=system_prompt, role="system")]
|
||||
|
||||
prefix_messages = prefix_messages or []
|
||||
|
||||
return cls(
|
||||
tools=tools,
|
||||
tool_retriever=tool_retriever,
|
||||
llm=llm,
|
||||
memory=memory,
|
||||
prefix_messages=prefix_messages,
|
||||
verbose=verbose,
|
||||
max_function_calls=max_function_calls,
|
||||
callback_manager=callback_manager,
|
||||
default_tool_choice=default_tool_choice,
|
||||
)
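

# Minimal usage sketch: the convenience constructor above with a single tool and a system
# prompt. Assumes an OpenAI API key is configured; `get_weather` is a hypothetical stub
# defined only for illustration.
if __name__ == "__main__":
    from llama_index.tools import FunctionTool

    def get_weather(city: str) -> str:
        """Return a canned weather report for a city."""
        return f"It is always sunny in {city}."

    agent = OpenAIAgent.from_tools(
        tools=[FunctionTool.from_defaults(fn=get_weather)],
        llm=OpenAI(model=DEFAULT_MODEL_NAME),
        system_prompt="You are a terse weather assistant.",
        verbose=True,
    )
    print(agent.chat("What's the weather in Berlin?"))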
@@ -0,0 +1,644 @@
"""OpenAI agent worker."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args
|
||||
|
||||
from llama_index.agent.openai.utils import resolve_tool_choice
|
||||
from llama_index.agent.types import (
|
||||
BaseAgentWorker,
|
||||
Task,
|
||||
TaskStep,
|
||||
TaskStepOutput,
|
||||
)
|
||||
from llama_index.agent.utils import add_user_step_to_memory
|
||||
from llama_index.callbacks import (
|
||||
CallbackManager,
|
||||
CBEventType,
|
||||
EventPayload,
|
||||
trace_method,
|
||||
)
|
||||
from llama_index.chat_engine.types import (
|
||||
AGENT_CHAT_RESPONSE_TYPE,
|
||||
AgentChatResponse,
|
||||
ChatResponseMode,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.core.llms.types import MessageRole
|
||||
from llama_index.llms.base import ChatMessage, ChatResponse
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.llms.openai_utils import OpenAIToolCall
|
||||
from llama_index.memory import BaseMemory, ChatMemoryBuffer
|
||||
from llama_index.objects.base import ObjectRetriever
|
||||
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
DEFAULT_MAX_FUNCTION_CALLS = 5
|
||||
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
|
||||
|
||||
|
||||
def get_function_by_name(tools: List[BaseTool], name: str) -> BaseTool:
|
||||
"""Get function by name."""
|
||||
name_to_tool = {tool.metadata.name: tool for tool in tools}
|
||||
if name not in name_to_tool:
|
||||
raise ValueError(f"Tool with name {name} not found")
|
||||
return name_to_tool[name]
|
||||
|
||||
|
||||
def call_tool_with_error_handling(
|
||||
tool: BaseTool,
|
||||
input_dict: Dict,
|
||||
error_message: Optional[str] = None,
|
||||
raise_error: bool = False,
|
||||
) -> ToolOutput:
|
||||
"""Call tool with error handling.
|
||||
|
||||
The input is a dictionary of keyword arguments for the tool.
|
||||
|
||||
"""
|
||||
try:
|
||||
return tool(**input_dict)
|
||||
except Exception as e:
|
||||
if raise_error:
|
||||
raise
|
||||
error_message = error_message or f"Error: {e!s}"
|
||||
return ToolOutput(
|
||||
content=error_message,
|
||||
tool_name=tool.metadata.name,
|
||||
raw_input={"kwargs": input_dict},
|
||||
raw_output=e,
|
||||
)
|
||||
|
||||
|
||||
def call_function(
|
||||
tools: List[BaseTool],
|
||||
tool_call: OpenAIToolCall,
|
||||
verbose: bool = False,
|
||||
) -> Tuple[ChatMessage, ToolOutput]:
|
||||
"""Call a function and return the output as a string."""
|
||||
# validations to get past mypy
|
||||
assert tool_call.id is not None
|
||||
assert tool_call.function is not None
|
||||
assert tool_call.function.name is not None
|
||||
assert tool_call.function.arguments is not None
|
||||
|
||||
id_ = tool_call.id
|
||||
function_call = tool_call.function
|
||||
name = tool_call.function.name
|
||||
arguments_str = tool_call.function.arguments
|
||||
if verbose:
|
||||
print("=== Calling Function ===")
|
||||
print(f"Calling function: {name} with args: {arguments_str}")
|
||||
tool = get_function_by_name(tools, name)
|
||||
argument_dict = json.loads(arguments_str)
|
||||
|
||||
# Call tool
|
||||
# Use default error message
|
||||
output = call_tool_with_error_handling(tool, argument_dict, error_message=None)
|
||||
if verbose:
|
||||
print(f"Got output: {output!s}")
|
||||
print("========================\n")
|
||||
return (
|
||||
ChatMessage(
|
||||
content=str(output),
|
||||
role=MessageRole.TOOL,
|
||||
additional_kwargs={
|
||||
"name": name,
|
||||
"tool_call_id": id_,
|
||||
},
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
|
||||
async def acall_function(
|
||||
tools: List[BaseTool], tool_call: OpenAIToolCall, verbose: bool = False
|
||||
) -> Tuple[ChatMessage, ToolOutput]:
|
||||
"""Call a function and return the output as a string."""
|
||||
# validations to get past mypy
|
||||
assert tool_call.id is not None
|
||||
assert tool_call.function is not None
|
||||
assert tool_call.function.name is not None
|
||||
assert tool_call.function.arguments is not None
|
||||
|
||||
id_ = tool_call.id
|
||||
function_call = tool_call.function
|
||||
name = tool_call.function.name
|
||||
arguments_str = tool_call.function.arguments
|
||||
if verbose:
|
||||
print("=== Calling Function ===")
|
||||
print(f"Calling function: {name} with args: {arguments_str}")
|
||||
tool = get_function_by_name(tools, name)
|
||||
async_tool = adapt_to_async_tool(tool)
|
||||
argument_dict = json.loads(arguments_str)
|
||||
output = await async_tool.acall(**argument_dict)
|
||||
if verbose:
|
||||
print(f"Got output: {output!s}")
|
||||
print("========================\n")
|
||||
return (
|
||||
ChatMessage(
|
||||
content=str(output),
|
||||
role=MessageRole.TOOL,
|
||||
additional_kwargs={
|
||||
"name": name,
|
||||
"tool_call_id": id_,
|
||||
},
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIAgentWorker(BaseAgentWorker):
|
||||
"""OpenAI Agent agent worker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[BaseTool],
|
||||
llm: OpenAI,
|
||||
prefix_messages: List[ChatMessage],
|
||||
verbose: bool = False,
|
||||
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
):
|
||||
self._llm = llm
|
||||
self._verbose = verbose
|
||||
self._max_function_calls = max_function_calls
|
||||
self.prefix_messages = prefix_messages
|
||||
self.callback_manager = callback_manager or self._llm.callback_manager
|
||||
|
||||
if len(tools) > 0 and tool_retriever is not None:
|
||||
raise ValueError("Cannot specify both tools and tool_retriever")
|
||||
elif len(tools) > 0:
|
||||
self._get_tools = lambda _: tools
|
||||
elif tool_retriever is not None:
|
||||
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
|
||||
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
|
||||
else:
|
||||
# no tools
|
||||
self._get_tools = lambda _: []
|
||||
|
||||
@classmethod
|
||||
def from_tools(
|
||||
cls,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
llm: Optional[LLM] = None,
|
||||
verbose: bool = False,
|
||||
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
prefix_messages: Optional[List[ChatMessage]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "OpenAIAgentWorker":
|
||||
"""Create an OpenAIAgent from a list of tools.
|
||||
|
||||
Similar to `from_defaults` in other classes, this method will
|
||||
infer defaults for a variety of parameters, including the LLM,
|
||||
if they are not specified.
|
||||
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
|
||||
if not isinstance(llm, OpenAI):
|
||||
raise ValueError("llm must be a OpenAI instance")
|
||||
|
||||
if callback_manager is not None:
|
||||
llm.callback_manager = callback_manager
|
||||
|
||||
if not llm.metadata.is_function_calling_model:
|
||||
raise ValueError(
|
||||
f"Model name {llm.model} does not support function calling API. "
|
||||
)
|
||||
|
||||
if system_prompt is not None:
|
||||
if prefix_messages is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both system_prompt and prefix_messages"
|
||||
)
|
||||
prefix_messages = [ChatMessage(content=system_prompt, role="system")]
|
||||
|
||||
prefix_messages = prefix_messages or []
|
||||
|
||||
return cls(
|
||||
tools=tools,
|
||||
tool_retriever=tool_retriever,
|
||||
llm=llm,
|
||||
prefix_messages=prefix_messages,
|
||||
verbose=verbose,
|
||||
max_function_calls=max_function_calls,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
|
||||
def get_all_messages(self, task: Task) -> List[ChatMessage]:
|
||||
return (
|
||||
self.prefix_messages
|
||||
+ task.memory.get()
|
||||
+ task.extra_state["new_memory"].get_all()
|
||||
)
|
||||
|
||||
def get_latest_tool_calls(self, task: Task) -> Optional[List[OpenAIToolCall]]:
|
||||
chat_history: List[ChatMessage] = task.extra_state["new_memory"].get_all()
|
||||
return (
|
||||
chat_history[-1].additional_kwargs.get("tool_calls", None)
|
||||
if chat_history
|
||||
else None
|
||||
)
|
||||
|
||||
def _get_llm_chat_kwargs(
|
||||
self,
|
||||
task: Task,
|
||||
openai_tools: List[dict],
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
) -> Dict[str, Any]:
|
||||
llm_chat_kwargs: dict = {"messages": self.get_all_messages(task)}
|
||||
if openai_tools:
|
||||
llm_chat_kwargs.update(
|
||||
tools=openai_tools, tool_choice=resolve_tool_choice(tool_choice)
|
||||
)
|
||||
return llm_chat_kwargs
|
||||
|
||||
def _process_message(
|
||||
self, task: Task, chat_response: ChatResponse
|
||||
) -> AgentChatResponse:
|
||||
ai_message = chat_response.message
|
||||
task.extra_state["new_memory"].put(ai_message)
|
||||
return AgentChatResponse(
|
||||
response=str(ai_message.content), sources=task.extra_state["sources"]
|
||||
)
|
||||
|
||||
def _get_stream_ai_response(
|
||||
self, task: Task, **llm_chat_kwargs: Any
|
||||
) -> StreamingAgentChatResponse:
|
||||
chat_stream_response = StreamingAgentChatResponse(
|
||||
chat_stream=self._llm.stream_chat(**llm_chat_kwargs),
|
||||
sources=task.extra_state["sources"],
|
||||
)
|
||||
# Get the response in a separate thread so we can yield the response
|
||||
thread = Thread(
|
||||
target=chat_stream_response.write_response_to_history,
|
||||
args=(task.extra_state["new_memory"],),
|
||||
)
|
||||
thread.start()
|
||||
# Wait for the event to be set
|
||||
chat_stream_response._is_function_not_none_thread_event.wait()
|
||||
# If it is executing an OpenAI function, wait for the thread to finish
|
||||
if chat_stream_response._is_function:
|
||||
thread.join()
|
||||
|
||||
# if it's false, return the answer (to stream)
|
||||
return chat_stream_response
|
||||
|
||||
async def _get_async_stream_ai_response(
|
||||
self, task: Task, **llm_chat_kwargs: Any
|
||||
) -> StreamingAgentChatResponse:
|
||||
chat_stream_response = StreamingAgentChatResponse(
|
||||
achat_stream=await self._llm.astream_chat(**llm_chat_kwargs),
|
||||
sources=task.extra_state["sources"],
|
||||
)
|
||||
# create task to write chat response to history
|
||||
asyncio.create_task(
|
||||
chat_stream_response.awrite_response_to_history(
|
||||
task.extra_state["new_memory"]
|
||||
)
|
||||
)
|
||||
# wait until OpenAI functions stop executing
|
||||
await chat_stream_response._is_function_false_event.wait()
|
||||
# return response stream
|
||||
return chat_stream_response
|
||||
|
||||
def _get_agent_response(
|
||||
self, task: Task, mode: ChatResponseMode, **llm_chat_kwargs: Any
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
if mode == ChatResponseMode.WAIT:
|
||||
chat_response: ChatResponse = self._llm.chat(**llm_chat_kwargs)
|
||||
return self._process_message(task, chat_response)
|
||||
elif mode == ChatResponseMode.STREAM:
|
||||
return self._get_stream_ai_response(task, **llm_chat_kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _get_async_agent_response(
|
||||
self, task: Task, mode: ChatResponseMode, **llm_chat_kwargs: Any
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
if mode == ChatResponseMode.WAIT:
|
||||
chat_response: ChatResponse = await self._llm.achat(**llm_chat_kwargs)
|
||||
return self._process_message(task, chat_response)
|
||||
elif mode == ChatResponseMode.STREAM:
|
||||
return await self._get_async_stream_ai_response(task, **llm_chat_kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _call_function(
|
||||
self,
|
||||
tools: List[BaseTool],
|
||||
tool_call: OpenAIToolCall,
|
||||
memory: BaseMemory,
|
||||
sources: List[ToolOutput],
|
||||
) -> None:
|
||||
function_call = tool_call.function
|
||||
# validations to get past mypy
|
||||
assert function_call is not None
|
||||
assert function_call.name is not None
|
||||
assert function_call.arguments is not None
|
||||
|
||||
with self.callback_manager.event(
|
||||
CBEventType.FUNCTION_CALL,
|
||||
payload={
|
||||
EventPayload.FUNCTION_CALL: function_call.arguments,
|
||||
EventPayload.TOOL: get_function_by_name(
|
||||
tools, function_call.name
|
||||
).metadata,
|
||||
},
|
||||
) as event:
|
||||
function_message, tool_output = call_function(
|
||||
tools, tool_call, verbose=self._verbose
|
||||
)
|
||||
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
||||
sources.append(tool_output)
|
||||
memory.put(function_message)
|
||||
|
||||
async def _acall_function(
|
||||
self,
|
||||
tools: List[BaseTool],
|
||||
tool_call: OpenAIToolCall,
|
||||
memory: BaseMemory,
|
||||
sources: List[ToolOutput],
|
||||
) -> None:
|
||||
function_call = tool_call.function
|
||||
# validations to get past mypy
|
||||
assert function_call is not None
|
||||
assert function_call.name is not None
|
||||
assert function_call.arguments is not None
|
||||
|
||||
with self.callback_manager.event(
|
||||
CBEventType.FUNCTION_CALL,
|
||||
payload={
|
||||
EventPayload.FUNCTION_CALL: function_call.arguments,
|
||||
EventPayload.TOOL: get_function_by_name(
|
||||
tools, function_call.name
|
||||
).metadata,
|
||||
},
|
||||
) as event:
|
||||
function_message, tool_output = await acall_function(
|
||||
tools, tool_call, verbose=self._verbose
|
||||
)
|
||||
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
||||
sources.append(tool_output)
|
||||
memory.put(function_message)
|
||||
|
||||
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
|
||||
"""Initialize step from task."""
|
||||
sources: List[ToolOutput] = []
|
||||
# temporary memory for new messages
|
||||
new_memory = ChatMemoryBuffer.from_defaults()
|
||||
# initialize task state
|
||||
task_state = {
|
||||
"sources": sources,
|
||||
"n_function_calls": 0,
|
||||
"new_memory": new_memory,
|
||||
}
|
||||
task.extra_state.update(task_state)
|
||||
|
||||
return TaskStep(
|
||||
task_id=task.task_id,
|
||||
step_id=str(uuid.uuid4()),
|
||||
input=task.input,
|
||||
)
|
||||
|
||||
def _should_continue(
|
||||
self, tool_calls: Optional[List[OpenAIToolCall]], n_function_calls: int
|
||||
) -> bool:
|
||||
if n_function_calls > self._max_function_calls:
|
||||
return False
|
||||
if not tool_calls:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_tools(self, input: str) -> List[BaseTool]:
|
||||
"""Get tools."""
|
||||
return self._get_tools(input)
|
||||
|
||||
def _run_step(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
if step.input is not None:
|
||||
add_user_step_to_memory(
|
||||
step, task.extra_state["new_memory"], verbose=self._verbose
|
||||
)
|
||||
# TODO: see if we want to do step-based inputs
|
||||
tools = self.get_tools(task.input)
|
||||
openai_tools = [tool.metadata.to_openai_tool() for tool in tools]
|
||||
|
||||
llm_chat_kwargs = self._get_llm_chat_kwargs(task, openai_tools, tool_choice)
|
||||
|
||||
agent_chat_response = self._get_agent_response(
|
||||
task, mode=mode, **llm_chat_kwargs
|
||||
)
|
||||
|
||||
# TODO: implement _should_continue
|
||||
latest_tool_calls = self.get_latest_tool_calls(task) or []
|
||||
if not self._should_continue(
|
||||
latest_tool_calls, task.extra_state["n_function_calls"]
|
||||
):
|
||||
is_done = True
|
||||
new_steps = []
|
||||
# TODO: return response
|
||||
else:
|
||||
is_done = False
|
||||
for tool_call in latest_tool_calls:
|
||||
# Some validation
|
||||
if not isinstance(tool_call, get_args(OpenAIToolCall)):
|
||||
raise ValueError("Invalid tool_call object")
|
||||
|
||||
if tool_call.type != "function":
|
||||
raise ValueError("Invalid tool type. Unsupported by OpenAI")
|
||||
# TODO: maybe execute this with multi-threading
|
||||
self._call_function(
|
||||
tools,
|
||||
tool_call,
|
||||
task.extra_state["new_memory"],
|
||||
task.extra_state["sources"],
|
||||
)
|
||||
# reset tool_choice to the default value if a specific function was requested
# ("none" and "auto" are the only values predefined by OpenAI)
|
||||
if tool_choice not in ("auto", "none"):
|
||||
tool_choice = "auto"
|
||||
task.extra_state["n_function_calls"] += 1
|
||||
new_steps = [
|
||||
step.get_next_step(
|
||||
step_id=str(uuid.uuid4()),
|
||||
# NOTE: input is unused
|
||||
input=None,
|
||||
)
|
||||
]
|
||||
|
||||
# attach next step to task
|
||||
|
||||
return TaskStepOutput(
|
||||
output=agent_chat_response,
|
||||
task_step=step,
|
||||
is_last=is_done,
|
||||
next_steps=new_steps,
|
||||
)
|
||||
|
||||
async def _arun_step(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
if step.input is not None:
|
||||
add_user_step_to_memory(
|
||||
step, task.extra_state["new_memory"], verbose=self._verbose
|
||||
)
|
||||
|
||||
# TODO: see if we want to do step-based inputs
|
||||
tools = self.get_tools(task.input)
|
||||
openai_tools = [tool.metadata.to_openai_tool() for tool in tools]
|
||||
|
||||
llm_chat_kwargs = self._get_llm_chat_kwargs(task, openai_tools, tool_choice)
|
||||
agent_chat_response = await self._get_async_agent_response(
|
||||
task, mode=mode, **llm_chat_kwargs
|
||||
)
|
||||
|
||||
# TODO: implement _should_continue
|
||||
latest_tool_calls = self.get_latest_tool_calls(task) or []
|
||||
if not self._should_continue(
|
||||
latest_tool_calls, task.extra_state["n_function_calls"]
|
||||
):
|
||||
is_done = True
|
||||
|
||||
else:
|
||||
is_done = False
|
||||
for tool_call in latest_tool_calls:
|
||||
# Some validation
|
||||
if not isinstance(tool_call, get_args(OpenAIToolCall)):
|
||||
raise ValueError("Invalid tool_call object")
|
||||
|
||||
if tool_call.type != "function":
|
||||
raise ValueError("Invalid tool type. Unsupported by OpenAI")
|
||||
# TODO: maybe execute this with multi-threading
|
||||
await self._acall_function(
|
||||
tools,
|
||||
tool_call,
|
||||
task.extra_state["new_memory"],
|
||||
task.extra_state["sources"],
|
||||
)
|
||||
# reset tool_choice to the default value if a specific function was requested
# ("none" and "auto" are the only values predefined by OpenAI)
|
||||
if tool_choice not in ("auto", "none"):
|
||||
tool_choice = "auto"
|
||||
task.extra_state["n_function_calls"] += 1
|
||||
|
||||
# generate next step, append to task queue
|
||||
new_steps = (
|
||||
[
|
||||
step.get_next_step(
|
||||
step_id=str(uuid.uuid4()),
|
||||
# NOTE: input is unused
|
||||
input=None,
|
||||
)
|
||||
]
|
||||
if not is_done
|
||||
else []
|
||||
)
|
||||
|
||||
return TaskStepOutput(
|
||||
output=agent_chat_response,
|
||||
task_step=step,
|
||||
is_last=is_done,
|
||||
next_steps=new_steps,
|
||||
)
|
||||
|
||||
@trace_method("run_step")
|
||||
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
tool_choice = kwargs.get("tool_choice", "auto")
|
||||
return self._run_step(
|
||||
step, task, mode=ChatResponseMode.WAIT, tool_choice=tool_choice
|
||||
)
|
||||
|
||||
@trace_method("run_step")
|
||||
async def arun_step(
|
||||
self, step: TaskStep, task: Task, **kwargs: Any
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async)."""
|
||||
tool_choice = kwargs.get("tool_choice", "auto")
|
||||
return await self._arun_step(
|
||||
step, task, mode=ChatResponseMode.WAIT, tool_choice=tool_choice
|
||||
)
|
||||
|
||||
@trace_method("run_step")
|
||||
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
||||
"""Run step (stream)."""
|
||||
# TODO: figure out if we need a different type for TaskStepOutput
|
||||
tool_choice = kwargs.get("tool_choice", "auto")
|
||||
return self._run_step(
|
||||
step, task, mode=ChatResponseMode.STREAM, tool_choice=tool_choice
|
||||
)
|
||||
|
||||
@trace_method("run_step")
|
||||
async def astream_step(
|
||||
self, step: TaskStep, task: Task, **kwargs: Any
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async stream)."""
|
||||
tool_choice = kwargs.get("tool_choice", "auto")
|
||||
return await self._arun_step(
|
||||
step, task, mode=ChatResponseMode.STREAM, tool_choice=tool_choice
|
||||
)
|
||||
|
||||
def finalize_task(self, task: Task, **kwargs: Any) -> None:
|
||||
"""Finalize task, after all the steps are completed."""
|
||||
# add new messages to memory
|
||||
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
|
||||
# reset new memory
|
||||
task.extra_state["new_memory"].reset()
|
||||
|
||||
def undo_step(self, task: Task, **kwargs: Any) -> Optional[TaskStep]:
|
||||
"""Undo step from task.
|
||||
|
||||
If this cannot be implemented, return None.
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Undo is not yet implemented")
|
||||
# if len(task.completed_steps) == 0:
|
||||
# return None
|
||||
|
||||
# # pop last step output
|
||||
# last_step_output = task.completed_steps.pop()
|
||||
# # add step to the front of the queue
|
||||
# task.step_queue.appendleft(last_step_output.task_step)
|
||||
|
||||
# # undo any `step_state` variables that have changed
|
||||
# last_step_output.step_state["n_function_calls"] -= 1
|
||||
|
||||
# # TODO: we don't have memory pop capabilities yet
|
||||
# # # now pop the memory until we get to the state
|
||||
# # last_step_response = cast(AgentChatResponse, last_step_output.output)
|
||||
# # while last_step_response != task.memory.:
|
||||
# # last_message = last_step_output.task_step.memory.pop()
|
||||
# # if last_message == cast(AgentChatResponse, last_step_output.output).response:
|
||||
# # break
|
||||
|
||||
# # while cast(AgentChatResponse, last_step_output.output).response !=
|
||||
|
||||
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
||||
"""Set callback manager."""
|
||||
# TODO: make this abstractmethod (right now will break some agent impls)
|
||||
self.callback_manager = callback_manager
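

# Sketch of composing the worker with an AgentRunner directly, which is roughly what the
# OpenAIAgent convenience class does. Assumes an OpenAI API key is configured and that
# AgentRunner takes the worker as its first positional argument; `echo` is a hypothetical
# tool defined only for illustration.
if __name__ == "__main__":
    from llama_index.agent.runner.base import AgentRunner
    from llama_index.tools import FunctionTool

    def echo(text: str) -> str:
        """Echo the input text back unchanged."""
        return text

    worker = OpenAIAgentWorker.from_tools(
        tools=[FunctionTool.from_defaults(fn=echo)],
        llm=OpenAI(model=DEFAULT_MODEL_NAME),
        verbose=True,
    )
    agent = AgentRunner(worker)
    print(agent.chat("Use the echo tool to repeat the word 'hello'."))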
@@ -0,0 +1,24 @@
"""Utils for OpenAI agent."""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from llama_index.tools import BaseTool
|
||||
|
||||
|
||||
def get_function_by_name(tools: List[BaseTool], name: str) -> BaseTool:
|
||||
"""Get function by name."""
|
||||
name_to_tool = {tool.metadata.name: tool for tool in tools}
|
||||
if name not in name_to_tool:
|
||||
raise ValueError(f"Tool with name {name} not found")
|
||||
return name_to_tool[name]
|
||||
|
||||
|
||||
def resolve_tool_choice(tool_choice: Union[str, dict] = "auto") -> Union[str, dict]:
|
||||
"""Resolve tool choice.
|
||||
|
||||
If tool_choice is a function name string, return the appropriate dict.
|
||||
"""
|
||||
if isinstance(tool_choice, str) and tool_choice not in ["none", "auto"]:
|
||||
return {"type": "function", "function": {"name": tool_choice}}
|
||||
|
||||
return tool_choice
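

# Quick illustrative check of resolve_tool_choice (no API key or network required):
# the predefined choices pass through unchanged, while a bare function name is wrapped
# into the dict form the OpenAI API expects.
if __name__ == "__main__":
    assert resolve_tool_choice("auto") == "auto"
    assert resolve_tool_choice("none") == "none"
    assert resolve_tool_choice("get_weather") == {
        "type": "function",
        "function": {"name": "get_weather"},
    }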
@@ -0,0 +1,554 @@
"""OpenAI Assistant Agent."""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from llama_index.agent.openai.utils import get_function_by_name
|
||||
from llama_index.agent.types import BaseAgent
|
||||
from llama_index.callbacks import (
|
||||
CallbackManager,
|
||||
CBEventType,
|
||||
EventPayload,
|
||||
trace_method,
|
||||
)
|
||||
from llama_index.chat_engine.types import (
|
||||
AGENT_CHAT_RESPONSE_TYPE,
|
||||
AgentChatResponse,
|
||||
ChatResponseMode,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.core.llms.types import ChatMessage, MessageRole
|
||||
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def from_openai_thread_message(thread_message: Any) -> ChatMessage:
|
||||
"""From OpenAI thread message."""
|
||||
from openai.types.beta.threads import MessageContentText, ThreadMessage
|
||||
|
||||
thread_message = cast(ThreadMessage, thread_message)
|
||||
|
||||
# we don't have a way of showing images, just do text for now
|
||||
text_contents = [
|
||||
t for t in thread_message.content if isinstance(t, MessageContentText)
|
||||
]
|
||||
text_content_str = " ".join([t.text.value for t in text_contents])
|
||||
|
||||
return ChatMessage(
|
||||
role=thread_message.role,
|
||||
content=text_content_str,
|
||||
additional_kwargs={
|
||||
"thread_message": thread_message,
|
||||
"thread_id": thread_message.thread_id,
|
||||
"assistant_id": thread_message.assistant_id,
|
||||
"id": thread_message.id,
|
||||
"metadata": thread_message.metadata,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def from_openai_thread_messages(thread_messages: List[Any]) -> List[ChatMessage]:
|
||||
"""From OpenAI thread messages."""
|
||||
return [
|
||||
from_openai_thread_message(thread_message) for thread_message in thread_messages
|
||||
]
|
||||
|
||||
|
||||
def call_function(
|
||||
tools: List[BaseTool], fn_obj: Any, verbose: bool = False
|
||||
) -> Tuple[ChatMessage, ToolOutput]:
|
||||
"""Call a function and return the output as a string."""
|
||||
from openai.types.beta.threads.required_action_function_tool_call import Function
|
||||
|
||||
fn_obj = cast(Function, fn_obj)
|
||||
# TMP: consolidate with other abstractions
|
||||
name = fn_obj.name
|
||||
arguments_str = fn_obj.arguments
|
||||
if verbose:
|
||||
print("=== Calling Function ===")
|
||||
print(f"Calling function: {name} with args: {arguments_str}")
|
||||
tool = get_function_by_name(tools, name)
|
||||
argument_dict = json.loads(arguments_str)
|
||||
output = tool(**argument_dict)
|
||||
if verbose:
|
||||
print(f"Got output: {output!s}")
|
||||
print("========================")
|
||||
return (
|
||||
ChatMessage(
|
||||
content=str(output),
|
||||
role=MessageRole.FUNCTION,
|
||||
additional_kwargs={
|
||||
"name": fn_obj.name,
|
||||
},
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
|
||||
async def acall_function(
|
||||
tools: List[BaseTool], fn_obj: Any, verbose: bool = False
|
||||
) -> Tuple[ChatMessage, ToolOutput]:
|
||||
"""Call an async function and return the output as a string."""
|
||||
from openai.types.beta.threads.required_action_function_tool_call import Function
|
||||
|
||||
fn_obj = cast(Function, fn_obj)
|
||||
# TMP: consolidate with other abstractions
|
||||
name = fn_obj.name
|
||||
arguments_str = fn_obj.arguments
|
||||
if verbose:
|
||||
print("=== Calling Function ===")
|
||||
print(f"Calling function: {name} with args: {arguments_str}")
|
||||
tool = get_function_by_name(tools, name)
|
||||
argument_dict = json.loads(arguments_str)
|
||||
async_tool = adapt_to_async_tool(tool)
|
||||
output = await async_tool.acall(**argument_dict)
|
||||
if verbose:
|
||||
print(f"Got output: {output!s}")
|
||||
print("========================")
|
||||
return (
|
||||
ChatMessage(
|
||||
content=str(output),
|
||||
role=MessageRole.FUNCTION,
|
||||
additional_kwargs={
|
||||
"name": fn_obj.name,
|
||||
},
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
|
||||
def _process_files(client: Any, files: List[str]) -> Dict[str, str]:
|
||||
"""Process files."""
|
||||
from openai import OpenAI
|
||||
|
||||
client = cast(OpenAI, client)
|
||||
|
||||
file_dict = {}
|
||||
for file in files:
|
||||
file_obj = client.files.create(file=open(file, "rb"), purpose="assistants")
|
||||
file_dict[file_obj.id] = file
|
||||
return file_dict
|
||||
|
||||
|
||||
class OpenAIAssistantAgent(BaseAgent):
|
||||
"""OpenAIAssistant agent.
|
||||
|
||||
Wrapper around OpenAI assistant API: https://platform.openai.com/docs/assistants/overview
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Any,
|
||||
assistant: Any,
|
||||
tools: Optional[List[BaseTool]],
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
instructions_prefix: Optional[str] = None,
|
||||
run_retrieve_sleep_time: float = 0.1,
|
||||
file_dict: Dict[str, str] = {},
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
from openai import OpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
|
||||
self._client = cast(OpenAI, client)
|
||||
self._assistant = cast(Assistant, assistant)
|
||||
self._tools = tools or []
|
||||
if thread_id is None:
|
||||
thread = self._client.beta.threads.create()
|
||||
thread_id = thread.id
|
||||
self._thread_id = thread_id
|
||||
self._instructions_prefix = instructions_prefix
|
||||
self._run_retrieve_sleep_time = run_retrieve_sleep_time
|
||||
self._verbose = verbose
|
||||
self.file_dict = file_dict
|
||||
|
||||
self.callback_manager = callback_manager or CallbackManager([])
|
||||
|
||||
@classmethod
|
||||
def from_new(
|
||||
cls,
|
||||
name: str,
|
||||
instructions: str,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
openai_tools: Optional[List[Dict]] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
model: str = "gpt-4-1106-preview",
|
||||
instructions_prefix: Optional[str] = None,
|
||||
run_retrieve_sleep_time: float = 0.1,
|
||||
files: Optional[List[str]] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
file_ids: Optional[List[str]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> "OpenAIAssistantAgent":
|
||||
"""From new assistant.
|
||||
|
||||
Args:
|
||||
name: name of assistant
|
||||
instructions: instructions for assistant
|
||||
tools: list of tools
|
||||
openai_tools: list of openai tools
|
||||
thread_id: thread id
|
||||
model: model
|
||||
run_retrieve_sleep_time: run retrieve sleep time
|
||||
files: files
|
||||
instructions_prefix: instructions prefix
|
||||
callback_manager: callback manager
|
||||
verbose: verbose
|
||||
file_ids: list of file ids
|
||||
api_key: OpenAI API key
|
||||
|
||||
"""
|
||||
from openai import OpenAI
|
||||
|
||||
# this is the set of openai tools
|
||||
# not to be confused with the tools we pass in for function calling
|
||||
openai_tools = openai_tools or []
|
||||
tools = tools or []
|
||||
tool_fns = [t.metadata.to_openai_tool() for t in tools]
|
||||
all_openai_tools = openai_tools + tool_fns
|
||||
|
||||
# initialize client
|
||||
client = OpenAI(api_key=api_key)
|
||||
|
||||
# process files
|
||||
files = files or []
|
||||
file_ids = file_ids or []
|
||||
|
||||
file_dict = _process_files(client, files)
|
||||
all_file_ids = list(file_dict.keys()) + file_ids
|
||||
|
||||
# TODO: openai's typing is a bit sus
|
||||
all_openai_tools = cast(List[Any], all_openai_tools)
|
||||
assistant = client.beta.assistants.create(
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
tools=cast(List[Any], all_openai_tools),
|
||||
model=model,
|
||||
file_ids=all_file_ids,
|
||||
)
|
||||
return cls(
|
||||
client,
|
||||
assistant,
|
||||
tools,
|
||||
callback_manager=callback_manager,
|
||||
thread_id=thread_id,
|
||||
instructions_prefix=instructions_prefix,
|
||||
file_dict=file_dict,
|
||||
run_retrieve_sleep_time=run_retrieve_sleep_time,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_existing(
|
||||
cls,
|
||||
assistant_id: str,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
instructions_prefix: Optional[str] = None,
|
||||
run_retrieve_sleep_time: float = 0.1,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
api_key: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
) -> "OpenAIAssistantAgent":
|
||||
"""From existing assistant id.
|
||||
|
||||
Args:
|
||||
assistant_id: id of assistant
|
||||
tools: list of BaseTools Assistant can use
|
||||
thread_id: thread id
|
||||
run_retrieve_sleep_time: run retrieve sleep time
|
||||
instructions_prefix: instructions prefix
|
||||
callback_manager: callback manager
|
||||
api_key: OpenAI API key
|
||||
verbose: verbose
|
||||
|
||||
"""
|
||||
from openai import OpenAI
|
||||
|
||||
# initialize client
|
||||
client = OpenAI(api_key=api_key)
|
||||
|
||||
# get assistant
|
||||
assistant = client.beta.assistants.retrieve(assistant_id)
|
||||
# assistant.tools is incompatible with BaseTools, so tools must be passed in via params
|
||||
|
||||
return cls(
|
||||
client,
|
||||
assistant,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
thread_id=thread_id,
|
||||
instructions_prefix=instructions_prefix,
|
||||
run_retrieve_sleep_time=run_retrieve_sleep_time,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@property
|
||||
def assistant(self) -> Any:
|
||||
"""Get assistant."""
|
||||
return self._assistant
|
||||
|
||||
@property
|
||||
def client(self) -> Any:
|
||||
"""Get client."""
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def thread_id(self) -> str:
|
||||
"""Get thread id."""
|
||||
return self._thread_id
|
||||
|
||||
@property
|
||||
def files_dict(self) -> Dict[str, str]:
|
||||
"""Get files dict."""
|
||||
return self.file_dict
|
||||
|
||||
@property
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
raw_messages = self._client.beta.threads.messages.list(
|
||||
thread_id=self._thread_id, order="asc"
|
||||
)
|
||||
return from_openai_thread_messages(list(raw_messages))
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Delete and create a new thread."""
|
||||
self._client.beta.threads.delete(self._thread_id)
|
||||
thread = self._client.beta.threads.create()
|
||||
thread_id = thread.id
|
||||
self._thread_id = thread_id
|
||||
|
||||
def get_tools(self, message: str) -> List[BaseTool]:
|
||||
"""Get tools."""
|
||||
return self._tools
|
||||
|
||||
def upload_files(self, files: List[str]) -> Dict[str, Any]:
|
||||
"""Upload files."""
|
||||
return _process_files(self._client, files)
|
||||
|
||||
def add_message(self, message: str, file_ids: Optional[List[str]] = None) -> Any:
|
||||
"""Add message to assistant."""
|
||||
file_ids = file_ids or []
|
||||
return self._client.beta.threads.messages.create(
|
||||
thread_id=self._thread_id,
|
||||
role="user",
|
||||
content=message,
|
||||
file_ids=file_ids,
|
||||
)
|
||||
|
||||
def _run_function_calling(self, run: Any) -> List[ToolOutput]:
|
||||
"""Run function calling."""
|
||||
tool_calls = run.required_action.submit_tool_outputs.tool_calls
|
||||
tool_output_dicts = []
|
||||
tool_output_objs: List[ToolOutput] = []
|
||||
for tool_call in tool_calls:
|
||||
fn_obj = tool_call.function
|
||||
_, tool_output = call_function(self._tools, fn_obj, verbose=self._verbose)
|
||||
tool_output_dicts.append(
|
||||
{"tool_call_id": tool_call.id, "output": str(tool_output)}
|
||||
)
|
||||
tool_output_objs.append(tool_output)
|
||||
|
||||
# submit tool outputs
|
||||
# TODO: openai's typing is a bit sus
|
||||
self._client.beta.threads.runs.submit_tool_outputs(
|
||||
thread_id=self._thread_id,
|
||||
run_id=run.id,
|
||||
tool_outputs=cast(List[Any], tool_output_dicts),
|
||||
)
|
||||
return tool_output_objs
|
||||
|
||||
async def _arun_function_calling(self, run: Any) -> List[ToolOutput]:
|
||||
"""Run function calling."""
|
||||
tool_calls = run.required_action.submit_tool_outputs.tool_calls
|
||||
tool_output_dicts = []
|
||||
tool_output_objs: List[ToolOutput] = []
|
||||
for tool_call in tool_calls:
|
||||
fn_obj = tool_call.function
|
||||
_, tool_output = await acall_function(
|
||||
self._tools, fn_obj, verbose=self._verbose
|
||||
)
|
||||
tool_output_dicts.append(
|
||||
{"tool_call_id": tool_call.id, "output": str(tool_output)}
|
||||
)
|
||||
tool_output_objs.append(tool_output)
|
||||
|
||||
# submit tool outputs
|
||||
self._client.beta.threads.runs.submit_tool_outputs(
|
||||
thread_id=self._thread_id,
|
||||
run_id=run.id,
|
||||
tool_outputs=cast(List[Any], tool_output_dicts),
|
||||
)
|
||||
return tool_output_objs
|
||||
|
||||
def run_assistant(
|
||||
self, instructions_prefix: Optional[str] = None
|
||||
) -> Tuple[Any, Dict]:
|
||||
"""Run assistant."""
|
||||
instructions_prefix = instructions_prefix or self._instructions_prefix
|
||||
run = self._client.beta.threads.runs.create(
|
||||
thread_id=self._thread_id,
|
||||
assistant_id=self._assistant.id,
|
||||
instructions=instructions_prefix,
|
||||
)
|
||||
from openai.types.beta.threads import Run
|
||||
|
||||
run = cast(Run, run)
|
||||
|
||||
sources = []
|
||||
|
||||
while run.status in ["queued", "in_progress", "requires_action"]:
|
||||
run = self._client.beta.threads.runs.retrieve(
|
||||
thread_id=self._thread_id, run_id=run.id
|
||||
)
|
||||
if run.status == "requires_action":
|
||||
cur_tool_outputs = self._run_function_calling(run)
|
||||
sources.extend(cur_tool_outputs)
|
||||
|
||||
time.sleep(self._run_retrieve_sleep_time)
|
||||
if run.status == "failed":
|
||||
raise ValueError(
|
||||
f"Run failed with status {run.status}.\n" f"Error: {run.last_error}"
|
||||
)
|
||||
return run, {"sources": sources}
|
||||
|
||||
async def arun_assistant(
|
||||
self, instructions_prefix: Optional[str] = None
|
||||
) -> Tuple[Any, Dict]:
|
||||
"""Run assistant."""
|
||||
instructions_prefix = instructions_prefix or self._instructions_prefix
|
||||
run = self._client.beta.threads.runs.create(
|
||||
thread_id=self._thread_id,
|
||||
assistant_id=self._assistant.id,
|
||||
instructions=instructions_prefix,
|
||||
)
|
||||
from openai.types.beta.threads import Run
|
||||
|
||||
run = cast(Run, run)
|
||||
|
||||
sources = []
|
||||
|
||||
while run.status in ["queued", "in_progress", "requires_action"]:
|
||||
run = self._client.beta.threads.runs.retrieve(
|
||||
thread_id=self._thread_id, run_id=run.id
|
||||
)
|
||||
if run.status == "requires_action":
|
||||
cur_tool_outputs = await self._arun_function_calling(run)
|
||||
sources.extend(cur_tool_outputs)
|
||||
|
||||
await asyncio.sleep(self._run_retrieve_sleep_time)
|
||||
if run.status == "failed":
|
||||
raise ValueError(
|
||||
f"Run failed with status {run.status}.\n" f"Error: {run.last_error}"
|
||||
)
|
||||
return run, {"sources": sources}
|
||||
|
||||
@property
|
||||
def latest_message(self) -> ChatMessage:
|
||||
"""Get latest message."""
|
||||
raw_messages = self._client.beta.threads.messages.list(
|
||||
thread_id=self._thread_id, order="desc"
|
||||
)
|
||||
messages = from_openai_thread_messages(list(raw_messages))
|
||||
return messages[0]
|
||||
|
||||
def _chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
function_call: Union[str, dict] = "auto",
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Main chat interface."""
|
||||
# TODO: since chat interface doesn't expose additional kwargs
|
||||
# we can't pass in file_ids per message
|
||||
added_message_obj = self.add_message(message)
|
||||
run, metadata = self.run_assistant(
|
||||
instructions_prefix=self._instructions_prefix,
|
||||
)
|
||||
latest_message = self.latest_message
|
||||
# get most recent message content
|
||||
return AgentChatResponse(
|
||||
response=str(latest_message.content),
|
||||
sources=metadata["sources"],
|
||||
)
|
||||
|
||||
async def _achat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
function_call: Union[str, dict] = "auto",
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Asynchronous main chat interface."""
|
||||
self.add_message(message)
|
||||
run, metadata = await self.arun_assistant(
|
||||
instructions_prefix=self._instructions_prefix,
|
||||
)
|
||||
latest_message = self.latest_message
|
||||
# get most recent message content
|
||||
return AgentChatResponse(
|
||||
response=str(latest_message.content),
|
||||
sources=metadata["sources"],
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
function_call: Union[str, dict] = "auto",
|
||||
) -> AgentChatResponse:
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = self._chat(
|
||||
message, chat_history, function_call, mode=ChatResponseMode.WAIT
|
||||
)
|
||||
assert isinstance(chat_response, AgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
function_call: Union[str, dict] = "auto",
|
||||
) -> AgentChatResponse:
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = await self._achat(
|
||||
message, chat_history, function_call, mode=ChatResponseMode.WAIT
|
||||
)
|
||||
assert isinstance(chat_response, AgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
function_call: Union[str, dict] = "auto",
|
||||
) -> StreamingAgentChatResponse:
|
||||
raise NotImplementedError("stream_chat not implemented")
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
function_call: Union[str, dict] = "auto",
|
||||
) -> StreamingAgentChatResponse:
|
||||
raise NotImplementedError("astream_chat not implemented")
|
||||
|
|
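For orientation, here is a minimal usage sketch of the assistant agent defined above. The exported class name and import path are assumptions (the class name is not visible in this hunk; it is assumed to be exported as `OpenAIAssistantAgent`), the assistant id and API key are placeholders, and it presumes an Assistant has already been created on the OpenAI side:

```python
from llama_index.agent import OpenAIAssistantAgent  # assumed export path in this commit

# Attach to an existing Assistant; ids/keys below are placeholders.
agent = OpenAIAssistantAgent.from_existing(
    assistant_id="asst_...",   # hypothetical assistant id
    api_key="sk-...",          # or rely on the OPENAI_API_KEY env var
    tools=[],                  # BaseTool instances must be passed explicitly (see comment above)
    verbose=True,
)
response = agent.chat("Summarize the files attached to this assistant.")
print(response.response)
```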
@@ -0,0 +1,5 @@
|
|||
from llama_index.agent.react.base import ReActAgent
|
||||
from llama_index.agent.react.formatter import ReActChatFormatter
|
||||
from llama_index.agent.react.step import ReActAgentWorker
|
||||
|
||||
__all__ = ["ReActChatFormatter", "ReActAgentWorker", "ReActAgent"]
|
||||
|
|
@@ -0,0 +1,10 @@
|
|||
"""ReAct agent.
|
||||
|
||||
Simple wrapper around AgentRunner + ReActAgentWorker.
|
||||
|
||||
For the legacy implementation see:
|
||||
```python
|
||||
from llama_index.agent.legacy.react.base import ReActAgent
|
||||
```
|
||||
|
||||
"""
|
||||
|
|
@@ -0,0 +1,135 @@
|
|||
"""ReAct agent.
|
||||
|
||||
Simple wrapper around AgentRunner + ReActAgentWorker.
|
||||
|
||||
For the legacy implementation see:
|
||||
```python
|
||||
from llama_index.agent.legacy.react.base import ReActAgent
|
||||
```
|
||||
|
||||
"""
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
)
|
||||
|
||||
from llama_index.agent.react.formatter import ReActChatFormatter
|
||||
from llama_index.agent.react.output_parser import ReActOutputParser
|
||||
from llama_index.agent.react.step import ReActAgentWorker
|
||||
from llama_index.agent.runner.base import AgentRunner
|
||||
from llama_index.callbacks import (
|
||||
CallbackManager,
|
||||
)
|
||||
from llama_index.core.llms.types import ChatMessage
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
|
||||
from llama_index.memory.types import BaseMemory
|
||||
from llama_index.objects.base import ObjectRetriever
|
||||
from llama_index.prompts.mixin import PromptMixinType
|
||||
from llama_index.tools import BaseTool
|
||||
|
||||
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
|
||||
|
||||
|
||||
class ReActAgent(AgentRunner):
|
||||
"""ReAct agent.
|
||||
|
||||
Subclasses AgentRunner with a ReActAgentWorker.
|
||||
|
||||
For the legacy implementation see:
|
||||
```python
|
||||
from llama_index.agent.legacy.react.base import ReActAgent
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: Sequence[BaseTool],
|
||||
llm: LLM,
|
||||
memory: BaseMemory,
|
||||
max_iterations: int = 10,
|
||||
react_chat_formatter: Optional[ReActChatFormatter] = None,
|
||||
output_parser: Optional[ReActOutputParser] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
context: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
callback_manager = callback_manager or llm.callback_manager
|
||||
if context and react_chat_formatter:
|
||||
raise ValueError("Cannot provide both context and react_chat_formatter")
|
||||
if context:
|
||||
react_chat_formatter = ReActChatFormatter.from_context(context)
|
||||
|
||||
step_engine = ReActAgentWorker.from_tools(
|
||||
tools=tools,
|
||||
tool_retriever=tool_retriever,
|
||||
llm=llm,
|
||||
max_iterations=max_iterations,
|
||||
react_chat_formatter=react_chat_formatter,
|
||||
output_parser=output_parser,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
)
|
||||
super().__init__(
|
||||
step_engine,
|
||||
memory=memory,
|
||||
llm=llm,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tools(
|
||||
cls,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
llm: Optional[LLM] = None,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
|
||||
max_iterations: int = 10,
|
||||
react_chat_formatter: Optional[ReActChatFormatter] = None,
|
||||
output_parser: Optional[ReActOutputParser] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
context: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> "ReActAgent":
|
||||
"""Convenience constructor method from set of of BaseTools (Optional).
|
||||
|
||||
NOTE: kwargs should have been exhausted by this point. In other words
|
||||
the various upstream components such as BaseSynthesizer (response synthesizer)
|
||||
or BaseRetriever should have picked up their respective kwargs in their
|
||||
constructions.
|
||||
|
||||
Returns:
|
||||
ReActAgent
|
||||
"""
|
||||
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
|
||||
if callback_manager is not None:
|
||||
llm.callback_manager = callback_manager
|
||||
memory = memory or memory_cls.from_defaults(
|
||||
chat_history=chat_history or [], llm=llm
|
||||
)
|
||||
return cls(
|
||||
tools=tools or [],
|
||||
tool_retriever=tool_retriever,
|
||||
llm=llm,
|
||||
memory=memory,
|
||||
max_iterations=max_iterations,
|
||||
react_chat_formatter=react_chat_formatter,
|
||||
output_parser=output_parser,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _get_prompt_modules(self) -> PromptMixinType:
|
||||
"""Get prompt modules."""
|
||||
return {"agent_worker": self.agent_worker}
|
||||
|
|
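A minimal usage sketch of the `ReActAgent.from_tools` path above. The tool and question are illustrative, and `FunctionTool` is assumed to be available from `llama_index.tools` in this version; the model name matches `DEFAULT_MODEL_NAME` above:

```python
from llama_index.agent.react.base import ReActAgent
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool  # assumed helper for wrapping a function as a BaseTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


agent = ReActAgent.from_tools(
    tools=[FunctionTool.from_defaults(fn=multiply)],
    llm=OpenAI(model="gpt-3.5-turbo-0613"),
    verbose=True,
)
print(agent.chat("What is 7 times 6?"))
```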
@@ -0,0 +1,127 @@
|
|||
# ReAct agent formatter
|
||||
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from llama_index.agent.react.prompts import (
|
||||
CONTEXT_REACT_CHAT_SYSTEM_HEADER,
|
||||
REACT_CHAT_SYSTEM_HEADER,
|
||||
)
|
||||
from llama_index.agent.react.types import BaseReasoningStep, ObservationReasoningStep
|
||||
from llama_index.bridge.pydantic import BaseModel
|
||||
from llama_index.core.llms.types import ChatMessage, MessageRole
|
||||
from llama_index.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_react_tool_descriptions(tools: Sequence[BaseTool]) -> List[str]:
|
||||
"""Tool."""
|
||||
tool_descs = []
|
||||
for tool in tools:
|
||||
tool_desc = (
|
||||
f"> Tool Name: {tool.metadata.name}\n"
|
||||
f"Tool Description: {tool.metadata.description}\n"
|
||||
f"Tool Args: {tool.metadata.fn_schema_str}\n"
|
||||
)
|
||||
tool_descs.append(tool_desc)
|
||||
return tool_descs
|
||||
|
||||
|
||||
# TODO: come up with better name
|
||||
class BaseAgentChatFormatter(BaseModel):
|
||||
"""Base chat formatter."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def format(
|
||||
self,
|
||||
tools: Sequence[BaseTool],
|
||||
chat_history: List[ChatMessage],
|
||||
current_reasoning: Optional[List[BaseReasoningStep]] = None,
|
||||
) -> List[ChatMessage]:
|
||||
"""Format chat history into list of ChatMessage."""
|
||||
|
||||
|
||||
class ReActChatFormatter(BaseAgentChatFormatter):
|
||||
"""ReAct chat formatter."""
|
||||
|
||||
system_header: str = REACT_CHAT_SYSTEM_HEADER # default
|
||||
context: str = "" # not needed w/ default
|
||||
|
||||
def format(
|
||||
self,
|
||||
tools: Sequence[BaseTool],
|
||||
chat_history: List[ChatMessage],
|
||||
current_reasoning: Optional[List[BaseReasoningStep]] = None,
|
||||
) -> List[ChatMessage]:
|
||||
"""Format chat history into list of ChatMessage."""
|
||||
current_reasoning = current_reasoning or []
|
||||
|
||||
format_args = {
|
||||
"tool_desc": "\n".join(get_react_tool_descriptions(tools)),
|
||||
"tool_names": ", ".join([tool.metadata.get_name() for tool in tools]),
|
||||
}
|
||||
if self.context:
|
||||
format_args["context"] = self.context
|
||||
|
||||
fmt_sys_header = self.system_header.format(**format_args)
|
||||
|
||||
# format reasoning history as alternating user and assistant messages
|
||||
# where the assistant messages are thoughts and actions and the user
|
||||
# messages are observations
|
||||
reasoning_history = []
|
||||
for reasoning_step in current_reasoning:
|
||||
if isinstance(reasoning_step, ObservationReasoningStep):
|
||||
message = ChatMessage(
|
||||
role=MessageRole.USER,
|
||||
content=reasoning_step.get_content(),
|
||||
)
|
||||
else:
|
||||
message = ChatMessage(
|
||||
role=MessageRole.ASSISTANT,
|
||||
content=reasoning_step.get_content(),
|
||||
)
|
||||
reasoning_history.append(message)
|
||||
|
||||
return [
|
||||
ChatMessage(role=MessageRole.SYSTEM, content=fmt_sys_header),
|
||||
*chat_history,
|
||||
*reasoning_history,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_defaults(
|
||||
cls,
|
||||
system_header: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
) -> "ReActChatFormatter":
|
||||
"""Create ReActChatFormatter from defaults."""
|
||||
if not system_header:
|
||||
system_header = (
|
||||
REACT_CHAT_SYSTEM_HEADER
|
||||
if not context
|
||||
else CONTEXT_REACT_CHAT_SYSTEM_HEADER
|
||||
)
|
||||
|
||||
return ReActChatFormatter(
|
||||
system_header=system_header,
|
||||
context=context or "",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_context(cls, context: str) -> "ReActChatFormatter":
|
||||
"""Create ReActChatFormatter from context.
|
||||
|
||||
NOTE: deprecated
|
||||
|
||||
"""
|
||||
logger.warning(
|
||||
"ReActChatFormatter.from_context is deprecated, please use `from_defaults` instead."
|
||||
)
|
||||
return ReActChatFormatter.from_defaults(
|
||||
system_header=CONTEXT_REACT_CHAT_SYSTEM_HEADER, context=context
|
||||
)
|
||||
|
|
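To make the formatter's contract concrete, a small sketch of how the formatter above turns tools and chat history into the final message list. `FunctionTool` usage is an assumption, as in the previous sketch:

```python
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.core.llms.types import ChatMessage, MessageRole
from llama_index.tools import FunctionTool  # assumed helper for building a BaseTool


def echo(text: str) -> str:
    """Echo the input text."""
    return text


formatter = ReActChatFormatter.from_defaults()  # uses REACT_CHAT_SYSTEM_HEADER
messages = formatter.format(
    tools=[FunctionTool.from_defaults(fn=echo)],
    chat_history=[ChatMessage(role=MessageRole.USER, content="hello")],
    current_reasoning=None,  # no prior Thought/Action/Observation steps yet
)
# messages[0] is the formatted system header; the rest is chat history plus reasoning
print(messages[0].content[:200])
```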
@@ -0,0 +1,113 @@
|
|||
"""ReAct output parser."""
|
||||
|
||||
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
from llama_index.agent.react.types import (
|
||||
ActionReasoningStep,
|
||||
BaseReasoningStep,
|
||||
ResponseReasoningStep,
|
||||
)
|
||||
from llama_index.output_parsers.utils import extract_json_str
|
||||
from llama_index.types import BaseOutputParser
|
||||
|
||||
|
||||
def extract_tool_use(input_text: str) -> Tuple[str, str, str]:
|
||||
pattern = (
|
||||
r"\s*Thought: (.*?)\nAction: ([a-zA-Z0-9_]+).*?\nAction Input: .*?(\{.*\})"
|
||||
)
|
||||
|
||||
match = re.search(pattern, input_text, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError(f"Could not extract tool use from input text: {input_text}")
|
||||
|
||||
thought = match.group(1).strip()
|
||||
action = match.group(2).strip()
|
||||
action_input = match.group(3).strip()
|
||||
return thought, action, action_input
|
||||
|
||||
|
||||
def action_input_parser(json_str: str) -> dict:
|
||||
processed_string = re.sub(r"(?<!\w)\'|\'(?!\w)", '"', json_str)
|
||||
pattern = r'"(\w+)":\s*"([^"]*)"'
|
||||
matches = re.findall(pattern, processed_string)
|
||||
return dict(matches)
|
||||
|
||||
|
||||
def extract_final_response(input_text: str) -> Tuple[str, str]:
|
||||
pattern = r"\s*Thought:(.*?)Answer:(.*?)(?:$)"
|
||||
|
||||
match = re.search(pattern, input_text, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"Could not extract final answer from input text: {input_text}"
|
||||
)
|
||||
|
||||
thought = match.group(1).strip()
|
||||
answer = match.group(2).strip()
|
||||
return thought, answer
|
||||
|
||||
|
||||
def parse_action_reasoning_step(output: str) -> ActionReasoningStep:
|
||||
"""
|
||||
Parse an action reasoning step from the LLM output.
|
||||
"""
|
||||
# Weaker LLMs may generate ReActAgent steps whose Action Input is malformed JSON.
|
||||
# `dirtyjson` is more lenient than `json` in parsing JSON strings.
|
||||
import dirtyjson as json
|
||||
|
||||
thought, action, action_input = extract_tool_use(output)
|
||||
json_str = extract_json_str(action_input)
|
||||
# First we try dirtyjson; if that fails we fall back to the regex-based parser
|
||||
try:
|
||||
action_input_dict = json.loads(json_str)
|
||||
except Exception:
|
||||
action_input_dict = action_input_parser(json_str)
|
||||
return ActionReasoningStep(
|
||||
thought=thought, action=action, action_input=action_input_dict
|
||||
)
|
||||
|
||||
|
||||
class ReActOutputParser(BaseOutputParser):
|
||||
"""ReAct Output parser."""
|
||||
|
||||
def parse(self, output: str, is_streaming: bool = False) -> BaseReasoningStep:
|
||||
"""Parse output from ReAct agent.
|
||||
|
||||
We expect the output to be in one of the following formats:
|
||||
1. If the agent needs to use a tool to answer the question:
|
||||
```
|
||||
Thought: <thought>
|
||||
Action: <action>
|
||||
Action Input: <action_input>
|
||||
```
|
||||
2. If the agent can answer the question without any tools:
|
||||
```
|
||||
Thought: <thought>
|
||||
Answer: <answer>
|
||||
```
|
||||
"""
|
||||
if "Thought:" not in output:
|
||||
# NOTE: handle the case where the agent directly outputs the answer
|
||||
# instead of following the thought-answer format
|
||||
return ResponseReasoningStep(
|
||||
thought="(Implicit) I can answer without any more tools!",
|
||||
response=output,
|
||||
is_streaming=is_streaming,
|
||||
)
|
||||
|
||||
if "Answer:" in output:
|
||||
thought, answer = extract_final_response(output)
|
||||
return ResponseReasoningStep(
|
||||
thought=thought, response=answer, is_streaming=is_streaming
|
||||
)
|
||||
|
||||
if "Action:" in output:
|
||||
return parse_action_reasoning_step(output)
|
||||
|
||||
raise ValueError(f"Could not parse output: {output}")
|
||||
|
||||
def format(self, output: str) -> str:
|
||||
"""Format a query with structured output formatting instructions."""
|
||||
raise NotImplementedError
|
||||
|
|
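A quick sketch exercising the two formats the parser's docstring describes; the tool name and inputs are made up:

```python
from llama_index.agent.react.output_parser import ReActOutputParser

parser = ReActOutputParser()

# A tool-use step in the expected Thought/Action/Action Input format
step = parser.parse(
    'Thought: I need to look this up.\n'
    'Action: search_tool\n'
    'Action Input: {"input": "llama index"}'
)
print(type(step).__name__, step.action, step.action_input)

# A final-answer step
final = parser.parse("Thought: I know the answer.\nAnswer: 42")
print(type(final).__name__, final.response)
```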
@@ -0,0 +1,112 @@
|
|||
"""Default prompt for ReAct agent."""
|
||||
|
||||
|
||||
# ReAct chat prompt
|
||||
# TODO: have formatting instructions be a part of react output parser
|
||||
|
||||
REACT_CHAT_SYSTEM_HEADER = """\
|
||||
|
||||
You are designed to help with a variety of tasks, from answering questions \
|
||||
to providing summaries to other types of analyses.
|
||||
|
||||
## Tools
|
||||
You have access to a wide variety of tools. You are responsible for using
|
||||
the tools in any sequence you deem appropriate to complete the task at hand.
|
||||
This may require breaking the task into subtasks and using different tools
|
||||
to complete each subtask.
|
||||
|
||||
You have access to the following tools:
|
||||
{tool_desc}
|
||||
|
||||
## Output Format
|
||||
To answer the question, please use the following format.
|
||||
|
||||
```
|
||||
Thought: I need to use a tool to help me answer the question.
|
||||
Action: tool name (one of {tool_names}) if using a tool.
|
||||
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
|
||||
```
|
||||
|
||||
Please ALWAYS start with a Thought.
|
||||
|
||||
Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}.
|
||||
|
||||
If this format is used, the user will respond in the following format:
|
||||
|
||||
```
|
||||
Observation: tool response
|
||||
```
|
||||
|
||||
You should keep repeating the above format until you have enough information
|
||||
to answer the question without using any more tools. At that point, you MUST respond
|
||||
in one of the following two formats:
|
||||
|
||||
```
|
||||
Thought: I can answer without using any more tools.
|
||||
Answer: [your answer here]
|
||||
```
|
||||
|
||||
```
|
||||
Thought: I cannot answer the question with the provided tools.
|
||||
Answer: Sorry, I cannot answer your query.
|
||||
```
|
||||
|
||||
## Current Conversation
|
||||
Below is the current conversation consisting of interleaving human and assistant messages.
|
||||
|
||||
"""
|
||||
|
||||
CONTEXT_REACT_CHAT_SYSTEM_HEADER = """\
|
||||
|
||||
You are designed to help with a variety of tasks, from answering questions \
|
||||
to providing summaries to other types of analyses.
|
||||
|
||||
## Tools
|
||||
You have access to a wide variety of tools. You are responsible for using
|
||||
the tools in any sequence you deem appropriate to complete the task at hand.
|
||||
This may require breaking the task into subtasks and using different tools
|
||||
to complete each subtask.
|
||||
|
||||
Here is some context to help you answer the question and plan:
|
||||
{context}
|
||||
|
||||
You have access to the following tools:
|
||||
{tool_desc}
|
||||
|
||||
## Output Format
|
||||
To answer the question, please use the following format.
|
||||
|
||||
```
|
||||
Thought: I need to use a tool to help me answer the question.
|
||||
Action: tool name (one of {tool_names}) if using a tool.
|
||||
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
|
||||
```
|
||||
|
||||
Please ALWAYS start with a Thought.
|
||||
|
||||
Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}.
|
||||
|
||||
If this format is used, the user will respond in the following format:
|
||||
|
||||
```
|
||||
Observation: tool response
|
||||
```
|
||||
|
||||
You should keep repeating the above format until you have enough information
|
||||
to answer the question without using any more tools. At that point, you MUST respond
|
||||
in one of the following two formats:
|
||||
|
||||
```
|
||||
Thought: I can answer without using any more tools.
|
||||
Answer: [your answer here]
|
||||
```
|
||||
|
||||
```
|
||||
Thought: I cannot answer the question with the provided tools.
|
||||
Answer: Sorry, I cannot answer your query.
|
||||
```
|
||||
|
||||
## Current Conversation
|
||||
Below is the current conversation consisting of interleaving human and assistant messages.
|
||||
|
||||
"""
|
||||
|
|
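As a sketch of how these templates are consumed (the formatter in formatter.py does the equivalent), note that `{tool_desc}` and `{tool_names}` are filled with `str.format` while the doubled braces stay literal; the tool description below is illustrative:

```python
from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER

header = REACT_CHAT_SYSTEM_HEADER.format(
    tool_desc="> Tool Name: search\nTool Description: search the web\nTool Args: {'input': str}",
    tool_names="search",
)
# The doubled braces in the template (e.g. {{"input": ...}}) come out as single
# literal braces in the rendered prompt.
assert '{"input": "hello world"' in header
```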
@@ -0,0 +1,640 @@
|
|||
"""ReAct agent worker."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from itertools import chain
|
||||
from threading import Thread
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
from llama_index.agent.react.formatter import ReActChatFormatter
|
||||
from llama_index.agent.react.output_parser import ReActOutputParser
|
||||
from llama_index.agent.react.types import (
|
||||
ActionReasoningStep,
|
||||
BaseReasoningStep,
|
||||
ObservationReasoningStep,
|
||||
ResponseReasoningStep,
|
||||
)
|
||||
from llama_index.agent.types import (
|
||||
BaseAgentWorker,
|
||||
Task,
|
||||
TaskStep,
|
||||
TaskStepOutput,
|
||||
)
|
||||
from llama_index.callbacks import (
|
||||
CallbackManager,
|
||||
CBEventType,
|
||||
EventPayload,
|
||||
trace_method,
|
||||
)
|
||||
from llama_index.chat_engine.types import (
|
||||
AGENT_CHAT_RESPONSE_TYPE,
|
||||
AgentChatResponse,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.core.llms.types import MessageRole
|
||||
from llama_index.llms.base import ChatMessage, ChatResponse
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
|
||||
from llama_index.memory.types import BaseMemory
|
||||
from llama_index.objects.base import ObjectRetriever
|
||||
from llama_index.prompts.base import PromptTemplate
|
||||
from llama_index.prompts.mixin import PromptDictType
|
||||
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
|
||||
from llama_index.tools.types import AsyncBaseTool
|
||||
from llama_index.utils import print_text, unit_generator
|
||||
|
||||
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
|
||||
|
||||
|
||||
def add_user_step_to_reasoning(
|
||||
step: TaskStep,
|
||||
memory: BaseMemory,
|
||||
current_reasoning: List[BaseReasoningStep],
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""Add user step to memory."""
|
||||
if "is_first" in step.step_state and step.step_state["is_first"]:
|
||||
# add to new memory
|
||||
memory.put(ChatMessage(content=step.input, role=MessageRole.USER))
|
||||
step.step_state["is_first"] = False
|
||||
else:
|
||||
reasoning_step = ObservationReasoningStep(observation=step.input)
|
||||
current_reasoning.append(reasoning_step)
|
||||
if verbose:
|
||||
print(f"Added user message to memory: {step.input}")
|
||||
|
||||
|
||||
class ReActAgentWorker(BaseAgentWorker):
|
||||
"""OpenAI Agent worker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: Sequence[BaseTool],
|
||||
llm: LLM,
|
||||
max_iterations: int = 10,
|
||||
react_chat_formatter: Optional[ReActChatFormatter] = None,
|
||||
output_parser: Optional[ReActOutputParser] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
) -> None:
|
||||
self._llm = llm
|
||||
self.callback_manager = callback_manager or llm.callback_manager
|
||||
self._max_iterations = max_iterations
|
||||
self._react_chat_formatter = react_chat_formatter or ReActChatFormatter()
|
||||
self._output_parser = output_parser or ReActOutputParser()
|
||||
self._verbose = verbose
|
||||
|
||||
if len(tools) > 0 and tool_retriever is not None:
|
||||
raise ValueError("Cannot specify both tools and tool_retriever")
|
||||
elif len(tools) > 0:
|
||||
self._get_tools = lambda _: tools
|
||||
elif tool_retriever is not None:
|
||||
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
|
||||
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
|
||||
else:
|
||||
self._get_tools = lambda _: []
|
||||
|
||||
@classmethod
|
||||
def from_tools(
|
||||
cls,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
llm: Optional[LLM] = None,
|
||||
max_iterations: int = 10,
|
||||
react_chat_formatter: Optional[ReActChatFormatter] = None,
|
||||
output_parser: Optional[ReActOutputParser] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> "ReActAgentWorker":
|
||||
"""Convenience constructor method from set of of BaseTools (Optional).
|
||||
|
||||
NOTE: kwargs should have been exhausted by this point. In other words
|
||||
the various upstream components such as BaseSynthesizer (response synthesizer)
|
||||
or BaseRetriever should have picked up their respective kwargs in their
|
||||
constructions.
|
||||
|
||||
Returns:
|
||||
ReActAgentWorker
|
||||
"""
|
||||
llm = llm or OpenAI(model=DEFAULT_MODEL_NAME)
|
||||
if callback_manager is not None:
|
||||
llm.callback_manager = callback_manager
|
||||
return cls(
|
||||
tools=tools or [],
|
||||
tool_retriever=tool_retriever,
|
||||
llm=llm,
|
||||
max_iterations=max_iterations,
|
||||
react_chat_formatter=react_chat_formatter,
|
||||
output_parser=output_parser,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
def _get_prompts(self) -> PromptDictType:
|
||||
"""Get prompts."""
|
||||
# TODO: the ReAct formatter does not explicitly specify PromptTemplate
|
||||
# objects, so we wrap the system header in one to obey the interface
|
||||
sys_header = self._react_chat_formatter.system_header
|
||||
return {"system_prompt": PromptTemplate(sys_header)}
|
||||
|
||||
def _update_prompts(self, prompts: PromptDictType) -> None:
|
||||
"""Update prompts."""
|
||||
if "system_prompt" in prompts:
|
||||
sys_prompt = cast(PromptTemplate, prompts["system_prompt"])
|
||||
self._react_chat_formatter.system_header = sys_prompt.template
|
||||
|
||||
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
|
||||
"""Initialize step from task."""
|
||||
sources: List[ToolOutput] = []
|
||||
current_reasoning: List[BaseReasoningStep] = []
|
||||
# temporary memory for new messages
|
||||
new_memory = ChatMemoryBuffer.from_defaults()
|
||||
|
||||
# initialize task state
|
||||
task_state = {
|
||||
"sources": sources,
|
||||
"current_reasoning": current_reasoning,
|
||||
"new_memory": new_memory,
|
||||
}
|
||||
task.extra_state.update(task_state)
|
||||
|
||||
return TaskStep(
|
||||
task_id=task.task_id,
|
||||
step_id=str(uuid.uuid4()),
|
||||
input=task.input,
|
||||
step_state={"is_first": True},
|
||||
)
|
||||
|
||||
def get_tools(self, input: str) -> List[AsyncBaseTool]:
|
||||
"""Get tools."""
|
||||
return [adapt_to_async_tool(t) for t in self._get_tools(input)]
|
||||
|
||||
def _extract_reasoning_step(
|
||||
self, output: ChatResponse, is_streaming: bool = False
|
||||
) -> Tuple[str, List[BaseReasoningStep], bool]:
|
||||
"""
|
||||
Extracts the reasoning step from the given output.
|
||||
|
||||
This method parses the message content from the output,
|
||||
extracts the reasoning step, and determines whether the processing is
|
||||
complete. It also performs validation checks on the output and
|
||||
handles possible errors.
|
||||
"""
|
||||
if output.message.content is None:
|
||||
raise ValueError("Got empty message.")
|
||||
message_content = output.message.content
|
||||
current_reasoning = []
|
||||
try:
|
||||
reasoning_step = self._output_parser.parse(message_content, is_streaming)
|
||||
except BaseException as exc:
|
||||
raise ValueError(f"Could not parse output: {message_content}") from exc
|
||||
if self._verbose:
|
||||
print_text(f"{reasoning_step.get_content()}\n", color="pink")
|
||||
current_reasoning.append(reasoning_step)
|
||||
|
||||
if reasoning_step.is_done:
|
||||
return message_content, current_reasoning, True
|
||||
|
||||
reasoning_step = cast(ActionReasoningStep, reasoning_step)
|
||||
if not isinstance(reasoning_step, ActionReasoningStep):
|
||||
raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")
|
||||
|
||||
return message_content, current_reasoning, False
|
||||
|
||||
def _process_actions(
|
||||
self,
|
||||
task: Task,
|
||||
tools: Sequence[AsyncBaseTool],
|
||||
output: ChatResponse,
|
||||
is_streaming: bool = False,
|
||||
) -> Tuple[List[BaseReasoningStep], bool]:
|
||||
tools_dict: Dict[str, AsyncBaseTool] = {
|
||||
tool.metadata.get_name(): tool for tool in tools
|
||||
}
|
||||
_, current_reasoning, is_done = self._extract_reasoning_step(
|
||||
output, is_streaming
|
||||
)
|
||||
|
||||
if is_done:
|
||||
return current_reasoning, True
|
||||
|
||||
# call tool with input
|
||||
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
|
||||
tool = tools_dict[reasoning_step.action]
|
||||
with self.callback_manager.event(
|
||||
CBEventType.FUNCTION_CALL,
|
||||
payload={
|
||||
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
|
||||
EventPayload.TOOL: tool.metadata,
|
||||
},
|
||||
) as event:
|
||||
tool_output = tool.call(**reasoning_step.action_input)
|
||||
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
||||
|
||||
task.extra_state["sources"].append(tool_output)
|
||||
|
||||
observation_step = ObservationReasoningStep(observation=str(tool_output))
|
||||
current_reasoning.append(observation_step)
|
||||
if self._verbose:
|
||||
print_text(f"{observation_step.get_content()}\n", color="blue")
|
||||
return current_reasoning, False
|
||||
|
||||
async def _aprocess_actions(
|
||||
self,
|
||||
task: Task,
|
||||
tools: Sequence[AsyncBaseTool],
|
||||
output: ChatResponse,
|
||||
is_streaming: bool = False,
|
||||
) -> Tuple[List[BaseReasoningStep], bool]:
|
||||
tools_dict = {tool.metadata.name: tool for tool in tools}
|
||||
_, current_reasoning, is_done = self._extract_reasoning_step(
|
||||
output, is_streaming
|
||||
)
|
||||
|
||||
if is_done:
|
||||
return current_reasoning, True
|
||||
|
||||
# call tool with input
|
||||
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
|
||||
tool = tools_dict[reasoning_step.action]
|
||||
with self.callback_manager.event(
|
||||
CBEventType.FUNCTION_CALL,
|
||||
payload={
|
||||
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
|
||||
EventPayload.TOOL: tool.metadata,
|
||||
},
|
||||
) as event:
|
||||
tool_output = await tool.acall(**reasoning_step.action_input)
|
||||
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
||||
|
||||
task.extra_state["sources"].append(tool_output)
|
||||
|
||||
observation_step = ObservationReasoningStep(observation=str(tool_output))
|
||||
current_reasoning.append(observation_step)
|
||||
if self._verbose:
|
||||
print_text(f"{observation_step.get_content()}\n", color="blue")
|
||||
return current_reasoning, False
|
||||
|
||||
def _get_response(
|
||||
self,
|
||||
current_reasoning: List[BaseReasoningStep],
|
||||
sources: List[ToolOutput],
|
||||
) -> AgentChatResponse:
|
||||
"""Get response from reasoning steps."""
|
||||
if len(current_reasoning) == 0:
|
||||
raise ValueError("No reasoning steps were taken.")
|
||||
elif len(current_reasoning) == self._max_iterations:
|
||||
raise ValueError("Reached max iterations.")
|
||||
|
||||
if isinstance(current_reasoning[-1], ResponseReasoningStep):
|
||||
response_step = cast(ResponseReasoningStep, current_reasoning[-1])
|
||||
response_str = response_step.response
|
||||
else:
|
||||
response_str = current_reasoning[-1].get_content()
|
||||
|
||||
# TODO: add sources from reasoning steps
|
||||
return AgentChatResponse(response=response_str, sources=sources)
|
||||
|
||||
def _get_task_step_response(
|
||||
self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
|
||||
) -> TaskStepOutput:
|
||||
"""Get task step response."""
|
||||
if is_done:
|
||||
new_steps = []
|
||||
else:
|
||||
new_steps = [
|
||||
step.get_next_step(
|
||||
step_id=str(uuid.uuid4()),
|
||||
# NOTE: input is unused
|
||||
input=None,
|
||||
)
|
||||
]
|
||||
|
||||
return TaskStepOutput(
|
||||
output=agent_response,
|
||||
task_step=step,
|
||||
is_last=is_done,
|
||||
next_steps=new_steps,
|
||||
)
|
||||
|
||||
def _infer_stream_chunk_is_final(self, chunk: ChatResponse) -> bool:
|
||||
"""Infers if a chunk from a live stream is the start of the final
|
||||
reasoning step. (i.e., and should eventually become
|
||||
ResponseReasoningStep — not part of this function's logic tho.).
|
||||
|
||||
Args:
|
||||
chunk (ChatResponse): the current chunk stream to check
|
||||
|
||||
Returns:
|
||||
bool: Boolean on whether the chunk is the start of the final response
|
||||
"""
|
||||
latest_content = chunk.message.content
|
||||
if latest_content:
|
||||
if not latest_content.startswith(
|
||||
"Thought"
|
||||
): # doesn't follow thought-action format
|
||||
return True
|
||||
else:
|
||||
if "Answer: " in latest_content:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _add_back_chunk_to_stream(
|
||||
self, chunk: ChatResponse, chat_stream: Generator[ChatResponse, None, None]
|
||||
) -> Generator[ChatResponse, None, None]:
|
||||
"""Helper method for adding back initial chunk stream of final response
|
||||
back to the rest of the chat_stream.
|
||||
|
||||
Args:
|
||||
chunk (ChatResponse): the chunk to add back to the beginning of the
|
||||
chat_stream.
|
||||
|
||||
Return:
|
||||
Generator[ChatResponse, None, None]: the updated chat_stream
|
||||
"""
|
||||
updated_stream = chain.from_iterable( # need to add back partial response chunk
|
||||
[
|
||||
unit_generator(chunk),
|
||||
chat_stream,
|
||||
]
|
||||
)
|
||||
# use cast to avoid mypy issue with chain and Generator
|
||||
updated_stream_c: Generator[ChatResponse, None, None] = cast(
|
||||
Generator[ChatResponse, None, None], updated_stream
|
||||
)
|
||||
return updated_stream_c
|
||||
|
||||
async def _async_add_back_chunk_to_stream(
|
||||
self, chunk: ChatResponse, chat_stream: AsyncGenerator[ChatResponse, None]
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Helper method for adding back initial chunk stream of final response
|
||||
back to the rest of the chat_stream.
|
||||
|
||||
NOTE: this is an async generator; calling it returns the updated stream rather than a coroutine to await.
|
||||
|
||||
Args:
|
||||
chunk (ChatResponse): the chunk to add back to the beginning of the
|
||||
chat_stream.
|
||||
|
||||
Return:
|
||||
AsyncGenerator[ChatResponse, None]: the updated async chat_stream
|
||||
"""
|
||||
yield chunk
|
||||
async for item in chat_stream:
|
||||
yield item
|
||||
|
||||
def _run_step(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
if step.input is not None:
|
||||
add_user_step_to_reasoning(
|
||||
step,
|
||||
task.extra_state["new_memory"],
|
||||
task.extra_state["current_reasoning"],
|
||||
verbose=self._verbose,
|
||||
)
|
||||
# TODO: see if we want to do step-based inputs
|
||||
tools = self.get_tools(task.input)
|
||||
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
|
||||
current_reasoning=task.extra_state["current_reasoning"],
|
||||
)
|
||||
|
||||
# send prompt
|
||||
chat_response = self._llm.chat(input_chat)
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, is_done = self._process_actions(
|
||||
task, tools, output=chat_response
|
||||
)
|
||||
task.extra_state["current_reasoning"].extend(reasoning_steps)
|
||||
agent_response = self._get_response(
|
||||
task.extra_state["current_reasoning"], task.extra_state["sources"]
|
||||
)
|
||||
if is_done:
|
||||
task.extra_state["new_memory"].put(
|
||||
ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
|
||||
)
|
||||
|
||||
return self._get_task_step_response(agent_response, step, is_done)
|
||||
|
||||
async def _arun_step(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
if step.input is not None:
|
||||
add_user_step_to_reasoning(
|
||||
step,
|
||||
task.extra_state["new_memory"],
|
||||
task.extra_state["current_reasoning"],
|
||||
verbose=self._verbose,
|
||||
)
|
||||
# TODO: see if we want to do step-based inputs
|
||||
tools = self.get_tools(task.input)
|
||||
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
|
||||
current_reasoning=task.extra_state["current_reasoning"],
|
||||
)
|
||||
# send prompt
|
||||
chat_response = await self._llm.achat(input_chat)
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, is_done = await self._aprocess_actions(
|
||||
task, tools, output=chat_response
|
||||
)
|
||||
task.extra_state["current_reasoning"].extend(reasoning_steps)
|
||||
agent_response = self._get_response(
|
||||
task.extra_state["current_reasoning"], task.extra_state["sources"]
|
||||
)
|
||||
if is_done:
|
||||
task.extra_state["new_memory"].put(
|
||||
ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
|
||||
)
|
||||
|
||||
return self._get_task_step_response(agent_response, step, is_done)
|
||||
|
||||
def _run_step_stream(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
if step.input is not None:
|
||||
add_user_step_to_reasoning(
|
||||
step,
|
||||
task.extra_state["new_memory"],
|
||||
task.extra_state["current_reasoning"],
|
||||
verbose=self._verbose,
|
||||
)
|
||||
# TODO: see if we want to do step-based inputs
|
||||
tools = self.get_tools(task.input)
|
||||
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
|
||||
current_reasoning=task.extra_state["current_reasoning"],
|
||||
)
|
||||
|
||||
chat_stream = self._llm.stream_chat(input_chat)
|
||||
|
||||
# iterate over the stream; break out once a chunk looks like the start of the final answer (after "Answer: ")
|
||||
full_response = ChatResponse(
|
||||
message=ChatMessage(content=None, role="assistant")
|
||||
)
|
||||
is_done = False
|
||||
for latest_chunk in chat_stream:
|
||||
full_response = latest_chunk
|
||||
is_done = self._infer_stream_chunk_is_final(latest_chunk)
|
||||
if is_done:
|
||||
break
|
||||
|
||||
if not is_done:
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, _ = self._process_actions(
|
||||
task, tools=tools, output=full_response, is_streaming=True
|
||||
)
|
||||
task.extra_state["current_reasoning"].extend(reasoning_steps)
|
||||
# use _get_response to return intermediate response
|
||||
agent_response: AGENT_CHAT_RESPONSE_TYPE = self._get_response(
|
||||
task.extra_state["current_reasoning"], task.extra_state["sources"]
|
||||
)
|
||||
else:
|
||||
# Write the response to history in a separate thread so we can yield the response stream
|
||||
response_stream = self._add_back_chunk_to_stream(
|
||||
chunk=latest_chunk, chat_stream=chat_stream
|
||||
)
|
||||
|
||||
agent_response = StreamingAgentChatResponse(
|
||||
chat_stream=response_stream,
|
||||
sources=task.extra_state["sources"],
|
||||
)
|
||||
thread = Thread(
|
||||
target=agent_response.write_response_to_history,
|
||||
args=(task.extra_state["new_memory"],),
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return self._get_task_step_response(agent_response, step, is_done)
|
||||
|
||||
async def _arun_step_stream(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
if step.input is not None:
|
||||
add_user_step_to_reasoning(
|
||||
step,
|
||||
task.extra_state["new_memory"],
|
||||
task.extra_state["current_reasoning"],
|
||||
verbose=self._verbose,
|
||||
)
|
||||
# TODO: see if we want to do step-based inputs
|
||||
tools = self.get_tools(task.input)
|
||||
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
|
||||
current_reasoning=task.extra_state["current_reasoning"],
|
||||
)
|
||||
|
||||
chat_stream = await self._llm.astream_chat(input_chat)
|
||||
|
||||
# iterate over the stream; break out once a chunk looks like the start of the final answer (after "Answer: ")
|
||||
full_response = ChatResponse(
|
||||
message=ChatMessage(content=None, role="assistant")
|
||||
)
|
||||
is_done = False
|
||||
async for latest_chunk in chat_stream:
|
||||
full_response = latest_chunk
|
||||
is_done = self._infer_stream_chunk_is_final(latest_chunk)
|
||||
if is_done:
|
||||
break
|
||||
|
||||
if not is_done:
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, _ = self._process_actions(
|
||||
task, tools=tools, output=full_response, is_streaming=True
|
||||
)
|
||||
task.extra_state["current_reasoning"].extend(reasoning_steps)
|
||||
# use _get_response to return intermediate response
|
||||
agent_response: AGENT_CHAT_RESPONSE_TYPE = self._get_response(
|
||||
task.extra_state["current_reasoning"], task.extra_state["sources"]
|
||||
)
|
||||
else:
|
||||
# Write the response to history in a background task so we can yield the response stream
|
||||
response_stream = self._async_add_back_chunk_to_stream(
|
||||
chunk=latest_chunk, chat_stream=chat_stream
|
||||
)
|
||||
|
||||
agent_response = StreamingAgentChatResponse(
|
||||
achat_stream=response_stream,
|
||||
sources=task.extra_state["sources"],
|
||||
)
|
||||
# create task to write chat response to history
|
||||
asyncio.create_task(
|
||||
agent_response.awrite_response_to_history(
|
||||
task.extra_state["new_memory"]
|
||||
)
|
||||
)
|
||||
# wait until response writing is done
|
||||
await agent_response._is_function_false_event.wait()
|
||||
|
||||
return self._get_task_step_response(agent_response, step, is_done)
|
||||
|
||||
@trace_method("run_step")
|
||||
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
return self._run_step(step, task)
|
||||
|
||||
@trace_method("run_step")
|
||||
async def arun_step(
|
||||
self, step: TaskStep, task: Task, **kwargs: Any
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async)."""
|
||||
return await self._arun_step(step, task)
|
||||
|
||||
@trace_method("run_step")
|
||||
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
||||
"""Run step (stream)."""
|
||||
# TODO: figure out if we need a different type for TaskStepOutput
|
||||
return self._run_step_stream(step, task)
|
||||
|
||||
@trace_method("run_step")
|
||||
async def astream_step(
|
||||
self, step: TaskStep, task: Task, **kwargs: Any
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async stream)."""
|
||||
return await self._arun_step_stream(step, task)
|
||||
|
||||
def finalize_task(self, task: Task, **kwargs: Any) -> None:
|
||||
"""Finalize task, after all the steps are completed."""
|
||||
# add new messages to memory
|
||||
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
|
||||
# reset new memory
|
||||
task.extra_state["new_memory"].reset()
|
||||
|
||||
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
||||
"""Set callback manager."""
|
||||
# TODO: make this abstractmethod (right now will break some agent impls)
|
||||
self.callback_manager = callback_manager
|
||||
|
|
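To show where this worker plugs in, a minimal sketch of wiring it into an `AgentRunner`, which is equivalent to what `ReActAgent` in base.py does internally. The tool is illustrative, `FunctionTool` is assumed as before, and `AgentRunner` is assumed to supply default memory when constructed with just a worker:

```python
from llama_index.agent.react.step import ReActAgentWorker
from llama_index.agent.runner.base import AgentRunner
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool  # assumed helper, as above


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


worker = ReActAgentWorker.from_tools(
    tools=[FunctionTool.from_defaults(fn=add)],
    llm=OpenAI(model="gpt-3.5-turbo-0613"),
    verbose=True,
)
agent = AgentRunner(worker)
# AgentRunner handles task creation and calls run_step until is_last is True.
print(agent.chat("What is 2 + 3?"))
```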
@@ -0,0 +1,77 @@
|
|||
"""Base types for ReAct agent."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.bridge.pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseReasoningStep(BaseModel):
|
||||
"""Reasoning step."""
|
||||
|
||||
@abstractmethod
|
||||
def get_content(self) -> str:
|
||||
"""Get content."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_done(self) -> bool:
|
||||
"""Is the reasoning step the last one."""
|
||||
|
||||
|
||||
class ActionReasoningStep(BaseReasoningStep):
|
||||
"""Action Reasoning step."""
|
||||
|
||||
thought: str
|
||||
action: str
|
||||
action_input: Dict
|
||||
|
||||
def get_content(self) -> str:
|
||||
"""Get content."""
|
||||
return (
|
||||
f"Thought: {self.thought}\nAction: {self.action}\n"
|
||||
f"Action Input: {self.action_input}"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
"""Is the reasoning step the last one."""
|
||||
return False
|
||||
|
||||
|
||||
class ObservationReasoningStep(BaseReasoningStep):
|
||||
"""Observation reasoning step."""
|
||||
|
||||
observation: str
|
||||
|
||||
def get_content(self) -> str:
|
||||
"""Get content."""
|
||||
return f"Observation: {self.observation}"
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
"""Is the reasoning step the last one."""
|
||||
return False
|
||||
|
||||
|
||||
class ResponseReasoningStep(BaseReasoningStep):
|
||||
"""Response reasoning step."""
|
||||
|
||||
thought: str
|
||||
response: str
|
||||
is_streaming: bool = False
|
||||
|
||||
def get_content(self) -> str:
|
||||
"""Get content."""
|
||||
if self.is_streaming:
|
||||
return (
|
||||
f"Thought: {self.thought}\n"
|
||||
f"Answer (Starts With): {self.response} ..."
|
||||
)
|
||||
else:
|
||||
return f"Thought: {self.thought}\n" f"Answer: {self.response}"
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
"""Is the reasoning step the last one."""
|
||||
return True
|
||||
|
|
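A small sketch of the reasoning-step data structures above and the strings they render to; the values are arbitrary:

```python
from llama_index.agent.react.types import (
    ActionReasoningStep,
    ObservationReasoningStep,
    ResponseReasoningStep,
)

steps = [
    ActionReasoningStep(
        thought="I need the weather.",
        action="weather_tool",            # hypothetical tool name
        action_input={"city": "Paris"},
    ),
    ObservationReasoningStep(observation="It is 18C and sunny."),
    ResponseReasoningStep(thought="I can answer now.", response="18C and sunny."),
]
for step in steps:
    print(step.is_done, "|", step.get_content())
```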
@@ -0,0 +1,87 @@
|
|||
"""Default prompt for ReAct agent."""
|
||||
|
||||
|
||||
# ReAct multimodal chat prompt
|
||||
# TODO: have formatting instructions be a part of react output parser
|
||||
|
||||
REACT_MM_CHAT_SYSTEM_HEADER = """\
|
||||
|
||||
You are designed to help with a variety of tasks, from answering questions \
|
||||
to providing summaries to other types of analyses. You can take in both text \
|
||||
and images.
|
||||
|
||||
|
||||
## Tools
|
||||
You have access to a wide variety of tools. You are responsible for using
|
||||
the tools in any sequence you deem appropriate to complete the task at hand.
|
||||
This may require breaking the task into subtasks and using different tools
|
||||
to complete each subtask.
|
||||
|
||||
NOTE: you do NOT need to use a tool to understand the provided images. You can
|
||||
use both the input text and images as context to decide which tool to use.
|
||||
|
||||
You have access to the following tools:
|
||||
{tool_desc}
|
||||
|
||||
## Input
|
||||
The user will specify a task (in text) and a set of images. Treat
|
||||
the images as additional context for the task.
|
||||
|
||||
## Output Format
|
||||
To answer the question, please use the following format.
|
||||
|
||||
```
|
||||
Thought: I need to use a tool to help me answer the question.
|
||||
Action: tool name (one of {tool_names}) if using a tool.
|
||||
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
|
||||
```
|
||||
|
||||
Please ALWAYS start with a Thought.
|
||||
|
||||
Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}.
|
||||
|
||||
If this format is used, the user will respond in the following format:
|
||||
|
||||
```
|
||||
Observation: tool response
|
||||
```
|
||||
|
||||
Here's a concrete example. Again, you can take in both text and images as input. This can generate a thought which can be used to decide which tool to use.
|
||||
The input to the tool should not assume knowledge of the image. Therefore it is your responsibility \
|
||||
to translate the input text/images into a format that the tool can understand.
|
||||
|
||||
For example:
|
||||
```
|
||||
Thought: This image is a picture of a brown dog. The text asked me to identify its name, so I need to use a tool to lookup its name.
|
||||
Action: churchill_bio_tool
|
||||
Action Input: {{"input": "brown dog name"}}
|
||||
|
||||
```
|
||||
Example user response:
|
||||
|
||||
```
|
||||
Observation: The name of the brown dog is Rufus.
|
||||
```
|
||||
|
||||
|
||||
You should keep repeating the above format until you have enough information
|
||||
to answer the question without using any more tools. At that point, you MUST respond
|
||||
in one of the following two formats:
|
||||
|
||||
```
|
||||
Thought: I can answer without using any more tools.
|
||||
Answer: [your answer here]
|
||||
```
|
||||
|
||||
```
|
||||
Thought: I cannot answer the question with the provided tools.
|
||||
Answer: Sorry, I cannot answer your query.
|
||||
```
|
||||
|
||||
The answer MUST be grounded in the input text and images. Do not give an answer that is irrelevant to the image
|
||||
provided.
|
||||
|
||||
## Current Conversation
|
||||
Below is the current conversation consisting of interleaving human and assistant messages.
|
||||
|
||||
"""
|
||||
|
|
@@ -0,0 +1,479 @@
|
|||
"""ReAct multimodal agent."""
|
||||
|
||||
import uuid
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
from llama_index.agent.react.formatter import ReActChatFormatter
|
||||
from llama_index.agent.react.output_parser import ReActOutputParser
|
||||
from llama_index.agent.react.types import (
|
||||
ActionReasoningStep,
|
||||
BaseReasoningStep,
|
||||
ObservationReasoningStep,
|
||||
ResponseReasoningStep,
|
||||
)
|
||||
from llama_index.agent.react_multimodal.prompts import REACT_MM_CHAT_SYSTEM_HEADER
|
||||
from llama_index.agent.types import (
|
||||
BaseAgentWorker,
|
||||
Task,
|
||||
TaskStep,
|
||||
TaskStepOutput,
|
||||
)
|
||||
from llama_index.callbacks import (
|
||||
CallbackManager,
|
||||
CBEventType,
|
||||
EventPayload,
|
||||
trace_method,
|
||||
)
|
||||
from llama_index.chat_engine.types import (
|
||||
AGENT_CHAT_RESPONSE_TYPE,
|
||||
AgentChatResponse,
|
||||
)
|
||||
from llama_index.core.llms.types import MessageRole
|
||||
from llama_index.llms.base import ChatMessage, ChatResponse
|
||||
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
|
||||
from llama_index.memory.types import BaseMemory
|
||||
from llama_index.multi_modal_llms.base import MultiModalLLM
|
||||
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
|
||||
from llama_index.multi_modal_llms.openai_utils import (
|
||||
generate_openai_multi_modal_chat_message,
|
||||
)
|
||||
from llama_index.objects.base import ObjectRetriever
|
||||
from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
|
||||
from llama_index.tools.types import AsyncBaseTool
|
||||
from llama_index.utils import print_text
|
||||
|
||||
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
|
||||
|
||||
|
||||
def add_user_step_to_reasoning(
|
||||
step: TaskStep,
|
||||
memory: BaseMemory,
|
||||
current_reasoning: List[BaseReasoningStep],
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""Add user step to reasoning.
|
||||
|
||||
Adds both text input and image input to reasoning.
|
||||
|
||||
"""
|
||||
# raise error if step.input is None
|
||||
if step.input is None:
|
||||
raise ValueError("Step input is None.")
|
||||
# TODO: support gemini as well. Currently just supports OpenAI
|
||||
|
||||
# TODO: currently assume that you can't generate images in the loop,
|
||||
# so step_state contains the original image_docs from the task
|
||||
# (it doesn't change)
|
||||
image_docs = step.step_state["image_docs"]
|
||||
image_kwargs = step.step_state.get("image_kwargs", {})
|
||||
|
||||
if "is_first" in step.step_state and step.step_state["is_first"]:
|
||||
mm_message = generate_openai_multi_modal_chat_message(
|
||||
prompt=step.input,
|
||||
role=MessageRole.USER,
|
||||
image_documents=image_docs,
|
||||
**image_kwargs,
|
||||
)
|
||||
# add to new memory
|
||||
memory.put(mm_message)
|
||||
step.step_state["is_first"] = False
|
||||
else:
|
||||
# NOTE: this is where the user specifies an intermediate step in the middle
|
||||
# TODO: don't support specifying image_docs here for now
|
||||
reasoning_step = ObservationReasoningStep(observation=step.input)
|
||||
current_reasoning.append(reasoning_step)
|
||||
if verbose:
|
||||
print(f"Added user message to memory: {step.input}")
|
||||
|
||||
|
||||
class MultimodalReActAgentWorker(BaseAgentWorker):
|
||||
"""Multimodal ReAct Agent worker.
|
||||
|
||||
**NOTE**: This is a BETA feature.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: Sequence[BaseTool],
|
||||
multi_modal_llm: MultiModalLLM,
|
||||
max_iterations: int = 10,
|
||||
react_chat_formatter: Optional[ReActChatFormatter] = None,
|
||||
output_parser: Optional[ReActOutputParser] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
) -> None:
|
||||
self._multi_modal_llm = multi_modal_llm
|
||||
self.callback_manager = callback_manager or CallbackManager([])
|
||||
self._max_iterations = max_iterations
|
||||
self._react_chat_formatter = react_chat_formatter or ReActChatFormatter(
|
||||
system_header=REACT_MM_CHAT_SYSTEM_HEADER
|
||||
)
|
||||
self._output_parser = output_parser or ReActOutputParser()
|
||||
self._verbose = verbose
|
||||
|
||||
if len(tools) > 0 and tool_retriever is not None:
|
||||
raise ValueError("Cannot specify both tools and tool_retriever")
|
||||
elif len(tools) > 0:
|
||||
self._get_tools = lambda _: tools
|
||||
elif tool_retriever is not None:
|
||||
tool_retriever_c = cast(ObjectRetriever[BaseTool], tool_retriever)
|
||||
self._get_tools = lambda message: tool_retriever_c.retrieve(message)
|
||||
else:
|
||||
self._get_tools = lambda _: []
|
||||
|
||||
@classmethod
|
||||
def from_tools(
|
||||
cls,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
|
||||
multi_modal_llm: Optional[MultiModalLLM] = None,
|
||||
max_iterations: int = 10,
|
||||
react_chat_formatter: Optional[ReActChatFormatter] = None,
|
||||
output_parser: Optional[ReActOutputParser] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> "MultimodalReActAgentWorker":
|
||||
"""Convenience constructor method from set of of BaseTools (Optional).
|
||||
|
||||
        NOTE: kwargs should have been exhausted by this point. In other words,
        the various upstream components, such as BaseSynthesizer (response synthesizer)
        or BaseRetriever, should have already consumed their respective kwargs in their
        constructors.
|
||||
|
||||
Returns:
|
||||
            MultimodalReActAgentWorker
|
||||
"""
|
||||
multi_modal_llm = multi_modal_llm or OpenAIMultiModal(
|
||||
model="gpt-4-vision-preview", max_new_tokens=1000
|
||||
)
|
||||
return cls(
|
||||
tools=tools or [],
|
||||
tool_retriever=tool_retriever,
|
||||
multi_modal_llm=multi_modal_llm,
|
||||
max_iterations=max_iterations,
|
||||
react_chat_formatter=react_chat_formatter,
|
||||
output_parser=output_parser,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
|
||||
"""Initialize step from task."""
|
||||
sources: List[ToolOutput] = []
|
||||
current_reasoning: List[BaseReasoningStep] = []
|
||||
# temporary memory for new messages
|
||||
new_memory = ChatMemoryBuffer.from_defaults()
|
||||
|
||||
# validation
|
||||
if "image_docs" not in task.extra_state:
|
||||
raise ValueError("Image docs not found in task extra state.")
|
||||
|
||||
# initialize task state
|
||||
task_state = {
|
||||
"sources": sources,
|
||||
"current_reasoning": current_reasoning,
|
||||
"new_memory": new_memory,
|
||||
}
|
||||
task.extra_state.update(task_state)
|
||||
|
||||
return TaskStep(
|
||||
task_id=task.task_id,
|
||||
step_id=str(uuid.uuid4()),
|
||||
input=task.input,
|
||||
step_state={"is_first": True, "image_docs": task.extra_state["image_docs"]},
|
||||
)
|
||||
|
||||
def get_tools(self, input: str) -> List[AsyncBaseTool]:
|
||||
"""Get tools."""
|
||||
return [adapt_to_async_tool(t) for t in self._get_tools(input)]
|
||||
|
||||
def _extract_reasoning_step(
|
||||
self, output: ChatResponse, is_streaming: bool = False
|
||||
) -> Tuple[str, List[BaseReasoningStep], bool]:
|
||||
"""
|
||||
Extracts the reasoning step from the given output.
|
||||
|
||||
This method parses the message content from the output,
|
||||
extracts the reasoning step, and determines whether the processing is
|
||||
complete. It also performs validation checks on the output and
|
||||
handles possible errors.
|
||||
"""
|
||||
if output.message.content is None:
|
||||
raise ValueError("Got empty message.")
|
||||
message_content = output.message.content
|
||||
current_reasoning = []
|
||||
try:
|
||||
reasoning_step = self._output_parser.parse(message_content, is_streaming)
|
||||
except BaseException as exc:
|
||||
raise ValueError(f"Could not parse output: {message_content}") from exc
|
||||
if self._verbose:
|
||||
print_text(f"{reasoning_step.get_content()}\n", color="pink")
|
||||
current_reasoning.append(reasoning_step)
|
||||
|
||||
if reasoning_step.is_done:
|
||||
return message_content, current_reasoning, True
|
||||
|
||||
reasoning_step = cast(ActionReasoningStep, reasoning_step)
|
||||
if not isinstance(reasoning_step, ActionReasoningStep):
|
||||
raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}")
|
||||
|
||||
return message_content, current_reasoning, False
|
||||
|
||||
def _process_actions(
|
||||
self,
|
||||
task: Task,
|
||||
tools: Sequence[AsyncBaseTool],
|
||||
output: ChatResponse,
|
||||
is_streaming: bool = False,
|
||||
) -> Tuple[List[BaseReasoningStep], bool]:
|
||||
tools_dict: Dict[str, AsyncBaseTool] = {
|
||||
tool.metadata.get_name(): tool for tool in tools
|
||||
}
|
||||
_, current_reasoning, is_done = self._extract_reasoning_step(
|
||||
output, is_streaming
|
||||
)
|
||||
|
||||
if is_done:
|
||||
return current_reasoning, True
|
||||
|
||||
# call tool with input
|
||||
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
|
||||
tool = tools_dict[reasoning_step.action]
|
||||
with self.callback_manager.event(
|
||||
CBEventType.FUNCTION_CALL,
|
||||
payload={
|
||||
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
|
||||
EventPayload.TOOL: tool.metadata,
|
||||
},
|
||||
) as event:
|
||||
tool_output = tool.call(**reasoning_step.action_input)
|
||||
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
||||
|
||||
task.extra_state["sources"].append(tool_output)
|
||||
|
||||
observation_step = ObservationReasoningStep(observation=str(tool_output))
|
||||
current_reasoning.append(observation_step)
|
||||
if self._verbose:
|
||||
print_text(f"{observation_step.get_content()}\n", color="blue")
|
||||
return current_reasoning, False
|
||||
|
||||
async def _aprocess_actions(
|
||||
self,
|
||||
task: Task,
|
||||
tools: Sequence[AsyncBaseTool],
|
||||
output: ChatResponse,
|
||||
is_streaming: bool = False,
|
||||
) -> Tuple[List[BaseReasoningStep], bool]:
|
||||
tools_dict = {tool.metadata.name: tool for tool in tools}
|
||||
_, current_reasoning, is_done = self._extract_reasoning_step(
|
||||
output, is_streaming
|
||||
)
|
||||
|
||||
if is_done:
|
||||
return current_reasoning, True
|
||||
|
||||
# call tool with input
|
||||
reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
|
||||
tool = tools_dict[reasoning_step.action]
|
||||
with self.callback_manager.event(
|
||||
CBEventType.FUNCTION_CALL,
|
||||
payload={
|
||||
EventPayload.FUNCTION_CALL: reasoning_step.action_input,
|
||||
EventPayload.TOOL: tool.metadata,
|
||||
},
|
||||
) as event:
|
||||
tool_output = await tool.acall(**reasoning_step.action_input)
|
||||
event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
|
||||
|
||||
task.extra_state["sources"].append(tool_output)
|
||||
|
||||
observation_step = ObservationReasoningStep(observation=str(tool_output))
|
||||
current_reasoning.append(observation_step)
|
||||
if self._verbose:
|
||||
print_text(f"{observation_step.get_content()}\n", color="blue")
|
||||
return current_reasoning, False
|
||||
|
||||
def _get_response(
|
||||
self,
|
||||
current_reasoning: List[BaseReasoningStep],
|
||||
sources: List[ToolOutput],
|
||||
) -> AgentChatResponse:
|
||||
"""Get response from reasoning steps."""
|
||||
if len(current_reasoning) == 0:
|
||||
raise ValueError("No reasoning steps were taken.")
|
||||
elif len(current_reasoning) == self._max_iterations:
|
||||
raise ValueError("Reached max iterations.")
|
||||
|
||||
if isinstance(current_reasoning[-1], ResponseReasoningStep):
|
||||
response_step = cast(ResponseReasoningStep, current_reasoning[-1])
|
||||
response_str = response_step.response
|
||||
else:
|
||||
response_str = current_reasoning[-1].get_content()
|
||||
|
||||
# TODO: add sources from reasoning steps
|
||||
return AgentChatResponse(response=response_str, sources=sources)
|
||||
|
||||
def _get_task_step_response(
|
||||
self, agent_response: AGENT_CHAT_RESPONSE_TYPE, step: TaskStep, is_done: bool
|
||||
) -> TaskStepOutput:
|
||||
"""Get task step response."""
|
||||
if is_done:
|
||||
new_steps = []
|
||||
else:
|
||||
new_steps = [
|
||||
step.get_next_step(
|
||||
step_id=str(uuid.uuid4()),
|
||||
# NOTE: input is unused
|
||||
input=None,
|
||||
)
|
||||
]
|
||||
|
||||
return TaskStepOutput(
|
||||
output=agent_response,
|
||||
task_step=step,
|
||||
is_last=is_done,
|
||||
next_steps=new_steps,
|
||||
)
|
||||
|
||||
def _run_step(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
        # step.input is not None either on the first step or when the user
        # supplies an intermediate step mid-run
|
||||
if step.input is not None:
|
||||
add_user_step_to_reasoning(
|
||||
step,
|
||||
task.extra_state["new_memory"],
|
||||
task.extra_state["current_reasoning"],
|
||||
verbose=self._verbose,
|
||||
)
|
||||
# TODO: see if we want to do step-based inputs
|
||||
tools = self.get_tools(task.input)
|
||||
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
|
||||
current_reasoning=task.extra_state["current_reasoning"],
|
||||
)
|
||||
|
||||
# send prompt
|
||||
chat_response = self._multi_modal_llm.chat(input_chat)
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, is_done = self._process_actions(
|
||||
task, tools, output=chat_response
|
||||
)
|
||||
task.extra_state["current_reasoning"].extend(reasoning_steps)
|
||||
agent_response = self._get_response(
|
||||
task.extra_state["current_reasoning"], task.extra_state["sources"]
|
||||
)
|
||||
if is_done:
|
||||
task.extra_state["new_memory"].put(
|
||||
ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
|
||||
)
|
||||
|
||||
return self._get_task_step_response(agent_response, step, is_done)
|
||||
|
||||
async def _arun_step(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
if step.input is not None:
|
||||
add_user_step_to_reasoning(
|
||||
step,
|
||||
task.extra_state["new_memory"],
|
||||
task.extra_state["current_reasoning"],
|
||||
verbose=self._verbose,
|
||||
)
|
||||
# TODO: see if we want to do step-based inputs
|
||||
tools = self.get_tools(task.input)
|
||||
|
||||
input_chat = self._react_chat_formatter.format(
|
||||
tools,
|
||||
chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
|
||||
current_reasoning=task.extra_state["current_reasoning"],
|
||||
)
|
||||
# send prompt
|
||||
chat_response = await self._multi_modal_llm.achat(input_chat)
|
||||
# given react prompt outputs, call tools or return response
|
||||
reasoning_steps, is_done = await self._aprocess_actions(
|
||||
task, tools, output=chat_response
|
||||
)
|
||||
task.extra_state["current_reasoning"].extend(reasoning_steps)
|
||||
agent_response = self._get_response(
|
||||
task.extra_state["current_reasoning"], task.extra_state["sources"]
|
||||
)
|
||||
if is_done:
|
||||
task.extra_state["new_memory"].put(
|
||||
ChatMessage(content=agent_response.response, role=MessageRole.ASSISTANT)
|
||||
)
|
||||
|
||||
return self._get_task_step_response(agent_response, step, is_done)
|
||||
|
||||
def _run_step_stream(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
raise NotImplementedError("Stream step not implemented yet.")
|
||||
|
||||
async def _arun_step_stream(
|
||||
self,
|
||||
step: TaskStep,
|
||||
task: Task,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
raise NotImplementedError("Stream step not implemented yet.")
|
||||
|
||||
@trace_method("run_step")
|
||||
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
return self._run_step(step, task)
|
||||
|
||||
@trace_method("run_step")
|
||||
async def arun_step(
|
||||
self, step: TaskStep, task: Task, **kwargs: Any
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async)."""
|
||||
return await self._arun_step(step, task)
|
||||
|
||||
@trace_method("run_step")
|
||||
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
||||
"""Run step (stream)."""
|
||||
# TODO: figure out if we need a different type for TaskStepOutput
|
||||
return self._run_step_stream(step, task)
|
||||
|
||||
@trace_method("run_step")
|
||||
async def astream_step(
|
||||
self, step: TaskStep, task: Task, **kwargs: Any
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async stream)."""
|
||||
return await self._arun_step_stream(step, task)
|
||||
|
||||
def finalize_task(self, task: Task, **kwargs: Any) -> None:
|
||||
"""Finalize task, after all the steps are completed."""
|
||||
# add new messages to memory
|
||||
task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
|
||||
# reset new memory
|
||||
task.extra_state["new_memory"].reset()
|
||||
|
||||
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
||||
"""Set callback manager."""
|
||||
# TODO: make this abstractmethod (right now will break some agent impls)
|
||||
self.callback_manager = callback_manager
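

# Usage sketch (illustrative only, not part of this commit): wiring the worker
# above into an AgentRunner. Assumes a local "./images" folder of image files,
# a configured OpenAI API key, and that `FunctionTool` is available from
# `llama_index.tools`; everything else is defined or imported in this module.
if __name__ == "__main__":
    from llama_index import SimpleDirectoryReader
    from llama_index.agent.runner.base import AgentRunner
    from llama_index.tools import FunctionTool

    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    image_docs = SimpleDirectoryReader("./images").load_data()
    worker = MultimodalReActAgentWorker.from_tools(
        tools=[FunctionTool.from_defaults(fn=add)],
        multi_modal_llm=OpenAIMultiModal(model="gpt-4-vision-preview"),
        verbose=True,
    )
    agent = AgentRunner(agent_worker=worker)
    # the worker expects the task's images under extra_state["image_docs"]
    task = agent.create_task(
        "Describe the chart and add the two numbers shown in it.",
        extra_state={"image_docs": image_docs},
    )
    step_output = agent.run_step(task.task_id)
    while not step_output.is_last:
        step_output = agent.run_step(task.task_id)
    print(agent.finalize_response(task.task_id))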
|
||||
|
|
@@ -0,0 +1 @@
|
|||
"""Init params."""
|
||||
|
|
@@ -0,0 +1,631 @@
|
|||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional, Union, cast
|
||||
|
||||
from llama_index.agent.types import (
|
||||
BaseAgent,
|
||||
BaseAgentWorker,
|
||||
Task,
|
||||
TaskStep,
|
||||
TaskStepOutput,
|
||||
)
|
||||
from llama_index.bridge.pydantic import BaseModel, Field
|
||||
from llama_index.callbacks import (
|
||||
CallbackManager,
|
||||
CBEventType,
|
||||
EventPayload,
|
||||
trace_method,
|
||||
)
|
||||
from llama_index.chat_engine.types import (
|
||||
AGENT_CHAT_RESPONSE_TYPE,
|
||||
AgentChatResponse,
|
||||
ChatResponseMode,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.llms.base import ChatMessage
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.memory import BaseMemory, ChatMemoryBuffer
|
||||
from llama_index.tools.types import BaseTool
|
||||
|
||||
|
||||
class BaseAgentRunner(BaseAgent):
|
||||
"""Base agent runner."""
|
||||
|
||||
@abstractmethod
|
||||
def create_task(self, input: str, **kwargs: Any) -> Task:
|
||||
"""Create task."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_task(
|
||||
self,
|
||||
task_id: str,
|
||||
) -> None:
|
||||
"""Delete task.
|
||||
|
||||
NOTE: this will not delete any previous executions from memory.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tasks(self, **kwargs: Any) -> List[Task]:
|
||||
"""List tasks."""
|
||||
|
||||
@abstractmethod
|
||||
def get_task(self, task_id: str, **kwargs: Any) -> Task:
|
||||
"""Get task."""
|
||||
|
||||
@abstractmethod
|
||||
def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
|
||||
"""Get upcoming steps."""
|
||||
|
||||
@abstractmethod
|
||||
def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
|
||||
"""Get completed steps."""
|
||||
|
||||
def get_completed_step(
|
||||
self, task_id: str, step_id: str, **kwargs: Any
|
||||
) -> TaskStepOutput:
|
||||
"""Get completed step."""
|
||||
        # call get_completed_steps, and then find the right step
|
||||
completed_steps = self.get_completed_steps(task_id, **kwargs)
|
||||
for step_output in completed_steps:
|
||||
if step_output.task_step.step_id == step_id:
|
||||
return step_output
|
||||
raise ValueError(f"Could not find step_id: {step_id}")
|
||||
|
||||
@abstractmethod
|
||||
def run_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
|
||||
@abstractmethod
|
||||
async def arun_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async)."""
|
||||
|
||||
@abstractmethod
|
||||
def stream_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (stream)."""
|
||||
|
||||
@abstractmethod
|
||||
async def astream_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async stream)."""
|
||||
|
||||
@abstractmethod
|
||||
def finalize_response(
|
||||
self,
|
||||
task_id: str,
|
||||
step_output: Optional[TaskStepOutput] = None,
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Finalize response."""
|
||||
|
||||
@abstractmethod
|
||||
def undo_step(self, task_id: str) -> None:
|
||||
"""Undo previous step."""
|
||||
raise NotImplementedError("undo_step not implemented")
|
||||
|
||||
|
||||
def validate_step_from_args(
|
||||
task_id: str, input: Optional[str] = None, step: Optional[Any] = None, **kwargs: Any
|
||||
) -> Optional[TaskStep]:
|
||||
"""Validate step from args."""
|
||||
if step is not None:
|
||||
if input is not None:
|
||||
raise ValueError("Cannot specify both `step` and `input`")
|
||||
if not isinstance(step, TaskStep):
|
||||
raise ValueError(f"step must be TaskStep: {step}")
|
||||
return step
|
||||
else:
|
||||
return None
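

# Illustration of the two accepted calling styles for the step APIs below
# (hypothetical `task_id` / `some_task_step` objects):
#
#   agent.run_step(task_id, input="refine the previous answer")  # free-form input
#   agent.run_step(task_id, step=some_task_step)                 # explicit TaskStep
#
# Passing both `input` and `step` raises a ValueError, as enforced above.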
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
"""Task state."""
|
||||
|
||||
task: Task = Field(..., description="Task.")
|
||||
step_queue: Deque[TaskStep] = Field(
|
||||
default_factory=deque, description="Task step queue."
|
||||
)
|
||||
completed_steps: List[TaskStepOutput] = Field(
|
||||
default_factory=list, description="Completed step outputs."
|
||||
)
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
"""Agent state."""
|
||||
|
||||
task_dict: Dict[str, TaskState] = Field(
|
||||
default_factory=dict, description="Task dictionary."
|
||||
)
|
||||
|
||||
def get_task(self, task_id: str) -> Task:
|
||||
"""Get task state."""
|
||||
return self.task_dict[task_id].task
|
||||
|
||||
def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
|
||||
"""Get completed steps."""
|
||||
return self.task_dict[task_id].completed_steps
|
||||
|
||||
def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
|
||||
"""Get step queue."""
|
||||
return self.task_dict[task_id].step_queue
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset."""
|
||||
self.task_dict = {}
|
||||
|
||||
|
||||
class AgentRunner(BaseAgentRunner):
|
||||
"""Agent runner.
|
||||
|
||||
Top-level agent orchestrator that can create tasks, run each step in a task,
|
||||
or run a task e2e. Stores state and keeps track of tasks.
|
||||
|
||||
Args:
|
||||
agent_worker (BaseAgentWorker): step executor
|
||||
chat_history (Optional[List[ChatMessage]], optional): chat history. Defaults to None.
|
||||
state (Optional[AgentState], optional): agent state. Defaults to None.
|
||||
memory (Optional[BaseMemory], optional): memory. Defaults to None.
|
||||
llm (Optional[LLM], optional): LLM. Defaults to None.
|
||||
callback_manager (Optional[CallbackManager], optional): callback manager. Defaults to None.
|
||||
init_task_state_kwargs (Optional[dict], optional): init task state kwargs. Defaults to None.
|
||||
|
||||
"""
|
||||
|
||||
    # TODO: implement this in Pydantic
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_worker: BaseAgentWorker,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
state: Optional[AgentState] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
llm: Optional[LLM] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
init_task_state_kwargs: Optional[dict] = None,
|
||||
delete_task_on_finish: bool = False,
|
||||
default_tool_choice: str = "auto",
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
self.agent_worker = agent_worker
|
||||
self.state = state or AgentState()
|
||||
self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)
|
||||
|
||||
# get and set callback manager
|
||||
if callback_manager is not None:
|
||||
self.agent_worker.set_callback_manager(callback_manager)
|
||||
self.callback_manager = callback_manager
|
||||
else:
|
||||
# TODO: This is *temporary*
|
||||
# Stopgap before having a callback on the BaseAgentWorker interface.
|
||||
# Doing that requires a bit more refactoring to make sure existing code
|
||||
# doesn't break.
|
||||
if hasattr(self.agent_worker, "callback_manager"):
|
||||
self.callback_manager = (
|
||||
self.agent_worker.callback_manager or CallbackManager()
|
||||
)
|
||||
else:
|
||||
self.callback_manager = CallbackManager()
|
||||
|
||||
self.init_task_state_kwargs = init_task_state_kwargs or {}
|
||||
self.delete_task_on_finish = delete_task_on_finish
|
||||
self.default_tool_choice = default_tool_choice
|
||||
self.verbose = verbose
|
||||
|
||||
@staticmethod
|
||||
def from_llm(
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
llm: Optional[LLM] = None,
|
||||
**kwargs: Any,
|
||||
) -> "AgentRunner":
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.llms.openai_utils import is_function_calling_model
|
||||
|
||||
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
|
||||
from llama_index.agent import OpenAIAgent
|
||||
|
||||
return OpenAIAgent.from_tools(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
from llama_index.agent import ReActAgent
|
||||
|
||||
return ReActAgent.from_tools(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
return self.memory.get_all()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.memory.reset()
|
||||
self.state.reset()
|
||||
|
||||
def create_task(self, input: str, **kwargs: Any) -> Task:
|
||||
"""Create task."""
|
||||
if not self.init_task_state_kwargs:
|
||||
extra_state = kwargs.pop("extra_state", {})
|
||||
else:
|
||||
if "extra_state" in kwargs:
|
||||
raise ValueError(
|
||||
"Cannot specify both `extra_state` and `init_task_state_kwargs`"
|
||||
)
|
||||
else:
|
||||
extra_state = self.init_task_state_kwargs
|
||||
|
||||
callback_manager = kwargs.pop("callback_manager", self.callback_manager)
|
||||
task = Task(
|
||||
input=input,
|
||||
memory=self.memory,
|
||||
extra_state=extra_state,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
# # put input into memory
|
||||
# self.memory.put(ChatMessage(content=input, role=MessageRole.USER))
|
||||
|
||||
# get initial step from task, and put it in the step queue
|
||||
initial_step = self.agent_worker.initialize_step(task)
|
||||
task_state = TaskState(
|
||||
task=task,
|
||||
step_queue=deque([initial_step]),
|
||||
)
|
||||
# add it to state
|
||||
self.state.task_dict[task.task_id] = task_state
|
||||
|
||||
return task
|
||||
|
||||
def delete_task(
|
||||
self,
|
||||
task_id: str,
|
||||
) -> None:
|
||||
"""Delete task.
|
||||
|
||||
NOTE: this will not delete any previous executions from memory.
|
||||
|
||||
"""
|
||||
self.state.task_dict.pop(task_id)
|
||||
|
||||
def list_tasks(self, **kwargs: Any) -> List[Task]:
|
||||
"""List tasks."""
|
||||
return list(self.state.task_dict.values())
|
||||
|
||||
def get_task(self, task_id: str, **kwargs: Any) -> Task:
|
||||
"""Get task."""
|
||||
return self.state.get_task(task_id)
|
||||
|
||||
def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
|
||||
"""Get upcoming steps."""
|
||||
return list(self.state.get_step_queue(task_id))
|
||||
|
||||
def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
|
||||
"""Get completed steps."""
|
||||
return self.state.get_completed_steps(task_id)
|
||||
|
||||
def _run_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step: Optional[TaskStep] = None,
|
||||
input: Optional[str] = None,
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Execute step."""
|
||||
task = self.state.get_task(task_id)
|
||||
step_queue = self.state.get_step_queue(task_id)
|
||||
step = step or step_queue.popleft()
|
||||
if input is not None:
|
||||
step.input = input
|
||||
|
||||
if self.verbose:
|
||||
print(f"> Running step {step.step_id}. Step input: {step.input}")
|
||||
|
||||
# TODO: figure out if you can dynamically swap in different step executors
|
||||
        # not clear when you would do that, but it's theoretically possible
|
||||
|
||||
if mode == ChatResponseMode.WAIT:
|
||||
cur_step_output = self.agent_worker.run_step(step, task, **kwargs)
|
||||
elif mode == ChatResponseMode.STREAM:
|
||||
cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
# append cur_step_output next steps to queue
|
||||
next_steps = cur_step_output.next_steps
|
||||
step_queue.extend(next_steps)
|
||||
|
||||
# add cur_step_output to completed steps
|
||||
completed_steps = self.state.get_completed_steps(task_id)
|
||||
completed_steps.append(cur_step_output)
|
||||
|
||||
return cur_step_output
|
||||
|
||||
async def _arun_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step: Optional[TaskStep] = None,
|
||||
input: Optional[str] = None,
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Execute step."""
|
||||
task = self.state.get_task(task_id)
|
||||
step_queue = self.state.get_step_queue(task_id)
|
||||
step = step or step_queue.popleft()
|
||||
if input is not None:
|
||||
step.input = input
|
||||
|
||||
if self.verbose:
|
||||
print(f"> Running step {step.step_id}. Step input: {step.input}")
|
||||
|
||||
# TODO: figure out if you can dynamically swap in different step executors
|
||||
        # not clear when you would do that, but it's theoretically possible
|
||||
if mode == ChatResponseMode.WAIT:
|
||||
cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
|
||||
elif mode == ChatResponseMode.STREAM:
|
||||
cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
# append cur_step_output next steps to queue
|
||||
next_steps = cur_step_output.next_steps
|
||||
step_queue.extend(next_steps)
|
||||
|
||||
# add cur_step_output to completed steps
|
||||
completed_steps = self.state.get_completed_steps(task_id)
|
||||
completed_steps.append(cur_step_output)
|
||||
|
||||
return cur_step_output
|
||||
|
||||
def run_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
step = validate_step_from_args(task_id, input, step, **kwargs)
|
||||
return self._run_step(
|
||||
task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
|
||||
)
|
||||
|
||||
async def arun_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async)."""
|
||||
step = validate_step_from_args(task_id, input, step, **kwargs)
|
||||
return await self._arun_step(
|
||||
task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs
|
||||
)
|
||||
|
||||
def stream_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (stream)."""
|
||||
step = validate_step_from_args(task_id, input, step, **kwargs)
|
||||
return self._run_step(
|
||||
task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
|
||||
)
|
||||
|
||||
async def astream_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async stream)."""
|
||||
step = validate_step_from_args(task_id, input, step, **kwargs)
|
||||
return await self._arun_step(
|
||||
task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs
|
||||
)
|
||||
|
||||
def finalize_response(
|
||||
self,
|
||||
task_id: str,
|
||||
step_output: Optional[TaskStepOutput] = None,
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Finalize response."""
|
||||
if step_output is None:
|
||||
step_output = self.state.get_completed_steps(task_id)[-1]
|
||||
if not step_output.is_last:
|
||||
raise ValueError(
|
||||
"finalize_response can only be called on the last step output"
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
step_output.output,
|
||||
(AgentChatResponse, StreamingAgentChatResponse),
|
||||
):
|
||||
raise ValueError(
|
||||
"When `is_last` is True, cur_step_output.output must be "
|
||||
f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
|
||||
)
|
||||
|
||||
# finalize task
|
||||
self.agent_worker.finalize_task(self.state.get_task(task_id))
|
||||
|
||||
if self.delete_task_on_finish:
|
||||
self.delete_task(task_id)
|
||||
|
||||
return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)
|
||||
|
||||
def _chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Chat with step executor."""
|
||||
if chat_history is not None:
|
||||
self.memory.set(chat_history)
|
||||
task = self.create_task(message)
|
||||
|
||||
result_output = None
|
||||
while True:
|
||||
# pass step queue in as argument, assume step executor is stateless
|
||||
cur_step_output = self._run_step(
|
||||
task.task_id, mode=mode, tool_choice=tool_choice
|
||||
)
|
||||
|
||||
if cur_step_output.is_last:
|
||||
result_output = cur_step_output
|
||||
break
|
||||
|
||||
# ensure tool_choice does not cause endless loops
|
||||
tool_choice = "auto"
|
||||
|
||||
return self.finalize_response(task.task_id, result_output)
|
||||
|
||||
async def _achat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Chat with step executor."""
|
||||
if chat_history is not None:
|
||||
self.memory.set(chat_history)
|
||||
task = self.create_task(message)
|
||||
|
||||
result_output = None
|
||||
while True:
|
||||
# pass step queue in as argument, assume step executor is stateless
|
||||
cur_step_output = await self._arun_step(
|
||||
task.task_id, mode=mode, tool_choice=tool_choice
|
||||
)
|
||||
|
||||
if cur_step_output.is_last:
|
||||
result_output = cur_step_output
|
||||
break
|
||||
|
||||
# ensure tool_choice does not cause endless loops
|
||||
tool_choice = "auto"
|
||||
|
||||
return self.finalize_response(task.task_id, result_output)
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
) -> AgentChatResponse:
|
||||
        # override tool choice if provided as input.
|
||||
if tool_choice is None:
|
||||
tool_choice = self.default_tool_choice
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = self._chat(
|
||||
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
|
||||
)
|
||||
assert isinstance(chat_response, AgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
) -> AgentChatResponse:
|
||||
        # override tool choice if provided as input.
|
||||
if tool_choice is None:
|
||||
tool_choice = self.default_tool_choice
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = await self._achat(
|
||||
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
|
||||
)
|
||||
assert isinstance(chat_response, AgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
) -> StreamingAgentChatResponse:
|
||||
        # override tool choice if provided as input.
|
||||
if tool_choice is None:
|
||||
tool_choice = self.default_tool_choice
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = self._chat(
|
||||
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
|
||||
)
|
||||
assert isinstance(chat_response, StreamingAgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
) -> StreamingAgentChatResponse:
|
||||
        # override tool choice if provided as input.
|
||||
if tool_choice is None:
|
||||
tool_choice = self.default_tool_choice
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = await self._achat(
|
||||
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
|
||||
)
|
||||
assert isinstance(chat_response, StreamingAgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
def undo_step(self, task_id: str) -> None:
|
||||
"""Undo previous step."""
|
||||
raise NotImplementedError("undo_step not implemented")
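

# Minimal end-to-end sketch (illustrative only, not part of this commit): a toy
# single-step worker plugged into AgentRunner. `EchoWorker` is hypothetical and
# exists only to show the BaseAgentWorker contract; every other name is defined
# in, or imported by, this module.
if __name__ == "__main__":
    import uuid

    class EchoWorker(BaseAgentWorker):
        """Toy worker that answers every task in a single step."""

        def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
            return TaskStep(
                task_id=task.task_id, step_id=str(uuid.uuid4()), input=task.input
            )

        def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
            response = AgentChatResponse(response=f"echo: {step.input}")
            return TaskStepOutput(
                output=response, task_step=step, next_steps=[], is_last=True
            )

        async def arun_step(
            self, step: TaskStep, task: Task, **kwargs: Any
        ) -> TaskStepOutput:
            return self.run_step(step, task, **kwargs)

        def stream_step(
            self, step: TaskStep, task: Task, **kwargs: Any
        ) -> TaskStepOutput:
            raise NotImplementedError

        async def astream_step(
            self, step: TaskStep, task: Task, **kwargs: Any
        ) -> TaskStepOutput:
            raise NotImplementedError

        def finalize_task(self, task: Task, **kwargs: Any) -> None:
            """Nothing to clean up for this toy worker."""

    agent = AgentRunner(agent_worker=EchoWorker())
    task = agent.create_task("hello")
    agent.run_step(task.task_id)  # executes the single step
    print(agent.finalize_response(task.task_id))  # -> "echo: hello"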
|
||||
|
|
@@ -0,0 +1,472 @@
|
|||
"""Agent executor."""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional, Union, cast
|
||||
|
||||
from llama_index.agent.runner.base import BaseAgentRunner
|
||||
from llama_index.agent.types import (
|
||||
BaseAgentWorker,
|
||||
Task,
|
||||
TaskStep,
|
||||
TaskStepOutput,
|
||||
)
|
||||
from llama_index.bridge.pydantic import BaseModel, Field
|
||||
from llama_index.callbacks import (
|
||||
CallbackManager,
|
||||
CBEventType,
|
||||
EventPayload,
|
||||
trace_method,
|
||||
)
|
||||
from llama_index.chat_engine.types import (
|
||||
AGENT_CHAT_RESPONSE_TYPE,
|
||||
AgentChatResponse,
|
||||
ChatResponseMode,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.llms.base import ChatMessage
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.memory import BaseMemory, ChatMemoryBuffer
|
||||
|
||||
|
||||
class DAGTaskState(BaseModel):
|
||||
"""DAG Task state."""
|
||||
|
||||
task: Task = Field(..., description="Task.")
|
||||
root_step: TaskStep = Field(..., description="Root step.")
|
||||
step_queue: Deque[TaskStep] = Field(
|
||||
default_factory=deque, description="Task step queue."
|
||||
)
|
||||
completed_steps: List[TaskStepOutput] = Field(
|
||||
default_factory=list, description="Completed step outputs."
|
||||
)
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
"""Task id."""
|
||||
return self.task.task_id
|
||||
|
||||
|
||||
class DAGAgentState(BaseModel):
|
||||
"""Agent state."""
|
||||
|
||||
task_dict: Dict[str, DAGTaskState] = Field(
|
||||
default_factory=dict, description="Task dictionary."
|
||||
)
|
||||
|
||||
def get_task(self, task_id: str) -> Task:
|
||||
"""Get task state."""
|
||||
return self.task_dict[task_id].task
|
||||
|
||||
def get_completed_steps(self, task_id: str) -> List[TaskStepOutput]:
|
||||
"""Get completed steps."""
|
||||
return self.task_dict[task_id].completed_steps
|
||||
|
||||
def get_step_queue(self, task_id: str) -> Deque[TaskStep]:
|
||||
"""Get step queue."""
|
||||
return self.task_dict[task_id].step_queue
|
||||
|
||||
|
||||
class ParallelAgentRunner(BaseAgentRunner):
|
||||
"""Parallel agent runner.
|
||||
|
||||
Executes steps in queue in parallel. Requires async support.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_worker: BaseAgentWorker,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
state: Optional[DAGAgentState] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
llm: Optional[LLM] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
init_task_state_kwargs: Optional[dict] = None,
|
||||
delete_task_on_finish: bool = False,
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
self.memory = memory or ChatMemoryBuffer.from_defaults(chat_history, llm=llm)
|
||||
self.state = state or DAGAgentState()
|
||||
self.callback_manager = callback_manager or CallbackManager([])
|
||||
self.init_task_state_kwargs = init_task_state_kwargs or {}
|
||||
self.agent_worker = agent_worker
|
||||
self.delete_task_on_finish = delete_task_on_finish
|
||||
|
||||
@property
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
return self.memory.get_all()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.memory.reset()
|
||||
|
||||
def create_task(self, input: str, **kwargs: Any) -> Task:
|
||||
"""Create task."""
|
||||
task = Task(
|
||||
input=input,
|
||||
memory=self.memory,
|
||||
extra_state=self.init_task_state_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
# # put input into memory
|
||||
# self.memory.put(ChatMessage(content=input, role=MessageRole.USER))
|
||||
|
||||
# add it to state
|
||||
# get initial step from task, and put it in the step queue
|
||||
initial_step = self.agent_worker.initialize_step(task)
|
||||
task_state = DAGTaskState(
|
||||
task=task,
|
||||
root_step=initial_step,
|
||||
step_queue=deque([initial_step]),
|
||||
)
|
||||
|
||||
self.state.task_dict[task.task_id] = task_state
|
||||
|
||||
return task
|
||||
|
||||
def delete_task(
|
||||
self,
|
||||
task_id: str,
|
||||
) -> None:
|
||||
"""Delete task.
|
||||
|
||||
NOTE: this will not delete any previous executions from memory.
|
||||
|
||||
"""
|
||||
self.state.task_dict.pop(task_id)
|
||||
|
||||
def list_tasks(self, **kwargs: Any) -> List[Task]:
|
||||
"""List tasks."""
|
||||
task_states = list(self.state.task_dict.values())
|
||||
return [task_state.task for task_state in task_states]
|
||||
|
||||
def get_task(self, task_id: str, **kwargs: Any) -> Task:
|
||||
"""Get task."""
|
||||
return self.state.get_task(task_id)
|
||||
|
||||
def get_upcoming_steps(self, task_id: str, **kwargs: Any) -> List[TaskStep]:
|
||||
"""Get upcoming steps."""
|
||||
return list(self.state.get_step_queue(task_id))
|
||||
|
||||
def get_completed_steps(self, task_id: str, **kwargs: Any) -> List[TaskStepOutput]:
|
||||
"""Get completed steps."""
|
||||
return self.state.get_completed_steps(task_id)
|
||||
|
||||
def run_steps_in_queue(
|
||||
self,
|
||||
task_id: str,
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
**kwargs: Any,
|
||||
) -> List[TaskStepOutput]:
|
||||
"""Execute steps in queue.
|
||||
|
||||
Run all steps in queue, clearing it out.
|
||||
|
||||
Assume that all steps can be run in parallel.
|
||||
|
||||
"""
|
||||
return asyncio.run(self.arun_steps_in_queue(task_id, mode=mode, **kwargs))
|
||||
|
||||
async def arun_steps_in_queue(
|
||||
self,
|
||||
task_id: str,
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
**kwargs: Any,
|
||||
) -> List[TaskStepOutput]:
|
||||
"""Execute all steps in queue.
|
||||
|
||||
All steps in queue are assumed to be ready.
|
||||
|
||||
"""
|
||||
# first pop all steps from step_queue
|
||||
steps: List[TaskStep] = []
|
||||
while len(self.state.get_step_queue(task_id)) > 0:
|
||||
steps.append(self.state.get_step_queue(task_id).popleft())
|
||||
|
||||
# take every item in the queue, and run it
|
||||
tasks = []
|
||||
for step in steps:
|
||||
tasks.append(self._arun_step(task_id, step=step, mode=mode, **kwargs))
|
||||
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def _run_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step: Optional[TaskStep] = None,
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Execute step."""
|
||||
task = self.state.get_task(task_id)
|
||||
task_queue = self.state.get_step_queue(task_id)
|
||||
step = step or task_queue.popleft()
|
||||
|
||||
if not step.is_ready:
|
||||
raise ValueError(f"Step {step.step_id} is not ready")
|
||||
|
||||
if mode == ChatResponseMode.WAIT:
|
||||
cur_step_output: TaskStepOutput = self.agent_worker.run_step(
|
||||
step, task, **kwargs
|
||||
)
|
||||
elif mode == ChatResponseMode.STREAM:
|
||||
cur_step_output = self.agent_worker.stream_step(step, task, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
for next_step in cur_step_output.next_steps:
|
||||
if next_step.is_ready:
|
||||
task_queue.append(next_step)
|
||||
|
||||
# add cur_step_output to completed steps
|
||||
completed_steps = self.state.get_completed_steps(task_id)
|
||||
completed_steps.append(cur_step_output)
|
||||
|
||||
return cur_step_output
|
||||
|
||||
async def _arun_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step: Optional[TaskStep] = None,
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Execute step."""
|
||||
task = self.state.get_task(task_id)
|
||||
task_queue = self.state.get_step_queue(task_id)
|
||||
step = step or task_queue.popleft()
|
||||
|
||||
if not step.is_ready:
|
||||
raise ValueError(f"Step {step.step_id} is not ready")
|
||||
|
||||
if mode == ChatResponseMode.WAIT:
|
||||
cur_step_output = await self.agent_worker.arun_step(step, task, **kwargs)
|
||||
elif mode == ChatResponseMode.STREAM:
|
||||
cur_step_output = await self.agent_worker.astream_step(step, task, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
for next_step in cur_step_output.next_steps:
|
||||
if next_step.is_ready:
|
||||
task_queue.append(next_step)
|
||||
|
||||
# add cur_step_output to completed steps
|
||||
completed_steps = self.state.get_completed_steps(task_id)
|
||||
completed_steps.append(cur_step_output)
|
||||
|
||||
return cur_step_output
|
||||
|
||||
def run_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
return self._run_step(task_id, step, mode=ChatResponseMode.WAIT, **kwargs)
|
||||
|
||||
async def arun_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async)."""
|
||||
return await self._arun_step(
|
||||
task_id, step, mode=ChatResponseMode.WAIT, **kwargs
|
||||
)
|
||||
|
||||
def stream_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (stream)."""
|
||||
return self._run_step(task_id, step, mode=ChatResponseMode.STREAM, **kwargs)
|
||||
|
||||
async def astream_step(
|
||||
self,
|
||||
task_id: str,
|
||||
input: Optional[str] = None,
|
||||
step: Optional[TaskStep] = None,
|
||||
**kwargs: Any,
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async stream)."""
|
||||
return await self._arun_step(
|
||||
task_id, step, mode=ChatResponseMode.STREAM, **kwargs
|
||||
)
|
||||
|
||||
def finalize_response(
|
||||
self,
|
||||
task_id: str,
|
||||
step_output: Optional[TaskStepOutput] = None,
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Finalize response."""
|
||||
if step_output is None:
|
||||
step_output = self.state.get_completed_steps(task_id)[-1]
|
||||
if not step_output.is_last:
|
||||
raise ValueError(
|
||||
"finalize_response can only be called on the last step output"
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
step_output.output,
|
||||
(AgentChatResponse, StreamingAgentChatResponse),
|
||||
):
|
||||
raise ValueError(
|
||||
"When `is_last` is True, cur_step_output.output must be "
|
||||
f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
|
||||
)
|
||||
|
||||
# finalize task
|
||||
self.agent_worker.finalize_task(self.state.get_task(task_id))
|
||||
|
||||
if self.delete_task_on_finish:
|
||||
self.delete_task(task_id)
|
||||
|
||||
return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)
|
||||
|
||||
def _chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Chat with step executor."""
|
||||
if chat_history is not None:
|
||||
self.memory.set(chat_history)
|
||||
task = self.create_task(message)
|
||||
|
||||
result_output = None
|
||||
while True:
|
||||
# pass step queue in as argument, assume step executor is stateless
|
||||
cur_step_outputs = self.run_steps_in_queue(task.task_id, mode=mode)
|
||||
|
||||
# check if a step output is_last
|
||||
is_last = any(
|
||||
cur_step_output.is_last for cur_step_output in cur_step_outputs
|
||||
)
|
||||
if is_last:
|
||||
if len(cur_step_outputs) > 1:
|
||||
raise ValueError(
|
||||
"More than one step output returned in final step."
|
||||
)
|
||||
cur_step_output = cur_step_outputs[0]
|
||||
result_output = cur_step_output
|
||||
break
|
||||
|
||||
return self.finalize_response(task.task_id, result_output)
|
||||
|
||||
async def _achat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
mode: ChatResponseMode = ChatResponseMode.WAIT,
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Chat with step executor."""
|
||||
if chat_history is not None:
|
||||
self.memory.set(chat_history)
|
||||
task = self.create_task(message)
|
||||
|
||||
result_output = None
|
||||
while True:
|
||||
# pass step queue in as argument, assume step executor is stateless
|
||||
cur_step_outputs = await self.arun_steps_in_queue(task.task_id, mode=mode)
|
||||
|
||||
# check if a step output is_last
|
||||
is_last = any(
|
||||
cur_step_output.is_last for cur_step_output in cur_step_outputs
|
||||
)
|
||||
if is_last:
|
||||
if len(cur_step_outputs) > 1:
|
||||
raise ValueError(
|
||||
"More than one step output returned in final step."
|
||||
)
|
||||
cur_step_output = cur_step_outputs[0]
|
||||
result_output = cur_step_output
|
||||
break
|
||||
|
||||
return self.finalize_response(task.task_id, result_output)
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
) -> AgentChatResponse:
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = self._chat(
|
||||
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
|
||||
)
|
||||
assert isinstance(chat_response, AgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
) -> AgentChatResponse:
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = await self._achat(
|
||||
message, chat_history, tool_choice, mode=ChatResponseMode.WAIT
|
||||
)
|
||||
assert isinstance(chat_response, AgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
) -> StreamingAgentChatResponse:
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = self._chat(
|
||||
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
|
||||
)
|
||||
assert isinstance(chat_response, StreamingAgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self,
|
||||
message: str,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
tool_choice: Union[str, dict] = "auto",
|
||||
) -> StreamingAgentChatResponse:
|
||||
with self.callback_manager.event(
|
||||
CBEventType.AGENT_STEP,
|
||||
payload={EventPayload.MESSAGES: [message]},
|
||||
) as e:
|
||||
chat_response = await self._achat(
|
||||
message, chat_history, tool_choice, mode=ChatResponseMode.STREAM
|
||||
)
|
||||
assert isinstance(chat_response, StreamingAgentChatResponse)
|
||||
e.on_end(payload={EventPayload.RESPONSE: chat_response})
|
||||
return chat_response
|
||||
|
||||
def undo_step(self, task_id: str) -> None:
|
||||
"""Undo previous step."""
|
||||
raise NotImplementedError("undo_step not implemented")
|
||||
|
|
@@ -0,0 +1,235 @@
|
|||
"""Base agent type."""
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_index.bridge.pydantic import BaseModel, Field
|
||||
from llama_index.callbacks import CallbackManager, trace_method
|
||||
from llama_index.chat_engine.types import BaseChatEngine, StreamingAgentChatResponse
|
||||
from llama_index.core.base_query_engine import BaseQueryEngine
|
||||
from llama_index.core.llms.types import ChatMessage
|
||||
from llama_index.core.response.schema import RESPONSE_TYPE, Response
|
||||
from llama_index.memory.types import BaseMemory
|
||||
from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType
|
||||
from llama_index.schema import QueryBundle
|
||||
|
||||
|
||||
class BaseAgent(BaseChatEngine, BaseQueryEngine):
|
||||
"""Base Agent."""
|
||||
|
||||
def _get_prompts(self) -> PromptDictType:
|
||||
"""Get prompts."""
|
||||
# TODO: the ReAct agent does not explicitly specify prompts, would need a
|
||||
# refactor to expose those prompts
|
||||
return {}
|
||||
|
||||
def _get_prompt_modules(self) -> PromptMixinType:
|
||||
"""Get prompt modules."""
|
||||
return {}
|
||||
|
||||
def _update_prompts(self, prompts: PromptDictType) -> None:
|
||||
"""Update prompts."""
|
||||
|
||||
# ===== Query Engine Interface =====
|
||||
@trace_method("query")
|
||||
def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
|
||||
agent_response = self.chat(
|
||||
query_bundle.query_str,
|
||||
chat_history=[],
|
||||
)
|
||||
return Response(
|
||||
response=str(agent_response), source_nodes=agent_response.source_nodes
|
||||
)
|
||||
|
||||
@trace_method("query")
|
||||
async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
|
||||
agent_response = await self.achat(
|
||||
query_bundle.query_str,
|
||||
chat_history=[],
|
||||
)
|
||||
return Response(
|
||||
response=str(agent_response), source_nodes=agent_response.source_nodes
|
||||
)
|
||||
|
||||
def stream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
raise NotImplementedError("stream_chat not implemented")
|
||||
|
||||
async def astream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
raise NotImplementedError("astream_chat not implemented")
|
||||
|
||||
|
||||
class TaskStep(BaseModel):
|
||||
"""Agent task step.
|
||||
|
||||
Represents a single input step within the execution run ("Task") of an agent
|
||||
given a user input.
|
||||
|
||||
The output is returned as a `TaskStepOutput`.
|
||||
|
||||
"""
|
||||
|
||||
    task_id: str = Field(..., description="Task ID")
|
||||
step_id: str = Field(..., description="Step ID")
|
||||
input: Optional[str] = Field(default=None, description="User input")
|
||||
# memory: BaseMemory = Field(
|
||||
# ..., type=BaseMemory, description="Conversational Memory"
|
||||
# )
|
||||
step_state: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional state for a given step."
|
||||
)
|
||||
|
||||
# NOTE: the state below may change throughout the course of execution
|
||||
# this tracks the relationships to other steps
|
||||
next_steps: Dict[str, "TaskStep"] = Field(
|
||||
default_factory=dict, description="Next steps to be executed."
|
||||
)
|
||||
prev_steps: Dict[str, "TaskStep"] = Field(
|
||||
default_factory=dict,
|
||||
description="Previous steps that were dependencies for this step.",
|
||||
)
|
||||
is_ready: bool = Field(
|
||||
default=True, description="Is this step ready to be executed?"
|
||||
)
|
||||
|
||||
def get_next_step(
|
||||
self,
|
||||
step_id: str,
|
||||
input: Optional[str] = None,
|
||||
step_state: Optional[Dict[str, Any]] = None,
|
||||
) -> "TaskStep":
|
||||
"""Convenience function to get next step.
|
||||
|
||||
Preserve task_id, memory, step_state.
|
||||
|
||||
"""
|
||||
return TaskStep(
|
||||
task_id=self.task_id,
|
||||
step_id=step_id,
|
||||
input=input,
|
||||
# memory=self.memory,
|
||||
step_state=step_state or self.step_state,
|
||||
)
|
||||
|
||||
def link_step(
|
||||
self,
|
||||
next_step: "TaskStep",
|
||||
) -> None:
|
||||
"""Link to next step.
|
||||
|
||||
Add link from this step to next, and from next step to current.
|
||||
|
||||
"""
|
||||
self.next_steps[next_step.step_id] = next_step
|
||||
next_step.prev_steps[self.step_id] = self
|
||||
|
||||
|
||||
class TaskStepOutput(BaseModel):
|
||||
"""Agent task step output."""
|
||||
|
||||
output: Any = Field(..., description="Task step output")
|
||||
task_step: TaskStep = Field(..., description="Task step input")
|
||||
next_steps: List[TaskStep] = Field(..., description="Next steps to be executed.")
|
||||
is_last: bool = Field(default=False, description="Is this the last step?")
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation."""
|
||||
return str(self.output)
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
"""Agent Task.
|
||||
|
||||
Represents a "run" of an agent given a user input.
|
||||
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
task_id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), type=str, description="Task ID"
|
||||
)
|
||||
input: str = Field(..., type=str, description="User input")
|
||||
|
||||
# NOTE: this is state that may be modified throughout the course of execution of the task
|
||||
memory: BaseMemory = Field(
|
||||
...,
|
||||
type=BaseMemory,
|
||||
description=(
|
||||
"Conversational Memory. Maintains state before execution of this task."
|
||||
),
|
||||
)
|
||||
|
||||
callback_manager: CallbackManager = Field(
|
||||
default_factory=CallbackManager,
|
||||
exclude=True,
|
||||
description="Callback manager for the task.",
|
||||
)
|
||||
|
||||
extra_state: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description=(
|
||||
"Additional user-specified state for a given task. "
|
||||
"Can be modified throughout the execution of a task."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class BaseAgentWorker(PromptMixin):
|
||||
"""Base agent worker."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_prompts(self) -> PromptDictType:
|
||||
"""Get prompts."""
|
||||
# TODO: the ReAct agent does not explicitly specify prompts, would need a
|
||||
# refactor to expose those prompts
|
||||
return {}
|
||||
|
||||
def _get_prompt_modules(self) -> PromptMixinType:
|
||||
"""Get prompt modules."""
|
||||
return {}
|
||||
|
||||
def _update_prompts(self, prompts: PromptDictType) -> None:
|
||||
"""Update prompts."""
|
||||
|
||||
@abstractmethod
|
||||
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
|
||||
"""Initialize step from task."""
|
||||
|
||||
@abstractmethod
|
||||
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
||||
"""Run step."""
|
||||
|
||||
@abstractmethod
|
||||
async def arun_step(
|
||||
self, step: TaskStep, task: Task, **kwargs: Any
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async)."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
|
||||
"""Run step (stream)."""
|
||||
# TODO: figure out if we need a different type for TaskStepOutput
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def astream_step(
|
||||
self, step: TaskStep, task: Task, **kwargs: Any
|
||||
) -> TaskStepOutput:
|
||||
"""Run step (async stream)."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def finalize_task(self, task: Task, **kwargs: Any) -> None:
|
||||
"""Finalize task, after all the steps are completed."""
|
||||
|
||||
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
||||
"""Set callback manager."""
|
||||
# TODO: make this abstractmethod (right now will break some agent impls)
|
||||
|
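# Illustrative sketch of how a runner might drive a BaseAgentWorker (the real
# agent runner lives elsewhere): initialize a step from the task, keep running
# steps until one is marked is_last, then finalize. For simplicity it assumes
# the worker queues exactly one next step per iteration.
def _run_task_sketch(worker: BaseAgentWorker, task: Task) -> Any:
    step = worker.initialize_step(task)
    while True:
        cur_output = worker.run_step(step, task)
        if cur_output.is_last:
            break
        # hypothetical simplification: take the single queued next step
        step = cur_output.next_steps[0]
    worker.finalize_task(task)
    return cur_output.output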
|
@ -0,0 +1,17 @@
|
|||
"""Agent utils."""
|
||||
|
||||
|
||||
from llama_index.agent.types import TaskStep
|
||||
from llama_index.core.llms.types import MessageRole
|
||||
from llama_index.llms.base import ChatMessage
|
||||
from llama_index.memory import BaseMemory
|
||||
|
||||
|
||||
def add_user_step_to_memory(
|
||||
step: TaskStep, memory: BaseMemory, verbose: bool = False
|
||||
) -> None:
|
||||
"""Add user step to memory."""
|
||||
user_message = ChatMessage(content=step.input, role=MessageRole.USER)
|
||||
memory.put(user_message)
|
||||
if verbose:
|
||||
print(f"Added user message to memory: {step.input}")
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
"""Async utils."""
|
||||
import asyncio
|
||||
from itertools import zip_longest
|
||||
from typing import Any, Coroutine, Iterable, List
|
||||
|
||||
|
||||
def asyncio_module(show_progress: bool = False) -> Any:
|
||||
if show_progress:
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
module = tqdm_asyncio
|
||||
else:
|
||||
module = asyncio
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def run_async_tasks(
|
||||
tasks: List[Coroutine],
|
||||
show_progress: bool = False,
|
||||
progress_bar_desc: str = "Running async tasks",
|
||||
) -> List[Any]:
|
||||
"""Run a list of async tasks."""
|
||||
tasks_to_execute: List[Any] = tasks
|
||||
if show_progress:
|
||||
try:
|
||||
import nest_asyncio
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
# jupyter notebooks already have an event loop running
|
||||
# we need to reuse it instead of creating a new one
|
||||
nest_asyncio.apply()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def _tqdm_gather() -> List[Any]:
|
||||
return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)
|
||||
|
||||
tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
|
||||
return tqdm_outputs
|
||||
# run the operation w/o tqdm on hitting a fatal
|
||||
# may occur in some environments where tqdm.asyncio
|
||||
# is not supported
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _gather() -> List[Any]:
|
||||
return await asyncio.gather(*tasks_to_execute)
|
||||
|
||||
outputs: List[Any] = asyncio.run(_gather())
|
||||
return outputs
|
||||
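# Hedged usage sketch (not used by the library itself): run a small batch of
# coroutines from synchronous code that is not already inside an event loop.
# `_double` is a made-up placeholder for real async work.
async def _double(x: int) -> int:
    await asyncio.sleep(0)
    return x * 2


def _example_run_async_tasks() -> List[int]:
    # expected result: [0, 2, 4]
    return run_async_tasks([_double(i) for i in range(3)])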
|
||||
|
||||
def chunks(iterable: Iterable, size: int) -> Iterable:
|
||||
args = [iter(iterable)] * size
|
||||
return zip_longest(*args, fillvalue=None)
|
||||
|
||||
|
||||
async def batch_gather(
|
||||
tasks: List[Coroutine], batch_size: int = 10, verbose: bool = False
|
||||
) -> List[Any]:
|
||||
output: List[Any] = []
|
||||
for task_chunk in chunks(tasks, batch_size):
|
||||
output_chunk = await asyncio.gather(*task_chunk)
|
||||
output.extend(output_chunk)
|
||||
if verbose:
|
||||
print(f"Completed {len(output)} out of {len(tasks)} tasks")
|
||||
return output
|
||||
|
||||
|
||||
def get_asyncio_module(show_progress: bool = False) -> Any:
|
||||
if show_progress:
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
module = tqdm_asyncio
|
||||
else:
|
||||
module = asyncio
|
||||
|
||||
return module
|
||||
|
||||
|
||||
DEFAULT_NUM_WORKERS = 4
|
||||
|
||||
|
||||
async def run_jobs(
|
||||
jobs: List[Coroutine],
|
||||
show_progress: bool = False,
|
||||
workers: int = DEFAULT_NUM_WORKERS,
|
||||
) -> List[Any]:
|
||||
"""Run jobs.
|
||||
|
||||
Args:
|
||||
jobs (List[Coroutine]):
|
||||
List of jobs to run.
|
||||
show_progress (bool):
|
||||
Whether to show progress bar.
|
||||
|
||||
Returns:
|
||||
List[Any]:
|
||||
List of results.
|
||||
"""
|
||||
asyncio_mod = get_asyncio_module(show_progress=show_progress)
|
||||
semaphore = asyncio.Semaphore(workers)
|
||||
|
||||
async def worker(job: Coroutine) -> Any:
|
||||
async with semaphore:
|
||||
return await job
|
||||
|
||||
pool_jobs = [worker(job) for job in jobs]
|
||||
|
||||
return await asyncio_mod.gather(*pool_jobs)
|
||||
|
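# Hedged usage sketch: run_jobs is itself a coroutine, so callers await it from
# async code or hand it to asyncio.run from synchronous code. `_job` is a
# made-up placeholder coroutine.
async def _job(i: int) -> int:
    await asyncio.sleep(0)
    return i


def _example_run_jobs() -> List[Any]:
    # at most 2 jobs run concurrently; results keep the order of the input jobs
    return asyncio.run(run_jobs([_job(i) for i in range(5)], workers=2))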
|
@ -0,0 +1,108 @@
|
|||
import langchain
|
||||
from langchain.agents import AgentExecutor, AgentType, initialize_agent
|
||||
|
||||
# agents and tools
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
# callback
|
||||
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
|
||||
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
|
||||
|
||||
# chat and memory
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.output_parsers import ResponseSchema
|
||||
|
||||
# prompts
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
BaseMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
|
||||
# schema
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMemory,
|
||||
BaseMessage,
|
||||
BaseOutputParser,
|
||||
ChatGeneration,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
LLMResult,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
# embeddings
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
|
||||
# input & output
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from langchain.tools import BaseTool, StructuredTool, Tool
|
||||
from langchain_community.chat_models import ChatAnyscale, ChatOpenAI
|
||||
from langchain_community.embeddings import (
|
||||
HuggingFaceBgeEmbeddings,
|
||||
HuggingFaceEmbeddings,
|
||||
)
|
||||
|
||||
# LLMs
|
||||
from langchain_community.llms import AI21, BaseLLM, Cohere, FakeListLLM, OpenAI
|
||||
|
||||
__all__ = [
|
||||
"langchain",
|
||||
"BaseLLM",
|
||||
"FakeListLLM",
|
||||
"OpenAI",
|
||||
"AI21",
|
||||
"Cohere",
|
||||
"BaseChatModel",
|
||||
"ChatAnyscale",
|
||||
"ChatOpenAI",
|
||||
"BaseLanguageModel",
|
||||
"Embeddings",
|
||||
"HuggingFaceEmbeddings",
|
||||
"HuggingFaceBgeEmbeddings",
|
||||
"PromptTemplate",
|
||||
"BasePromptTemplate",
|
||||
"ConditionalPromptSelector",
|
||||
"is_chat_model",
|
||||
"AIMessagePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
"HumanMessagePromptTemplate",
|
||||
"BaseMessagePromptTemplate",
|
||||
"SystemMessagePromptTemplate",
|
||||
"BaseChatMemory",
|
||||
"ConversationBufferMemory",
|
||||
"ChatMessageHistory",
|
||||
"BaseToolkit",
|
||||
"AgentType",
|
||||
"AgentExecutor",
|
||||
"initialize_agent",
|
||||
"StructuredTool",
|
||||
"Tool",
|
||||
"BaseTool",
|
||||
"ResponseSchema",
|
||||
"BaseCallbackHandler",
|
||||
"BaseCallbackManager",
|
||||
"AIMessage",
|
||||
"FunctionMessage",
|
||||
"BaseMessage",
|
||||
"ChatMessage",
|
||||
"HumanMessage",
|
||||
"SystemMessage",
|
||||
"BaseMemory",
|
||||
"BaseOutputParser",
|
||||
"LLMResult",
|
||||
"ChatGeneration",
|
||||
"Document",
|
||||
"RecursiveCharacterTextSplitter",
|
||||
"TextSplitter",
|
||||
]
|
||||
|
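# Hedged usage note: downstream llama_index code imports langchain symbols
# through this bridge instead of from langchain directly, e.g.
#
#     from llama_index.bridge.langchain import ChatOpenAI, ConversationBufferMemory
#
# so that version-specific import locations (langchain vs. langchain_community)
# are handled in one place.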
|
@ -0,0 +1,51 @@
|
|||
try:
|
||||
import pydantic.v1 as pydantic
|
||||
from pydantic.v1 import (
|
||||
BaseConfig,
|
||||
BaseModel,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
StrictFloat,
|
||||
StrictInt,
|
||||
StrictStr,
|
||||
create_model,
|
||||
root_validator,
|
||||
validator,
|
||||
)
|
||||
from pydantic.v1.error_wrappers import ValidationError
|
||||
from pydantic.v1.fields import FieldInfo
|
||||
from pydantic.v1.generics import GenericModel
|
||||
except ImportError:
|
||||
import pydantic # type: ignore
|
||||
from pydantic import (
|
||||
BaseConfig,
|
||||
BaseModel,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
StrictFloat,
|
||||
StrictInt,
|
||||
StrictStr,
|
||||
create_model,
|
||||
root_validator,
|
||||
validator,
|
||||
)
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
__all__ = [
|
||||
"pydantic",
|
||||
"BaseModel",
|
||||
"Field",
|
||||
"PrivateAttr",
|
||||
"root_validator",
|
||||
"validator",
|
||||
"create_model",
|
||||
"StrictFloat",
|
||||
"StrictInt",
|
||||
"StrictStr",
|
||||
"FieldInfo",
|
||||
"ValidationError",
|
||||
"GenericModel",
|
||||
"BaseConfig",
|
||||
]
|
||||
|
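# Hedged usage note: library code imports pydantic symbols from this bridge so
# the same model definitions work on both pydantic v1 and the v1 compatibility
# layer shipped with pydantic v2, e.g.
#
#     from llama_index.bridge.pydantic import BaseModel, Field
#
#     class ExampleConfig(BaseModel):  # illustrative model, not part of the library
#         name: str = Field(description="A human-readable name")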
|
@ -0,0 +1,24 @@
|
|||
from .aim import AimCallback
|
||||
from .base import CallbackManager
|
||||
from .finetuning_handler import GradientAIFineTuningHandler, OpenAIFineTuningHandler
|
||||
from .llama_debug import LlamaDebugHandler
|
||||
from .open_inference_callback import OpenInferenceCallbackHandler
|
||||
from .schema import CBEvent, CBEventType, EventPayload
|
||||
from .token_counting import TokenCountingHandler
|
||||
from .utils import trace_method
|
||||
from .wandb_callback import WandbCallbackHandler
|
||||
|
||||
__all__ = [
|
||||
"OpenInferenceCallbackHandler",
|
||||
"CallbackManager",
|
||||
"CBEvent",
|
||||
"CBEventType",
|
||||
"EventPayload",
|
||||
"LlamaDebugHandler",
|
||||
"AimCallback",
|
||||
"WandbCallbackHandler",
|
||||
"TokenCountingHandler",
|
||||
"OpenAIFineTuningHandler",
|
||||
"GradientAIFineTuningHandler",
|
||||
"trace_method",
|
||||
]
|
||||
|
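# Hedged usage sketch (kept as a comment so nothing runs at import time): the
# handlers exported above are typically combined through a CallbackManager.
#
#     from llama_index.callbacks import (
#         CallbackManager,
#         LlamaDebugHandler,
#         TokenCountingHandler,
#     )
#
#     callback_manager = CallbackManager(
#         [LlamaDebugHandler(print_trace_on_end=True), TokenCountingHandler()]
#     )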
|
@ -0,0 +1,191 @@
|
|||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from aim import Run, Text
|
||||
except ModuleNotFoundError:
|
||||
Run, Text = None, None
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.callbacks.schema import CBEventType, EventPayload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class AimCallback(BaseCallbackHandler):
|
||||
"""
|
||||
AimCallback callback class.
|
||||
|
||||
Args:
|
||||
repo (:obj:`str`, optional):
|
||||
Aim repository path or Repo object to which Run object is bound.
|
||||
If skipped, default Repo is used.
|
||||
experiment_name (:obj:`str`, optional):
|
||||
Sets Run's `experiment` property. 'default' if not specified.
|
||||
Can be used later to query runs/sequences.
|
||||
system_tracking_interval (:obj:`int`, optional):
|
||||
Sets the tracking interval in seconds for system usage
|
||||
metrics (CPU, Memory, etc.). Set to `None` to disable
|
||||
system metrics tracking.
|
||||
log_system_params (:obj:`bool`, optional):
|
||||
Enable/Disable logging of system params such as installed packages,
|
||||
git info, environment variables, etc.
|
||||
capture_terminal_logs (:obj:`bool`, optional):
|
||||
Enable/Disable terminal stdout logging.
|
||||
event_starts_to_ignore (Optional[List[CBEventType]]):
|
||||
list of event types to ignore when tracking event starts.
|
||||
event_ends_to_ignore (Optional[List[CBEventType]]):
|
||||
list of event types to ignore when tracking event ends.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo: Optional[str] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
system_tracking_interval: Optional[int] = 1,
|
||||
log_system_params: Optional[bool] = True,
|
||||
capture_terminal_logs: Optional[bool] = True,
|
||||
event_starts_to_ignore: Optional[List[CBEventType]] = None,
|
||||
event_ends_to_ignore: Optional[List[CBEventType]] = None,
|
||||
run_params: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if Run is None:
|
||||
raise ModuleNotFoundError(
|
||||
"Please install aim to use the AimCallback: 'pip install aim'"
|
||||
)
|
||||
|
||||
event_starts_to_ignore = (
|
||||
event_starts_to_ignore if event_starts_to_ignore else []
|
||||
)
|
||||
event_ends_to_ignore = event_ends_to_ignore if event_ends_to_ignore else []
|
||||
super().__init__(
|
||||
event_starts_to_ignore=event_starts_to_ignore,
|
||||
event_ends_to_ignore=event_ends_to_ignore,
|
||||
)
|
||||
|
||||
self.repo = repo
|
||||
self.experiment_name = experiment_name
|
||||
self.system_tracking_interval = system_tracking_interval
|
||||
self.log_system_params = log_system_params
|
||||
self.capture_terminal_logs = capture_terminal_logs
|
||||
self._run: Optional[Any] = None
|
||||
self._run_hash = None
|
||||
|
||||
self._llm_response_step = 0
|
||||
|
||||
self.setup(run_params)
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
parent_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
event_type (CBEventType): event type to store.
|
||||
payload (Optional[Dict[str, Any]]): payload to store.
|
||||
event_id (str): event id to store.
|
||||
parent_id (str): parent event id.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
event_type (CBEventType): event type to store.
|
||||
payload (Optional[Dict[str, Any]]): payload to store.
|
||||
event_id (str): event id to store.
|
||||
"""
|
||||
if not self._run:
|
||||
raise ValueError("AimCallback failed to init properly.")
|
||||
|
||||
if event_type is CBEventType.LLM and payload:
|
||||
if EventPayload.PROMPT in payload:
|
||||
llm_input = str(payload[EventPayload.PROMPT])
|
||||
llm_output = str(payload[EventPayload.COMPLETION])
|
||||
else:
|
||||
message = payload.get(EventPayload.MESSAGES, [])
|
||||
llm_input = "\n".join([str(x) for x in message])
|
||||
llm_output = str(payload[EventPayload.RESPONSE])
|
||||
|
||||
self._run.track(
|
||||
Text(llm_input),
|
||||
name="prompt",
|
||||
step=self._llm_response_step,
|
||||
context={"event_id": event_id},
|
||||
)
|
||||
|
||||
self._run.track(
|
||||
Text(llm_output),
|
||||
name="response",
|
||||
step=self._llm_response_step,
|
||||
context={"event_id": event_id},
|
||||
)
|
||||
|
||||
self._llm_response_step += 1
|
||||
elif event_type is CBEventType.CHUNKING and payload:
|
||||
for chunk_id, chunk in enumerate(payload[EventPayload.CHUNKS]):
|
||||
self._run.track(
|
||||
Text(chunk),
|
||||
name="chunk",
|
||||
step=self._llm_response_step,
|
||||
context={"chunk_id": chunk_id, "event_id": event_id},
|
||||
)
|
||||
|
||||
@property
|
||||
def experiment(self) -> Run:
|
||||
if not self._run:
|
||||
self.setup()
|
||||
return self._run
|
||||
|
||||
def setup(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
if not self._run:
|
||||
if self._run_hash:
|
||||
self._run = Run(
|
||||
self._run_hash,
|
||||
repo=self.repo,
|
||||
system_tracking_interval=self.system_tracking_interval,
|
||||
log_system_params=self.log_system_params,
|
||||
capture_terminal_logs=self.capture_terminal_logs,
|
||||
)
|
||||
else:
|
||||
self._run = Run(
|
||||
repo=self.repo,
|
||||
experiment=self.experiment_name,
|
||||
system_tracking_interval=self.system_tracking_interval,
|
||||
log_system_params=self.log_system_params,
|
||||
capture_terminal_logs=self.capture_terminal_logs,
|
||||
)
|
||||
self._run_hash = self._run.hash
|
||||
|
||||
# Log config parameters
|
||||
if args:
|
||||
try:
|
||||
for key in args:
|
||||
self._run.set(key, args[key], strict=False)
|
||||
except Exception as e:
|
||||
logger.warning(f"Aim could not log config parameters -> {e}")
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._run and self._run.active:
|
||||
self._run.close()
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
pass
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
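# Hedged usage sketch: register the handler on a CallbackManager so LLM prompts
# and responses are tracked to an Aim run. The repo path and experiment name
# below are placeholders.
#
#     from llama_index.callbacks import AimCallback, CallbackManager
#
#     aim_cb = AimCallback(repo="./aim_repo", experiment_name="rag-experiment")
#     callback_manager = CallbackManager([aim_cb])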
|
@ -0,0 +1,12 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
|
||||
|
||||
def argilla_callback_handler(**kwargs: Any) -> BaseCallbackHandler:
|
||||
try:
|
||||
# lazy import
|
||||
from argilla_llama_index import ArgillaCallbackHandler
|
||||
except ImportError:
|
||||
raise ImportError("Please install Argilla with `pip install argilla`")
|
||||
return ArgillaCallbackHandler(**kwargs)
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
|
||||
|
||||
def arize_phoenix_callback_handler(**kwargs: Any) -> BaseCallbackHandler:
|
||||
try:
|
||||
from phoenix.trace.llama_index import OpenInferenceTraceCallbackHandler
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install Arize Phoenix with `pip install -q arize-phoenix`"
|
||||
)
|
||||
return OpenInferenceTraceCallbackHandler(**kwargs)
|
||||
|
|
@ -0,0 +1,274 @@
|
|||
import logging
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.callbacks.schema import (
|
||||
BASE_TRACE_EVENT,
|
||||
LEAF_EVENTS,
|
||||
CBEventType,
|
||||
EventPayload,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
global_stack_trace = ContextVar("trace", default=[BASE_TRACE_EVENT])
|
||||
empty_trace_ids: List[str] = []
|
||||
global_stack_trace_ids = ContextVar("trace_ids", default=empty_trace_ids)
|
||||
|
||||
|
||||
class CallbackManager(BaseCallbackHandler, ABC):
|
||||
"""
|
||||
Callback manager that handles callbacks for events within LlamaIndex.
|
||||
|
||||
The callback manager provides a way to call handlers on event starts/ends.
|
||||
|
||||
Additionally, the callback manager traces the current stack of events.
|
||||
It does this by using a few key attributes.
|
||||
- trace_stack - The current stack of events that have not ended yet.
|
||||
When an event ends, it's removed from the stack.
|
||||
Since this is a contextvar, it is unique to each
|
||||
thread/task.
|
||||
- trace_map - A mapping of event ids to their children events.
|
||||
On the start of events, the bottom of the trace stack
|
||||
is used as the current parent event for the trace map.
|
||||
- trace_id - A simple name for the current trace, usually denoting the
|
||||
entrypoint (query, index_construction, insert, etc.)
|
||||
|
||||
Args:
|
||||
handlers (List[BaseCallbackHandler]): list of handlers to use.
|
||||
|
||||
Usage:
|
||||
with callback_manager.event(CBEventType.QUERY) as event:
|
||||
event.on_start(payload={key: val})
|
||||
...
|
||||
event.on_end(payload={key: val})
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, handlers: Optional[List[BaseCallbackHandler]] = None):
|
||||
"""Initialize the manager with a list of handlers."""
|
||||
from llama_index import global_handler
|
||||
|
||||
handlers = handlers or []
|
||||
|
||||
# add eval handlers based on global defaults
|
||||
if global_handler is not None:
|
||||
new_handler = global_handler
|
||||
# go through existing handlers, check if any are same type as new handler
|
||||
# if so, error
|
||||
for existing_handler in handlers:
|
||||
if isinstance(existing_handler, type(new_handler)):
|
||||
raise ValueError(
|
||||
"Cannot add two handlers of the same type "
|
||||
f"{type(new_handler)} to the callback manager."
|
||||
)
|
||||
handlers.append(new_handler)
|
||||
|
||||
self.handlers = handlers
|
||||
self._trace_map: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: Optional[str] = None,
|
||||
parent_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run handlers when an event starts and return id of event."""
|
||||
event_id = event_id or str(uuid.uuid4())
|
||||
|
||||
# if no trace is running, start a default trace
|
||||
try:
|
||||
parent_id = parent_id or global_stack_trace.get()[-1]
|
||||
except IndexError:
|
||||
self.start_trace("llama-index")
|
||||
parent_id = global_stack_trace.get()[-1]
|
||||
|
||||
self._trace_map[parent_id].append(event_id)
|
||||
for handler in self.handlers:
|
||||
if event_type not in handler.event_starts_to_ignore:
|
||||
handler.on_event_start(
|
||||
event_type,
|
||||
payload,
|
||||
event_id=event_id,
|
||||
parent_id=parent_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if event_type not in LEAF_EVENTS:
|
||||
# copy the stack trace to prevent conflicts with threads/coroutines
|
||||
current_trace_stack = global_stack_trace.get().copy()
|
||||
current_trace_stack.append(event_id)
|
||||
global_stack_trace.set(current_trace_stack)
|
||||
|
||||
return event_id
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run handlers when an event ends."""
|
||||
event_id = event_id or str(uuid.uuid4())
|
||||
for handler in self.handlers:
|
||||
if event_type not in handler.event_ends_to_ignore:
|
||||
handler.on_event_end(event_type, payload, event_id=event_id, **kwargs)
|
||||
|
||||
if event_type not in LEAF_EVENTS:
|
||||
# copy the stack trace to prevent conflicts with threads/coroutines
|
||||
current_trace_stack = global_stack_trace.get().copy()
|
||||
current_trace_stack.pop()
|
||||
global_stack_trace.set(current_trace_stack)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Remove a handler from the callback manager."""
|
||||
self.handlers.remove(handler)
|
||||
|
||||
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager."""
|
||||
self.handlers = handlers
|
||||
|
||||
@contextmanager
|
||||
def event(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: Optional[str] = None,
|
||||
) -> Generator["EventContext", None, None]:
|
||||
"""Context manager for lanching and shutdown of events.
|
||||
|
||||
Handles sending on_event_start and on_event_end to handlers for the specified event.
|
||||
|
||||
Usage:
|
||||
with callback_manager.event(CBEventType.QUERY, payload={key: val}) as event:
|
||||
...
|
||||
event.on_end(payload={key: val})  # optional
|
||||
"""
|
||||
# create event context wrapper
|
||||
event = EventContext(self, event_type, event_id=event_id)
|
||||
event.on_start(payload=payload)
|
||||
|
||||
payload = None
|
||||
try:
|
||||
yield event
|
||||
except Exception as e:
|
||||
# data already logged to trace?
|
||||
if not hasattr(e, "event_added"):
|
||||
payload = {EventPayload.EXCEPTION: e}
|
||||
e.event_added = True # type: ignore
|
||||
if not event.finished:
|
||||
event.on_end(payload=payload)
|
||||
raise
|
||||
finally:
|
||||
# ensure event is ended
|
||||
if not event.finished:
|
||||
event.on_end(payload=payload)
|
||||
|
||||
@contextmanager
|
||||
def as_trace(self, trace_id: str) -> Generator[None, None, None]:
|
||||
"""Context manager tracer for lanching and shutdown of traces."""
|
||||
self.start_trace(trace_id=trace_id)
|
||||
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
# event already added to trace?
|
||||
if not hasattr(e, "event_added"):
|
||||
self.on_event_start(
|
||||
CBEventType.EXCEPTION, payload={EventPayload.EXCEPTION: e}
|
||||
)
|
||||
e.event_added = True # type: ignore
|
||||
|
||||
raise
|
||||
finally:
|
||||
# ensure trace is ended
|
||||
self.end_trace(trace_id=trace_id)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""Run when an overall trace is launched."""
|
||||
current_trace_stack_ids = global_stack_trace_ids.get().copy()
|
||||
if trace_id is not None:
|
||||
if len(current_trace_stack_ids) == 0:
|
||||
self._reset_trace_events()
|
||||
|
||||
for handler in self.handlers:
|
||||
handler.start_trace(trace_id=trace_id)
|
||||
|
||||
current_trace_stack_ids = [trace_id]
|
||||
else:
|
||||
current_trace_stack_ids.append(trace_id)
|
||||
|
||||
global_stack_trace_ids.set(current_trace_stack_ids)
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""Run when an overall trace is exited."""
|
||||
current_trace_stack_ids = global_stack_trace_ids.get().copy()
|
||||
if trace_id is not None and len(current_trace_stack_ids) > 0:
|
||||
current_trace_stack_ids.pop()
|
||||
if len(current_trace_stack_ids) == 0:
|
||||
for handler in self.handlers:
|
||||
handler.end_trace(trace_id=trace_id, trace_map=self._trace_map)
|
||||
current_trace_stack_ids = []
|
||||
|
||||
global_stack_trace_ids.set(current_trace_stack_ids)
|
||||
|
||||
def _reset_trace_events(self) -> None:
|
||||
"""Helper function to reset the current trace."""
|
||||
self._trace_map = defaultdict(list)
|
||||
global_stack_trace.set([BASE_TRACE_EVENT])
|
||||
|
||||
@property
|
||||
def trace_map(self) -> Dict[str, List[str]]:
|
||||
return self._trace_map
|
||||
|
||||
|
||||
class EventContext:
|
||||
"""
|
||||
Simple wrapper to call callbacks on event starts and ends
|
||||
with an event type and id.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
callback_manager: CallbackManager,
|
||||
event_type: CBEventType,
|
||||
event_id: Optional[str] = None,
|
||||
):
|
||||
self._callback_manager = callback_manager
|
||||
self._event_type = event_type
|
||||
self._event_id = event_id or str(uuid.uuid4())
|
||||
self.started = False
|
||||
self.finished = False
|
||||
|
||||
def on_start(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
|
||||
if not self.started:
|
||||
self.started = True
|
||||
self._callback_manager.on_event_start(
|
||||
self._event_type, payload=payload, event_id=self._event_id, **kwargs
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Event {self._event_type!s}: {self._event_id} already started!"
|
||||
)
|
||||
|
||||
def on_end(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
|
||||
if not self.finished:
|
||||
self.finished = True
|
||||
self._callback_manager.on_event_end(
|
||||
self._event_type, payload=payload, event_id=self._event_id, **kwargs
|
||||
)
|
||||
|
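# Hedged usage sketch mirroring the CallbackManager docstring above: wrap a
# unit of work in a trace and emit one event inside it. The payload values are
# illustrative.
#
#     manager = CallbackManager([])
#     with manager.as_trace("query"):
#         with manager.event(
#             CBEventType.QUERY, payload={EventPayload.QUERY_STR: "hi"}
#         ) as event:
#             result = "..."  # do the actual work here
#             event.on_end(payload={EventPayload.RESPONSE: result})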
|
@ -0,0 +1,55 @@
|
|||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_index.callbacks.schema import BASE_TRACE_EVENT, CBEventType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
global_stack_trace = ContextVar("trace", default=[BASE_TRACE_EVENT])
|
||||
|
||||
|
||||
class BaseCallbackHandler(ABC):
|
||||
"""Base callback handler that can be used to track event starts and ends."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_starts_to_ignore: List[CBEventType],
|
||||
event_ends_to_ignore: List[CBEventType],
|
||||
) -> None:
|
||||
"""Initialize the base callback handler."""
|
||||
self.event_starts_to_ignore = tuple(event_starts_to_ignore)
|
||||
self.event_ends_to_ignore = tuple(event_ends_to_ignore)
|
||||
|
||||
@abstractmethod
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
parent_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run when an event starts and return id of event."""
|
||||
|
||||
@abstractmethod
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when an event ends."""
|
||||
|
||||
@abstractmethod
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""Run when an overall trace is launched."""
|
||||
|
||||
@abstractmethod
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""Run when an overall trace is exited."""
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
|
||||
|
||||
def deepeval_callback_handler(**kwargs: Any) -> BaseCallbackHandler:
|
||||
try:
|
||||
from deepeval.tracing.integrations.llama_index import LlamaIndexCallbackHandler
|
||||
except ImportError:
|
||||
raise ImportError("Please install DeepEval with `pip install -U deepeval`")
|
||||
return LlamaIndexCallbackHandler(**kwargs)
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_index.callbacks.base import BaseCallbackHandler
|
||||
from llama_index.callbacks.schema import CBEventType, EventPayload
|
||||
|
||||
|
||||
class BaseFinetuningHandler(BaseCallbackHandler):
|
||||
"""
|
||||
Callback handler for finetuning.
|
||||
|
||||
This handler will collect all messages
|
||||
sent to the LLM, along with their responses.
|
||||
It also defines a `get_finetuning_events` method as well as a
|
||||
`save_finetuning_events` method.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the base callback handler."""
|
||||
super().__init__(
|
||||
event_starts_to_ignore=[],
|
||||
event_ends_to_ignore=[],
|
||||
)
|
||||
self._finetuning_events: Dict[str, List[Any]] = {}
|
||||
self._function_calls: Dict[str, List[Any]] = {}
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
parent_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run when an event starts and return id of event."""
|
||||
from llama_index.core.llms.types import ChatMessage, MessageRole
|
||||
|
||||
if event_type == CBEventType.LLM:
|
||||
cur_messages = []
|
||||
if payload and EventPayload.PROMPT in payload:
|
||||
message = ChatMessage(
|
||||
role=MessageRole.USER, content=str(payload[EventPayload.PROMPT])
|
||||
)
|
||||
cur_messages = [message]
|
||||
elif payload and EventPayload.MESSAGES in payload:
|
||||
cur_messages = payload[EventPayload.MESSAGES]
|
||||
|
||||
if len(cur_messages) > 0:
|
||||
if event_id in self._finetuning_events:
|
||||
self._finetuning_events[event_id].extend(cur_messages)
|
||||
else:
|
||||
self._finetuning_events[event_id] = cur_messages
|
||||
|
||||
# if functions exists, add that
|
||||
if payload and EventPayload.ADDITIONAL_KWARGS in payload:
|
||||
kwargs_dict = payload[EventPayload.ADDITIONAL_KWARGS]
|
||||
if "functions" in kwargs_dict:
|
||||
self._function_calls[event_id] = kwargs_dict["functions"]
|
||||
return event_id
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when an event ends."""
|
||||
from llama_index.core.llms.types import ChatMessage, MessageRole
|
||||
|
||||
if (
|
||||
event_type == CBEventType.LLM
|
||||
and event_id in self._finetuning_events
|
||||
and payload is not None
|
||||
):
|
||||
if isinstance(payload[EventPayload.RESPONSE], str):
|
||||
response = ChatMessage(
|
||||
role=MessageRole.ASSISTANT, content=str(payload[EventPayload.RESPONSE])
|
||||
)
|
||||
else:
|
||||
response = payload[EventPayload.RESPONSE].message
|
||||
|
||||
self._finetuning_events[event_id].append(response)
|
||||
|
||||
@abstractmethod
|
||||
def get_finetuning_events(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get finetuning events."""
|
||||
|
||||
@abstractmethod
|
||||
def save_finetuning_events(self, path: str) -> None:
|
||||
"""Save the finetuning events to a file."""
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""Run when an overall trace is launched."""
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""Run when an overall trace is exited."""
|
||||
|
||||
|
||||
class OpenAIFineTuningHandler(BaseFinetuningHandler):
|
||||
"""
|
||||
Callback handler for OpenAI fine-tuning.
|
||||
|
||||
This handler will collect all messages
|
||||
sent to the LLM, along with their responses. It will then save these messages
|
||||
in a `.jsonl` format that can be used for fine-tuning with OpenAI's API.
|
||||
"""
|
||||
|
||||
def get_finetuning_events(self) -> Dict[str, Dict[str, Any]]:
|
||||
events_dict = {}
|
||||
for event_id, event in self._finetuning_events.items():
|
||||
events_dict[event_id] = {"messages": event[:-1], "response": event[-1]}
|
||||
|
||||
return events_dict
|
||||
|
||||
def save_finetuning_events(self, path: str) -> None:
|
||||
"""
|
||||
Save the finetuning events to a file.
|
||||
|
||||
This saved format can be used for fine-tuning with OpenAI's API.
|
||||
The structure for each json line is as follows:
|
||||
{
|
||||
messages: [
|
||||
{ rol: "system", content: "Text"},
|
||||
{ role: "user", content: "Text" },
|
||||
]
|
||||
},
|
||||
...
|
||||
"""
|
||||
from llama_index.llms.openai_utils import to_openai_message_dicts
|
||||
|
||||
events_dict = self.get_finetuning_events()
|
||||
json_strs = []
|
||||
for event_id, event in events_dict.items():
|
||||
all_messages = event["messages"] + [event["response"]]
|
||||
message_dicts = to_openai_message_dicts(all_messages, drop_none=True)
|
||||
event_dict = {"messages": message_dicts}
|
||||
if event_id in self._function_calls:
|
||||
event_dict["functions"] = self._function_calls[event_id]
|
||||
json_strs.append(json.dumps(event_dict))
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(json_strs))
|
||||
print(f"Wrote {len(json_strs)} examples to {path}")
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""Run when an overall trace is launched."""
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""Run when an overall trace is exited."""
|
||||
|
||||
|
||||
class GradientAIFineTuningHandler(BaseFinetuningHandler):
|
||||
"""
|
||||
Callback handler for Gradient AI fine-tuning.
|
||||
|
||||
This handler will collect all messages
|
||||
sent to the LLM, along with their responses. It will then save these messages
|
||||
in a `.jsonl` format that can be used for fine-tuning with Gradient AI's API.
|
||||
"""
|
||||
|
||||
def get_finetuning_events(self) -> Dict[str, Dict[str, Any]]:
|
||||
events_dict = {}
|
||||
for event_id, event in self._finetuning_events.items():
|
||||
events_dict[event_id] = {"messages": event[:-1], "response": event[-1]}
|
||||
|
||||
return events_dict
|
||||
|
||||
def save_finetuning_events(self, path: str) -> None:
|
||||
"""
|
||||
Save the finetuning events to a file.
|
||||
|
||||
This saved format can be used for fine-tuning with Gradient AI's API.
|
||||
The structure for each json line is as follows:
|
||||
{
|
||||
"inputs": "<full_prompt_str>"
|
||||
},
|
||||
...
|
||||
"""
|
||||
from llama_index.llms.generic_utils import messages_to_history_str
|
||||
|
||||
events_dict = self.get_finetuning_events()
|
||||
json_strs = []
|
||||
for event in events_dict.values():
|
||||
all_messages = event["messages"] + [event["response"]]
|
||||
|
||||
# TODO: come up with model-specific message->prompt serialization format
|
||||
prompt_str = messages_to_history_str(all_messages)
|
||||
|
||||
input_dict = {"inputs": prompt_str}
|
||||
json_strs.append(json.dumps(input_dict))
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(json_strs))
|
||||
print(f"Wrote {len(json_strs)} examples to {path}")
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""Run when an overall trace is launched."""
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""Run when an overall trace is exited."""
|
||||
|
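# Hedged usage sketch: collect the LLM calls made during normal usage and dump
# them to a .jsonl file for fine-tuning. The file name is a placeholder.
#
#     from llama_index.callbacks import CallbackManager, OpenAIFineTuningHandler
#
#     finetune_handler = OpenAIFineTuningHandler()
#     callback_manager = CallbackManager([finetune_handler])
#     # ... run queries through a service context that uses callback_manager ...
#     finetune_handler.save_finetuning_events("finetuning_events.jsonl")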
|
@ -0,0 +1,44 @@
|
|||
"""Global eval handlers."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.callbacks.argilla_callback import argilla_callback_handler
|
||||
from llama_index.callbacks.arize_phoenix_callback import arize_phoenix_callback_handler
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.callbacks.deepeval_callback import deepeval_callback_handler
|
||||
from llama_index.callbacks.honeyhive_callback import honeyhive_callback_handler
|
||||
from llama_index.callbacks.open_inference_callback import OpenInferenceCallbackHandler
|
||||
from llama_index.callbacks.promptlayer_handler import PromptLayerHandler
|
||||
from llama_index.callbacks.simple_llm_handler import SimpleLLMHandler
|
||||
from llama_index.callbacks.wandb_callback import WandbCallbackHandler
|
||||
|
||||
|
||||
def set_global_handler(eval_mode: str, **eval_params: Any) -> None:
|
||||
"""Set global eval handlers."""
|
||||
import llama_index
|
||||
|
||||
llama_index.global_handler = create_global_handler(eval_mode, **eval_params)
|
||||
|
||||
|
||||
def create_global_handler(eval_mode: str, **eval_params: Any) -> BaseCallbackHandler:
|
||||
"""Get global eval handler."""
|
||||
if eval_mode == "wandb":
|
||||
handler: BaseCallbackHandler = WandbCallbackHandler(**eval_params)
|
||||
elif eval_mode == "openinference":
|
||||
handler = OpenInferenceCallbackHandler(**eval_params)
|
||||
elif eval_mode == "arize_phoenix":
|
||||
handler = arize_phoenix_callback_handler(**eval_params)
|
||||
elif eval_mode == "honeyhive":
|
||||
handler = honeyhive_callback_handler(**eval_params)
|
||||
elif eval_mode == "promptlayer":
|
||||
handler = PromptLayerHandler(**eval_params)
|
||||
elif eval_mode == "deepeval":
|
||||
handler = deepeval_callback_handler(**eval_params)
|
||||
elif eval_mode == "simple":
|
||||
handler = SimpleLLMHandler(**eval_params)
|
||||
elif eval_mode == "argilla":
|
||||
handler = argilla_callback_handler(**eval_params)
|
||||
else:
|
||||
raise ValueError(f"Eval mode {eval_mode} not supported.")
|
||||
|
||||
return handler
|
||||
|
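# Hedged usage sketch: the global handler is set once per process and is then
# picked up automatically by every CallbackManager created afterwards. "simple"
# takes no extra parameters; other modes forward **eval_params to their handler.
#
#     import llama_index
#
#     llama_index.set_global_handler("simple")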
|
@ -0,0 +1,11 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
|
||||
|
||||
def honeyhive_callback_handler(**kwargs: Any) -> BaseCallbackHandler:
|
||||
try:
|
||||
from honeyhive.utils.llamaindex_tracer import HoneyHiveLlamaIndexTracer
|
||||
except ImportError:
|
||||
raise ImportError("Please install HoneyHive with `pip install honeyhive`")
|
||||
return HoneyHiveLlamaIndexTracer(**kwargs)
|
||||
|
|
@ -0,0 +1,205 @@
|
|||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.callbacks.schema import (
|
||||
BASE_TRACE_EVENT,
|
||||
TIMESTAMP_FORMAT,
|
||||
CBEvent,
|
||||
CBEventType,
|
||||
EventStats,
|
||||
)
|
||||
|
||||
|
||||
class LlamaDebugHandler(BaseCallbackHandler):
|
||||
"""Callback handler that keeps track of debug info.
|
||||
|
||||
NOTE: this is a beta feature. The usage within our codebase and the interface
|
||||
may change.
|
||||
|
||||
This handler simply keeps track of event starts/ends, separated by event types.
|
||||
You can use this callback handler to keep track of and debug events.
|
||||
|
||||
Args:
|
||||
event_starts_to_ignore (Optional[List[CBEventType]]): list of event types to
|
||||
ignore when tracking event starts.
|
||||
event_ends_to_ignore (Optional[List[CBEventType]]): list of event types to
|
||||
ignore when tracking event ends.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_starts_to_ignore: Optional[List[CBEventType]] = None,
|
||||
event_ends_to_ignore: Optional[List[CBEventType]] = None,
|
||||
print_trace_on_end: bool = True,
|
||||
) -> None:
|
||||
"""Initialize the llama debug handler."""
|
||||
self._event_pairs_by_type: Dict[CBEventType, List[CBEvent]] = defaultdict(list)
|
||||
self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)
|
||||
self._sequential_events: List[CBEvent] = []
|
||||
self._cur_trace_id: Optional[str] = None
|
||||
self._trace_map: Dict[str, List[str]] = defaultdict(list)
|
||||
self.print_trace_on_end = print_trace_on_end
|
||||
event_starts_to_ignore = (
|
||||
event_starts_to_ignore if event_starts_to_ignore else []
|
||||
)
|
||||
event_ends_to_ignore = event_ends_to_ignore if event_ends_to_ignore else []
|
||||
super().__init__(
|
||||
event_starts_to_ignore=event_starts_to_ignore,
|
||||
event_ends_to_ignore=event_ends_to_ignore,
|
||||
)
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
parent_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Store event start data by event type.
|
||||
|
||||
Args:
|
||||
event_type (CBEventType): event type to store.
|
||||
payload (Optional[Dict[str, Any]]): payload to store.
|
||||
event_id (str): event id to store.
|
||||
parent_id (str): parent event id.
|
||||
|
||||
"""
|
||||
event = CBEvent(event_type, payload=payload, id_=event_id)
|
||||
self._event_pairs_by_type[event.event_type].append(event)
|
||||
self._event_pairs_by_id[event.id_].append(event)
|
||||
self._sequential_events.append(event)
|
||||
return event.id_
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Store event end data by event type.
|
||||
|
||||
Args:
|
||||
event_type (CBEventType): event type to store.
|
||||
payload (Optional[Dict[str, Any]]): payload to store.
|
||||
event_id (str): event id to store.
|
||||
|
||||
"""
|
||||
event = CBEvent(event_type, payload=payload, id_=event_id)
|
||||
self._event_pairs_by_type[event.event_type].append(event)
|
||||
self._event_pairs_by_id[event.id_].append(event)
|
||||
self._sequential_events.append(event)
|
||||
self._trace_map = defaultdict(list)
|
||||
|
||||
def get_events(self, event_type: Optional[CBEventType] = None) -> List[CBEvent]:
|
||||
"""Get all events for a specific event type."""
|
||||
if event_type is not None:
|
||||
return self._event_pairs_by_type[event_type]
|
||||
|
||||
return self._sequential_events
|
||||
|
||||
def _get_event_pairs(self, events: List[CBEvent]) -> List[List[CBEvent]]:
|
||||
"""Helper function to pair events according to their ID."""
|
||||
event_pairs: Dict[str, List[CBEvent]] = defaultdict(list)
|
||||
for event in events:
|
||||
event_pairs[event.id_].append(event)
|
||||
|
||||
return sorted(
|
||||
event_pairs.values(),
|
||||
key=lambda x: datetime.strptime(x[0].time, TIMESTAMP_FORMAT),
|
||||
)
|
||||
|
||||
def _get_time_stats_from_event_pairs(
|
||||
self, event_pairs: List[List[CBEvent]]
|
||||
) -> EventStats:
|
||||
"""Calculate time-based stats for a set of event pairs."""
|
||||
total_secs = 0.0
|
||||
for event_pair in event_pairs:
|
||||
start_time = datetime.strptime(event_pair[0].time, TIMESTAMP_FORMAT)
|
||||
end_time = datetime.strptime(event_pair[-1].time, TIMESTAMP_FORMAT)
|
||||
total_secs += (end_time - start_time).total_seconds()
|
||||
|
||||
return EventStats(
|
||||
total_secs=total_secs,
|
||||
average_secs=total_secs / len(event_pairs),
|
||||
total_count=len(event_pairs),
|
||||
)
|
||||
|
||||
def get_event_pairs(
|
||||
self, event_type: Optional[CBEventType] = None
|
||||
) -> List[List[CBEvent]]:
|
||||
"""Pair events by ID, either all events or a specific type."""
|
||||
if event_type is not None:
|
||||
return self._get_event_pairs(self._event_pairs_by_type[event_type])
|
||||
|
||||
return self._get_event_pairs(self._sequential_events)
|
||||
|
||||
def get_llm_inputs_outputs(self) -> List[List[CBEvent]]:
|
||||
"""Get the exact LLM inputs and outputs."""
|
||||
return self._get_event_pairs(self._event_pairs_by_type[CBEventType.LLM])
|
||||
|
||||
def get_event_time_info(
|
||||
self, event_type: Optional[CBEventType] = None
|
||||
) -> EventStats:
|
||||
event_pairs = self.get_event_pairs(event_type)
|
||||
return self._get_time_stats_from_event_pairs(event_pairs)
|
||||
|
||||
def flush_event_logs(self) -> None:
|
||||
"""Clear all events from memory."""
|
||||
self._event_pairs_by_type = defaultdict(list)
|
||||
self._event_pairs_by_id = defaultdict(list)
|
||||
self._sequential_events = []
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""Launch a trace."""
|
||||
self._trace_map = defaultdict(list)
|
||||
self._cur_trace_id = trace_id
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""Shutdown the current trace."""
|
||||
self._trace_map = trace_map or defaultdict(list)
|
||||
if self.print_trace_on_end:
|
||||
self.print_trace_map()
|
||||
|
||||
def _print_trace_map(self, cur_event_id: str, level: int = 0) -> None:
|
||||
"""Recursively print trace map to terminal for debugging."""
|
||||
event_pair = self._event_pairs_by_id[cur_event_id]
|
||||
if event_pair:
|
||||
time_stats = self._get_time_stats_from_event_pairs([event_pair])
|
||||
indent = " " * level * 2
|
||||
print(
|
||||
f"{indent}|_{event_pair[0].event_type} -> ",
|
||||
f"{time_stats.total_secs} seconds",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
child_event_ids = self._trace_map[cur_event_id]
|
||||
for child_event_id in child_event_ids:
|
||||
self._print_trace_map(child_event_id, level=level + 1)
|
||||
|
||||
def print_trace_map(self) -> None:
|
||||
"""Print simple trace map to terminal for debugging of the most recent trace."""
|
||||
print("*" * 10, flush=True)
|
||||
print(f"Trace: {self._cur_trace_id}", flush=True)
|
||||
self._print_trace_map(BASE_TRACE_EVENT, level=1)
|
||||
print("*" * 10, flush=True)
|
||||
|
||||
@property
|
||||
def event_pairs_by_type(self) -> Dict[CBEventType, List[CBEvent]]:
|
||||
return self._event_pairs_by_type
|
||||
|
||||
@property
|
||||
def events_pairs_by_id(self) -> Dict[str, List[CBEvent]]:
|
||||
return self._event_pairs_by_id
|
||||
|
||||
@property
|
||||
def sequential_events(self) -> List[CBEvent]:
|
||||
return self._sequential_events
|
||||
|
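# Hedged usage sketch: register the handler, run some queries, then inspect
# timing stats and the raw LLM input/output event pairs.
#
#     from llama_index.callbacks import CallbackManager, CBEventType, LlamaDebugHandler
#
#     debug_handler = LlamaDebugHandler(print_trace_on_end=True)
#     callback_manager = CallbackManager([debug_handler])
#     # ... run queries through a service context that uses callback_manager ...
#     print(debug_handler.get_event_time_info(CBEventType.LLM))
#     llm_events = debug_handler.get_llm_inputs_outputs()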
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Callback handler for storing generation data in OpenInference format.
|
||||
OpenInference is an open standard for capturing and storing AI model inferences.
|
||||
It enables production LLM app servers to seamlessly integrate with LLM
|
||||
observability solutions such as Arize and Phoenix.
|
||||
|
||||
For more information on the specification, see
|
||||
https://github.com/Arize-ai/open-inference-spec
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, fields
|
||||
from datetime import datetime
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, TypeVar
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.callbacks.schema import CBEventType, EventPayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pandas import DataFrame
|
||||
|
||||
|
||||
OPENINFERENCE_COLUMN_NAME = "openinference_column_name"
|
||||
Embedding = List[float]
|
||||
|
||||
|
||||
def _generate_random_id() -> str:
|
||||
"""Generates a random ID.
|
||||
|
||||
Returns:
|
||||
str: A random ID.
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryData:
|
||||
"""
|
||||
Query data with column names following the OpenInference specification.
|
||||
"""
|
||||
|
||||
id: str = field(
|
||||
default_factory=_generate_random_id,
|
||||
metadata={OPENINFERENCE_COLUMN_NAME: ":id.id:"},
|
||||
)
|
||||
timestamp: Optional[str] = field(
|
||||
default=None, metadata={OPENINFERENCE_COLUMN_NAME: ":timestamp.iso_8601:"}
|
||||
)
|
||||
query_text: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={OPENINFERENCE_COLUMN_NAME: ":feature.text:prompt"},
|
||||
)
|
||||
query_embedding: Optional[Embedding] = field(
|
||||
default=None,
|
||||
metadata={OPENINFERENCE_COLUMN_NAME: ":feature.[float].embedding:prompt"},
|
||||
)
|
||||
response_text: Optional[str] = field(
|
||||
default=None, metadata={OPENINFERENCE_COLUMN_NAME: ":prediction.text:response"}
|
||||
)
|
||||
node_ids: List[str] = field(
|
||||
default_factory=list,
|
||||
metadata={
|
||||
OPENINFERENCE_COLUMN_NAME: ":feature.[str].retrieved_document_ids:prompt"
|
||||
},
|
||||
)
|
||||
scores: List[float] = field(
|
||||
default_factory=list,
|
||||
metadata={
|
||||
OPENINFERENCE_COLUMN_NAME: (
|
||||
":feature.[float].retrieved_document_scores:prompt"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeData:
|
||||
"""Node data."""
|
||||
|
||||
id: str
|
||||
node_text: Optional[str] = None
|
||||
node_embedding: Optional[Embedding] = None
|
||||
|
||||
|
||||
BaseDataType = TypeVar("BaseDataType", QueryData, NodeData)
|
||||
|
||||
|
||||
def as_dataframe(data: Iterable[BaseDataType]) -> "DataFrame":
|
||||
"""Converts a list of BaseDataType to a pandas dataframe.
|
||||
|
||||
Args:
|
||||
data (Iterable[BaseDataType]): A list of BaseDataType.
|
||||
|
||||
Returns:
|
||||
DataFrame: The converted pandas dataframe.
|
||||
"""
|
||||
pandas = _import_package("pandas")
|
||||
as_dict_list = []
|
||||
for datum in data:
|
||||
as_dict = {
|
||||
field.metadata.get(OPENINFERENCE_COLUMN_NAME, field.name): getattr(
|
||||
datum, field.name
|
||||
)
|
||||
for field in fields(datum)
|
||||
}
|
||||
as_dict_list.append(as_dict)
|
||||
|
||||
return pandas.DataFrame(as_dict_list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraceData:
|
||||
"""Trace data."""
|
||||
|
||||
query_data: QueryData = field(default_factory=QueryData)
|
||||
node_datas: List[NodeData] = field(default_factory=list)
|
||||
|
||||
|
||||
def _import_package(package_name: str) -> ModuleType:
|
||||
"""Dynamically imports a package.
|
||||
|
||||
Args:
|
||||
package_name (str): Name of the package to import.
|
||||
|
||||
Raises:
|
||||
ImportError: If the package is not installed.
|
||||
|
||||
Returns:
|
||||
ModuleType: The imported package.
|
||||
"""
|
||||
try:
|
||||
package = importlib.import_module(package_name)
|
||||
except ImportError:
|
||||
raise ImportError(f"The {package_name} package must be installed.")
|
||||
return package
|
||||
|
||||
|
||||
class OpenInferenceCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for storing generation data in OpenInference format.
|
||||
OpenInference is an open standard for capturing and storing AI model
|
||||
inferences. It enables production LLM app servers to seamlessly integrate
|
||||
with LLM observability solutions such as Arize and Phoenix.
|
||||
|
||||
For more information on the specification, see
|
||||
https://github.com/Arize-ai/open-inference-spec
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
callback: Optional[Callable[[List[QueryData], List[NodeData]], None]] = None,
|
||||
) -> None:
|
||||
"""Initializes the OpenInferenceCallbackHandler.
|
||||
|
||||
Args:
|
||||
callback (Optional[Callable[[List[QueryData], List[NodeData]], None]], optional): A
|
||||
callback function that will be called when a query trace is
|
||||
completed, often used for logging or persisting query data.
|
||||
"""
|
||||
super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
|
||||
self._callback = callback
|
||||
self._trace_data = TraceData()
|
||||
self._query_data_buffer: List[QueryData] = []
|
||||
self._node_data_buffer: List[NodeData] = []
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
if trace_id == "query":
|
||||
self._trace_data = TraceData()
|
||||
self._trace_data.query_data.timestamp = datetime.now().isoformat()
|
||||
self._trace_data.query_data.id = _generate_random_id()
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
if trace_id == "query":
|
||||
self._query_data_buffer.append(self._trace_data.query_data)
|
||||
self._node_data_buffer.extend(self._trace_data.node_datas)
|
||||
self._trace_data = TraceData()
|
||||
if self._callback is not None:
|
||||
self._callback(self._query_data_buffer, self._node_data_buffer)
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
parent_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if payload is not None:
|
||||
if event_type is CBEventType.QUERY:
|
||||
query_text = payload[EventPayload.QUERY_STR]
|
||||
self._trace_data.query_data.query_text = query_text
|
||||
return event_id
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if payload is None:
|
||||
return
|
||||
if event_type is CBEventType.RETRIEVE:
|
||||
for node_with_score in payload[EventPayload.NODES]:
|
||||
node = node_with_score.node
|
||||
score = node_with_score.score
|
||||
self._trace_data.query_data.node_ids.append(node.hash)
|
||||
self._trace_data.query_data.scores.append(score)
|
||||
self._trace_data.node_datas.append(
|
||||
NodeData(
|
||||
id=node.hash,
|
||||
node_text=node.text,
|
||||
)
|
||||
)
|
||||
elif event_type is CBEventType.LLM:
|
||||
self._trace_data.query_data.response_text = str(
|
||||
payload.get(EventPayload.RESPONSE, "")
|
||||
) or str(payload.get(EventPayload.COMPLETION, ""))
|
||||
elif event_type is CBEventType.EMBEDDING:
|
||||
self._trace_data.query_data.query_embedding = payload[
|
||||
EventPayload.EMBEDDINGS
|
||||
][0]
|
||||
|
||||
def flush_query_data_buffer(self) -> List[QueryData]:
|
||||
"""Clears the query data buffer and returns the data.
|
||||
|
||||
Returns:
|
||||
List[QueryData]: The query data.
|
||||
"""
|
||||
query_data_buffer = self._query_data_buffer
|
||||
self._query_data_buffer = []
|
||||
return query_data_buffer
|
||||
|
||||
def flush_node_data_buffer(self) -> List[NodeData]:
|
||||
"""Clears the node data buffer and returns the data.
|
||||
|
||||
Returns:
|
||||
List[NodeData]: The node data.
|
||||
"""
|
||||
node_data_buffer = self._node_data_buffer
|
||||
self._node_data_buffer = []
|
||||
return node_data_buffer
|
||||
|
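# Hedged usage sketch: buffer query/node data in OpenInference format while
# queries run, then turn the buffers into pandas dataframes for inspection or
# export.
#
#     from llama_index.callbacks import CallbackManager, OpenInferenceCallbackHandler
#     from llama_index.callbacks.open_inference_callback import as_dataframe
#
#     handler = OpenInferenceCallbackHandler()
#     callback_manager = CallbackManager([handler])
#     # ... run queries through a service context that uses callback_manager ...
#     query_df = as_dataframe(handler.flush_query_data_buffer())
#     node_df = as_dataframe(handler.flush_node_data_buffer())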
|
@ -0,0 +1,136 @@
|
|||
import datetime
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from llama_index.bridge.pydantic import BaseModel
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.callbacks.schema import CBEventType, EventPayload
|
||||
from llama_index.llms import ChatMessage
|
||||
|
||||
PROMPT_LAYER_CHAT_FUNCTION_NAME = "llamaindex.chat.openai"
|
||||
PROMPT_LAYER_COMPLETION_FUNCTION_NAME = "llamaindex.completion.openai"
|
||||
|
||||
|
||||
class PromptLayerHandler(BaseCallbackHandler):
|
||||
"""Callback handler for sending to promptlayer.com."""
|
||||
|
||||
pl_tags: Optional[List[str]]
|
||||
return_pl_id: bool = False
|
||||
|
||||
def __init__(self, pl_tags: List[str] = [], return_pl_id: bool = False) -> None:
|
||||
try:
|
||||
from promptlayer.utils import get_api_key, promptlayer_api_request
|
||||
|
||||
self._promptlayer_api_request = promptlayer_api_request
|
||||
self._promptlayer_api_key = get_api_key()
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install PromptLAyer with `pip install promptlayer`"
|
||||
)
|
||||
self.pl_tags = pl_tags
|
||||
self.return_pl_id = return_pl_id
|
||||
super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
return
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
event_map: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def add_event(self, event_id: str, **kwargs: Any) -> None:
|
||||
self.event_map[event_id] = {
|
||||
"kwargs": kwargs,
|
||||
"request_start_time": datetime.datetime.now().timestamp(),
|
||||
}
|
||||
|
||||
def get_event(
|
||||
self,
|
||||
event_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
return self.event_map.get(event_id) or {}
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
parent_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if event_type == CBEventType.LLM and payload is not None:
|
||||
self.add_event(
|
||||
event_id=event_id, **payload.get(EventPayload.SERIALIZED, {})
|
||||
)
|
||||
return event_id
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if event_type != CBEventType.LLM or payload is None:
|
||||
return
|
||||
request_end_time = datetime.datetime.now().timestamp()
|
||||
prompt = str(payload.get(EventPayload.PROMPT))
|
||||
completion = payload.get(EventPayload.COMPLETION)
|
||||
response = payload.get(EventPayload.RESPONSE)
|
||||
function_name = PROMPT_LAYER_CHAT_FUNCTION_NAME
|
||||
event_data = self.get_event(event_id=event_id)
|
||||
resp: Union[str, Dict]
|
||||
extra_args = {}
|
||||
if response:
|
||||
messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
|
||||
resp = response.message.dict()
|
||||
assert isinstance(resp, dict)
|
||||
|
||||
usage_dict: Dict[str, int] = {}
|
||||
try:
|
||||
usage = response.raw.get("usage", None) # type: ignore
|
||||
|
||||
if isinstance(usage, dict):
|
||||
usage_dict = {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
}
|
||||
elif isinstance(usage, BaseModel):
|
||||
usage_dict = usage.dict()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
extra_args = {
|
||||
"messages": [message.dict() for message in messages],
|
||||
"usage": usage_dict,
|
||||
}
|
||||
# PromptLayer expects `tool_calls` at the top level of the response dict.
|
||||
if "tool_calls" in response.message.additional_kwargs:
|
||||
resp["tool_calls"] = [
|
||||
tool_call.dict()
|
||||
for tool_call in resp["additional_kwargs"]["tool_calls"]
|
||||
]
|
||||
del resp["additional_kwargs"]["tool_calls"]
|
||||
if completion:
|
||||
function_name = PROMPT_LAYER_COMPLETION_FUNCTION_NAME
|
||||
resp = str(completion)
|
||||
pl_request_id = self._promptlayer_api_request(
|
||||
function_name,
|
||||
"openai",
|
||||
[prompt],
|
||||
{
|
||||
**extra_args,
|
||||
**event_data["kwargs"],
|
||||
},
|
||||
self.pl_tags,
|
||||
[resp],
|
||||
event_data["request_start_time"],
|
||||
request_end_time,
|
||||
self._promptlayer_api_key,
|
||||
return_pl_id=self.return_pl_id,
|
||||
)
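# --- Usage sketch (illustrative, not part of the diff above) ---------------------
# Typical wiring for this handler. Assumes `promptlayer` is installed, the
# PROMPTLAYER_API_KEY environment variable is set, and that the class is used
# from this module (its final import path is not shown in this hunk).
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.callbacks import CallbackManager
from llama_index.llms import OpenAI

handler = PromptLayerHandler(pl_tags=["llama-index"], return_pl_id=False)
service_context = ServiceContext.from_defaults(
    llm=OpenAI(model="gpt-3.5-turbo"),
    callback_manager=CallbackManager([handler]),
)

documents = SimpleDirectoryReader("./data").load_data()  # placeholder corpus
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
print(index.as_query_engine().query("Summarize the corpus."))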
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
"""Base schema for callback managers."""
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
# timestamp for callback events
|
||||
TIMESTAMP_FORMAT = "%m/%d/%Y, %H:%M:%S.%f"
|
||||
|
||||
# base trace_id for the tracemap in callback_manager
|
||||
BASE_TRACE_EVENT = "root"
|
||||
|
||||
|
||||
class CBEventType(str, Enum):
|
||||
"""Callback manager event types.
|
||||
|
||||
Attributes:
|
||||
CHUNKING: Logs for the before and after of text splitting.
|
||||
NODE_PARSING: Logs for the documents and the nodes that they are parsed into.
|
||||
EMBEDDING: Logs for the number of texts embedded.
|
||||
LLM: Logs for the template and response of LLM calls.
|
||||
QUERY: Keeps track of the start and end of each query.
|
||||
RETRIEVE: Logs for the nodes retrieved for a query.
|
||||
SYNTHESIZE: Logs for the result for synthesize calls.
|
||||
TREE: Logs for the summary and level of summaries generated.
|
||||
SUB_QUESTION: Logs for a generated sub question and answer.
TEMPLATING: Logs for prompt template formatting.
FUNCTION_CALL: Logs for function/tool calls made by the LLM.
RERANKING: Logs for node reranking steps.
EXCEPTION: Logs for exceptions raised during an event.
AGENT_STEP: Logs for individual agent steps.
"""
|
||||
|
||||
CHUNKING = "chunking"
|
||||
NODE_PARSING = "node_parsing"
|
||||
EMBEDDING = "embedding"
|
||||
LLM = "llm"
|
||||
QUERY = "query"
|
||||
RETRIEVE = "retrieve"
|
||||
SYNTHESIZE = "synthesize"
|
||||
TREE = "tree"
|
||||
SUB_QUESTION = "sub_question"
|
||||
TEMPLATING = "templating"
|
||||
FUNCTION_CALL = "function_call"
|
||||
RERANKING = "reranking"
|
||||
EXCEPTION = "exception"
|
||||
AGENT_STEP = "agent_step"
|
||||
|
||||
|
||||
class EventPayload(str, Enum):
|
||||
DOCUMENTS = "documents" # list of documents before parsing
|
||||
CHUNKS = "chunks" # list of text chunks
|
||||
NODES = "nodes" # list of nodes
|
||||
PROMPT = "formatted_prompt" # formatted prompt sent to LLM
|
||||
MESSAGES = "messages" # list of messages sent to LLM
|
||||
COMPLETION = "completion" # completion from LLM
|
||||
RESPONSE = "response" # message response from LLM
|
||||
QUERY_STR = "query_str" # query used for query engine
|
||||
SUB_QUESTION = "sub_question" # a sub question & answer + sources
|
||||
EMBEDDINGS = "embeddings" # list of embeddings
|
||||
TOP_K = "top_k" # top k nodes retrieved
|
||||
ADDITIONAL_KWARGS = "additional_kwargs" # additional kwargs for event call
|
||||
SERIALIZED = "serialized" # serialized object for event caller
|
||||
FUNCTION_CALL = "function_call" # function call for the LLM
|
||||
FUNCTION_OUTPUT = "function_call_response" # function call output
|
||||
TOOL = "tool" # tool used in LLM call
|
||||
MODEL_NAME = "model_name" # model name used in an event
|
||||
TEMPLATE = "template" # template used in LLM call
|
||||
TEMPLATE_VARS = "template_vars" # template variables used in LLM call
|
||||
SYSTEM_PROMPT = "system_prompt" # system prompt used in LLM call
|
||||
QUERY_WRAPPER_PROMPT = "query_wrapper_prompt" # query wrapper prompt used in LLM
|
||||
EXCEPTION = "exception" # exception raised in an event
|
||||
|
||||
|
||||
# events that will never have children events
|
||||
LEAF_EVENTS = (CBEventType.CHUNKING, CBEventType.LLM, CBEventType.EMBEDDING)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CBEvent:
|
||||
"""Generic class to store event information."""
|
||||
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
time: str = ""
|
||||
id_: str = ""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Init time and id if needed."""
|
||||
if not self.time:
|
||||
self.time = datetime.now().strftime(TIMESTAMP_FORMAT)
|
||||
if not self.id_:
|
||||
self.id_ = str(uuid.uuid4())
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventStats:
|
||||
"""Time-based Statistics for events."""
|
||||
|
||||
total_secs: float
|
||||
average_secs: float
|
||||
total_count: int
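# --- Illustration (not part of the diff above) -----------------------------------
# A sketch of how a start/end pair of CBEvents maps onto EventStats, using only
# names defined in this module (datetime is imported at the top of the file).
start_event = CBEvent(CBEventType.LLM, payload={EventPayload.PROMPT: "Say hello"})
end_event = CBEvent(
    CBEventType.LLM,
    payload={EventPayload.COMPLETION: "Hello!"},
    id_=start_event.id_,  # share the id so the pair can be matched up
)

elapsed = (
    datetime.strptime(end_event.time, TIMESTAMP_FORMAT)
    - datetime.strptime(start_event.time, TIMESTAMP_FORMAT)
).total_seconds()
print(EventStats(total_secs=elapsed, average_secs=elapsed, total_count=1))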
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.callbacks.schema import CBEventType, EventPayload
|
||||
|
||||
|
||||
class SimpleLLMHandler(BaseCallbackHandler):
|
||||
"""Callback handler for printing llms inputs/outputs."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
return
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
def _print_llm_event(self, payload: dict) -> None:
|
||||
from llama_index.llms import ChatMessage
|
||||
|
||||
if EventPayload.PROMPT in payload:
|
||||
prompt = str(payload.get(EventPayload.PROMPT))
|
||||
completion = str(payload.get(EventPayload.COMPLETION))
|
||||
|
||||
print(f"** Prompt: **\n{prompt}")
|
||||
print("*" * 50)
|
||||
print(f"** Completion: **\n{completion}")
|
||||
print("*" * 50)
|
||||
print("\n")
|
||||
elif EventPayload.MESSAGES in payload:
|
||||
messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
|
||||
messages_str = "\n".join([str(x) for x in messages])
|
||||
response = str(payload.get(EventPayload.RESPONSE))
|
||||
|
||||
print(f"** Messages: **\n{messages_str}")
|
||||
print("*" * 50)
|
||||
print(f"** Response: **\n{response}")
|
||||
print("*" * 50)
|
||||
print("\n")
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
parent_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
return event_id
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Count the LLM or Embedding tokens as needed."""
|
||||
if event_type == CBEventType.LLM and payload is not None:
|
||||
self._print_llm_event(payload)
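# --- Usage sketch (illustrative, not part of the diff above) ---------------------
# In recent releases this handler is usually enabled through the global handler
# registry under the name "simple"; that mapping is an assumption here, so check
# set_global_handler's supported names for your version.
import llama_index

llama_index.set_global_handler("simple")
# From this point on, every LLM call made through llama_index prints its
# prompt/completion (or messages/response) pair to stdout.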
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.callbacks.schema import CBEventType, EventPayload
|
||||
from llama_index.utilities.token_counting import TokenCounter
|
||||
from llama_index.utils import get_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenCountingEvent:
|
||||
prompt: str
|
||||
completion: str
|
||||
completion_token_count: int
|
||||
prompt_token_count: int
|
||||
total_token_count: int = 0
|
||||
event_id: str = ""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.total_token_count = self.prompt_token_count + self.completion_token_count
|
||||
|
||||
|
||||
def get_llm_token_counts(
|
||||
token_counter: TokenCounter, payload: Dict[str, Any], event_id: str = ""
|
||||
) -> TokenCountingEvent:
|
||||
from llama_index.llms import ChatMessage
|
||||
|
||||
if EventPayload.PROMPT in payload:
|
||||
prompt = str(payload.get(EventPayload.PROMPT))
|
||||
completion = str(payload.get(EventPayload.COMPLETION))
|
||||
|
||||
return TokenCountingEvent(
|
||||
event_id=event_id,
|
||||
prompt=prompt,
|
||||
prompt_token_count=token_counter.get_string_tokens(prompt),
|
||||
completion=completion,
|
||||
completion_token_count=token_counter.get_string_tokens(completion),
|
||||
)
|
||||
|
||||
elif EventPayload.MESSAGES in payload:
|
||||
messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
|
||||
messages_str = "\n".join([str(x) for x in messages])
|
||||
|
||||
response = payload.get(EventPayload.RESPONSE)
|
||||
response_str = str(response)
|
||||
|
||||
# try getting attached token counts first
|
||||
try:
|
||||
messages_tokens = 0
|
||||
response_tokens = 0
|
||||
|
||||
if response is not None and response.raw is not None:
|
||||
usage = response.raw.get("usage", None)
|
||||
|
||||
if usage is not None:
|
||||
if not isinstance(usage, dict):
|
||||
usage = dict(usage)
|
||||
messages_tokens = usage.get("prompt_tokens", 0)
|
||||
response_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
if messages_tokens == 0 or response_tokens == 0:
|
||||
raise ValueError("Invalid token counts!")
|
||||
|
||||
return TokenCountingEvent(
|
||||
event_id=event_id,
|
||||
prompt=messages_str,
|
||||
prompt_token_count=messages_tokens,
|
||||
completion=response_str,
|
||||
completion_token_count=response_tokens,
|
||||
)
|
||||
|
||||
except (ValueError, KeyError):
|
||||
# Invalid token counts, or no token counts attached
|
||||
pass
|
||||
|
||||
# Should count tokens ourselves
|
||||
messages_tokens = token_counter.estimate_tokens_in_messages(messages)
|
||||
response_tokens = token_counter.get_string_tokens(response_str)
|
||||
|
||||
return TokenCountingEvent(
|
||||
event_id=event_id,
|
||||
prompt=messages_str,
|
||||
prompt_token_count=messages_tokens,
|
||||
completion=response_str,
|
||||
completion_token_count=response_tokens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid payload! Need prompt and completion or messages and response."
|
||||
)
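# --- Illustration (not part of the diff above) -----------------------------------
# get_llm_token_counts accepts two payload shapes. A quick sketch of the
# completion-style shape, counted with the default (global) tokenizer:
counter = TokenCounter()
event = get_llm_token_counts(
    token_counter=counter,
    payload={
        EventPayload.PROMPT: "Say hello",
        EventPayload.COMPLETION: "Hello there!",
    },
)
print(event.prompt_token_count, event.completion_token_count, event.total_token_count)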
|
||||
|
||||
|
||||
class TokenCountingHandler(BaseCallbackHandler):
|
||||
"""Callback handler for counting tokens in LLM and Embedding events.
|
||||
|
||||
Args:
|
||||
tokenizer:
|
||||
Tokenizer to use. Defaults to the global tokenizer
|
||||
(see llama_index.utils.globals_helper).
|
||||
event_starts_to_ignore: List of event types to ignore at the start of a trace.
|
||||
event_ends_to_ignore: List of event types to ignore at the end of a trace.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Optional[Callable[[str], List]] = None,
|
||||
event_starts_to_ignore: Optional[List[CBEventType]] = None,
|
||||
event_ends_to_ignore: Optional[List[CBEventType]] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.llm_token_counts: List[TokenCountingEvent] = []
|
||||
self.embedding_token_counts: List[TokenCountingEvent] = []
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
|
||||
self._token_counter = TokenCounter(tokenizer=self.tokenizer)
|
||||
self._verbose = verbose
|
||||
|
||||
super().__init__(
|
||||
event_starts_to_ignore=event_starts_to_ignore or [],
|
||||
event_ends_to_ignore=event_ends_to_ignore or [],
|
||||
)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
return
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
parent_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
return event_id
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Count the LLM or Embedding tokens as needed."""
|
||||
if (
|
||||
event_type == CBEventType.LLM
|
||||
and event_type not in self.event_ends_to_ignore
|
||||
and payload is not None
|
||||
):
|
||||
self.llm_token_counts.append(
|
||||
get_llm_token_counts(
|
||||
token_counter=self._token_counter,
|
||||
payload=payload,
|
||||
event_id=event_id,
|
||||
)
|
||||
)
|
||||
|
||||
if self._verbose:
|
||||
print(
|
||||
"LLM Prompt Token Usage: "
|
||||
f"{self.llm_token_counts[-1].prompt_token_count}\n"
|
||||
"LLM Completion Token Usage: "
|
||||
f"{self.llm_token_counts[-1].completion_token_count}",
|
||||
flush=True,
|
||||
)
|
||||
elif (
|
||||
event_type == CBEventType.EMBEDDING
|
||||
and event_type not in self.event_ends_to_ignore
|
||||
and payload is not None
|
||||
):
|
||||
total_chunk_tokens = 0
|
||||
for chunk in payload.get(EventPayload.CHUNKS, []):
|
||||
self.embedding_token_counts.append(
|
||||
TokenCountingEvent(
|
||||
event_id=event_id,
|
||||
prompt=chunk,
|
||||
prompt_token_count=self._token_counter.get_string_tokens(chunk),
|
||||
completion="",
|
||||
completion_token_count=0,
|
||||
)
|
||||
)
|
||||
total_chunk_tokens += self.embedding_token_counts[-1].total_token_count
|
||||
|
||||
if self._verbose:
|
||||
print(f"Embedding Token Usage: {total_chunk_tokens}", flush=True)
|
||||
|
||||
@property
|
||||
def total_llm_token_count(self) -> int:
|
||||
"""Get the current total LLM token count."""
|
||||
return sum([x.total_token_count for x in self.llm_token_counts])
|
||||
|
||||
@property
|
||||
def prompt_llm_token_count(self) -> int:
|
||||
"""Get the current total LLM prompt token count."""
|
||||
return sum([x.prompt_token_count for x in self.llm_token_counts])
|
||||
|
||||
@property
|
||||
def completion_llm_token_count(self) -> int:
|
||||
"""Get the current total LLM completion token count."""
|
||||
return sum([x.completion_token_count for x in self.llm_token_counts])
|
||||
|
||||
@property
|
||||
def total_embedding_token_count(self) -> int:
|
||||
"""Get the current total Embedding token count."""
|
||||
return sum([x.total_token_count for x in self.embedding_token_counts])
|
||||
|
||||
def reset_counts(self) -> None:
|
||||
"""Reset the token counts."""
|
||||
self.llm_token_counts = []
|
||||
self.embedding_token_counts = []
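# --- Usage sketch (illustrative, not part of the diff above) ---------------------
# Typical wiring with an explicit tiktoken tokenizer; assumes the `tiktoken`
# package is installed and that queries are then run through a ServiceContext
# built with this callback manager.
import tiktoken
from llama_index.callbacks import CallbackManager

token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
    verbose=True,  # print usage as each LLM/embedding event completes
)
callback_manager = CallbackManager([token_counter])

# ... run queries with ServiceContext.from_defaults(callback_manager=callback_manager) ...

print("prompt tokens:", token_counter.prompt_llm_token_count)
print("completion tokens:", token_counter.completion_llm_token_count)
print("embedding tokens:", token_counter.total_embedding_token_count)
token_counter.reset_counts()  # start a fresh measurement window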
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from llama_index.callbacks.base import CallbackManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def trace_method(
|
||||
trace_id: str, callback_manager_attr: str = "callback_manager"
|
||||
) -> Callable[[Callable], Callable]:
|
||||
"""
|
||||
Decorator to trace a method.
|
||||
|
||||
Example:
|
||||
@trace_method("my_trace_id")
|
||||
def my_method(self):
|
||||
pass
|
||||
|
||||
Assumes that the self instance has a CallbackManager instance in an attribute
|
||||
named `callback_manager`.
|
||||
This can be overridden by passing in a `callback_manager_attr` keyword argument.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func) # preserve signature, name, etc. of func
|
||||
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
callback_manager = getattr(self, callback_manager_attr)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"Could not find attribute %s on %s.",
|
||||
callback_manager_attr,
|
||||
type(self),
|
||||
)
|
||||
return func(self, *args, **kwargs)
|
||||
callback_manager = cast(CallbackManager, callback_manager)
|
||||
with callback_manager.as_trace(trace_id):
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
@functools.wraps(func) # preserve signature, name, etc. of func
|
||||
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
callback_manager = getattr(self, callback_manager_attr)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"Could not find attribute %s on %s.",
|
||||
callback_manager_attr,
|
||||
type(self),
|
||||
)
|
||||
return await func(self, *args, **kwargs)
|
||||
callback_manager = cast(CallbackManager, callback_manager)
|
||||
with callback_manager.as_trace(trace_id):
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper
|
||||
|
||||
return decorator
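# --- Usage sketch (illustrative, not part of the diff above) ---------------------
# The decorator only requires that the instance expose a CallbackManager under the
# attribute named by `callback_manager_attr` (default: "callback_manager").
class EchoEngine:
    def __init__(self) -> None:
        self.callback_manager = CallbackManager([])

    @trace_method("chat")
    def chat(self, message: str) -> str:
        # The body runs inside callback_manager.as_trace("chat"), so any events
        # emitted here are grouped under the "chat" trace.
        return f"echo: {message}"


print(EchoEngine().chat("hello"))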
|
||||
|
|
@ -0,0 +1,570 @@
|
|||
import os
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
|
||||
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.callbacks.schema import (
|
||||
TIMESTAMP_FORMAT,
|
||||
CBEvent,
|
||||
CBEventType,
|
||||
EventPayload,
|
||||
)
|
||||
from llama_index.callbacks.token_counting import get_llm_token_counts
|
||||
from llama_index.utilities.token_counting import TokenCounter
|
||||
from llama_index.utils import get_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from wandb import Settings as WBSettings
|
||||
from wandb.sdk.data_types import trace_tree
|
||||
|
||||
from llama_index.indices import (
|
||||
ComposableGraph,
|
||||
GPTEmptyIndex,
|
||||
GPTKeywordTableIndex,
|
||||
GPTRAKEKeywordTableIndex,
|
||||
GPTSimpleKeywordTableIndex,
|
||||
GPTSQLStructStoreIndex,
|
||||
GPTTreeIndex,
|
||||
GPTVectorStoreIndex,
|
||||
SummaryIndex,
|
||||
)
|
||||
from llama_index.storage.storage_context import StorageContext
|
||||
|
||||
IndexType = Union[
|
||||
ComposableGraph,
|
||||
GPTKeywordTableIndex,
|
||||
GPTSimpleKeywordTableIndex,
|
||||
GPTRAKEKeywordTableIndex,
|
||||
SummaryIndex,
|
||||
GPTEmptyIndex,
|
||||
GPTTreeIndex,
|
||||
GPTVectorStoreIndex,
|
||||
GPTSQLStructStoreIndex,
|
||||
]
|
||||
|
||||
|
||||
# TODO: remove this class
|
||||
class WandbRunArgs(TypedDict):
|
||||
job_type: Optional[str]
|
||||
dir: Optional[str]
|
||||
config: Union[Dict, str, None]
|
||||
project: Optional[str]
|
||||
entity: Optional[str]
|
||||
reinit: Optional[bool]
|
||||
tags: Optional[Sequence]
|
||||
group: Optional[str]
|
||||
name: Optional[str]
|
||||
notes: Optional[str]
|
||||
magic: Optional[Union[dict, str, bool]]
|
||||
config_exclude_keys: Optional[List[str]]
|
||||
config_include_keys: Optional[List[str]]
|
||||
anonymous: Optional[str]
|
||||
mode: Optional[str]
|
||||
allow_val_change: Optional[bool]
|
||||
resume: Optional[Union[bool, str]]
|
||||
force: Optional[bool]
|
||||
tensorboard: Optional[bool]
|
||||
sync_tensorboard: Optional[bool]
|
||||
monitor_gym: Optional[bool]
|
||||
save_code: Optional[bool]
|
||||
id: Optional[str]
|
||||
settings: Union["WBSettings", Dict[str, Any], None]
|
||||
|
||||
|
||||
class WandbCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler that logs events to wandb.
|
||||
|
||||
NOTE: this is a beta feature. The usage within our codebase, and the interface
|
||||
may change.
|
||||
|
||||
Use the `WandbCallbackHandler` to log trace events to wandb. This handler is
|
||||
useful for debugging and visualizing the trace events. It captures the payload of
|
||||
the events and logs them to wandb. The handler also tracks the start and end of
|
||||
events. This is particularly useful for debugging your LLM calls.
|
||||
|
||||
The `WandbCallbackHandler` can also be used to log the indices and graphs to wandb
|
||||
using the `persist_index` method. This will save the indexes as artifacts in wandb.
|
||||
The `load_storage_context` method can be used to load the indexes from wandb
|
||||
artifacts. This method will return a `StorageContext` object that can be used to
|
||||
build the index, using `load_index_from_storage`, `load_indices_from_storage` or
|
||||
`load_graph_from_storage` functions.
|
||||
|
||||
|
||||
Args:
|
||||
event_starts_to_ignore (Optional[List[CBEventType]]): list of event types to
|
||||
ignore when tracking event starts.
|
||||
event_ends_to_ignore (Optional[List[CBEventType]]): list of event types to
|
||||
ignore when tracking event ends.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_args: Optional[WandbRunArgs] = None,
|
||||
tokenizer: Optional[Callable[[str], List]] = None,
|
||||
event_starts_to_ignore: Optional[List[CBEventType]] = None,
|
||||
event_ends_to_ignore: Optional[List[CBEventType]] = None,
|
||||
) -> None:
|
||||
try:
|
||||
import wandb
|
||||
from wandb.sdk.data_types import trace_tree
|
||||
|
||||
self._wandb = wandb
|
||||
self._trace_tree = trace_tree
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"WandbCallbackHandler requires wandb. "
|
||||
"Please install it with `pip install wandb`."
|
||||
)
|
||||
|
||||
from llama_index.indices import (
|
||||
ComposableGraph,
|
||||
GPTEmptyIndex,
|
||||
GPTKeywordTableIndex,
|
||||
GPTRAKEKeywordTableIndex,
|
||||
GPTSimpleKeywordTableIndex,
|
||||
GPTSQLStructStoreIndex,
|
||||
GPTTreeIndex,
|
||||
GPTVectorStoreIndex,
|
||||
SummaryIndex,
|
||||
)
|
||||
|
||||
self._IndexType = (
|
||||
ComposableGraph,
|
||||
GPTKeywordTableIndex,
|
||||
GPTSimpleKeywordTableIndex,
|
||||
GPTRAKEKeywordTableIndex,
|
||||
SummaryIndex,
|
||||
GPTEmptyIndex,
|
||||
GPTTreeIndex,
|
||||
GPTVectorStoreIndex,
|
||||
GPTSQLStructStoreIndex,
|
||||
)
|
||||
|
||||
self._run_args = run_args
|
||||
# Check if a W&B run is already initialized; if not, initialize one
|
||||
self._ensure_run(should_print_url=(self._wandb.run is None))
|
||||
|
||||
self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)
|
||||
self._cur_trace_id: Optional[str] = None
|
||||
self._trace_map: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
self._token_counter = TokenCounter(tokenizer=self.tokenizer)
|
||||
|
||||
event_starts_to_ignore = (
|
||||
event_starts_to_ignore if event_starts_to_ignore else []
|
||||
)
|
||||
event_ends_to_ignore = event_ends_to_ignore if event_ends_to_ignore else []
|
||||
super().__init__(
|
||||
event_starts_to_ignore=event_starts_to_ignore,
|
||||
event_ends_to_ignore=event_ends_to_ignore,
|
||||
)
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
parent_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Store event start data by event type.
|
||||
|
||||
Args:
|
||||
event_type (CBEventType): event type to store.
|
||||
payload (Optional[Dict[str, Any]]): payload to store.
|
||||
event_id (str): event id to store.
|
||||
parent_id (str): parent event id.
|
||||
|
||||
"""
|
||||
event = CBEvent(event_type, payload=payload, id_=event_id)
|
||||
self._event_pairs_by_id[event.id_].append(event)
|
||||
return event.id_
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Store event end data by event type.
|
||||
|
||||
Args:
|
||||
event_type (CBEventType): event type to store.
|
||||
payload (Optional[Dict[str, Any]]): payload to store.
|
||||
event_id (str): event id to store.
|
||||
|
||||
"""
|
||||
event = CBEvent(event_type, payload=payload, id_=event_id)
|
||||
self._event_pairs_by_id[event.id_].append(event)
|
||||
self._trace_map = defaultdict(list)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""Launch a trace."""
|
||||
self._trace_map = defaultdict(list)
|
||||
self._cur_trace_id = trace_id
|
||||
self._start_time = datetime.now()
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
# Ensure W&B run is initialized
|
||||
self._ensure_run()
|
||||
|
||||
self._trace_map = trace_map or defaultdict(list)
|
||||
self._end_time = datetime.now()
|
||||
|
||||
# Log the trace map to wandb
|
||||
# We can control what trace ids we want to log here.
|
||||
self.log_trace_tree()
|
||||
|
||||
# TODO (ayulockin): Log the LLM token counts to wandb when weave is ready
|
||||
|
||||
def log_trace_tree(self) -> None:
|
||||
"""Log the trace tree to wandb."""
|
||||
try:
|
||||
child_nodes = self._trace_map["root"]
|
||||
root_span = self._convert_event_pair_to_wb_span(
|
||||
self._event_pairs_by_id[child_nodes[0]],
|
||||
trace_id=self._cur_trace_id if len(child_nodes) > 1 else None,
|
||||
)
|
||||
|
||||
if len(child_nodes) == 1:
|
||||
child_nodes = self._trace_map[child_nodes[0]]
|
||||
root_span = self._build_trace_tree(child_nodes, root_span)
|
||||
else:
|
||||
root_span = self._build_trace_tree(child_nodes, root_span)
|
||||
if root_span:
|
||||
root_trace = self._trace_tree.WBTraceTree(root_span)
|
||||
if self._wandb.run:
|
||||
self._wandb.run.log({"trace": root_trace})
|
||||
self._wandb.termlog("Logged trace tree to W&B.")
|
||||
except Exception as e:
|
||||
print(f"Failed to log trace tree to W&B: {e}")
|
||||
# ignore errors to not break user code
|
||||
|
||||
def persist_index(
|
||||
self, index: "IndexType", index_name: str, persist_dir: Union[str, None] = None
|
||||
) -> None:
|
||||
"""Upload an index to wandb as an artifact. You can learn more about W&B
|
||||
artifacts here: https://docs.wandb.ai/guides/artifacts.
|
||||
|
||||
For the `ComposableGraph` index, the root id is stored as artifact metadata.
|
||||
|
||||
Args:
|
||||
index (IndexType): index to upload.
|
||||
index_name (str): name of the index. This will be used as the artifact name.
|
||||
persist_dir (Union[str, None]): directory to persist the index. If None, a
|
||||
temporary directory will be created and used.
|
||||
|
||||
"""
|
||||
_default_persist_dir = False
if persist_dir is None:
|
||||
persist_dir = f"{self._wandb.run.dir}/storage" # type: ignore
|
||||
_default_persist_dir = True
|
||||
if not os.path.exists(persist_dir):
|
||||
os.makedirs(persist_dir)
|
||||
|
||||
if isinstance(index, self._IndexType):
|
||||
try:
|
||||
index.storage_context.persist(persist_dir) # type: ignore
|
||||
|
||||
metadata = None
|
||||
# For the `ComposableGraph` index, store the root id as metadata
|
||||
if isinstance(index, self._IndexType[0]):
|
||||
root_id = index.root_id
|
||||
metadata = {"root_id": root_id}
|
||||
|
||||
self._upload_index_as_wb_artifact(persist_dir, index_name, metadata)
|
||||
except Exception as e:
|
||||
# Silently ignore errors to not break user code
|
||||
self._print_upload_index_fail_message(e)
|
||||
|
||||
# clear the default storage dir
|
||||
if _default_persist_dir:
|
||||
shutil.rmtree(persist_dir, ignore_errors=True)
|
||||
|
||||
def load_storage_context(
|
||||
self, artifact_url: str, index_download_dir: Union[str, None] = None
|
||||
) -> "StorageContext":
|
||||
"""Download an index from wandb and return a storage context.
|
||||
|
||||
Use this storage context to load the index into memory using
|
||||
`load_index_from_storage`, `load_indices_from_storage` or
|
||||
`load_graph_from_storage` functions.
|
||||
|
||||
Args:
|
||||
artifact_url (str): url of the artifact to download. The artifact url will
|
||||
be of the form: `entity/project/index_name:version` and can be found in
|
||||
the W&B UI.
|
||||
index_download_dir (Union[str, None]): directory to download the index to.
|
||||
"""
|
||||
from llama_index.storage.storage_context import StorageContext
|
||||
|
||||
artifact = self._wandb.use_artifact(artifact_url, type="storage_context")
|
||||
artifact_dir = artifact.download(root=index_download_dir)
|
||||
|
||||
return StorageContext.from_defaults(persist_dir=artifact_dir)
|
||||
|
||||
def _upload_index_as_wb_artifact(
|
||||
self, dir_path: str, artifact_name: str, metadata: Optional[Dict]
|
||||
) -> None:
|
||||
"""Utility function to upload a dir to W&B as an artifact."""
|
||||
artifact = self._wandb.Artifact(artifact_name, type="storage_context")
|
||||
|
||||
if metadata:
|
||||
artifact.metadata = metadata
|
||||
|
||||
artifact.add_dir(dir_path)
|
||||
self._wandb.run.log_artifact(artifact) # type: ignore
|
||||
|
||||
def _build_trace_tree(
|
||||
self, events: List[str], span: "trace_tree.Span"
|
||||
) -> "trace_tree.Span":
|
||||
"""Build the trace tree from the trace map."""
|
||||
for child_event in events:
|
||||
child_span = self._convert_event_pair_to_wb_span(
|
||||
self._event_pairs_by_id[child_event]
|
||||
)
|
||||
child_span = self._build_trace_tree(
|
||||
self._trace_map[child_event], child_span
|
||||
)
|
||||
span.add_child_span(child_span)
|
||||
|
||||
return span
|
||||
|
||||
def _convert_event_pair_to_wb_span(
|
||||
self,
|
||||
event_pair: List[CBEvent],
|
||||
trace_id: Optional[str] = None,
|
||||
) -> "trace_tree.Span":
|
||||
"""Convert a pair of events to a wandb trace tree span."""
|
||||
start_time_ms, end_time_ms = self._get_time_in_ms(event_pair)
|
||||
|
||||
if trace_id is None:
|
||||
event_type = event_pair[0].event_type
|
||||
span_kind = self._map_event_type_to_span_kind(event_type)
|
||||
else:
|
||||
event_type = trace_id # type: ignore
|
||||
span_kind = None
|
||||
|
||||
wb_span = self._trace_tree.Span(
|
||||
name=f"{event_type}",
|
||||
span_kind=span_kind,
|
||||
start_time_ms=start_time_ms,
|
||||
end_time_ms=end_time_ms,
|
||||
)
|
||||
|
||||
inputs, outputs, wb_span = self._add_payload_to_span(wb_span, event_pair)
|
||||
wb_span.add_named_result(inputs=inputs, outputs=outputs) # type: ignore
|
||||
|
||||
return wb_span
|
||||
|
||||
def _map_event_type_to_span_kind(
|
||||
self, event_type: CBEventType
|
||||
) -> Union[None, "trace_tree.SpanKind"]:
|
||||
"""Map a CBEventType to a wandb trace tree SpanKind."""
|
||||
if event_type == CBEventType.CHUNKING:
|
||||
span_kind = None
|
||||
elif event_type == CBEventType.NODE_PARSING:
|
||||
span_kind = None
|
||||
elif event_type == CBEventType.EMBEDDING:
|
||||
# TODO: add span kind for EMBEDDING when it's available
|
||||
span_kind = None
|
||||
elif event_type == CBEventType.LLM:
|
||||
span_kind = self._trace_tree.SpanKind.LLM
|
||||
elif event_type == CBEventType.QUERY:
|
||||
span_kind = self._trace_tree.SpanKind.AGENT
|
||||
elif event_type == CBEventType.AGENT_STEP:
|
||||
span_kind = self._trace_tree.SpanKind.AGENT
|
||||
elif event_type == CBEventType.RETRIEVE:
|
||||
span_kind = self._trace_tree.SpanKind.TOOL
|
||||
elif event_type == CBEventType.SYNTHESIZE:
|
||||
span_kind = self._trace_tree.SpanKind.CHAIN
|
||||
elif event_type == CBEventType.TREE:
|
||||
span_kind = self._trace_tree.SpanKind.CHAIN
|
||||
elif event_type == CBEventType.SUB_QUESTION:
|
||||
span_kind = self._trace_tree.SpanKind.CHAIN
|
||||
elif event_type == CBEventType.RERANKING:
|
||||
span_kind = self._trace_tree.SpanKind.CHAIN
|
||||
elif event_type == CBEventType.FUNCTION_CALL:
|
||||
span_kind = self._trace_tree.SpanKind.TOOL
|
||||
else:
|
||||
span_kind = None
|
||||
|
||||
return span_kind
|
||||
|
||||
def _add_payload_to_span(
|
||||
self, span: "trace_tree.Span", event_pair: List[CBEvent]
|
||||
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], "trace_tree.Span"]:
|
||||
"""Add the event's payload to the span."""
|
||||
assert len(event_pair) == 2
|
||||
event_type = event_pair[0].event_type
|
||||
inputs = None
|
||||
outputs = None
|
||||
|
||||
if event_type == CBEventType.NODE_PARSING:
|
||||
# TODO: disabled full detailed inputs/outputs due to UI lag
|
||||
inputs, outputs = self._handle_node_parsing_payload(event_pair)
|
||||
elif event_type == CBEventType.LLM:
|
||||
inputs, outputs, span = self._handle_llm_payload(event_pair, span)
|
||||
elif event_type == CBEventType.QUERY:
|
||||
inputs, outputs = self._handle_query_payload(event_pair)
|
||||
elif event_type == CBEventType.EMBEDDING:
|
||||
inputs, outputs = self._handle_embedding_payload(event_pair)
|
||||
|
||||
return inputs, outputs, span
|
||||
|
||||
def _handle_node_parsing_payload(
|
||||
self, event_pair: List[CBEvent]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""Handle the payload of a NODE_PARSING event."""
|
||||
inputs = event_pair[0].payload
|
||||
outputs = event_pair[-1].payload
|
||||
|
||||
if inputs and EventPayload.DOCUMENTS in inputs:
|
||||
documents = inputs.pop(EventPayload.DOCUMENTS)
|
||||
inputs["num_documents"] = len(documents)
|
||||
|
||||
if outputs and EventPayload.NODES in outputs:
|
||||
nodes = outputs.pop(EventPayload.NODES)
|
||||
outputs["num_nodes"] = len(nodes)
|
||||
|
||||
return inputs or {}, outputs or {}
|
||||
|
||||
def _handle_llm_payload(
|
||||
self, event_pair: List[CBEvent], span: "trace_tree.Span"
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any], "trace_tree.Span"]:
|
||||
"""Handle the payload of a LLM event."""
|
||||
inputs = event_pair[0].payload
|
||||
outputs = event_pair[-1].payload
|
||||
|
||||
assert isinstance(inputs, dict) and isinstance(outputs, dict)
|
||||
|
||||
# Get `original_template` from Prompt
|
||||
if EventPayload.PROMPT in inputs:
|
||||
inputs[EventPayload.PROMPT] = inputs[EventPayload.PROMPT]
|
||||
|
||||
# Format messages
|
||||
if EventPayload.MESSAGES in inputs:
|
||||
inputs[EventPayload.MESSAGES] = "\n".join(
|
||||
[str(x) for x in inputs[EventPayload.MESSAGES]]
|
||||
)
|
||||
|
||||
token_counts = get_llm_token_counts(self._token_counter, outputs)
|
||||
metadata = {
|
||||
"formatted_prompt_tokens_count": token_counts.prompt_token_count,
|
||||
"prediction_tokens_count": token_counts.completion_token_count,
|
||||
"total_tokens_used": token_counts.total_token_count,
|
||||
}
|
||||
span.attributes = metadata
|
||||
|
||||
# Make `response` part of `outputs`
|
||||
outputs = {EventPayload.RESPONSE: str(outputs[EventPayload.RESPONSE])}
|
||||
|
||||
return inputs, outputs, span
|
||||
|
||||
def _handle_query_payload(
|
||||
self, event_pair: List[CBEvent]
|
||||
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any]]:
|
||||
"""Handle the payload of a QUERY event."""
|
||||
inputs = event_pair[0].payload
|
||||
outputs = event_pair[-1].payload
|
||||
|
||||
if outputs:
|
||||
response_obj = outputs[EventPayload.RESPONSE]
|
||||
response = str(outputs[EventPayload.RESPONSE])
|
||||
|
||||
# Check the type of the original response object, not its string form.
if type(response_obj).__name__ == "Response":
response = response_obj.response
elif type(response_obj).__name__ == "StreamingResponse":
|
||||
response = response_obj.get_response().response
|
||||
else:
|
||||
response = " "
|
||||
|
||||
outputs = {"response": response}
|
||||
|
||||
return inputs, outputs
|
||||
|
||||
def _handle_embedding_payload(
|
||||
self,
|
||||
event_pair: List[CBEvent],
|
||||
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any]]:
|
||||
event_pair[0].payload
|
||||
outputs = event_pair[-1].payload
|
||||
|
||||
chunks = []
|
||||
if outputs:
|
||||
chunks = outputs.get(EventPayload.CHUNKS, [])
|
||||
|
||||
return {}, {"num_chunks": len(chunks)}
|
||||
|
||||
def _get_time_in_ms(self, event_pair: List[CBEvent]) -> Tuple[int, int]:
|
||||
"""Get the start and end time of an event pair in milliseconds."""
|
||||
start_time = datetime.strptime(event_pair[0].time, TIMESTAMP_FORMAT)
|
||||
end_time = datetime.strptime(event_pair[1].time, TIMESTAMP_FORMAT)
|
||||
|
||||
start_time_in_ms = int(
|
||||
(start_time - datetime(1970, 1, 1)).total_seconds() * 1000
|
||||
)
|
||||
end_time_in_ms = int((end_time - datetime(1970, 1, 1)).total_seconds() * 1000)
|
||||
|
||||
return start_time_in_ms, end_time_in_ms
|
||||
|
||||
def _ensure_run(self, should_print_url: bool = False) -> None:
|
||||
"""Ensures an active W&B run exists.
|
||||
|
||||
If not, will start a new run with the provided run_args.
|
||||
"""
|
||||
if self._wandb.run is None:
|
||||
# Make a shallow copy of the run args, so we don't modify the original
|
||||
run_args = self._run_args or {} # type: ignore
|
||||
run_args: dict = {**run_args} # type: ignore
|
||||
|
||||
# Prefer to run in silent mode since W&B has a lot of output
|
||||
# which can be undesirable when dealing with text-based models.
|
||||
if "settings" not in run_args: # type: ignore
|
||||
run_args["settings"] = {"silent": True} # type: ignore
|
||||
|
||||
# Start the run and add the stream table
|
||||
self._wandb.init(**run_args)
|
||||
self._wandb.run._label(repo="llama_index") # type: ignore
|
||||
|
||||
if should_print_url:
|
||||
self._print_wandb_init_message(
|
||||
self._wandb.run.settings.run_url # type: ignore
|
||||
)
|
||||
|
||||
def _print_wandb_init_message(self, run_url: str) -> None:
|
||||
"""Print a message to the terminal when W&B is initialized."""
|
||||
self._wandb.termlog(
|
||||
f"Streaming LlamaIndex events to W&B at {run_url}\n"
|
||||
"`WandbCallbackHandler` is currently in beta.\n"
|
||||
"Please report any issues to https://github.com/wandb/wandb/issues "
|
||||
"with the tag `llamaindex`."
|
||||
)
|
||||
|
||||
def _print_upload_index_fail_message(self, e: Exception) -> None:
|
||||
"""Print a message to the terminal when uploading the index fails."""
|
||||
self._wandb.termlog(
|
||||
f"Failed to upload index to W&B with the following error: {e}\n"
|
||||
)
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Finish the callback handler."""
|
||||
self._wandb.finish()
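# --- Usage sketch (illustrative, not part of the diff above) ---------------------
# End-to-end flow for this handler; requires `wandb` and a configured W&B account.
# The project name, data directory, and artifact URL below are examples only.
from llama_index import (
    ServiceContext,
    SimpleDirectoryReader,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.callbacks import CallbackManager

wandb_handler = WandbCallbackHandler(run_args={"project": "llamaindex-demo"})
service_context = ServiceContext.from_defaults(
    callback_manager=CallbackManager([wandb_handler])
)

documents = SimpleDirectoryReader("./data").load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)

# Upload the index as a W&B artifact, then pull it back into a StorageContext.
wandb_handler.persist_index(index, index_name="simple-vector-index")
storage_context = wandb_handler.load_storage_context(
    artifact_url="my-entity/llamaindex-demo/simple-vector-index:v0"
)
index = load_index_from_storage(storage_context)

wandb_handler.finish()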
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
from llama_index.chat_engine.condense_plus_context import CondensePlusContextChatEngine
from llama_index.chat_engine.condense_question import CondenseQuestionChatEngine
from llama_index.chat_engine.context import ContextChatEngine
from llama_index.chat_engine.simple import SimpleChatEngine

__all__ = [
"SimpleChatEngine",
"CondenseQuestionChatEngine",
"ContextChatEngine",
"CondensePlusContextChatEngine",
]
|
||||
|
|
@ -0,0 +1,362 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from threading import Thread
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from llama_index.callbacks import CallbackManager, trace_method
|
||||
from llama_index.chat_engine.types import (
|
||||
AgentChatResponse,
|
||||
BaseChatEngine,
|
||||
StreamingAgentChatResponse,
|
||||
ToolOutput,
|
||||
)
|
||||
from llama_index.core.llms.types import ChatMessage, MessageRole
|
||||
from llama_index.indices.base_retriever import BaseRetriever
|
||||
from llama_index.indices.query.schema import QueryBundle
|
||||
from llama_index.indices.service_context import ServiceContext
|
||||
from llama_index.llms.generic_utils import messages_to_history_str
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.memory import BaseMemory, ChatMemoryBuffer
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.prompts.base import PromptTemplate
|
||||
from llama_index.schema import MetadataMode, NodeWithScore
|
||||
from llama_index.utilities.token_counting import TokenCounter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CONTEXT_PROMPT_TEMPLATE = """
|
||||
The following is a friendly conversation between a user and an AI assistant.
|
||||
The assistant is talkative and provides lots of specific details from its context.
|
||||
If the assistant does not know the answer to a question, it truthfully says it
|
||||
does not know.
|
||||
|
||||
Here are the relevant documents for the context:
|
||||
|
||||
{context_str}
|
||||
|
||||
Instruction: Based on the above documents, provide a detailed answer for the user question below.
|
||||
Answer "don't know" if not present in the document.
|
||||
"""
|
||||
|
||||
DEFAULT_CONDENSE_PROMPT_TEMPLATE = """
|
||||
Given the following conversation between a user and an AI assistant and a follow up question from user,
|
||||
rephrase the follow up question to be a standalone question.
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
Follow Up Input: {question}
|
||||
Standalone question:"""
|
||||
|
||||
|
||||
class CondensePlusContextChatEngine(BaseChatEngine):
|
||||
"""Condensed Conversation & Context Chat Engine.
|
||||
|
||||
First, condense the conversation history and the latest user message into a standalone question.
Then, build context for the standalone question using a retriever.
Finally, pass the context along with the prompt and user message to the LLM to generate a response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever: BaseRetriever,
|
||||
llm: LLM,
|
||||
memory: BaseMemory,
|
||||
context_prompt: Optional[str] = None,
|
||||
condense_prompt: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
skip_condense: bool = False,
|
||||
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
self._retriever = retriever
|
||||
self._llm = llm
|
||||
self._memory = memory
|
||||
self._context_prompt_template = (
|
||||
context_prompt or DEFAULT_CONTEXT_PROMPT_TEMPLATE
|
||||
)
|
||||
condense_prompt_str = condense_prompt or DEFAULT_CONDENSE_PROMPT_TEMPLATE
|
||||
self._condense_prompt_template = PromptTemplate(condense_prompt_str)
|
||||
self._system_prompt = system_prompt
|
||||
self._skip_condense = skip_condense
|
||||
self._node_postprocessors = node_postprocessors or []
|
||||
self.callback_manager = callback_manager or CallbackManager([])
|
||||
for node_postprocessor in self._node_postprocessors:
|
||||
node_postprocessor.callback_manager = self.callback_manager
|
||||
|
||||
self._token_counter = TokenCounter()
|
||||
self._verbose = verbose
|
||||
|
||||
@classmethod
|
||||
def from_defaults(
|
||||
cls,
|
||||
retriever: BaseRetriever,
|
||||
service_context: Optional[ServiceContext] = None,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
context_prompt: Optional[str] = None,
|
||||
condense_prompt: Optional[str] = None,
|
||||
skip_condense: bool = False,
|
||||
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> "CondensePlusContextChatEngine":
|
||||
"""Initialize a CondensePlusContextChatEngine from default parameters."""
|
||||
service_context = service_context or ServiceContext.from_defaults()
|
||||
llm = service_context.llm
|
||||
chat_history = chat_history or []
|
||||
memory = memory or ChatMemoryBuffer.from_defaults(
|
||||
chat_history=chat_history, token_limit=llm.metadata.context_window - 256
|
||||
)
|
||||
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
memory=memory,
|
||||
context_prompt=context_prompt,
|
||||
condense_prompt=condense_prompt,
|
||||
skip_condense=skip_condense,
|
||||
callback_manager=service_context.callback_manager,
|
||||
node_postprocessors=node_postprocessors,
|
||||
system_prompt=system_prompt,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
def _condense_question(
|
||||
self, chat_history: List[ChatMessage], latest_message: str
|
||||
) -> str:
|
||||
"""Condense a conversation history and latest user message to a standalone question."""
|
||||
if self._skip_condense or len(chat_history) == 0:
|
||||
return latest_message
|
||||
|
||||
chat_history_str = messages_to_history_str(chat_history)
|
||||
logger.debug(chat_history_str)
|
||||
|
||||
return self._llm.predict(
|
||||
self._condense_prompt_template,
|
||||
question=latest_message,
|
||||
chat_history=chat_history_str,
|
||||
)
|
||||
|
||||
async def _acondense_question(
|
||||
self, chat_history: List[ChatMessage], latest_message: str
|
||||
) -> str:
|
||||
"""Condense a conversation history and latest user message to a standalone question."""
|
||||
if self._skip_condense or len(chat_history) == 0:
|
||||
return latest_message
|
||||
|
||||
chat_history_str = messages_to_history_str(chat_history)
|
||||
logger.debug(chat_history_str)
|
||||
|
||||
return await self._llm.apredict(
|
||||
self._condense_prompt_template,
|
||||
question=latest_message,
|
||||
chat_history=chat_history_str,
|
||||
)
|
||||
|
||||
def _retrieve_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
|
||||
"""Build context for a message from retriever."""
|
||||
nodes = self._retriever.retrieve(message)
|
||||
for postprocessor in self._node_postprocessors:
|
||||
nodes = postprocessor.postprocess_nodes(
|
||||
nodes, query_bundle=QueryBundle(message)
|
||||
)
|
||||
|
||||
context_str = "\n\n".join(
|
||||
[n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
|
||||
)
|
||||
return context_str, nodes
|
||||
|
||||
async def _aretrieve_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
|
||||
"""Build context for a message from retriever."""
|
||||
nodes = await self._retriever.aretrieve(message)
|
||||
context_str = "\n\n".join(
|
||||
[n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
|
||||
)
|
||||
return context_str, nodes
|
||||
|
||||
def _run_c3(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> Tuple[List[ChatMessage], ToolOutput, List[NodeWithScore]]:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
|
||||
chat_history = self._memory.get()
|
||||
|
||||
# Condense conversation history and latest message to a standalone question
|
||||
condensed_question = self._condense_question(chat_history, message)
|
||||
logger.info(f"Condensed question: {condensed_question}")
|
||||
if self._verbose:
|
||||
print(f"Condensed question: {condensed_question}")
|
||||
|
||||
# Build context for the standalone question from a retriever
|
||||
context_str, context_nodes = self._retrieve_context(condensed_question)
|
||||
context_source = ToolOutput(
|
||||
tool_name="retriever",
|
||||
content=context_str,
|
||||
raw_input={"message": condensed_question},
|
||||
raw_output=context_str,
|
||||
)
|
||||
logger.debug(f"Context: {context_str}")
|
||||
if self._verbose:
|
||||
print(f"Context: {context_str}")
|
||||
|
||||
system_message_content = self._context_prompt_template.format(
|
||||
context_str=context_str
|
||||
)
|
||||
if self._system_prompt:
|
||||
system_message_content = self._system_prompt + "\n" + system_message_content
|
||||
|
||||
system_message = ChatMessage(
|
||||
content=system_message_content, role=self._llm.metadata.system_role
|
||||
)
|
||||
|
||||
initial_token_count = self._token_counter.estimate_tokens_in_messages(
|
||||
[system_message]
|
||||
)
|
||||
|
||||
self._memory.put(ChatMessage(content=message, role=MessageRole.USER))
|
||||
chat_messages = [
|
||||
system_message,
|
||||
*self._memory.get(initial_token_count=initial_token_count),
|
||||
]
|
||||
return chat_messages, context_source, context_nodes
|
||||
|
||||
async def _arun_c3(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> Tuple[List[ChatMessage], ToolOutput, List[NodeWithScore]]:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
|
||||
chat_history = self._memory.get()
|
||||
|
||||
# Condense conversation history and latest message to a standalone question
|
||||
condensed_question = await self._acondense_question(chat_history, message)
|
||||
logger.info(f"Condensed question: {condensed_question}")
|
||||
if self._verbose:
|
||||
print(f"Condensed question: {condensed_question}")
|
||||
|
||||
# Build context for the standalone question from a retriever
|
||||
context_str, context_nodes = await self._aretrieve_context(condensed_question)
|
||||
context_source = ToolOutput(
|
||||
tool_name="retriever",
|
||||
content=context_str,
|
||||
raw_input={"message": condensed_question},
|
||||
raw_output=context_str,
|
||||
)
|
||||
logger.debug(f"Context: {context_str}")
|
||||
if self._verbose:
|
||||
print(f"Context: {context_str}")
|
||||
|
||||
system_message_content = self._context_prompt_template.format(
|
||||
context_str=context_str
|
||||
)
|
||||
if self._system_prompt:
|
||||
system_message_content = self._system_prompt + "\n" + system_message_content
|
||||
|
||||
system_message = ChatMessage(
|
||||
content=system_message_content, role=self._llm.metadata.system_role
|
||||
)
|
||||
|
||||
initial_token_count = self._token_counter.estimate_tokens_in_messages(
|
||||
[system_message]
|
||||
)
|
||||
|
||||
self._memory.put(ChatMessage(content=message, role=MessageRole.USER))
|
||||
chat_messages = [
|
||||
system_message,
|
||||
*self._memory.get(initial_token_count=initial_token_count),
|
||||
]
|
||||
|
||||
return chat_messages, context_source, context_nodes
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
chat_messages, context_source, context_nodes = self._run_c3(
|
||||
message, chat_history
|
||||
)
|
||||
|
||||
# pass the context, system prompt and user message as chat to LLM to generate a response
|
||||
chat_response = self._llm.chat(chat_messages)
|
||||
assistant_message = chat_response.message
|
||||
self._memory.put(assistant_message)
|
||||
|
||||
return AgentChatResponse(
|
||||
response=str(assistant_message.content),
|
||||
sources=[context_source],
|
||||
source_nodes=context_nodes,
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
chat_messages, context_source, context_nodes = self._run_c3(
|
||||
message, chat_history
|
||||
)
|
||||
|
||||
# pass the context, system prompt and user message as chat to LLM to generate a response
|
||||
chat_response = StreamingAgentChatResponse(
|
||||
chat_stream=self._llm.stream_chat(chat_messages),
|
||||
sources=[context_source],
|
||||
source_nodes=context_nodes,
|
||||
)
|
||||
thread = Thread(
|
||||
target=chat_response.write_response_to_history, args=(self._memory,)
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
chat_messages, context_source, context_nodes = await self._arun_c3(
|
||||
message, chat_history
|
||||
)
|
||||
|
||||
# pass the context, system prompt and user message as chat to LLM to generate a response
|
||||
chat_response = await self._llm.achat(chat_messages)
|
||||
assistant_message = chat_response.message
|
||||
self._memory.put(assistant_message)
|
||||
|
||||
return AgentChatResponse(
|
||||
response=str(assistant_message.content),
|
||||
sources=[context_source],
|
||||
source_nodes=context_nodes,
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
chat_messages, context_source, context_nodes = await self._arun_c3(
|
||||
message, chat_history
|
||||
)
|
||||
|
||||
# pass the context, system prompt and user message as chat to LLM to generate a response
|
||||
chat_response = StreamingAgentChatResponse(
|
||||
achat_stream=await self._llm.astream_chat(chat_messages),
|
||||
sources=[context_source],
|
||||
source_nodes=context_nodes,
|
||||
)
|
||||
thread = Thread(
|
||||
target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)),
|
||||
args=(self._memory,),
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return chat_response
|
||||
|
||||
def reset(self) -> None:
|
||||
# Clear chat history
|
||||
self._memory.reset()
|
||||
|
||||
@property
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
"""Get chat history."""
|
||||
return self._memory.get_all()
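# --- Usage sketch (illustrative, not part of the diff above) ---------------------
# Building the engine from an existing vector index; "./data" is a placeholder.
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex

service_context = ServiceContext.from_defaults()
documents = SimpleDirectoryReader("./data").load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)

chat_engine = CondensePlusContextChatEngine.from_defaults(
    retriever=index.as_retriever(similarity_top_k=3),
    service_context=service_context,
    verbose=True,  # prints the condensed question and retrieved context
)

print(chat_engine.chat("What topics does the corpus cover?"))
print(chat_engine.chat("Can you expand on the first one?"))  # condensed with history
chat_engine.reset()  # clear the conversation memory between sessions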
|
||||
|
|
@ -0,0 +1,362 @@
|
|||
import logging
|
||||
from threading import Thread
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
from llama_index.callbacks import CallbackManager, trace_method
|
||||
from llama_index.chat_engine.types import (
|
||||
AgentChatResponse,
|
||||
BaseChatEngine,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.chat_engine.utils import response_gen_from_query_engine
|
||||
from llama_index.core.base_query_engine import BaseQueryEngine
|
||||
from llama_index.core.llms.types import ChatMessage, MessageRole
|
||||
from llama_index.core.response.schema import RESPONSE_TYPE, StreamingResponse
|
||||
from llama_index.llm_predictor.base import LLMPredictorType
|
||||
from llama_index.llms.generic_utils import messages_to_history_str
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.memory import BaseMemory, ChatMemoryBuffer
|
||||
from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
|
||||
from llama_index.service_context import ServiceContext
|
||||
from llama_index.token_counter.mock_embed_model import MockEmbedding
|
||||
from llama_index.tools import ToolOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_TEMPLATE = """\
|
||||
Given a conversation (between Human and Assistant) and a follow up message from Human, \
|
||||
rewrite the message to be a standalone question that captures all relevant context \
|
||||
from the conversation.
|
||||
|
||||
<Chat History>
|
||||
{chat_history}
|
||||
|
||||
<Follow Up Message>
|
||||
{question}
|
||||
|
||||
<Standalone question>
|
||||
"""
|
||||
|
||||
DEFAULT_PROMPT = PromptTemplate(DEFAULT_TEMPLATE)
|
||||
|
||||
|
||||
class CondenseQuestionChatEngine(BaseChatEngine):
|
||||
"""Condense Question Chat Engine.
|
||||
|
||||
First generate a standalone question from conversation context and last message,
|
||||
then query the query engine for a response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_engine: BaseQueryEngine,
|
||||
condense_question_prompt: BasePromptTemplate,
|
||||
memory: BaseMemory,
|
||||
llm: LLMPredictorType,
|
||||
verbose: bool = False,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
) -> None:
|
||||
self._query_engine = query_engine
|
||||
self._condense_question_prompt = condense_question_prompt
|
||||
self._memory = memory
|
||||
self._llm = llm
|
||||
self._verbose = verbose
|
||||
self.callback_manager = callback_manager or CallbackManager([])
|
||||
|
||||
@classmethod
|
||||
def from_defaults(
|
||||
cls,
|
||||
query_engine: BaseQueryEngine,
|
||||
condense_question_prompt: Optional[BasePromptTemplate] = None,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
|
||||
service_context: Optional[ServiceContext] = None,
|
||||
verbose: bool = False,
|
||||
system_prompt: Optional[str] = None,
|
||||
prefix_messages: Optional[List[ChatMessage]] = None,
|
||||
llm: Optional[LLM] = None,
|
||||
**kwargs: Any,
|
||||
) -> "CondenseQuestionChatEngine":
|
||||
"""Initialize a CondenseQuestionChatEngine from default parameters."""
|
||||
condense_question_prompt = condense_question_prompt or DEFAULT_PROMPT
|
||||
|
||||
if llm is None:
|
||||
service_context = service_context or ServiceContext.from_defaults(
|
||||
embed_model=MockEmbedding(embed_dim=2)
|
||||
)
|
||||
llm = service_context.llm
|
||||
else:
|
||||
service_context = service_context or ServiceContext.from_defaults(
|
||||
llm=llm, embed_model=MockEmbedding(embed_dim=2)
|
||||
)
|
||||
|
||||
chat_history = chat_history or []
|
||||
memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)
|
||||
|
||||
if system_prompt is not None:
|
||||
raise NotImplementedError(
|
||||
"system_prompt is not supported for CondenseQuestionChatEngine."
|
||||
)
|
||||
if prefix_messages is not None:
|
||||
raise NotImplementedError(
|
||||
"prefix_messages is not supported for CondenseQuestionChatEngine."
|
||||
)
|
||||
|
||||
return cls(
|
||||
query_engine,
|
||||
condense_question_prompt,
|
||||
memory,
|
||||
llm,
|
||||
verbose=verbose,
|
||||
callback_manager=service_context.callback_manager,
|
||||
)
|
||||
|
||||
def _condense_question(
|
||||
self, chat_history: List[ChatMessage], last_message: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate standalone question from conversation context and last message.
|
||||
"""
|
||||
chat_history_str = messages_to_history_str(chat_history)
|
||||
logger.debug(chat_history_str)
|
||||
|
||||
return self._llm.predict(
|
||||
self._condense_question_prompt,
|
||||
question=last_message,
|
||||
chat_history=chat_history_str,
|
||||
)
|
||||
|
||||
async def _acondense_question(
|
||||
self, chat_history: List[ChatMessage], last_message: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate standalone question from conversation context and last message.
|
||||
"""
|
||||
chat_history_str = messages_to_history_str(chat_history)
|
||||
logger.debug(chat_history_str)
|
||||
|
||||
return await self._llm.apredict(
|
||||
self._condense_question_prompt,
|
||||
question=last_message,
|
||||
chat_history=chat_history_str,
|
||||
)
|
||||
|
||||
def _get_tool_output_from_response(
|
||||
self, query: str, response: RESPONSE_TYPE
|
||||
) -> ToolOutput:
|
||||
if isinstance(response, StreamingResponse):
|
||||
return ToolOutput(
|
||||
content="",
|
||||
tool_name="query_engine",
|
||||
raw_input={"query": query},
|
||||
raw_output=response,
|
||||
)
|
||||
else:
|
||||
return ToolOutput(
|
||||
content=str(response),
|
||||
tool_name="query_engine",
|
||||
raw_input={"query": query},
|
||||
raw_output=response,
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
chat_history = chat_history or self._memory.get()
|
||||
|
||||
# Generate standalone question from conversation context and last message
|
||||
condensed_question = self._condense_question(chat_history, message)
|
||||
|
||||
log_str = f"Querying with: {condensed_question}"
|
||||
logger.info(log_str)
|
||||
if self._verbose:
|
||||
print(log_str)
|
||||
|
||||
# TODO: right now, query engine uses class attribute to configure streaming,
|
||||
# we are moving towards separate streaming and non-streaming methods.
|
||||
# In the meantime, use this hack to toggle streaming.
|
||||
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
|
||||
|
||||
if isinstance(self._query_engine, RetrieverQueryEngine):
|
||||
is_streaming = self._query_engine._response_synthesizer._streaming
|
||||
self._query_engine._response_synthesizer._streaming = False
|
||||
|
||||
# Query with standalone question
|
||||
query_response = self._query_engine.query(condensed_question)
|
||||
|
||||
# NOTE: reset streaming flag
|
||||
if isinstance(self._query_engine, RetrieverQueryEngine):
|
||||
self._query_engine._response_synthesizer._streaming = is_streaming
|
||||
|
||||
tool_output = self._get_tool_output_from_response(
|
||||
condensed_question, query_response
|
||||
)
|
||||
|
||||
# Record response
|
||||
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
|
||||
self._memory.put(
|
||||
ChatMessage(role=MessageRole.ASSISTANT, content=str(query_response))
|
||||
)
|
||||
|
||||
return AgentChatResponse(response=str(query_response), sources=[tool_output])
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
chat_history = chat_history or self._memory.get()
|
||||
|
||||
# Generate standalone question from conversation context and last message
|
||||
condensed_question = self._condense_question(chat_history, message)
|
||||
|
||||
log_str = f"Querying with: {condensed_question}"
|
||||
logger.info(log_str)
|
||||
if self._verbose:
|
||||
print(log_str)
|
||||
|
||||
# TODO: right now, query engine uses class attribute to configure streaming,
|
||||
# we are moving towards separate streaming and non-streaming methods.
|
||||
# In the meantime, use this hack to toggle streaming.
|
||||
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
|
||||
|
||||
if isinstance(self._query_engine, RetrieverQueryEngine):
|
||||
is_streaming = self._query_engine._response_synthesizer._streaming
|
||||
self._query_engine._response_synthesizer._streaming = True
|
||||
|
||||
# Query with standalone question
|
||||
query_response = self._query_engine.query(condensed_question)
|
||||
|
||||
# NOTE: reset streaming flag
|
||||
if isinstance(self._query_engine, RetrieverQueryEngine):
|
||||
self._query_engine._response_synthesizer._streaming = is_streaming
|
||||
|
||||
tool_output = self._get_tool_output_from_response(
|
||||
condensed_question, query_response
|
||||
)
|
||||
|
||||
# Record response
|
||||
if (
|
||||
isinstance(query_response, StreamingResponse)
|
||||
and query_response.response_gen is not None
|
||||
):
|
||||
# override the generator to include writing to chat history
|
||||
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
|
||||
response = StreamingAgentChatResponse(
|
||||
chat_stream=response_gen_from_query_engine(query_response.response_gen),
|
||||
sources=[tool_output],
|
||||
)
|
||||
thread = Thread(
|
||||
target=response.write_response_to_history, args=(self._memory, True)
|
||||
)
|
||||
thread.start()
|
||||
else:
|
||||
raise ValueError("Streaming is not enabled. Please use chat() instead.")
|
||||
return response
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
chat_history = chat_history or self._memory.get()
|
||||
|
||||
# Generate standalone question from conversation context and last message
|
||||
condensed_question = await self._acondense_question(chat_history, message)
|
||||
|
||||
log_str = f"Querying with: {condensed_question}"
|
||||
logger.info(log_str)
|
||||
if self._verbose:
|
||||
print(log_str)
|
||||
|
||||
# TODO: right now, query engine uses class attribute to configure streaming,
|
||||
# we are moving towards separate streaming and non-streaming methods.
|
||||
# In the meantime, use this hack to toggle streaming.
|
||||
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
|
||||
|
||||
if isinstance(self._query_engine, RetrieverQueryEngine):
|
||||
is_streaming = self._query_engine._response_synthesizer._streaming
|
||||
self._query_engine._response_synthesizer._streaming = False
|
||||
|
||||
# Query with standalone question
|
||||
query_response = await self._query_engine.aquery(condensed_question)
|
||||
|
||||
# NOTE: reset streaming flag
|
||||
if isinstance(self._query_engine, RetrieverQueryEngine):
|
||||
self._query_engine._response_synthesizer._streaming = is_streaming
|
||||
|
||||
tool_output = self._get_tool_output_from_response(
|
||||
condensed_question, query_response
|
||||
)
|
||||
|
||||
# Record response
|
||||
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
|
||||
self._memory.put(
|
||||
ChatMessage(role=MessageRole.ASSISTANT, content=str(query_response))
|
||||
)
|
||||
|
||||
return AgentChatResponse(response=str(query_response), sources=[tool_output])
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
chat_history = chat_history or self._memory.get()
|
||||
|
||||
# Generate standalone question from conversation context and last message
|
||||
condensed_question = await self._acondense_question(chat_history, message)
|
||||
|
||||
log_str = f"Querying with: {condensed_question}"
|
||||
logger.info(log_str)
|
||||
if self._verbose:
|
||||
print(log_str)
|
||||
|
||||
# TODO: right now, query engine uses class attribute to configure streaming,
|
||||
# we are moving towards separate streaming and non-streaming methods.
|
||||
# In the meantime, use this hack to toggle streaming.
|
||||
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
|
||||
|
||||
if isinstance(self._query_engine, RetrieverQueryEngine):
|
||||
is_streaming = self._query_engine._response_synthesizer._streaming
|
||||
self._query_engine._response_synthesizer._streaming = True
|
||||
|
||||
# Query with standalone question
|
||||
query_response = await self._query_engine.aquery(condensed_question)
|
||||
|
||||
# NOTE: reset streaming flag
|
||||
if isinstance(self._query_engine, RetrieverQueryEngine):
|
||||
self._query_engine._response_synthesizer._streaming = is_streaming
|
||||
|
||||
tool_output = self._get_tool_output_from_response(
|
||||
condensed_question, query_response
|
||||
)
|
||||
|
||||
# Record response
|
||||
if (
|
||||
isinstance(query_response, StreamingResponse)
|
||||
and query_response.response_gen is not None
|
||||
):
|
||||
# override the generator to include writing to chat history
|
||||
# TODO: query engine does not support async generator yet
|
||||
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
|
||||
response = StreamingAgentChatResponse(
|
||||
chat_stream=response_gen_from_query_engine(query_response.response_gen),
|
||||
sources=[tool_output],
|
||||
)
|
||||
thread = Thread(
|
||||
target=response.write_response_to_history, args=(self._memory,)
|
||||
)
|
||||
thread.start()
|
||||
else:
|
||||
raise ValueError("Streaming is not enabled. Please use achat() instead.")
|
||||
return response
|
||||
|
||||
def reset(self) -> None:
|
||||
# Clear chat history
|
||||
self._memory.reset()
|
||||
|
||||
@property
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
"""Get chat history."""
|
||||
return self._memory.get_all()
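# --- Editorial usage sketch (not part of the module above) -------------------
# Minimal example wiring CondenseQuestionChatEngine to a query engine. The data
# directory is a placeholder, the default OpenAI LLM is assumed (requires
# OPENAI_API_KEY), and `as_query_engine()` is the standard index helper.
if __name__ == "__main__":
    from llama_index import SimpleDirectoryReader, VectorStoreIndex

    documents = SimpleDirectoryReader(input_dir="./data").load_data()
    index = VectorStoreIndex.from_documents(documents)

    chat_engine = CondenseQuestionChatEngine.from_defaults(
        query_engine=index.as_query_engine(),
        verbose=True,
    )
    print(chat_engine.chat("What does the document say about pricing?"))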
|
||||
|
|
@ -0,0 +1,301 @@
|
|||
import asyncio
|
||||
from threading import Thread
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from llama_index.callbacks import CallbackManager, trace_method
|
||||
from llama_index.chat_engine.types import (
|
||||
AgentChatResponse,
|
||||
BaseChatEngine,
|
||||
StreamingAgentChatResponse,
|
||||
ToolOutput,
|
||||
)
|
||||
from llama_index.core.base_retriever import BaseRetriever
|
||||
from llama_index.core.llms.types import ChatMessage, MessageRole
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.memory import BaseMemory, ChatMemoryBuffer
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle
|
||||
from llama_index.service_context import ServiceContext
|
||||
|
||||
DEFAULT_CONTEXT_TEMPLATE = (
|
||||
"Context information is below."
|
||||
"\n--------------------\n"
|
||||
"{context_str}"
|
||||
"\n--------------------\n"
|
||||
)
|
||||
|
||||
|
||||
class ContextChatEngine(BaseChatEngine):
|
||||
"""Context Chat Engine.
|
||||
|
||||
Uses a retriever to retrieve a context, sets the context in the system prompt,
|
||||
and then uses an LLM to generate a response, for a fluid chat experience.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever: BaseRetriever,
|
||||
llm: LLM,
|
||||
memory: BaseMemory,
|
||||
prefix_messages: List[ChatMessage],
|
||||
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
|
||||
context_template: Optional[str] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
) -> None:
|
||||
self._retriever = retriever
|
||||
self._llm = llm
|
||||
self._memory = memory
|
||||
self._prefix_messages = prefix_messages
|
||||
self._node_postprocessors = node_postprocessors or []
|
||||
self._context_template = context_template or DEFAULT_CONTEXT_TEMPLATE
|
||||
|
||||
self.callback_manager = callback_manager or CallbackManager([])
|
||||
for node_postprocessor in self._node_postprocessors:
|
||||
node_postprocessor.callback_manager = self.callback_manager
|
||||
|
||||
@classmethod
|
||||
def from_defaults(
|
||||
cls,
|
||||
retriever: BaseRetriever,
|
||||
service_context: Optional[ServiceContext] = None,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
prefix_messages: Optional[List[ChatMessage]] = None,
|
||||
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
|
||||
context_template: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> "ContextChatEngine":
|
||||
"""Initialize a ContextChatEngine from default parameters."""
|
||||
service_context = service_context or ServiceContext.from_defaults()
|
||||
llm = service_context.llm
|
||||
|
||||
chat_history = chat_history or []
|
||||
memory = memory or ChatMemoryBuffer.from_defaults(
|
||||
chat_history=chat_history, token_limit=llm.metadata.context_window - 256
|
||||
)
|
||||
|
||||
if system_prompt is not None:
|
||||
if prefix_messages is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both system_prompt and prefix_messages"
|
||||
)
|
||||
prefix_messages = [
|
||||
ChatMessage(content=system_prompt, role=llm.metadata.system_role)
|
||||
]
|
||||
|
||||
prefix_messages = prefix_messages or []
|
||||
node_postprocessors = node_postprocessors or []
|
||||
|
||||
return cls(
|
||||
retriever,
|
||||
llm=llm,
|
||||
memory=memory,
|
||||
prefix_messages=prefix_messages,
|
||||
node_postprocessors=node_postprocessors,
|
||||
callback_manager=service_context.callback_manager,
|
||||
context_template=context_template,
|
||||
)
|
||||
|
||||
def _generate_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
|
||||
"""Generate context information from a message."""
|
||||
nodes = self._retriever.retrieve(message)
|
||||
for postprocessor in self._node_postprocessors:
|
||||
nodes = postprocessor.postprocess_nodes(
|
||||
nodes, query_bundle=QueryBundle(message)
|
||||
)
|
||||
|
||||
context_str = "\n\n".join(
|
||||
[n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
|
||||
)
|
||||
|
||||
return self._context_template.format(context_str=context_str), nodes
|
||||
|
||||
async def _agenerate_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
|
||||
"""Generate context information from a message."""
|
||||
nodes = await self._retriever.aretrieve(message)
|
||||
for postprocessor in self._node_postprocessors:
|
||||
nodes = postprocessor.postprocess_nodes(
|
||||
nodes, query_bundle=QueryBundle(message)
|
||||
)
|
||||
context_str = "\n\n".join(
|
||||
[n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
|
||||
)
|
||||
|
||||
return self._context_template.format(context_str=context_str), nodes
|
||||
|
||||
def _get_prefix_messages_with_context(self, context_str: str) -> List[ChatMessage]:
|
||||
"""Get the prefix messages with context."""
|
||||
# ensure we grab the user-configured system prompt
|
||||
system_prompt = ""
|
||||
prefix_messages = self._prefix_messages
|
||||
if (
|
||||
len(self._prefix_messages) != 0
|
||||
and self._prefix_messages[0].role == MessageRole.SYSTEM
|
||||
):
|
||||
system_prompt = str(self._prefix_messages[0].content)
|
||||
prefix_messages = self._prefix_messages[1:]
|
||||
|
||||
context_str_w_sys_prompt = system_prompt.strip() + "\n" + context_str
|
||||
return [
|
||||
ChatMessage(
|
||||
content=context_str_w_sys_prompt, role=self._llm.metadata.system_role
|
||||
),
|
||||
*prefix_messages,
|
||||
]
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
|
||||
context_str_template, nodes = self._generate_context(message)
|
||||
prefix_messages = self._get_prefix_messages_with_context(context_str_template)
|
||||
prefix_messages_token_count = len(
|
||||
self._memory.tokenizer_fn(
|
||||
" ".join([(m.content or "") for m in prefix_messages])
|
||||
)
|
||||
)
|
||||
all_messages = prefix_messages + self._memory.get(
|
||||
initial_token_count=prefix_messages_token_count
|
||||
)
|
||||
chat_response = self._llm.chat(all_messages)
|
||||
ai_message = chat_response.message
|
||||
self._memory.put(ai_message)
|
||||
|
||||
return AgentChatResponse(
|
||||
response=str(chat_response.message.content),
|
||||
sources=[
|
||||
ToolOutput(
|
||||
tool_name="retriever",
|
||||
content=str(prefix_messages[0]),
|
||||
raw_input={"message": message},
|
||||
raw_output=prefix_messages[0],
|
||||
)
|
||||
],
|
||||
source_nodes=nodes,
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
|
||||
context_str_template, nodes = self._generate_context(message)
|
||||
prefix_messages = self._get_prefix_messages_with_context(context_str_template)
|
||||
initial_token_count = len(
|
||||
self._memory.tokenizer_fn(
|
||||
" ".join([(m.content or "") for m in prefix_messages])
|
||||
)
|
||||
)
|
||||
all_messages = prefix_messages + self._memory.get(
|
||||
initial_token_count=initial_token_count
|
||||
)
|
||||
|
||||
chat_response = StreamingAgentChatResponse(
|
||||
chat_stream=self._llm.stream_chat(all_messages),
|
||||
sources=[
|
||||
ToolOutput(
|
||||
tool_name="retriever",
|
||||
content=str(prefix_messages[0]),
|
||||
raw_input={"message": message},
|
||||
raw_output=prefix_messages[0],
|
||||
)
|
||||
],
|
||||
source_nodes=nodes,
|
||||
)
|
||||
thread = Thread(
|
||||
target=chat_response.write_response_to_history, args=(self._memory,)
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
|
||||
context_str_template, nodes = await self._agenerate_context(message)
|
||||
prefix_messages = self._get_prefix_messages_with_context(context_str_template)
|
||||
initial_token_count = len(
|
||||
self._memory.tokenizer_fn(
|
||||
" ".join([(m.content or "") for m in prefix_messages])
|
||||
)
|
||||
)
|
||||
all_messages = prefix_messages + self._memory.get(
|
||||
initial_token_count=initial_token_count
|
||||
)
|
||||
|
||||
chat_response = await self._llm.achat(all_messages)
|
||||
ai_message = chat_response.message
|
||||
self._memory.put(ai_message)
|
||||
|
||||
return AgentChatResponse(
|
||||
response=str(chat_response.message.content),
|
||||
sources=[
|
||||
ToolOutput(
|
||||
tool_name="retriever",
|
||||
content=str(prefix_messages[0]),
|
||||
raw_input={"message": message},
|
||||
raw_output=prefix_messages[0],
|
||||
)
|
||||
],
|
||||
source_nodes=nodes,
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
|
||||
context_str_template, nodes = await self._agenerate_context(message)
|
||||
prefix_messages = self._get_prefix_messages_with_context(context_str_template)
|
||||
initial_token_count = len(
|
||||
self._memory.tokenizer_fn(
|
||||
" ".join([(m.content or "") for m in prefix_messages])
|
||||
)
|
||||
)
|
||||
all_messages = prefix_messages + self._memory.get(
|
||||
initial_token_count=initial_token_count
|
||||
)
|
||||
|
||||
chat_response = StreamingAgentChatResponse(
|
||||
achat_stream=await self._llm.astream_chat(all_messages),
|
||||
sources=[
|
||||
ToolOutput(
|
||||
tool_name="retriever",
|
||||
content=str(prefix_messages[0]),
|
||||
raw_input={"message": message},
|
||||
raw_output=prefix_messages[0],
|
||||
)
|
||||
],
|
||||
source_nodes=nodes,
|
||||
)
|
||||
thread = Thread(
|
||||
target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)),
|
||||
args=(self._memory,),
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return chat_response
|
||||
|
||||
def reset(self) -> None:
|
||||
self._memory.reset()
|
||||
|
||||
@property
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
"""Get chat history."""
|
||||
return self._memory.get_all()
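# --- Editorial usage sketch (not part of the module above) -------------------
# Minimal example of building a ContextChatEngine directly from a retriever.
# The data directory is a placeholder; the default OpenAI LLM and embeddings
# are assumed (requires OPENAI_API_KEY).
if __name__ == "__main__":
    from llama_index import SimpleDirectoryReader, VectorStoreIndex

    documents = SimpleDirectoryReader(input_dir="./data").load_data()
    retriever = VectorStoreIndex.from_documents(documents).as_retriever(
        similarity_top_k=2
    )

    chat_engine = ContextChatEngine.from_defaults(
        retriever,
        system_prompt="Answer strictly from the provided context.",
    )
    print(chat_engine.chat("Summarize the key points."))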
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
import asyncio
|
||||
from threading import Thread
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
from llama_index.callbacks import CallbackManager, trace_method
|
||||
from llama_index.chat_engine.types import (
|
||||
AgentChatResponse,
|
||||
BaseChatEngine,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.core.llms.types import ChatMessage
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.memory import BaseMemory, ChatMemoryBuffer
|
||||
from llama_index.service_context import ServiceContext
|
||||
|
||||
|
||||
class SimpleChatEngine(BaseChatEngine):
|
||||
"""Simple Chat Engine.
|
||||
|
||||
Have a conversation with the LLM.
|
||||
This does not make use of a knowledge base.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: LLM,
|
||||
memory: BaseMemory,
|
||||
prefix_messages: List[ChatMessage],
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
) -> None:
|
||||
self._llm = llm
|
||||
self._memory = memory
|
||||
self._prefix_messages = prefix_messages
|
||||
self.callback_manager = callback_manager or CallbackManager([])
|
||||
|
||||
@classmethod
|
||||
def from_defaults(
|
||||
cls,
|
||||
service_context: Optional[ServiceContext] = None,
|
||||
chat_history: Optional[List[ChatMessage]] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
memory_cls: Type[BaseMemory] = ChatMemoryBuffer,
|
||||
system_prompt: Optional[str] = None,
|
||||
prefix_messages: Optional[List[ChatMessage]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "SimpleChatEngine":
|
||||
"""Initialize a SimpleChatEngine from default parameters."""
|
||||
service_context = service_context or ServiceContext.from_defaults()
|
||||
llm = service_context.llm
|
||||
|
||||
chat_history = chat_history or []
|
||||
memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)
|
||||
|
||||
if system_prompt is not None:
|
||||
if prefix_messages is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both system_prompt and prefix_messages"
|
||||
)
|
||||
prefix_messages = [
|
||||
ChatMessage(content=system_prompt, role=llm.metadata.system_role)
|
||||
]
|
||||
|
||||
prefix_messages = prefix_messages or []
|
||||
|
||||
return cls(
|
||||
llm=llm,
|
||||
memory=memory,
|
||||
prefix_messages=prefix_messages,
|
||||
callback_manager=service_context.callback_manager,
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
initial_token_count = len(
|
||||
self._memory.tokenizer_fn(
|
||||
" ".join([(m.content or "") for m in self._prefix_messages])
|
||||
)
|
||||
)
|
||||
all_messages = self._prefix_messages + self._memory.get(
|
||||
initial_token_count=initial_token_count
|
||||
)
|
||||
|
||||
chat_response = self._llm.chat(all_messages)
|
||||
ai_message = chat_response.message
|
||||
self._memory.put(ai_message)
|
||||
|
||||
return AgentChatResponse(response=str(chat_response.message.content))
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
initial_token_count = len(
|
||||
self._memory.tokenizer_fn(
|
||||
" ".join([(m.content or "") for m in self._prefix_messages])
|
||||
)
|
||||
)
|
||||
all_messages = self._prefix_messages + self._memory.get(
|
||||
initial_token_count=initial_token_count
|
||||
)
|
||||
|
||||
chat_response = StreamingAgentChatResponse(
|
||||
chat_stream=self._llm.stream_chat(all_messages)
|
||||
)
|
||||
thread = Thread(
|
||||
target=chat_response.write_response_to_history, args=(self._memory,)
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return chat_response
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AgentChatResponse:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
initial_token_count = len(
|
||||
self._memory.tokenizer_fn(
|
||||
" ".join([(m.content or "") for m in self._prefix_messages])
|
||||
)
|
||||
)
|
||||
all_messages = self._prefix_messages + self._memory.get(
|
||||
initial_token_count=initial_token_count
|
||||
)
|
||||
|
||||
chat_response = await self._llm.achat(all_messages)
|
||||
ai_message = chat_response.message
|
||||
self._memory.put(ai_message)
|
||||
|
||||
return AgentChatResponse(response=str(chat_response.message.content))
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
if chat_history is not None:
|
||||
self._memory.set(chat_history)
|
||||
self._memory.put(ChatMessage(content=message, role="user"))
|
||||
initial_token_count = len(
|
||||
self._memory.tokenizer_fn(
|
||||
" ".join([(m.content or "") for m in self._prefix_messages])
|
||||
)
|
||||
)
|
||||
all_messages = self._prefix_messages + self._memory.get(
|
||||
initial_token_count=initial_token_count
|
||||
)
|
||||
|
||||
chat_response = StreamingAgentChatResponse(
|
||||
achat_stream=await self._llm.astream_chat(all_messages)
|
||||
)
|
||||
thread = Thread(
|
||||
target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)),
|
||||
args=(self._memory,),
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return chat_response
|
||||
|
||||
def reset(self) -> None:
|
||||
self._memory.reset()
|
||||
|
||||
@property
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
"""Get chat history."""
|
||||
return self._memory.get_all()
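# --- Editorial usage sketch (not part of the module above) -------------------
# SimpleChatEngine needs no index or retriever; with no explicit service
# context, the default ServiceContext (OpenAI, requires OPENAI_API_KEY) is
# assumed.
if __name__ == "__main__":
    chat_engine = SimpleChatEngine.from_defaults(
        system_prompt="You are a terse, helpful assistant."
    )
    print(chat_engine.chat("Give me one fun fact about llamas."))
    chat_engine.reset()  # clear the conversation memory between sessions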
|
||||
|
|
@ -0,0 +1,314 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from threading import Event
|
||||
from typing import AsyncGenerator, Generator, List, Optional, Union
|
||||
|
||||
from llama_index.core.llms.types import (
|
||||
ChatMessage,
|
||||
ChatResponseAsyncGen,
|
||||
ChatResponseGen,
|
||||
)
|
||||
from llama_index.core.response.schema import Response, StreamingResponse
|
||||
from llama_index.memory import BaseMemory
|
||||
from llama_index.schema import NodeWithScore
|
||||
from llama_index.tools import ToolOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def is_function(message: ChatMessage) -> bool:
|
||||
"""Utility for ChatMessage responses from OpenAI models."""
|
||||
return "tool_calls" in message.additional_kwargs
|
||||
|
||||
|
||||
class ChatResponseMode(str, Enum):
|
||||
"""Flag toggling waiting/streaming in `Agent._chat`."""
|
||||
|
||||
WAIT = "wait"
|
||||
STREAM = "stream"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentChatResponse:
|
||||
"""Agent chat response."""
|
||||
|
||||
response: str = ""
|
||||
sources: List[ToolOutput] = field(default_factory=list)
|
||||
source_nodes: List[NodeWithScore] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.sources and not self.source_nodes:
|
||||
for tool_output in self.sources:
|
||||
if isinstance(tool_output.raw_output, (Response, StreamingResponse)):
|
||||
self.source_nodes.extend(tool_output.raw_output.source_nodes)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.response
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamingAgentChatResponse:
|
||||
"""Streaming chat response to user and writing to chat history."""
|
||||
|
||||
response: str = ""
|
||||
sources: List[ToolOutput] = field(default_factory=list)
|
||||
chat_stream: Optional[ChatResponseGen] = None
|
||||
achat_stream: Optional[ChatResponseAsyncGen] = None
|
||||
source_nodes: List[NodeWithScore] = field(default_factory=list)
|
||||
_unformatted_response: str = ""
|
||||
_queue: queue.Queue = field(default_factory=queue.Queue)
|
||||
_aqueue: asyncio.Queue = field(default_factory=asyncio.Queue)
|
||||
# flag when chat message is a function call
|
||||
_is_function: Optional[bool] = None
|
||||
# flag when processing done
|
||||
_is_done = False
|
||||
# signal when a new item is added to the queue
|
||||
_new_item_event: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
# NOTE: async code uses two events rather than one since it yields
|
||||
# control when waiting for queue item
|
||||
# signal when the OpenAI functions stop executing
|
||||
_is_function_false_event: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
# signal when an OpenAI function is being executed
|
||||
_is_function_not_none_thread_event: Event = field(default_factory=Event)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.sources and not self.source_nodes:
|
||||
for tool_output in self.sources:
|
||||
if isinstance(tool_output.raw_output, (Response, StreamingResponse)):
|
||||
self.source_nodes.extend(tool_output.raw_output.source_nodes)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self._is_done and not self._queue.empty() and not self._is_function:
|
||||
while self._queue.queue:
|
||||
delta = self._queue.queue.popleft()
|
||||
self._unformatted_response += delta
|
||||
self.response = self._unformatted_response.strip()
|
||||
return self.response
|
||||
|
||||
def put_in_queue(self, delta: Optional[str]) -> None:
|
||||
self._queue.put_nowait(delta)
|
||||
self._is_function_not_none_thread_event.set()
|
||||
|
||||
def aput_in_queue(self, delta: Optional[str]) -> None:
|
||||
self._aqueue.put_nowait(delta)
|
||||
self._new_item_event.set()
|
||||
|
||||
def write_response_to_history(
|
||||
self, memory: BaseMemory, raise_error: bool = False
|
||||
) -> None:
|
||||
if self.chat_stream is None:
|
||||
raise ValueError(
|
||||
"chat_stream is None. Cannot write to history without chat_stream."
|
||||
)
|
||||
|
||||
# try/except to prevent hanging on error
|
||||
try:
|
||||
final_text = ""
|
||||
for chat in self.chat_stream:
|
||||
self._is_function = is_function(chat.message)
|
||||
self.put_in_queue(chat.delta)
|
||||
final_text += chat.delta or ""
|
||||
if self._is_function is not None:  # if the loop ran at least one iteration
|
||||
# NOTE: this is to handle the special case where we consume some of the
|
||||
# chat stream, but not all of it (e.g. in react agent)
|
||||
chat.message.content = final_text.strip() # final message
|
||||
memory.put(chat.message)
|
||||
except Exception as e:
|
||||
if not raise_error:
|
||||
logger.warning(
|
||||
f"Encountered exception writing response to history: {e}"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
self._is_done = True
|
||||
|
||||
# This acts as an is_done event for any waiting consumers
|
||||
self._is_function_not_none_thread_event.set()
|
||||
|
||||
async def awrite_response_to_history(
|
||||
self,
|
||||
memory: BaseMemory,
|
||||
) -> None:
|
||||
if self.achat_stream is None:
|
||||
raise ValueError(
|
||||
"achat_stream is None. Cannot asynchronously write to "
|
||||
"history without achat_stream."
|
||||
)
|
||||
|
||||
# try/except to prevent hanging on error
|
||||
try:
|
||||
final_text = ""
|
||||
async for chat in self.achat_stream:
|
||||
self._is_function = is_function(chat.message)
|
||||
self.aput_in_queue(chat.delta)
|
||||
final_text += chat.delta or ""
|
||||
self._new_item_event.set()
|
||||
if self._is_function is False:
|
||||
self._is_function_false_event.set()
|
||||
if self._is_function is not None:  # if the loop ran at least one iteration
|
||||
# NOTE: this is to handle the special case where we consume some of the
|
||||
# chat stream, but not all of it (e.g. in react agent)
|
||||
chat.message.content = final_text.strip() # final message
|
||||
memory.put(chat.message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Encountered exception writing response to history: {e}")
|
||||
self._is_done = True
|
||||
|
||||
# These act as is_done events for any consumers waiting
|
||||
self._is_function_false_event.set()
|
||||
self._new_item_event.set()
|
||||
|
||||
@property
|
||||
def response_gen(self) -> Generator[str, None, None]:
|
||||
while not self._is_done or not self._queue.empty():
|
||||
try:
|
||||
delta = self._queue.get(block=False)
|
||||
self._unformatted_response += delta
|
||||
yield delta
|
||||
except queue.Empty:
|
||||
# Queue is empty, but we're not done yet
|
||||
time.sleep(0.01)
|
||||
self.response = self._unformatted_response.strip()
|
||||
|
||||
async def async_response_gen(self) -> AsyncGenerator[str, None]:
|
||||
while not self._is_done or not self._aqueue.empty():
|
||||
if not self._aqueue.empty():
|
||||
delta = self._aqueue.get_nowait()
|
||||
self._unformatted_response += delta
|
||||
yield delta
|
||||
else:
|
||||
await self._new_item_event.wait() # Wait until a new item is added
|
||||
self._new_item_event.clear() # Clear the event for the next wait
|
||||
self.response = self._unformatted_response.strip()
|
||||
|
||||
def print_response_stream(self) -> None:
|
||||
for token in self.response_gen:
|
||||
print(token, end="", flush=True)
|
||||
|
||||
async def aprint_response_stream(self) -> None:
|
||||
async for token in self.async_response_gen():
|
||||
print(token, end="", flush=True)
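# --- Editorial sketch (not part of the module above) -------------------------
# Illustrates the producer/consumer contract of StreamingAgentChatResponse: a
# writer thread drains the chat stream into memory while response_gen yields
# deltas to the caller. The fake stream stands in for a real LLM chat stream,
# and ChatMemoryBuffer is assumed to work with its defaults.
def _streaming_response_demo() -> None:
    from threading import Thread

    from llama_index.core.llms.types import ChatResponse, MessageRole
    from llama_index.memory import ChatMemoryBuffer

    def fake_stream() -> ChatResponseGen:
        text = ""
        for token in ("Hello", ", ", "world", "!"):
            text += token
            yield ChatResponse(
                message=ChatMessage(role=MessageRole.ASSISTANT, content=text),
                delta=token,
            )

    memory = ChatMemoryBuffer.from_defaults()
    response = StreamingAgentChatResponse(chat_stream=fake_stream())
    Thread(target=response.write_response_to_history, args=(memory,)).start()
    for delta in response.response_gen:  # yields tokens until the stream ends
        print(delta, end="", flush=True)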
|
||||
|
||||
|
||||
AGENT_CHAT_RESPONSE_TYPE = Union[AgentChatResponse, StreamingAgentChatResponse]
|
||||
|
||||
|
||||
class BaseChatEngine(ABC):
|
||||
"""Base Chat Engine."""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset conversation state."""
|
||||
|
||||
@abstractmethod
|
||||
def chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Main chat interface."""
|
||||
|
||||
@abstractmethod
|
||||
def stream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
"""Stream chat interface."""
|
||||
|
||||
@abstractmethod
|
||||
async def achat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Async version of main chat interface."""
|
||||
|
||||
@abstractmethod
|
||||
async def astream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
"""Async version of main chat interface."""
|
||||
|
||||
def chat_repl(self) -> None:
|
||||
"""Enter interactive chat REPL."""
|
||||
print("===== Entering Chat REPL =====")
|
||||
print('Type "exit" to exit.\n')
|
||||
self.reset()
|
||||
message = input("Human: ")
|
||||
while message != "exit":
|
||||
response = self.chat(message)
|
||||
print(f"Assistant: {response}\n")
|
||||
message = input("Human: ")
|
||||
|
||||
def streaming_chat_repl(self) -> None:
|
||||
"""Enter interactive chat REPL with streaming responses."""
|
||||
print("===== Entering Chat REPL =====")
|
||||
print('Type "exit" to exit.\n')
|
||||
self.reset()
|
||||
message = input("Human: ")
|
||||
while message != "exit":
|
||||
response = self.stream_chat(message)
|
||||
print("Assistant: ", end="", flush=True)
|
||||
response.print_response_stream()
|
||||
print("\n")
|
||||
message = input("Human: ")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
pass
|
||||
|
||||
|
||||
class ChatMode(str, Enum):
|
||||
"""Chat Engine Modes."""
|
||||
|
||||
SIMPLE = "simple"
|
||||
"""Corresponds to `SimpleChatEngine`.
|
||||
|
||||
Chat with LLM, without making use of a knowledge base.
|
||||
"""
|
||||
|
||||
CONDENSE_QUESTION = "condense_question"
|
||||
"""Corresponds to `CondenseQuestionChatEngine`.
|
||||
|
||||
First generate a standalone question from conversation context and last message,
|
||||
then query the query engine for a response.
|
||||
"""
|
||||
|
||||
CONTEXT = "context"
|
||||
"""Corresponds to `ContextChatEngine`.
|
||||
|
||||
First retrieve text from the index using the user's message, then use the context
|
||||
in the system prompt to generate a response.
|
||||
"""
|
||||
|
||||
CONDENSE_PLUS_CONTEXT = "condense_plus_context"
|
||||
"""Corresponds to `CondensePlusContextChatEngine`.
|
||||
|
||||
First condense a conversation and latest user message to a standalone question.
|
||||
Then build a context for the standalone question from a retriever.
|
||||
Then pass the context along with the prompt and user message to the LLM to generate a response.
|
||||
"""
|
||||
|
||||
REACT = "react"
|
||||
"""Corresponds to `ReActAgent`.
|
||||
|
||||
Use a ReAct agent loop with query engine tools.
|
||||
"""
|
||||
|
||||
OPENAI = "openai"
|
||||
"""Corresponds to `OpenAIAgent`.
|
||||
|
||||
Use an OpenAI function calling agent loop.
|
||||
|
||||
NOTE: only works with OpenAI models that support the function calling API.
|
||||
"""
|
||||
|
||||
BEST = "best"
|
||||
"""Select the best chat engine based on the current LLM.
|
||||
|
||||
Corresponds to `OpenAIAgent` if using an OpenAI model that supports
|
||||
the function calling API; otherwise, corresponds to `ReActAgent`.
|
||||
"""
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from llama_index.core.llms.types import (
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseGen,
|
||||
MessageRole,
|
||||
)
|
||||
from llama_index.types import TokenGen
|
||||
|
||||
|
||||
def response_gen_from_query_engine(response_gen: TokenGen) -> ChatResponseGen:
|
||||
response_str = ""
|
||||
for token in response_gen:
|
||||
response_str += token
|
||||
yield ChatResponse(
|
||||
message=ChatMessage(role=MessageRole.ASSISTANT, content=response_str),
|
||||
delta=token,
|
||||
)
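# --- Editorial usage sketch --------------------------------------------------
# Wraps a plain token generator (as produced by a streaming query engine) into
# a ChatResponseGen; the literal tokens below are illustrative only.
def _demo() -> None:
    def tokens() -> TokenGen:
        yield from ("Stream", "ing ", "works.")

    for chat_response in response_gen_from_query_engine(tokens()):
        print(chat_response.delta, end="")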
|
||||
|
|
@ -0,0 +1,172 @@
|
|||
import argparse
|
||||
from typing import Any, Optional
|
||||
|
||||
from llama_index.command_line.rag import RagCLI, default_ragcli_persist_dir
|
||||
from llama_index.embeddings import OpenAIEmbedding
|
||||
from llama_index.ingestion import IngestionCache, IngestionPipeline
|
||||
from llama_index.llama_dataset.download import (
|
||||
LLAMA_DATASETS_LFS_URL,
|
||||
LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
|
||||
download_llama_dataset,
|
||||
)
|
||||
from llama_index.llama_pack.download import LLAMA_HUB_URL, download_llama_pack
|
||||
from llama_index.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
from llama_index.vector_stores import ChromaVectorStore
|
||||
|
||||
|
||||
def handle_download_llama_pack(
|
||||
llama_pack_class: Optional[str] = None,
|
||||
download_dir: Optional[str] = None,
|
||||
llama_hub_url: str = LLAMA_HUB_URL,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
assert llama_pack_class is not None
|
||||
assert download_dir is not None
|
||||
|
||||
download_llama_pack(
|
||||
llama_pack_class=llama_pack_class,
|
||||
download_dir=download_dir,
|
||||
llama_hub_url=llama_hub_url,
|
||||
)
|
||||
print(f"Successfully downloaded {llama_pack_class} to {download_dir}")
|
||||
|
||||
|
||||
def handle_download_llama_dataset(
|
||||
llama_dataset_class: Optional[str] = None,
|
||||
download_dir: Optional[str] = None,
|
||||
llama_hub_url: str = LLAMA_HUB_URL,
|
||||
llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
|
||||
llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
assert llama_dataset_class is not None
|
||||
assert download_dir is not None
|
||||
|
||||
download_llama_dataset(
|
||||
llama_dataset_class=llama_dataset_class,
|
||||
download_dir=download_dir,
|
||||
llama_hub_url=llama_hub_url,
|
||||
llama_datasets_lfs_url=llama_datasets_lfs_url,
|
||||
llama_datasets_source_files_tree_url=llama_datasets_source_files_tree_url,
|
||||
show_progress=True,
|
||||
load_documents=False,
|
||||
)
|
||||
|
||||
print(f"Successfully downloaded {llama_dataset_class} to {download_dir}")
|
||||
|
||||
|
||||
def default_rag_cli() -> RagCLI:
|
||||
import chromadb
|
||||
|
||||
persist_dir = default_ragcli_persist_dir()
|
||||
chroma_client = chromadb.PersistentClient(path=persist_dir)
|
||||
chroma_collection = chroma_client.create_collection("default", get_or_create=True)
|
||||
vector_store = ChromaVectorStore(
|
||||
chroma_collection=chroma_collection, persist_dir=persist_dir
|
||||
)
|
||||
docstore = SimpleDocumentStore()
|
||||
|
||||
ingestion_pipeline = IngestionPipeline(
|
||||
transformations=[SentenceSplitter(), OpenAIEmbedding()],
|
||||
vector_store=vector_store,
|
||||
docstore=docstore,
|
||||
cache=IngestionCache(),
|
||||
)
|
||||
try:
|
||||
ingestion_pipeline.load(persist_dir=persist_dir)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return RagCLI(
|
||||
ingestion_pipeline=ingestion_pipeline,
|
||||
verbose=False,
|
||||
persist_dir=persist_dir,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="LlamaIndex CLI tool.")
|
||||
|
||||
# Subparsers for the main commands
|
||||
subparsers = parser.add_subparsers(title="commands", dest="command", required=True)
|
||||
|
||||
# llama rag command
|
||||
llamarag_parser = subparsers.add_parser(
|
||||
"rag", help="Ask a question to a document / a directory of documents."
|
||||
)
|
||||
RagCLI.add_parser_args(llamarag_parser, default_rag_cli)
|
||||
|
||||
# download llamapacks command
|
||||
llamapack_parser = subparsers.add_parser(
|
||||
"download-llamapack", help="Download a llama-pack"
|
||||
)
|
||||
llamapack_parser.add_argument(
|
||||
"llama_pack_class",
|
||||
type=str,
|
||||
help=(
|
||||
"The name of the llama-pack class you want to download, "
|
||||
"such as `GmailOpenAIAgentPack`."
|
||||
),
|
||||
)
|
||||
llamapack_parser.add_argument(
|
||||
"-d",
|
||||
"--download-dir",
|
||||
type=str,
|
||||
default="./llama_packs",
|
||||
help="Custom dirpath to download the pack into.",
|
||||
)
|
||||
llamapack_parser.add_argument(
|
||||
"--llama-hub-url",
|
||||
type=str,
|
||||
default=LLAMA_HUB_URL,
|
||||
help="URL to llama hub.",
|
||||
)
|
||||
llamapack_parser.set_defaults(
|
||||
func=lambda args: handle_download_llama_pack(**vars(args))
|
||||
)
|
||||
|
||||
# download llamadatasets command
|
||||
llamadataset_parser = subparsers.add_parser(
|
||||
"download-llamadataset", help="Download a llama-dataset"
|
||||
)
|
||||
llamadataset_parser.add_argument(
|
||||
"llama_dataset_class",
|
||||
type=str,
|
||||
help=(
|
||||
"The name of the llama-dataset class you want to download, "
|
||||
"such as `PaulGrahamEssayDataset`."
|
||||
),
|
||||
)
|
||||
llamadataset_parser.add_argument(
|
||||
"-d",
|
||||
"--download-dir",
|
||||
type=str,
|
||||
default="./llama_datasets",
|
||||
help="Custom dirpath to download the pack into.",
|
||||
)
|
||||
llamadataset_parser.add_argument(
|
||||
"--llama-hub-url",
|
||||
type=str,
|
||||
default=LLAMA_HUB_URL,
|
||||
help="URL to llama hub.",
|
||||
)
|
||||
llamadataset_parser.add_argument(
|
||||
"--llama-datasets-lfs-url",
|
||||
type=str,
|
||||
default=LLAMA_DATASETS_LFS_URL,
|
||||
help="URL to llama datasets.",
|
||||
)
|
||||
llamadataset_parser.set_defaults(
|
||||
func=lambda args: handle_download_llama_dataset(**vars(args))
|
||||
)
|
||||
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Call the appropriate function based on the command
|
||||
args.func(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
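# --- Editorial note -----------------------------------------------------------
# Example invocations of the commands defined above, assuming this `main` is
# exposed as the `llamaindex-cli` entry point referenced by the rag module:
#
#   llamaindex-cli rag --files "./docs/**/*.md" --question "What is covered?"
#   llamaindex-cli download-llamapack GmailOpenAIAgentPack -d ./llama_packs
#   llamaindex-cli download-llamadataset PaulGrahamEssayDataset -d ./llama_datasets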
|
||||
|
|
@ -0,0 +1,373 @@
|
|||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
from argparse import ArgumentParser
|
||||
from glob import iglob
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Union, cast
|
||||
|
||||
from llama_index import (
|
||||
Response,
|
||||
ServiceContext,
|
||||
SimpleDirectoryReader,
|
||||
VectorStoreIndex,
|
||||
)
|
||||
from llama_index.bridge.pydantic import BaseModel, Field, validator
|
||||
from llama_index.chat_engine import CondenseQuestionChatEngine
|
||||
from llama_index.core.response.schema import RESPONSE_TYPE, StreamingResponse
|
||||
from llama_index.embeddings.base import BaseEmbedding
|
||||
from llama_index.ingestion import IngestionPipeline
|
||||
from llama_index.llms import LLM, OpenAI
|
||||
from llama_index.query_engine import CustomQueryEngine
|
||||
from llama_index.query_pipeline import FnComponent
|
||||
from llama_index.query_pipeline.query import QueryPipeline
|
||||
from llama_index.readers.base import BaseReader
|
||||
from llama_index.response_synthesizers import CompactAndRefine
|
||||
from llama_index.utils import get_cache_dir
|
||||
|
||||
RAG_HISTORY_FILE_NAME = "files_history.txt"
|
||||
|
||||
|
||||
def default_ragcli_persist_dir() -> str:
|
||||
return str(Path(get_cache_dir()) / "rag_cli")
|
||||
|
||||
|
||||
def query_input(query_str: Optional[str] = None) -> str:
|
||||
return query_str or ""
|
||||
|
||||
|
||||
class QueryPipelineQueryEngine(CustomQueryEngine):
|
||||
query_pipeline: QueryPipeline = Field(
|
||||
description="Query Pipeline to use for Q&A.",
|
||||
)
|
||||
|
||||
def custom_query(self, query_str: str) -> RESPONSE_TYPE:
|
||||
return self.query_pipeline.run(query_str=query_str)
|
||||
|
||||
async def acustom_query(self, query_str: str) -> RESPONSE_TYPE:
|
||||
return await self.query_pipeline.arun(query_str=query_str)
|
||||
|
||||
|
||||
class RagCLI(BaseModel):
|
||||
"""
|
||||
CLI tool for chatting with the output of an IngestionPipeline via a QueryPipeline.
|
||||
"""
|
||||
|
||||
ingestion_pipeline: IngestionPipeline = Field(
|
||||
description="Ingestion pipeline to run for RAG ingestion."
|
||||
)
|
||||
verbose: bool = Field(
|
||||
description="Whether to print out verbose information during execution.",
|
||||
default=False,
|
||||
)
|
||||
persist_dir: str = Field(
|
||||
description="Directory to persist ingestion pipeline.",
|
||||
default_factory=default_ragcli_persist_dir,
|
||||
)
|
||||
llm: LLM = Field(
|
||||
description="Language model to use for response generation.",
|
||||
default_factory=lambda: OpenAI(model="gpt-3.5-turbo", streaming=True),
|
||||
)
|
||||
query_pipeline: Optional[QueryPipeline] = Field(
|
||||
description="Query Pipeline to use for Q&A.",
|
||||
default=None,
|
||||
)
|
||||
chat_engine: Optional[CondenseQuestionChatEngine] = Field(
|
||||
description="Chat engine to use for chatting.",
|
||||
default=None,
|
||||
)
|
||||
file_extractor: Optional[Dict[str, BaseReader]] = Field(
|
||||
description="File extractor to use for extracting text from files.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("query_pipeline", always=True)
|
||||
def query_pipeline_from_ingestion_pipeline(
|
||||
cls, query_pipeline: Any, values: Dict[str, Any]
|
||||
) -> Optional[QueryPipeline]:
|
||||
"""
|
||||
If query_pipeline is not provided, create one from ingestion_pipeline.
|
||||
"""
|
||||
if query_pipeline is not None:
|
||||
return query_pipeline
|
||||
|
||||
ingestion_pipeline = cast(IngestionPipeline, values["ingestion_pipeline"])
|
||||
if ingestion_pipeline.vector_store is None:
|
||||
return None
|
||||
verbose = cast(bool, values["verbose"])
|
||||
query_component = FnComponent(
|
||||
fn=query_input, output_key="output", req_params={"query_str"}
|
||||
)
|
||||
llm = cast(LLM, values["llm"])
|
||||
|
||||
# get embed_model from transformations if possible
|
||||
embed_model = None
|
||||
if ingestion_pipeline.transformations is not None:
|
||||
for transformation in ingestion_pipeline.transformations:
|
||||
if isinstance(transformation, BaseEmbedding):
|
||||
embed_model = transformation
|
||||
break
|
||||
|
||||
service_context = ServiceContext.from_defaults(
|
||||
llm=llm, embed_model=embed_model or "default"
|
||||
)
|
||||
retriever = VectorStoreIndex.from_vector_store(
|
||||
ingestion_pipeline.vector_store, service_context=service_context
|
||||
).as_retriever(similarity_top_k=8)
|
||||
response_synthesizer = CompactAndRefine(
|
||||
service_context=service_context, streaming=True, verbose=verbose
|
||||
)
|
||||
|
||||
# define query pipeline
|
||||
query_pipeline = QueryPipeline(verbose=verbose)
|
||||
query_pipeline.add_modules(
|
||||
{
|
||||
"query": query_component,
|
||||
"retriever": retriever,
|
||||
"summarizer": response_synthesizer,
|
||||
}
|
||||
)
|
||||
query_pipeline.add_link("query", "retriever")
|
||||
query_pipeline.add_link("retriever", "summarizer", dest_key="nodes")
|
||||
query_pipeline.add_link("query", "summarizer", dest_key="query_str")
|
||||
return query_pipeline
|
||||
|
||||
@validator("chat_engine", always=True)
|
||||
def chat_engine_from_query_pipeline(
|
||||
cls, chat_engine: Any, values: Dict[str, Any]
|
||||
) -> Optional[CondenseQuestionChatEngine]:
|
||||
"""
|
||||
If chat_engine is not provided, create one from query_pipeline.
|
||||
"""
|
||||
if chat_engine is not None:
|
||||
return chat_engine
|
||||
|
||||
if values.get("query_pipeline", None) is None:
|
||||
values["query_pipeline"] = cls.query_pipeline_from_ingestion_pipeline(
|
||||
query_pipeline=None, values=values
|
||||
)
|
||||
|
||||
query_pipeline = cast(QueryPipeline, values["query_pipeline"])
|
||||
if query_pipeline is None:
|
||||
return None
|
||||
query_engine = QueryPipelineQueryEngine(query_pipeline=query_pipeline) # type: ignore
|
||||
verbose = cast(bool, values["verbose"])
|
||||
llm = cast(LLM, values["llm"])
|
||||
return CondenseQuestionChatEngine.from_defaults(
|
||||
query_engine=query_engine, llm=llm, verbose=verbose
|
||||
)
|
||||
|
||||
async def handle_cli(
|
||||
self,
|
||||
files: Optional[str] = None,
|
||||
question: Optional[str] = None,
|
||||
chat: bool = False,
|
||||
verbose: bool = False,
|
||||
clear: bool = False,
|
||||
create_llama: bool = False,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Entrypoint for local document RAG CLI tool.
|
||||
"""
|
||||
if clear:
|
||||
# delete self.persist_dir directory including all subdirectories and files
|
||||
if os.path.exists(self.persist_dir):
|
||||
# Ask for confirmation
|
||||
response = input(
|
||||
f"Are you sure you want to delete data within {self.persist_dir}? [y/N] "
|
||||
)
|
||||
if response.strip().lower() != "y":
|
||||
print("Aborted.")
|
||||
return
|
||||
os.system(f"rm -rf {self.persist_dir}")
|
||||
print(f"Successfully cleared {self.persist_dir}")
|
||||
|
||||
self.verbose = verbose
|
||||
ingestion_pipeline = cast(IngestionPipeline, self.ingestion_pipeline)
|
||||
if self.verbose:
|
||||
print("Saving/Loading from persist_dir: ", self.persist_dir)
|
||||
if files is not None:
|
||||
documents = []
|
||||
for _file in iglob(files, recursive=True):
|
||||
_file = os.path.abspath(_file)
|
||||
if os.path.isdir(_file):
|
||||
reader = SimpleDirectoryReader(
|
||||
input_dir=_file,
|
||||
filename_as_id=True,
|
||||
file_extractor=self.file_extractor,
|
||||
)
|
||||
else:
|
||||
reader = SimpleDirectoryReader(
|
||||
input_files=[_file],
|
||||
filename_as_id=True,
|
||||
file_extractor=self.file_extractor,
|
||||
)
|
||||
|
||||
documents.extend(reader.load_data(show_progress=verbose))
|
||||
|
||||
await ingestion_pipeline.arun(show_progress=verbose, documents=documents)
|
||||
ingestion_pipeline.persist(persist_dir=self.persist_dir)
|
||||
|
||||
# Append the `--files` argument to the history file
|
||||
with open(f"{self.persist_dir}/{RAG_HISTORY_FILE_NAME}", "a") as f:
|
||||
f.write(files + "\n")
|
||||
|
||||
if create_llama:
|
||||
if shutil.which("npx") is None:
|
||||
print(
|
||||
"`npx` is not installed. Please install it by calling `npm install -g npx`"
|
||||
)
|
||||
else:
|
||||
history_file_path = Path(f"{self.persist_dir}/{RAG_HISTORY_FILE_NAME}")
|
||||
if not history_file_path.exists():
|
||||
print(
|
||||
"No data has been ingested, "
|
||||
"please specify `--files` to create llama dataset."
|
||||
)
|
||||
else:
|
||||
with open(history_file_path) as f:
|
||||
stored_paths = {line.strip() for line in f if line.strip()}
|
||||
if len(stored_paths) == 0:
|
||||
print(
|
||||
"No data has been ingested, "
|
||||
"please specify `--files` to create llama dataset."
|
||||
)
|
||||
elif len(stored_paths) > 1:
|
||||
print(
|
||||
"Multiple files or folders were ingested, which is not supported by create-llama. "
|
||||
"Please call `llamaindex-cli rag --clear` to clear the cache first, "
|
||||
"then call `llamaindex-cli rag --files` again with a single folder or file"
|
||||
)
|
||||
else:
|
||||
path = stored_paths.pop()
|
||||
if "*" in path:
|
||||
print(
|
||||
"Glob pattern is not supported by create-llama. "
|
||||
"Please call `llamaindex-cli rag --clear` to clear the cache first, "
|
||||
"then call `llamaindex-cli rag --files` again with a single folder or file."
|
||||
)
|
||||
elif not os.path.exists(path):
|
||||
print(
|
||||
f"The path {path} does not exist. "
|
||||
"Please call `llamaindex-cli rag --clear` to clear the cache first, "
|
||||
"then call `llamaindex-cli rag --files` again with a single folder or file."
|
||||
)
|
||||
else:
|
||||
print(f"Calling create-llama using data from {path} ...")
|
||||
command_args = [
|
||||
"npx",
|
||||
"create-llama@latest",
|
||||
"--frontend",
|
||||
"--template",
|
||||
"streaming",
|
||||
"--framework",
|
||||
"fastapi",
|
||||
"--ui",
|
||||
"shadcn",
|
||||
"--vector-db",
|
||||
"none",
|
||||
"--engine",
|
||||
"context",
|
||||
f"--files {path}",
|
||||
]
|
||||
os.system(" ".join(command_args))
|
||||
|
||||
if question is not None:
|
||||
await self.handle_question(question)
|
||||
if chat:
|
||||
await self.start_chat_repl()
|
||||
|
||||
async def handle_question(self, question: str) -> None:
|
||||
if self.query_pipeline is None:
|
||||
raise ValueError("query_pipeline is not defined.")
|
||||
query_pipeline = cast(QueryPipeline, self.query_pipeline)
|
||||
query_pipeline.verbose = self.verbose
|
||||
chat_engine = cast(CondenseQuestionChatEngine, self.chat_engine)
|
||||
response = chat_engine.chat(question)
|
||||
|
||||
if isinstance(response, StreamingResponse):
|
||||
response.print_response_stream()
|
||||
else:
|
||||
response = cast(Response, response)
|
||||
print(response)
|
||||
|
||||
async def start_chat_repl(self) -> None:
|
||||
"""
|
||||
Start a REPL for chatting with the agent.
|
||||
"""
|
||||
if self.query_pipeline is None:
|
||||
raise ValueError("query_pipeline is not defined.")
|
||||
chat_engine = cast(CondenseQuestionChatEngine, self.chat_engine)
|
||||
chat_engine.streaming_chat_repl()
|
||||
|
||||
@classmethod
|
||||
def add_parser_args(
|
||||
cls,
|
||||
parser: Union[ArgumentParser, Any],
|
||||
instance_generator: Callable[[], "RagCLI"],
|
||||
) -> None:
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--question",
|
||||
type=str,
|
||||
help="The question you want to ask.",
|
||||
required=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--files",
|
||||
type=str,
|
||||
help=(
|
||||
"The name of the file or directory you want to ask a question about,"
|
||||
'such as "file.pdf".'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--chat",
|
||||
help="If flag is present, opens a chat REPL.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
help="Whether to print out verbose information during execution.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clear",
|
||||
help="Clears out all currently embedded data.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--create-llama",
|
||||
help="Create a LlamaIndex application with your embedded data.",
|
||||
required=False,
|
||||
action="store_true",
|
||||
)
|
||||
parser.set_defaults(
|
||||
func=lambda args: asyncio.run(instance_generator().handle_cli(**vars(args)))
|
||||
)
|
||||
|
||||
def cli(self) -> None:
|
||||
"""
|
||||
Entrypoint for CLI tool.
|
||||
"""
|
||||
parser = ArgumentParser(description="LlamaIndex RAG Q&A tool.")
|
||||
subparsers = parser.add_subparsers(
|
||||
title="commands", dest="command", required=True
|
||||
)
|
||||
llamarag_parser = subparsers.add_parser(
|
||||
"rag", help="Ask a question to a document / a directory of documents."
|
||||
)
|
||||
self.add_parser_args(llamarag_parser, lambda: self)
|
||||
# Parse the command-line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Call the appropriate function based on the command
|
||||
args.func(args)
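A minimal sketch of how this entrypoint might be wired up; it assumes a `RagCLI` instance can be built from its defaults (the constructor and its fields are defined earlier in this file and not shown here), so treat the construction as a placeholder rather than the definitive invocation.

# hypothetical entrypoint wiring; RagCLI() with defaults is an assumption
if __name__ == "__main__":
    RagCLI().cli()

From the shell this maps to invocations such as `llamaindex-cli rag --files "./docs" --chat`, which ingests the files and then opens the streaming chat REPL handled above.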
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""Init composability."""
|
||||
|
||||
|
||||
from llama_index.composability.base import ComposableGraph
|
||||
from llama_index.composability.joint_qa_summary import QASummaryQueryEngineBuilder
|
||||
|
||||
__all__ = ["ComposableGraph", "QASummaryQueryEngineBuilder"]
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
"""Composable graph."""
|
||||
|
||||
# TODO: remove this file, only keep for backwards compatibility
|
||||
from llama_index.indices.composability.graph import ComposableGraph # noqa
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
"""Joint QA Summary graph."""
|
||||
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from llama_index.indices.list.base import SummaryIndex
|
||||
from llama_index.indices.vector_store import VectorStoreIndex
|
||||
from llama_index.ingestion import run_transformations
|
||||
from llama_index.query_engine.router_query_engine import RouterQueryEngine
|
||||
from llama_index.schema import Document
|
||||
from llama_index.service_context import ServiceContext
|
||||
from llama_index.storage.storage_context import StorageContext
|
||||
from llama_index.tools.query_engine import QueryEngineTool
|
||||
|
||||
DEFAULT_SUMMARY_TEXT = "Use this index for summarization queries"
|
||||
DEFAULT_QA_TEXT = (
|
||||
"Use this index for queries that require retrieval of specific "
|
||||
"context from documents."
|
||||
)
|
||||
|
||||
|
||||
class QASummaryQueryEngineBuilder:
|
||||
"""Joint QA Summary graph builder.
|
||||
|
||||
Can build a graph that provides a unified query interface
|
||||
for both QA and summarization tasks.
|
||||
|
||||
NOTE: this is a beta feature. The API may change in the future.
|
||||
|
||||
Args:
|
||||
storage_context (StorageContext): A StorageContext to use for storing nodes.
|
||||
service_context (ServiceContext): A ServiceContext to use for
|
||||
building indices.
|
||||
summary_text (str): Text to use for the summary index.
|
||||
qa_text (str): Text to use for the QA index.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_context: Optional[StorageContext] = None,
|
||||
service_context: Optional[ServiceContext] = None,
|
||||
summary_text: str = DEFAULT_SUMMARY_TEXT,
|
||||
qa_text: str = DEFAULT_QA_TEXT,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
self._storage_context = storage_context or StorageContext.from_defaults()
|
||||
self._service_context = service_context or ServiceContext.from_defaults()
|
||||
self._summary_text = summary_text
|
||||
self._qa_text = qa_text
|
||||
|
||||
def build_from_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
) -> RouterQueryEngine:
|
||||
"""Build query engine."""
|
||||
# parse nodes
|
||||
nodes = run_transformations(
|
||||
documents, self._service_context.transformations # type: ignore
|
||||
)
|
||||
|
||||
# ingest nodes
|
||||
self._storage_context.docstore.add_documents(nodes, allow_update=True)
|
||||
|
||||
# build indices
|
||||
vector_index = VectorStoreIndex(
|
||||
nodes,
|
||||
service_context=self._service_context,
|
||||
storage_context=self._storage_context,
|
||||
)
|
||||
summary_index = SummaryIndex(
|
||||
nodes,
|
||||
service_context=self._service_context,
|
||||
storage_context=self._storage_context,
|
||||
)
|
||||
|
||||
vector_query_engine = vector_index.as_query_engine(
|
||||
service_context=self._service_context
|
||||
)
|
||||
list_query_engine = summary_index.as_query_engine(
|
||||
service_context=self._service_context,
|
||||
response_mode="tree_summarize",
|
||||
)
|
||||
|
||||
# build query engine
|
||||
return RouterQueryEngine.from_defaults(
|
||||
query_engine_tools=[
|
||||
QueryEngineTool.from_defaults(
|
||||
vector_query_engine, description=self._qa_text
|
||||
),
|
||||
QueryEngineTool.from_defaults(
|
||||
list_query_engine, description=self._summary_text
|
||||
),
|
||||
],
|
||||
service_context=self._service_context,
|
||||
select_multi=False,
|
||||
)
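A short usage sketch for this builder, assuming documents loaded with SimpleDirectoryReader and the default ServiceContext (which in turn assumes OpenAI credentials are configured):

from llama_index import SimpleDirectoryReader
from llama_index.composability import QASummaryQueryEngineBuilder

# "./data" is a placeholder path
documents = SimpleDirectoryReader("./data").load_data()

# build a router that dispatches between a summarization engine and a QA engine
builder = QASummaryQueryEngineBuilder()
query_engine = builder.build_from_documents(documents)

print(query_engine.query("Give me a summary of these documents."))
print(query_engine.query("What does the report say about pricing?"))

The router uses the summary_text / qa_text descriptions above to decide which underlying engine should answer each query.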
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
"""Set of constants."""
|
||||
|
||||
DEFAULT_TEMPERATURE = 0.1
|
||||
DEFAULT_CONTEXT_WINDOW = 3900 # tokens
|
||||
DEFAULT_NUM_OUTPUTS = 256 # tokens
|
||||
DEFAULT_NUM_INPUT_FILES = 10 # files
|
||||
|
||||
DEFAULT_EMBED_BATCH_SIZE = 10
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1024 # tokens
|
||||
DEFAULT_CHUNK_OVERLAP = 20 # tokens
|
||||
DEFAULT_SIMILARITY_TOP_K = 2
|
||||
DEFAULT_IMAGE_SIMILARITY_TOP_K = 2
|
||||
|
||||
# NOTE: for text-embedding-ada-002
|
||||
DEFAULT_EMBEDDING_DIM = 1536
|
||||
|
||||
# context window size for llm predictor
|
||||
COHERE_CONTEXT_WINDOW = 2048
|
||||
AI21_J2_CONTEXT_WINDOW = 8192
|
||||
|
||||
|
||||
TYPE_KEY = "__type__"
|
||||
DATA_KEY = "__data__"
|
||||
VECTOR_STORE_KEY = "vector_store"
|
||||
IMAGE_STORE_KEY = "image_store"
|
||||
GRAPH_STORE_KEY = "graph_store"
|
||||
INDEX_STORE_KEY = "index_store"
|
||||
DOC_STORE_KEY = "doc_store"
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from llama_index.bridge.pydantic import BaseModel
|
||||
from llama_index.core.base_retriever import BaseRetriever
|
||||
from llama_index.schema import NodeWithScore, QueryBundle
|
||||
|
||||
|
||||
class BaseAutoRetriever(BaseRetriever):
|
||||
"""Base auto retriever."""
|
||||
|
||||
@abstractmethod
|
||||
def generate_retrieval_spec(
|
||||
self, query_bundle: QueryBundle, **kwargs: Any
|
||||
) -> BaseModel:
|
||||
"""Generate retrieval spec synchronously."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def agenerate_retrieval_spec(
|
||||
self, query_bundle: QueryBundle, **kwargs: Any
|
||||
) -> BaseModel:
|
||||
"""Generate retrieval spec asynchronously."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _build_retriever_from_spec(
|
||||
self, retrieval_spec: BaseModel
|
||||
) -> Tuple[BaseRetriever, QueryBundle]:
|
||||
"""Build retriever from spec and provide query bundle."""
|
||||
...
|
||||
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
"""Retrieve using generated spec."""
|
||||
retrieval_spec = self.generate_retrieval_spec(query_bundle)
|
||||
retriever, new_query_bundle = self._build_retriever_from_spec(retrieval_spec)
|
||||
return retriever.retrieve(new_query_bundle)
|
||||
|
||||
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
"""Retrieve using generated spec asynchronously."""
|
||||
retrieval_spec = await self.agenerate_retrieval_spec(query_bundle)
|
||||
retriever, new_query_bundle = self._build_retriever_from_spec(retrieval_spec)
|
||||
return await retriever.aretrieve(new_query_bundle)
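A hypothetical subclass showing the shape of this contract; FixedSpec and PassthroughAutoRetriever are illustration-only names, and the sketch relies on the imports at the top of this module.

from typing import cast


class FixedSpec(BaseModel):
    query_str: str = ""


class PassthroughAutoRetriever(BaseAutoRetriever):
    """Derives a trivial spec from the query and delegates to a wrapped retriever."""

    def __init__(self, retriever: BaseRetriever) -> None:
        super().__init__()
        self._retriever = retriever

    def generate_retrieval_spec(
        self, query_bundle: QueryBundle, **kwargs: Any
    ) -> BaseModel:
        return FixedSpec(query_str=query_bundle.query_str)

    async def agenerate_retrieval_spec(
        self, query_bundle: QueryBundle, **kwargs: Any
    ) -> BaseModel:
        return self.generate_retrieval_spec(query_bundle, **kwargs)

    def _build_retriever_from_spec(
        self, retrieval_spec: BaseModel
    ) -> Tuple[BaseRetriever, QueryBundle]:
        # a real auto-retriever would turn the spec into filters, top-k, etc.
        spec = cast(FixedSpec, retrieval_spec)
        return self._retriever, QueryBundle(query_str=spec.query_str)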
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
"""base multi modal retriever."""
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
from llama_index.core.base_retriever import BaseRetriever
|
||||
from llama_index.core.image_retriever import BaseImageRetriever
|
||||
from llama_index.indices.query.schema import QueryType
|
||||
from llama_index.schema import NodeWithScore
|
||||
|
||||
|
||||
class MultiModalRetriever(BaseRetriever, BaseImageRetriever):
|
||||
"""Multi Modal base retriever."""
|
||||
|
||||
@abstractmethod
|
||||
def text_retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
|
||||
"""Retrieve text nodes given text query.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def text_to_image_retrieve(
|
||||
self, str_or_query_bundle: QueryType
|
||||
) -> List[NodeWithScore]:
|
||||
"""Retrieve image nodes given text query.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def image_to_image_retrieve(
|
||||
self, str_or_query_bundle: QueryType
|
||||
) -> List[NodeWithScore]:
|
||||
"""Retrieve image nodes given image query.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def atext_retrieve(
|
||||
self, str_or_query_bundle: QueryType
|
||||
) -> List[NodeWithScore]:
|
||||
"""Async Retrieve text nodes given text query.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def atext_to_image_retrieve(
|
||||
self, str_or_query_bundle: QueryType
|
||||
) -> List[NodeWithScore]:
|
||||
"""Async Retrieve image nodes given text query.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def aimage_to_image_retrieve(
|
||||
self, str_or_query_bundle: QueryType
|
||||
) -> List[NodeWithScore]:
|
||||
"""Async Retrieve image nodes given image query.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
"""Base query engine."""
|
||||
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from llama_index.bridge.pydantic import Field
|
||||
from llama_index.callbacks.base import CallbackManager
|
||||
from llama_index.core.query_pipeline.query_component import (
|
||||
ChainableMixin,
|
||||
InputKeys,
|
||||
OutputKeys,
|
||||
QueryComponent,
|
||||
validate_and_convert_stringable,
|
||||
)
|
||||
from llama_index.core.response.schema import RESPONSE_TYPE
|
||||
from llama_index.prompts.mixin import PromptDictType, PromptMixin
|
||||
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseQueryEngine(ChainableMixin, PromptMixin):
|
||||
"""Base query engine."""
|
||||
|
||||
def __init__(self, callback_manager: Optional[CallbackManager]) -> None:
|
||||
self.callback_manager = callback_manager or CallbackManager([])
|
||||
|
||||
def _get_prompts(self) -> Dict[str, Any]:
|
||||
"""Get prompts."""
|
||||
return {}
|
||||
|
||||
def _update_prompts(self, prompts: PromptDictType) -> None:
|
||||
"""Update prompts."""
|
||||
|
||||
def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
|
||||
with self.callback_manager.as_trace("query"):
|
||||
if isinstance(str_or_query_bundle, str):
|
||||
str_or_query_bundle = QueryBundle(str_or_query_bundle)
|
||||
return self._query(str_or_query_bundle)
|
||||
|
||||
async def aquery(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
|
||||
with self.callback_manager.as_trace("query"):
|
||||
if isinstance(str_or_query_bundle, str):
|
||||
str_or_query_bundle = QueryBundle(str_or_query_bundle)
|
||||
return await self._aquery(str_or_query_bundle)
|
||||
|
||||
def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
raise NotImplementedError(
|
||||
"This query engine does not support retrieve, use query directly"
|
||||
)
|
||||
|
||||
def synthesize(
|
||||
self,
|
||||
query_bundle: QueryBundle,
|
||||
nodes: List[NodeWithScore],
|
||||
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
|
||||
) -> RESPONSE_TYPE:
|
||||
raise NotImplementedError(
|
||||
"This query engine does not support synthesize, use query directly"
|
||||
)
|
||||
|
||||
async def asynthesize(
|
||||
self,
|
||||
query_bundle: QueryBundle,
|
||||
nodes: List[NodeWithScore],
|
||||
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
|
||||
) -> RESPONSE_TYPE:
|
||||
raise NotImplementedError(
|
||||
"This query engine does not support asynthesize, use aquery directly"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
|
||||
pass
|
||||
|
||||
def _as_query_component(self, **kwargs: Any) -> QueryComponent:
|
||||
"""Return a query component."""
|
||||
return QueryEngineComponent(query_engine=self)
|
||||
|
||||
|
||||
class QueryEngineComponent(QueryComponent):
|
||||
"""Query engine component."""
|
||||
|
||||
query_engine: BaseQueryEngine = Field(..., description="Query engine")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
||||
"""Set callback manager."""
|
||||
self.query_engine.callback_manager = callback_manager
|
||||
|
||||
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs during run_component."""
|
||||
# make sure input is a string
|
||||
input["input"] = validate_and_convert_stringable(input["input"])
|
||||
return input
|
||||
|
||||
def _run_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component."""
|
||||
output = self.query_engine.query(kwargs["input"])
|
||||
return {"output": output}
|
||||
|
||||
async def _arun_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component."""
|
||||
output = await self.query_engine.aquery(kwargs["input"])
|
||||
return {"output": output}
|
||||
|
||||
@property
|
||||
def input_keys(self) -> InputKeys:
|
||||
"""Input keys."""
|
||||
return InputKeys.from_keys({"input"})
|
||||
|
||||
@property
|
||||
def output_keys(self) -> OutputKeys:
|
||||
"""Output keys."""
|
||||
return OutputKeys.from_keys({"output"})
|
||||
|
|
@ -0,0 +1,324 @@
|
|||
"""Base retriever."""
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_index.bridge.pydantic import Field
|
||||
from llama_index.callbacks.base import CallbackManager
|
||||
from llama_index.callbacks.schema import CBEventType, EventPayload
|
||||
from llama_index.core.base_query_engine import BaseQueryEngine
|
||||
from llama_index.core.query_pipeline.query_component import (
|
||||
ChainableMixin,
|
||||
InputKeys,
|
||||
OutputKeys,
|
||||
QueryComponent,
|
||||
validate_and_convert_stringable,
|
||||
)
|
||||
from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType
|
||||
from llama_index.schema import (
|
||||
BaseNode,
|
||||
IndexNode,
|
||||
NodeWithScore,
|
||||
QueryBundle,
|
||||
QueryType,
|
||||
TextNode,
|
||||
)
|
||||
from llama_index.service_context import ServiceContext
|
||||
from llama_index.utils import print_text
|
||||
|
||||
|
||||
class BaseRetriever(ChainableMixin, PromptMixin):
|
||||
"""Base retriever."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
object_map: Optional[Dict] = None,
|
||||
objects: Optional[List[IndexNode]] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.callback_manager = callback_manager or CallbackManager()
|
||||
|
||||
if objects is not None:
|
||||
object_map = {obj.index_id: obj.obj for obj in objects}
|
||||
|
||||
self.object_map = object_map or {}
|
||||
self._verbose = verbose
|
||||
|
||||
def _check_callback_manager(self) -> None:
|
||||
"""Check callback manager."""
|
||||
if not hasattr(self, "callback_manager"):
|
||||
self.callback_manager = CallbackManager()
|
||||
|
||||
def _get_prompts(self) -> PromptDictType:
|
||||
"""Get prompts."""
|
||||
return {}
|
||||
|
||||
def _get_prompt_modules(self) -> PromptMixinType:
|
||||
"""Get prompt modules."""
|
||||
return {}
|
||||
|
||||
def _update_prompts(self, prompts: PromptDictType) -> None:
|
||||
"""Update prompts."""
|
||||
|
||||
def _retrieve_from_object(
|
||||
self,
|
||||
obj: Any,
|
||||
query_bundle: QueryBundle,
|
||||
score: float,
|
||||
) -> List[NodeWithScore]:
|
||||
"""Retrieve nodes from object."""
|
||||
if self._verbose:
|
||||
print_text(
|
||||
f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n",
|
||||
color="llama_pink",
|
||||
)
|
||||
if isinstance(obj, NodeWithScore):
|
||||
return [obj]
|
||||
elif isinstance(obj, BaseNode):
|
||||
return [NodeWithScore(node=obj, score=score)]
|
||||
elif isinstance(obj, BaseQueryEngine):
|
||||
response = obj.query(query_bundle)
|
||||
return [
|
||||
NodeWithScore(
|
||||
node=TextNode(text=str(response), metadata=response.metadata or {}),
|
||||
score=score,
|
||||
)
|
||||
]
|
||||
elif isinstance(obj, BaseRetriever):
|
||||
return obj.retrieve(query_bundle)
|
||||
elif isinstance(obj, QueryComponent):
|
||||
component_keys = obj.input_keys.required_keys
|
||||
if len(component_keys) > 1:
|
||||
raise ValueError(
|
||||
f"QueryComponent {obj} has more than one input key: {component_keys}"
|
||||
)
|
||||
elif len(component_keys) == 0:
|
||||
component_response = obj.run_component()
|
||||
else:
|
||||
kwargs = {next(iter(component_keys)): query_bundle.query_str}
|
||||
component_response = obj.run_component(**kwargs)
|
||||
|
||||
result_output = str(next(iter(component_response.values())))
|
||||
return [NodeWithScore(node=TextNode(text=result_output), score=score)]
|
||||
else:
|
||||
raise ValueError(f"Object {obj} is not retrievable.")
|
||||
|
||||
async def _aretrieve_from_object(
|
||||
self,
|
||||
obj: Any,
|
||||
query_bundle: QueryBundle,
|
||||
score: float,
|
||||
) -> List[NodeWithScore]:
|
||||
"""Retrieve nodes from object."""
|
||||
if isinstance(obj, NodeWithScore):
|
||||
return [obj]
|
||||
elif isinstance(obj, BaseNode):
|
||||
return [NodeWithScore(node=obj, score=score)]
|
||||
elif isinstance(obj, BaseQueryEngine):
|
||||
response = await obj.aquery(query_bundle)
|
||||
return [NodeWithScore(node=TextNode(text=str(response)), score=score)]
|
||||
elif isinstance(obj, BaseRetriever):
|
||||
return await obj.aretrieve(query_bundle)
|
||||
elif isinstance(obj, QueryComponent):
|
||||
component_keys = obj.input_keys.required_keys
|
||||
if len(component_keys) > 1:
|
||||
raise ValueError(
|
||||
f"QueryComponent {obj} has more than one input key: {component_keys}"
|
||||
)
|
||||
elif len(component_keys) == 0:
|
||||
component_response = await obj.arun_component()
|
||||
else:
|
||||
kwargs = {next(iter(component_keys)): query_bundle.query_str}
|
||||
component_response = await obj.arun_component(**kwargs)
|
||||
|
||||
result_output = str(next(iter(component_response.values())))
|
||||
return [NodeWithScore(node=TextNode(text=result_output), score=score)]
|
||||
else:
|
||||
raise ValueError(f"Object {obj} is not retrievable.")
|
||||
|
||||
def _handle_recursive_retrieval(
|
||||
self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
|
||||
) -> List[NodeWithScore]:
|
||||
retrieved_nodes: List[NodeWithScore] = []
|
||||
for n in nodes:
|
||||
node = n.node
|
||||
score = n.score or 1.0
|
||||
if isinstance(node, IndexNode):
|
||||
obj = self.object_map.get(node.index_id, None)
|
||||
if obj is not None:
|
||||
if self._verbose:
|
||||
print_text(
|
||||
f"Retrieval entering {node.index_id}: {obj.__class__.__name__}\n",
|
||||
color="llama_turquoise",
|
||||
)
|
||||
retrieved_nodes.extend(
|
||||
self._retrieve_from_object(
|
||||
obj, query_bundle=query_bundle, score=score
|
||||
)
|
||||
)
|
||||
else:
|
||||
retrieved_nodes.append(n)
|
||||
else:
|
||||
retrieved_nodes.append(n)
|
||||
|
||||
seen = set()
|
||||
return [
|
||||
n
|
||||
for n in retrieved_nodes
|
||||
if not (n.node.hash in seen or seen.add(n.node.hash)) # type: ignore[func-returns-value]
|
||||
]
|
||||
|
||||
async def _ahandle_recursive_retrieval(
|
||||
self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
|
||||
) -> List[NodeWithScore]:
|
||||
retrieved_nodes: List[NodeWithScore] = []
|
||||
for n in nodes:
|
||||
node = n.node
|
||||
score = n.score or 1.0
|
||||
if isinstance(node, IndexNode):
|
||||
obj = self.object_map.get(node.index_id, None)
|
||||
if obj is not None:
|
||||
if self._verbose:
|
||||
print_text(
|
||||
f"Retrieval entering {node.index_id}: {obj.__class__.__name__}\n",
|
||||
color="llama_turquoise",
|
||||
)
|
||||
# TODO: Add concurrent execution via `run_jobs()` ?
|
||||
retrieved_nodes.extend(
|
||||
await self._aretrieve_from_object(
|
||||
obj, query_bundle=query_bundle, score=score
|
||||
)
|
||||
)
|
||||
else:
|
||||
retrieved_nodes.append(n)
|
||||
else:
|
||||
retrieved_nodes.append(n)
|
||||
|
||||
# remove any duplicates based on hash
|
||||
seen = set()
|
||||
return [
|
||||
n
|
||||
for n in retrieved_nodes
|
||||
if not (n.node.hash in seen or seen.add(n.node.hash)) # type: ignore[func-returns-value]
|
||||
]
|
||||
|
||||
def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
|
||||
"""Retrieve nodes given query.
|
||||
|
||||
Args:
|
||||
str_or_query_bundle (QueryType): Either a query string or
|
||||
a QueryBundle object.
|
||||
|
||||
"""
|
||||
self._check_callback_manager()
|
||||
|
||||
if isinstance(str_or_query_bundle, str):
|
||||
query_bundle = QueryBundle(str_or_query_bundle)
|
||||
else:
|
||||
query_bundle = str_or_query_bundle
|
||||
with self.callback_manager.as_trace("query"):
|
||||
with self.callback_manager.event(
|
||||
CBEventType.RETRIEVE,
|
||||
payload={EventPayload.QUERY_STR: query_bundle.query_str},
|
||||
) as retrieve_event:
|
||||
nodes = self._retrieve(query_bundle)
|
||||
nodes = self._handle_recursive_retrieval(query_bundle, nodes)
|
||||
retrieve_event.on_end(
|
||||
payload={EventPayload.NODES: nodes},
|
||||
)
|
||||
|
||||
return nodes
|
||||
|
||||
async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
|
||||
self._check_callback_manager()
|
||||
|
||||
if isinstance(str_or_query_bundle, str):
|
||||
query_bundle = QueryBundle(str_or_query_bundle)
|
||||
else:
|
||||
query_bundle = str_or_query_bundle
|
||||
with self.callback_manager.as_trace("query"):
|
||||
with self.callback_manager.event(
|
||||
CBEventType.RETRIEVE,
|
||||
payload={EventPayload.QUERY_STR: query_bundle.query_str},
|
||||
) as retrieve_event:
|
||||
nodes = await self._aretrieve(query_bundle)
|
||||
nodes = await self._ahandle_recursive_retrieval(query_bundle, nodes)
|
||||
retrieve_event.on_end(
|
||||
payload={EventPayload.NODES: nodes},
|
||||
)
|
||||
|
||||
return nodes
|
||||
|
||||
@abstractmethod
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
"""Retrieve nodes given query.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
||||
# TODO: make this abstract
|
||||
# @abstractmethod
|
||||
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
"""Asynchronously retrieve nodes given query.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
return self._retrieve(query_bundle)
|
||||
|
||||
def get_service_context(self) -> Optional[ServiceContext]:
|
||||
"""Attempts to resolve a service context.
|
||||
Short-circuits at self.service_context, self._service_context,
|
||||
or self._index.service_context.
|
||||
"""
|
||||
if hasattr(self, "service_context"):
|
||||
return self.service_context
|
||||
if hasattr(self, "_service_context"):
|
||||
return self._service_context
|
||||
elif hasattr(self, "_index") and hasattr(self._index, "service_context"):
|
||||
return self._index.service_context
|
||||
return None
|
||||
|
||||
def _as_query_component(self, **kwargs: Any) -> QueryComponent:
|
||||
"""Return a query component."""
|
||||
return RetrieverComponent(retriever=self)
|
||||
|
||||
|
||||
class RetrieverComponent(QueryComponent):
|
||||
"""Retriever component."""
|
||||
|
||||
retriever: BaseRetriever = Field(..., description="Retriever")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
||||
"""Set callback manager."""
|
||||
self.retriever.callback_manager = callback_manager
|
||||
|
||||
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs during run_component."""
|
||||
# make sure input is a string
|
||||
input["input"] = validate_and_convert_stringable(input["input"])
|
||||
return input
|
||||
|
||||
def _run_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component."""
|
||||
output = self.retriever.retrieve(kwargs["input"])
|
||||
return {"output": output}
|
||||
|
||||
async def _arun_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component."""
|
||||
output = await self.retriever.aretrieve(kwargs["input"])
|
||||
return {"output": output}
|
||||
|
||||
@property
|
||||
def input_keys(self) -> InputKeys:
|
||||
"""Input keys."""
|
||||
return InputKeys.from_keys({"input"})
|
||||
|
||||
@property
|
||||
def output_keys(self) -> OutputKeys:
|
||||
"""Output keys."""
|
||||
return OutputKeys.from_keys({"output"})
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Any, List, Sequence, Union
|
||||
|
||||
from llama_index.bridge.pydantic import BaseModel
|
||||
from llama_index.core.query_pipeline.query_component import (
|
||||
ChainableMixin,
|
||||
QueryComponent,
|
||||
)
|
||||
from llama_index.prompts.mixin import PromptMixin, PromptMixinType
|
||||
from llama_index.schema import QueryBundle, QueryType
|
||||
from llama_index.tools.types import ToolMetadata
|
||||
|
||||
MetadataType = Union[str, ToolMetadata]
|
||||
|
||||
|
||||
class SingleSelection(BaseModel):
|
||||
"""A single selection of a choice."""
|
||||
|
||||
index: int
|
||||
reason: str
|
||||
|
||||
|
||||
class MultiSelection(BaseModel):
|
||||
"""A multi-selection of choices."""
|
||||
|
||||
selections: List[SingleSelection]
|
||||
|
||||
@property
|
||||
def ind(self) -> int:
|
||||
if len(self.selections) != 1:
|
||||
raise ValueError(
|
||||
f"There are {len(self.selections)} selections, " "please use .inds."
|
||||
)
|
||||
return self.selections[0].index
|
||||
|
||||
@property
|
||||
def reason(self) -> str:
|
||||
if len(self.reasons) != 1:
|
||||
raise ValueError(
|
||||
f"There are {len(self.reasons)} selections, " "please use .reasons."
|
||||
)
|
||||
return self.selections[0].reason
|
||||
|
||||
@property
|
||||
def inds(self) -> List[int]:
|
||||
return [x.index for x in self.selections]
|
||||
|
||||
@property
|
||||
def reasons(self) -> List[str]:
|
||||
return [x.reason for x in self.selections]
|
||||
|
||||
|
||||
# separate name for clarity and to avoid confusing the function-calling model
|
||||
SelectorResult = MultiSelection
|
||||
|
||||
|
||||
def _wrap_choice(choice: MetadataType) -> ToolMetadata:
|
||||
if isinstance(choice, ToolMetadata):
|
||||
return choice
|
||||
elif isinstance(choice, str):
|
||||
return ToolMetadata(description=choice)
|
||||
else:
|
||||
raise ValueError(f"Unexpected type: {type(choice)}")
|
||||
|
||||
|
||||
def _wrap_query(query: QueryType) -> QueryBundle:
|
||||
if isinstance(query, QueryBundle):
|
||||
return query
|
||||
elif isinstance(query, str):
|
||||
return QueryBundle(query_str=query)
|
||||
else:
|
||||
raise ValueError(f"Unexpected type: {type(query)}")
|
||||
|
||||
|
||||
class BaseSelector(PromptMixin, ChainableMixin):
|
||||
"""Base selector."""
|
||||
|
||||
def _get_prompt_modules(self) -> PromptMixinType:
|
||||
"""Get prompt sub-modules."""
|
||||
return {}
|
||||
|
||||
def select(
|
||||
self, choices: Sequence[MetadataType], query: QueryType
|
||||
) -> SelectorResult:
|
||||
metadatas = [_wrap_choice(choice) for choice in choices]
|
||||
query_bundle = _wrap_query(query)
|
||||
return self._select(choices=metadatas, query=query_bundle)
|
||||
|
||||
async def aselect(
|
||||
self, choices: Sequence[MetadataType], query: QueryType
|
||||
) -> SelectorResult:
|
||||
metadatas = [_wrap_choice(choice) for choice in choices]
|
||||
query_bundle = _wrap_query(query)
|
||||
return await self._aselect(choices=metadatas, query=query_bundle)
|
||||
|
||||
@abstractmethod
|
||||
def _select(
|
||||
self, choices: Sequence[ToolMetadata], query: QueryBundle
|
||||
) -> SelectorResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _aselect(
|
||||
self, choices: Sequence[ToolMetadata], query: QueryBundle
|
||||
) -> SelectorResult:
|
||||
pass
|
||||
|
||||
def _as_query_component(self, **kwargs: Any) -> QueryComponent:
|
||||
"""As query component."""
|
||||
from llama_index.query_pipeline.components.router import SelectorComponent
|
||||
|
||||
return SelectorComponent(selector=self)
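A hypothetical selector for illustration: it satisfies the PromptMixin and BaseSelector abstract methods and always picks the first choice.

class FirstChoiceSelector(BaseSelector):
    def _get_prompts(self):
        return {}

    def _update_prompts(self, prompts) -> None:
        pass

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        return SelectorResult(
            selections=[SingleSelection(index=0, reason="always use the first tool")]
        )

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        return self._select(choices, query)


selector = FirstChoiceSelector()
result = selector.select(
    ["Useful for summarization", "Useful for specific lookups"],
    "Summarize the report",
)
print(result.ind, result.reason)  # -> 0 'always use the first tool'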
|
||||
|
|
@ -0,0 +1,354 @@
|
|||
"""Base embeddings file."""
|
||||
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from llama_index.bridge.pydantic import Field, validator
|
||||
from llama_index.callbacks.base import CallbackManager
|
||||
from llama_index.callbacks.schema import CBEventType, EventPayload
|
||||
from llama_index.constants import (
|
||||
DEFAULT_EMBED_BATCH_SIZE,
|
||||
)
|
||||
from llama_index.schema import BaseNode, MetadataMode, TransformComponent
|
||||
from llama_index.utils import get_tqdm_iterable
|
||||
|
||||
# TODO: change to numpy array
|
||||
Embedding = List[float]
|
||||
|
||||
|
||||
class SimilarityMode(str, Enum):
|
||||
"""Modes for similarity/distance."""
|
||||
|
||||
DEFAULT = "cosine"
|
||||
DOT_PRODUCT = "dot_product"
|
||||
EUCLIDEAN = "euclidean"
|
||||
|
||||
|
||||
def mean_agg(embeddings: List[Embedding]) -> Embedding:
|
||||
"""Mean aggregation for embeddings."""
|
||||
return list(np.array(embeddings).mean(axis=0))
|
||||
|
||||
|
||||
def similarity(
|
||||
embedding1: Embedding,
|
||||
embedding2: Embedding,
|
||||
mode: SimilarityMode = SimilarityMode.DEFAULT,
|
||||
) -> float:
|
||||
"""Get embedding similarity."""
|
||||
if mode == SimilarityMode.EUCLIDEAN:
|
||||
# Using -euclidean distance as similarity to achieve same ranking order
|
||||
return -float(np.linalg.norm(np.array(embedding1) - np.array(embedding2)))
|
||||
elif mode == SimilarityMode.DOT_PRODUCT:
|
||||
return float(np.dot(embedding1, embedding2))
|
||||
else:
|
||||
product = np.dot(embedding1, embedding2)
|
||||
norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
|
||||
return float(product / norm)
|
||||
|
||||
|
||||
class BaseEmbedding(TransformComponent):
|
||||
"""Base class for embeddings."""
|
||||
|
||||
model_name: str = Field(
|
||||
default="unknown", description="The name of the embedding model."
|
||||
)
|
||||
embed_batch_size: int = Field(
|
||||
default=DEFAULT_EMBED_BATCH_SIZE,
|
||||
description="The batch size for embedding calls.",
|
||||
gt=0,
|
||||
le=2048,
|
||||
)
|
||||
callback_manager: CallbackManager = Field(
|
||||
default_factory=lambda: CallbackManager([]), exclude=True
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("callback_manager", pre=True)
|
||||
def _validate_callback_manager(
|
||||
cls, v: Optional[CallbackManager]
|
||||
) -> CallbackManager:
|
||||
if v is None:
|
||||
return CallbackManager([])
|
||||
return v
|
||||
|
||||
@abstractmethod
|
||||
def _get_query_embedding(self, query: str) -> Embedding:
|
||||
"""
|
||||
Embed the input query synchronously.
|
||||
|
||||
Subclasses should implement this method. Reference get_query_embedding's
|
||||
docstring for more information.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_query_embedding(self, query: str) -> Embedding:
|
||||
"""
|
||||
Embed the input query asynchronously.
|
||||
|
||||
Subclasses should implement this method. Reference get_query_embedding's
|
||||
docstring for more information.
|
||||
"""
|
||||
|
||||
def get_query_embedding(self, query: str) -> Embedding:
|
||||
"""
|
||||
Embed the input query.
|
||||
|
||||
When embedding a query, depending on the model, a special instruction
|
||||
can be prepended to the raw query string. For example, "Represent the
|
||||
question for retrieving supporting documents: ". If you're curious,
|
||||
other examples of predefined instructions can be found in
|
||||
embeddings/huggingface_utils.py.
|
||||
"""
|
||||
with self.callback_manager.event(
|
||||
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
|
||||
) as event:
|
||||
query_embedding = self._get_query_embedding(query)
|
||||
|
||||
event.on_end(
|
||||
payload={
|
||||
EventPayload.CHUNKS: [query],
|
||||
EventPayload.EMBEDDINGS: [query_embedding],
|
||||
},
|
||||
)
|
||||
return query_embedding
|
||||
|
||||
async def aget_query_embedding(self, query: str) -> Embedding:
|
||||
"""Get query embedding."""
|
||||
with self.callback_manager.event(
|
||||
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
|
||||
) as event:
|
||||
query_embedding = await self._aget_query_embedding(query)
|
||||
|
||||
event.on_end(
|
||||
payload={
|
||||
EventPayload.CHUNKS: [query],
|
||||
EventPayload.EMBEDDINGS: [query_embedding],
|
||||
},
|
||||
)
|
||||
return query_embedding
|
||||
|
||||
def get_agg_embedding_from_queries(
|
||||
self,
|
||||
queries: List[str],
|
||||
agg_fn: Optional[Callable[..., Embedding]] = None,
|
||||
) -> Embedding:
|
||||
"""Get aggregated embedding from multiple queries."""
|
||||
query_embeddings = [self.get_query_embedding(query) for query in queries]
|
||||
agg_fn = agg_fn or mean_agg
|
||||
return agg_fn(query_embeddings)
|
||||
|
||||
async def aget_agg_embedding_from_queries(
|
||||
self,
|
||||
queries: List[str],
|
||||
agg_fn: Optional[Callable[..., Embedding]] = None,
|
||||
) -> Embedding:
|
||||
"""Async get aggregated embedding from multiple queries."""
|
||||
query_embeddings = [await self.aget_query_embedding(query) for query in queries]
|
||||
agg_fn = agg_fn or mean_agg
|
||||
return agg_fn(query_embeddings)
|
||||
|
||||
@abstractmethod
|
||||
def _get_text_embedding(self, text: str) -> Embedding:
|
||||
"""
|
||||
Embed the input text synchronously.
|
||||
|
||||
Subclasses should implement this method. Reference get_text_embedding's
|
||||
docstring for more information.
|
||||
"""
|
||||
|
||||
async def _aget_text_embedding(self, text: str) -> Embedding:
|
||||
"""
|
||||
Embed the input text asynchronously.
|
||||
|
||||
Subclasses can implement this method if there is a true async
|
||||
implementation. Reference get_text_embedding's docstring for more
|
||||
information.
|
||||
"""
|
||||
# Default implementation just falls back on _get_text_embedding
|
||||
return self._get_text_embedding(text)
|
||||
|
||||
def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
|
||||
"""
|
||||
Embed the input sequence of text synchronously.
|
||||
|
||||
Subclasses can implement this method if batch queries are supported.
|
||||
"""
|
||||
# Default implementation just loops over _get_text_embedding
|
||||
return [self._get_text_embedding(text) for text in texts]
|
||||
|
||||
async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
|
||||
"""
|
||||
Embed the input sequence of text asynchronously.
|
||||
|
||||
Subclasses can implement this method if batch queries are supported.
|
||||
"""
|
||||
return await asyncio.gather(
|
||||
*[self._aget_text_embedding(text) for text in texts]
|
||||
)
|
||||
|
||||
def get_text_embedding(self, text: str) -> Embedding:
|
||||
"""
|
||||
Embed the input text.
|
||||
|
||||
When embedding text, depending on the model, a special instruction
|
||||
can be prepended to the raw text string. For example, "Represent the
|
||||
document for retrieval: ". If you're curious, other examples of
|
||||
predefined instructions can be found in embeddings/huggingface_utils.py.
|
||||
"""
|
||||
with self.callback_manager.event(
|
||||
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
|
||||
) as event:
|
||||
text_embedding = self._get_text_embedding(text)
|
||||
|
||||
event.on_end(
|
||||
payload={
|
||||
EventPayload.CHUNKS: [text],
|
||||
EventPayload.EMBEDDINGS: [text_embedding],
|
||||
}
|
||||
)
|
||||
|
||||
return text_embedding
|
||||
|
||||
async def aget_text_embedding(self, text: str) -> Embedding:
|
||||
"""Async get text embedding."""
|
||||
with self.callback_manager.event(
|
||||
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
|
||||
) as event:
|
||||
text_embedding = await self._aget_text_embedding(text)
|
||||
|
||||
event.on_end(
|
||||
payload={
|
||||
EventPayload.CHUNKS: [text],
|
||||
EventPayload.EMBEDDINGS: [text_embedding],
|
||||
}
|
||||
)
|
||||
|
||||
return text_embedding
|
||||
|
||||
def get_text_embedding_batch(
|
||||
self,
|
||||
texts: List[str],
|
||||
show_progress: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Embedding]:
|
||||
"""Get a list of text embeddings, with batching."""
|
||||
cur_batch: List[str] = []
|
||||
result_embeddings: List[Embedding] = []
|
||||
|
||||
queue_with_progress = enumerate(
|
||||
get_tqdm_iterable(texts, show_progress, "Generating embeddings")
|
||||
)
|
||||
|
||||
for idx, text in queue_with_progress:
|
||||
cur_batch.append(text)
|
||||
if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
|
||||
# flush
|
||||
with self.callback_manager.event(
|
||||
CBEventType.EMBEDDING,
|
||||
payload={EventPayload.SERIALIZED: self.to_dict()},
|
||||
) as event:
|
||||
embeddings = self._get_text_embeddings(cur_batch)
|
||||
result_embeddings.extend(embeddings)
|
||||
event.on_end(
|
||||
payload={
|
||||
EventPayload.CHUNKS: cur_batch,
|
||||
EventPayload.EMBEDDINGS: embeddings,
|
||||
},
|
||||
)
|
||||
cur_batch = []
|
||||
|
||||
return result_embeddings
|
||||
|
||||
async def aget_text_embedding_batch(
|
||||
self, texts: List[str], show_progress: bool = False
|
||||
) -> List[Embedding]:
|
||||
"""Asynchronously get a list of text embeddings, with batching."""
|
||||
cur_batch: List[str] = []
|
||||
callback_payloads: List[Tuple[str, List[str]]] = []
|
||||
result_embeddings: List[Embedding] = []
|
||||
embeddings_coroutines: List[Coroutine] = []
|
||||
for idx, text in enumerate(texts):
|
||||
cur_batch.append(text)
|
||||
if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
|
||||
# flush
|
||||
event_id = self.callback_manager.on_event_start(
|
||||
CBEventType.EMBEDDING,
|
||||
payload={EventPayload.SERIALIZED: self.to_dict()},
|
||||
)
|
||||
callback_payloads.append((event_id, cur_batch))
|
||||
embeddings_coroutines.append(self._aget_text_embeddings(cur_batch))
|
||||
cur_batch = []
|
||||
|
||||
# flatten the results of asyncio.gather, which is a list of embeddings lists
|
||||
nested_embeddings = []
|
||||
if show_progress:
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
nested_embeddings = [
|
||||
await f
|
||||
for f in tqdm(
|
||||
asyncio.as_completed(embeddings_coroutines),
|
||||
total=len(embeddings_coroutines),
|
||||
desc="Generating embeddings",
|
||||
)
|
||||
]
|
||||
except ImportError:
|
||||
nested_embeddings = await asyncio.gather(*embeddings_coroutines)
|
||||
else:
|
||||
nested_embeddings = await asyncio.gather(*embeddings_coroutines)
|
||||
|
||||
result_embeddings = [
|
||||
embedding for embeddings in nested_embeddings for embedding in embeddings
|
||||
]
|
||||
|
||||
for (event_id, text_batch), embeddings in zip(
|
||||
callback_payloads, nested_embeddings
|
||||
):
|
||||
self.callback_manager.on_event_end(
|
||||
CBEventType.EMBEDDING,
|
||||
payload={
|
||||
EventPayload.CHUNKS: text_batch,
|
||||
EventPayload.EMBEDDINGS: embeddings,
|
||||
},
|
||||
event_id=event_id,
|
||||
)
|
||||
|
||||
return result_embeddings
|
||||
|
||||
def similarity(
|
||||
self,
|
||||
embedding1: Embedding,
|
||||
embedding2: Embedding,
|
||||
mode: SimilarityMode = SimilarityMode.DEFAULT,
|
||||
) -> float:
|
||||
"""Get embedding similarity."""
|
||||
return similarity(embedding1=embedding1, embedding2=embedding2, mode=mode)
|
||||
|
||||
def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
|
||||
embeddings = self.get_text_embedding_batch(
|
||||
[node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for node, embedding in zip(nodes, embeddings):
|
||||
node.embedding = embedding
|
||||
|
||||
return nodes
|
||||
|
||||
async def acall(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
|
||||
embeddings = await self.aget_text_embedding_batch(
|
||||
[node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for node, embedding in zip(nodes, embeddings):
|
||||
node.embedding = embedding
|
||||
|
||||
return nodes
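A toy embedding subclass used purely for illustration (real subclasses such as OpenAIEmbedding call an actual model); it implements the three abstract methods above with a deterministic hash so the example runs offline.

import hashlib


class HashEmbedding(BaseEmbedding):
    """Deterministic 8-dimensional pseudo-embedding derived from a text hash."""

    def _embed(self, text: str) -> Embedding:
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        return [byte / 255.0 for byte in digest[:8]]

    def _get_query_embedding(self, query: str) -> Embedding:
        return self._embed(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        return self._embed(query)

    def _get_text_embedding(self, text: str) -> Embedding:
        return self._embed(text)


embed_model = HashEmbedding(model_name="hash-8d")
vec1 = embed_model.get_query_embedding("hello")
vec2 = embed_model.get_text_embedding("hello world")
print(embed_model.similarity(vec1, vec2))  # cosine similarity by default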
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
from llama_index.indices.query.schema import QueryBundle, QueryType
|
||||
from llama_index.prompts.mixin import PromptMixin
|
||||
from llama_index.schema import NodeWithScore
|
||||
|
||||
|
||||
class BaseImageRetriever(PromptMixin):
|
||||
"""Base Image Retriever Abstraction."""
|
||||
|
||||
def text_to_image_retrieve(
|
||||
self, str_or_query_bundle: QueryType
|
||||
) -> List[NodeWithScore]:
|
||||
"""Retrieve image nodes given query or single image input.
|
||||
|
||||
Args:
|
||||
str_or_query_bundle (QueryType): a query text
|
||||
string or a QueryBundle object.
|
||||
"""
|
||||
if isinstance(str_or_query_bundle, str):
|
||||
str_or_query_bundle = QueryBundle(query_str=str_or_query_bundle)
|
||||
return self._text_to_image_retrieve(str_or_query_bundle)
|
||||
|
||||
@abstractmethod
|
||||
def _text_to_image_retrieve(
|
||||
self,
|
||||
query_bundle: QueryBundle,
|
||||
) -> List[NodeWithScore]:
|
||||
"""Retrieve image nodes or documents given query text.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
||||
def image_to_image_retrieve(
|
||||
self, str_or_query_bundle: QueryType
|
||||
) -> List[NodeWithScore]:
|
||||
"""Retrieve image nodes given single image input.
|
||||
|
||||
Args:
|
||||
str_or_query_bundle (QueryType): an image path
|
||||
string or a QueryBundle object.
|
||||
"""
|
||||
if isinstance(str_or_query_bundle, str):
|
||||
# leave query_str as empty since we are using image_path for image retrieval
|
||||
str_or_query_bundle = QueryBundle(
|
||||
query_str="", image_path=str_or_query_bundle
|
||||
)
|
||||
return self._image_to_image_retrieve(str_or_query_bundle)
|
||||
|
||||
@abstractmethod
|
||||
def _image_to_image_retrieve(
|
||||
self,
|
||||
query_bundle: QueryBundle,
|
||||
) -> List[NodeWithScore]:
|
||||
"""Retrieve image nodes or documents given image.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
||||
# Async Methods
|
||||
async def atext_to_image_retrieve(
|
||||
self,
|
||||
str_or_query_bundle: QueryType,
|
||||
) -> List[NodeWithScore]:
|
||||
if isinstance(str_or_query_bundle, str):
|
||||
str_or_query_bundle = QueryBundle(query_str=str_or_query_bundle)
|
||||
return await self._atext_to_image_retrieve(str_or_query_bundle)
|
||||
|
||||
@abstractmethod
|
||||
async def _atext_to_image_retrieve(
|
||||
self,
|
||||
query_bundle: QueryBundle,
|
||||
) -> List[NodeWithScore]:
|
||||
"""Async retrieve image nodes or documents given query text.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
||||
async def aimage_to_image_retrieve(
|
||||
self,
|
||||
str_or_query_bundle: QueryType,
|
||||
) -> List[NodeWithScore]:
|
||||
if isinstance(str_or_query_bundle, str):
|
||||
# leave query_str as empty since we are using image_path for image retrieval
|
||||
str_or_query_bundle = QueryBundle(
|
||||
query_str="", image_path=str_or_query_bundle
|
||||
)
|
||||
return await self._aimage_to_image_retrieve(str_or_query_bundle)
|
||||
|
||||
@abstractmethod
|
||||
async def _aimage_to_image_retrieve(
|
||||
self,
|
||||
query_bundle: QueryBundle,
|
||||
) -> List[NodeWithScore]:
|
||||
"""Async retrieve image nodes or documents given image.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
"""
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator, Generator, Optional
|
||||
|
||||
from llama_index.bridge.pydantic import BaseModel, Field
|
||||
from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""Message role."""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
FUNCTION = "function"
|
||||
TOOL = "tool"
|
||||
CHATBOT = "chatbot"
|
||||
|
||||
|
||||
# ===== Generic Model Input - Chat =====
|
||||
class ChatMessage(BaseModel):
|
||||
"""Chat message."""
|
||||
|
||||
role: MessageRole = MessageRole.USER
|
||||
content: Optional[Any] = ""
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.role.value}: {self.content}"
|
||||
|
||||
|
||||
# ===== Generic Model Output - Chat =====
|
||||
class ChatResponse(BaseModel):
|
||||
"""Chat response."""
|
||||
|
||||
message: ChatMessage
|
||||
raw: Optional[dict] = None
|
||||
delta: Optional[str] = None
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.message)
|
||||
|
||||
|
||||
ChatResponseGen = Generator[ChatResponse, None, None]
|
||||
ChatResponseAsyncGen = AsyncGenerator[ChatResponse, None]
|
||||
|
||||
|
||||
# ===== Generic Model Output - Completion =====
|
||||
class CompletionResponse(BaseModel):
|
||||
"""
|
||||
Completion response.
|
||||
|
||||
Fields:
|
||||
text: Text content of the response if not streaming, or if streaming,
|
||||
the current extent of streamed text.
|
||||
additional_kwargs: Additional information on the response (i.e. token
|
||||
counts, function calling information).
|
||||
raw: Optional raw JSON that was parsed to populate text, if relevant.
|
||||
delta: New text that just streamed in (only relevant when streaming).
|
||||
"""
|
||||
|
||||
text: str
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
raw: Optional[dict] = None
|
||||
delta: Optional[str] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.text
|
||||
|
||||
|
||||
CompletionResponseGen = Generator[CompletionResponse, None, None]
|
||||
CompletionResponseAsyncGen = AsyncGenerator[CompletionResponse, None]
|
||||
|
||||
|
||||
class LLMMetadata(BaseModel):
|
||||
context_window: int = Field(
|
||||
default=DEFAULT_CONTEXT_WINDOW,
|
||||
description=(
|
||||
"Total number of tokens the model can be input and output for one response."
|
||||
),
|
||||
)
|
||||
num_output: int = Field(
|
||||
default=DEFAULT_NUM_OUTPUTS,
|
||||
description="Number of tokens the model can output when generating a response.",
|
||||
)
|
||||
is_chat_model: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Set True if the model exposes a chat interface (i.e. can be passed a"
|
||||
" sequence of messages, rather than text), like OpenAI's"
|
||||
" /v1/chat/completions endpoint."
|
||||
),
|
||||
)
|
||||
is_function_calling_model: bool = Field(
|
||||
default=False,
|
||||
# SEE: https://openai.com/blog/function-calling-and-other-api-updates
|
||||
description=(
|
||||
"Set True if the model supports function calling messages, similar to"
|
||||
" OpenAI's function calling API. For example, converting 'Email Anya to"
|
||||
" see if she wants to get coffee next Friday' to a function call like"
|
||||
" `send_email(to: string, body: string)`."
|
||||
),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="unknown",
|
||||
description=(
|
||||
"The model's name used for logging, testing, and sanity checking. For some"
|
||||
" models this can be automatically discerned. For other models, like"
|
||||
" locally loaded models, this must be manually specified."
|
||||
),
|
||||
)
|
||||
system_role: MessageRole = Field(
|
||||
default=MessageRole.SYSTEM,
|
||||
description="The role this specific LLM provider"
|
||||
"expects for system prompt. E.g. 'SYSTEM' for OpenAI, 'CHATBOT' for Cohere",
|
||||
)
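A small illustration of the message and response containers defined above; everything here is constructed directly from the classes in this file.

messages = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are a terse assistant."),
    ChatMessage(role=MessageRole.USER, content="Say hi."),
]
print(messages[1])  # "user: Say hi."

response = ChatResponse(
    message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi.")
)
print(str(response))  # "assistant: Hi."

metadata = LLMMetadata(context_window=4096, num_output=256, is_chat_model=True)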
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
"""Query pipeline components."""
|
||||
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple
|
||||
|
||||
from llama_index.bridge.pydantic import Field, PrivateAttr
|
||||
from llama_index.callbacks.base import CallbackManager
|
||||
from llama_index.core.query_pipeline.query_component import (
|
||||
InputKeys,
|
||||
OutputKeys,
|
||||
QueryComponent,
|
||||
)
|
||||
|
||||
|
||||
def get_parameters(fn: Callable) -> Tuple[Set[str], Set[str]]:
|
||||
"""Get parameters from function.
|
||||
|
||||
Returns:
|
||||
Tuple[Set[str], Set[str]]: required and optional parameters
|
||||
|
||||
"""
|
||||
# inspect the signature and split parameters into required vs. optional
|
||||
params = signature(fn).parameters
|
||||
required_params = set()
|
||||
optional_params = set()
|
||||
for param_name in params:
|
||||
param_default = params[param_name].default
|
||||
if param_default is params[param_name].empty:
|
||||
required_params.add(param_name)
|
||||
else:
|
||||
optional_params.add(param_name)
|
||||
return required_params, optional_params
|
||||
|
||||
|
||||
class FnComponent(QueryComponent):
|
||||
"""Query component that takes in an arbitrary function."""
|
||||
|
||||
fn: Callable = Field(..., description="Function to run.")
|
||||
async_fn: Optional[Callable] = Field(
|
||||
None, description="Async function to run. If not provided, will run `fn`."
|
||||
)
|
||||
output_key: str = Field(
|
||||
default="output", description="Output key for component output."
|
||||
)
|
||||
|
||||
_req_params: Set[str] = PrivateAttr()
|
||||
_opt_params: Set[str] = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn: Callable,
|
||||
async_fn: Optional[Callable] = None,
|
||||
req_params: Optional[Set[str]] = None,
|
||||
opt_params: Optional[Set[str]] = None,
|
||||
output_key: str = "output",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
# determine parameters
|
||||
default_req_params, default_opt_params = get_parameters(fn)
|
||||
if req_params is None:
|
||||
req_params = default_req_params
|
||||
if opt_params is None:
|
||||
opt_params = default_opt_params
|
||||
|
||||
self._req_params = req_params
|
||||
self._opt_params = opt_params
|
||||
super().__init__(fn=fn, async_fn=async_fn, output_key=output_key, **kwargs)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
||||
"""Set callback manager."""
|
||||
# TODO: implement
|
||||
|
||||
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs during run_component."""
|
||||
# check that all required parameters are present
|
||||
missing_params = self._req_params - set(input.keys())
|
||||
if missing_params:
|
||||
raise ValueError(
|
||||
f"Missing required parameters: {missing_params}. "
|
||||
f"Input keys: {input.keys()}"
|
||||
)
|
||||
|
||||
# check that no extra parameters are present
|
||||
extra_params = set(input.keys()) - self._req_params - self._opt_params
|
||||
if extra_params:
|
||||
raise ValueError(
|
||||
f"Extra parameters: {extra_params}. " f"Input keys: {input.keys()}"
|
||||
)
|
||||
return input
|
||||
|
||||
def _run_component(self, **kwargs: Any) -> Dict:
|
||||
"""Run component."""
|
||||
return {self.output_key: self.fn(**kwargs)}
|
||||
|
||||
async def _arun_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component (async)."""
|
||||
if self.async_fn is None:
|
||||
return self._run_component(**kwargs)
|
||||
else:
|
||||
return {self.output_key: await self.async_fn(**kwargs)}
|
||||
|
||||
@property
|
||||
def input_keys(self) -> InputKeys:
|
||||
"""Input keys."""
|
||||
return InputKeys.from_keys(
|
||||
required_keys=self._req_params, optional_keys=self._opt_params
|
||||
)
|
||||
|
||||
@property
|
||||
def output_keys(self) -> OutputKeys:
|
||||
"""Output keys."""
|
||||
return OutputKeys.from_keys({self.output_key})
|
||||
|
||||
|
||||
class InputComponent(QueryComponent):
|
||||
"""Input component."""
|
||||
|
||||
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return input
|
||||
|
||||
def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return input
|
||||
|
||||
def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs."""
|
||||
# NOTE: we override this to do nothing
|
||||
return input
|
||||
|
||||
def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component outputs."""
|
||||
# NOTE: we override this to do nothing
|
||||
return output
|
||||
|
||||
def set_callback_manager(self, callback_manager: Any) -> None:
|
||||
"""Set callback manager."""
|
||||
|
||||
def _run_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component."""
|
||||
return kwargs
|
||||
|
||||
async def _arun_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component (async)."""
|
||||
return self._run_component(**kwargs)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> InputKeys:
|
||||
"""Input keys."""
|
||||
# NOTE: this shouldn't be used
|
||||
return InputKeys.from_keys(set(), optional_keys=set())
|
||||
# return InputComponentKeys.from_keys(set(), optional_keys=set())
|
||||
|
||||
@property
|
||||
def output_keys(self) -> OutputKeys:
|
||||
"""Output keys."""
|
||||
return OutputKeys.from_keys(set())
|
||||
|
||||
|
||||
class ArgPackComponent(QueryComponent):
|
||||
"""Arg pack component.
|
||||
|
||||
Packs arbitrary number of args into a list.
|
||||
|
||||
"""
|
||||
|
||||
convert_fn: Optional[Callable] = Field(
|
||||
default=None, description="Function to convert output."
|
||||
)
|
||||
|
||||
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs during run_component."""
|
||||
raise NotImplementedError
|
||||
|
||||
def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs."""
|
||||
return input
|
||||
|
||||
def _validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component outputs."""
|
||||
# make sure output value is a list
|
||||
if not isinstance(output["output"], list):
|
||||
raise ValueError(f"Output is not a list.")
|
||||
return output
|
||||
|
||||
def set_callback_manager(self, callback_manager: Any) -> None:
|
||||
"""Set callback manager."""
|
||||
|
||||
def _run_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component."""
|
||||
# combine all lists into one
|
||||
output = []
|
||||
for v in kwargs.values():
|
||||
if self.convert_fn is not None:
|
||||
v = self.convert_fn(v)
|
||||
output.append(v)
|
||||
return {"output": output}
|
||||
|
||||
async def _arun_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component (async)."""
|
||||
return self._run_component(**kwargs)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> InputKeys:
|
||||
"""Input keys."""
|
||||
# NOTE: this shouldn't be used
|
||||
return InputKeys.from_keys(set(), optional_keys=set())
|
||||
|
||||
@property
|
||||
def output_keys(self) -> OutputKeys:
|
||||
"""Output keys."""
|
||||
return OutputKeys.from_keys({"output"})
|
||||
|
||||
|
||||
class KwargPackComponent(QueryComponent):
|
||||
"""Kwarg pack component.
|
||||
|
||||
Packs arbitrary number of kwargs into a dict.
|
||||
|
||||
"""
|
||||
|
||||
convert_fn: Optional[Callable] = Field(
|
||||
default=None, description="Function to convert output."
|
||||
)
|
||||
|
||||
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs during run_component."""
|
||||
raise NotImplementedError
|
||||
|
||||
def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs."""
|
||||
return input
|
||||
|
||||
def _validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component outputs."""
|
||||
# make sure output value is a dict
|
||||
if not isinstance(output["output"], dict):
|
||||
raise ValueError("Output is not a dict.")
|
||||
return output
|
||||
|
||||
def set_callback_manager(self, callback_manager: Any) -> None:
|
||||
"""Set callback manager."""
|
||||
|
||||
def _run_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component."""
|
||||
if self.convert_fn is not None:
|
||||
for k, v in kwargs.items():
|
||||
kwargs[k] = self.convert_fn(v)
|
||||
return {"output": kwargs}
|
||||
|
||||
async def _arun_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component (async)."""
|
||||
return self._run_component(**kwargs)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> InputKeys:
|
||||
"""Input keys."""
|
||||
# NOTE: this shouldn't be used
|
||||
return InputKeys.from_keys(set(), optional_keys=set())
|
||||
|
||||
@property
|
||||
def output_keys(self) -> OutputKeys:
|
||||
"""Output keys."""
|
||||
return OutputKeys.from_keys({"output"})
|
||||
|
|
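To make the two packing components above concrete, here is a small sketch of how they behave when run directly. The import path is an assumption, since only the class bodies appear in this diff.

```python
# Assumed import path; adjust to wherever these components are exposed.
# from llama_index.query_pipeline import ArgPackComponent, KwargPackComponent

arg_packer = ArgPackComponent()
print(arg_packer.run_component(first="a", second="b"))
# -> {"output": ["a", "b"]}

kwarg_packer = KwargPackComponent(convert_fn=str)
print(kwarg_packer.run_component(n=1, m=2))
# -> {"output": {"n": "1", "m": "2"}}
```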
@ -0,0 +1,338 @@
|
|||
"""Pipeline schema."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
|
||||
from llama_index.bridge.pydantic import BaseModel, Field
|
||||
from llama_index.callbacks.base import CallbackManager
|
||||
from llama_index.core.llms.types import (
|
||||
ChatResponse,
|
||||
CompletionResponse,
|
||||
)
|
||||
from llama_index.core.response.schema import Response
|
||||
from llama_index.schema import NodeWithScore, QueryBundle, TextNode
|
||||
|
||||
## Define common types used throughout these components
|
||||
StringableInput = Union[
|
||||
CompletionResponse,
|
||||
ChatResponse,
|
||||
str,
|
||||
QueryBundle,
|
||||
Response,
|
||||
Generator,
|
||||
NodeWithScore,
|
||||
TextNode,
|
||||
]
|
||||
|
||||
|
||||
def validate_and_convert_stringable(input: Any) -> str:
|
||||
# special handling for generator
|
||||
if isinstance(input, Generator):
|
||||
# iterate through each element, make sure is stringable
|
||||
new_input = ""
|
||||
for elem in input:
|
||||
if not isinstance(elem, get_args(StringableInput)):
|
||||
raise ValueError(f"Input {elem} is not stringable.")
|
||||
elif isinstance(elem, (ChatResponse, CompletionResponse)):
|
||||
new_input += cast(str, elem.delta)
|
||||
else:
|
||||
new_input += str(elem)
|
||||
return new_input
|
||||
elif isinstance(input, List):
|
||||
# iterate through each element, make sure is stringable
|
||||
# do this recursively
|
||||
new_input_list = []
|
||||
for elem in input:
|
||||
new_input_list.append(validate_and_convert_stringable(elem))
|
||||
return str(new_input_list)
|
||||
elif isinstance(input, ChatResponse):
|
||||
return input.message.content or ""
|
||||
elif isinstance(input, get_args(StringableInput)):
|
||||
return str(input)
|
||||
else:
|
||||
raise ValueError(f"Input {input} is not stringable.")
|
||||
|
||||
|
||||
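A brief sketch of the helper above: strings and lists are converted (lists recursively), while unsupported types raise a `ValueError`.

```python
print(validate_and_convert_stringable("hello"))     # "hello"
print(validate_and_convert_stringable(["a", "b"]))  # "['a', 'b']" (elements stringified recursively)

try:
    validate_and_convert_stringable({"not": "stringable"})
except ValueError as exc:
    print(exc)  # dicts are not part of StringableInput
```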
class InputKeys(BaseModel):
|
||||
"""Input keys."""
|
||||
|
||||
required_keys: Set[str] = Field(default_factory=set)
|
||||
optional_keys: Set[str] = Field(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def from_keys(
|
||||
cls, required_keys: Set[str], optional_keys: Optional[Set[str]] = None
|
||||
) -> "InputKeys":
|
||||
"""Create InputKeys from tuple."""
|
||||
return cls(required_keys=required_keys, optional_keys=optional_keys or set())
|
||||
|
||||
def validate(self, input_keys: Set[str]) -> None:
|
||||
"""Validate input keys."""
|
||||
# check if required keys are present, and that keys all are in required or optional
|
||||
if not self.required_keys.issubset(input_keys):
|
||||
raise ValueError(
|
||||
f"Required keys {self.required_keys} are not present in input keys {input_keys}"
|
||||
)
|
||||
if not input_keys.issubset(self.required_keys.union(self.optional_keys)):
|
||||
raise ValueError(
|
||||
f"Input keys {input_keys} contain keys not in required or optional keys {self.required_keys.union(self.optional_keys)}"
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Length of input keys."""
|
||||
return len(self.required_keys) + len(self.optional_keys)
|
||||
|
||||
def all(self) -> Set[str]:
|
||||
"""Get all input keys."""
|
||||
return self.required_keys.union(self.optional_keys)
|
||||
|
||||
|
||||
class OutputKeys(BaseModel):
|
||||
"""Output keys."""
|
||||
|
||||
required_keys: Set[str] = Field(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def from_keys(
|
||||
cls,
|
||||
required_keys: Set[str],
|
||||
) -> "InputKeys":
|
||||
"""Create InputKeys from tuple."""
|
||||
return cls(required_keys=required_keys)
|
||||
|
||||
def validate(self, input_keys: Set[str]) -> None:
|
||||
"""Validate input keys."""
|
||||
# validate that input keys exactly match required keys
|
||||
if input_keys != self.required_keys:
|
||||
raise ValueError(
|
||||
f"Input keys {input_keys} do not match required keys {self.required_keys}"
|
||||
)
|
||||
|
||||
|
||||
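A hedged sketch of how the two key containers above are used to check a component's inputs and outputs at runtime.

```python
in_keys = InputKeys.from_keys(required_keys={"query_str"}, optional_keys={"mode"})
in_keys.validate({"query_str"})          # ok: required keys present
in_keys.validate({"query_str", "mode"})  # ok: the extra key is declared optional
# in_keys.validate({"mode"})             # would raise: "query_str" missing

out_keys = OutputKeys.from_keys({"output"})
out_keys.validate({"output"})            # ok: output keys must match exactly
```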
class ChainableMixin(ABC):
|
||||
"""Chainable mixin.
|
||||
|
||||
A module that can produce a `QueryComponent` from a set of inputs through
|
||||
`as_query_component`.
|
||||
|
||||
If plugged in directly into a `QueryPipeline`, the `ChainableMixin` will be
|
||||
converted into a `QueryComponent` with default parameters.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _as_query_component(self, **kwargs: Any) -> "QueryComponent":
|
||||
"""Get query component."""
|
||||
|
||||
def as_query_component(
|
||||
self, partial: Optional[Dict[str, Any]] = None, **kwargs: Any
|
||||
) -> "QueryComponent":
|
||||
"""Get query component."""
|
||||
component = self._as_query_component(**kwargs)
|
||||
component.partial(**(partial or {}))
|
||||
return component
|
||||
|
||||
|
||||
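A minimal sketch of the `ChainableMixin` contract, assuming the `FnComponent` whose methods appear at the top of this diff can be constructed with an `fn` keyword argument; `TitleCaser` is an illustrative name, not part of the library.

```python
class TitleCaser(ChainableMixin):
    """Toy module that becomes an FnComponent when chained."""

    def _as_query_component(self, **kwargs: Any) -> "QueryComponent":
        # assumption: FnComponent(fn=...) wraps a callable and exposes its
        # parameters as required input keys
        return FnComponent(fn=lambda text: text.title())

component = TitleCaser().as_query_component(partial={"text": "query pipelines"})
print(component.run_component())  # e.g. {"output": "Query Pipelines"} (output key assumed)
```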
class QueryComponent(BaseModel):
|
||||
"""Query component.
|
||||
|
||||
Represents a component that can be run in a `QueryPipeline`.
|
||||
|
||||
"""
|
||||
|
||||
partial_dict: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Partial arguments to run_component"
|
||||
)
|
||||
|
||||
# TODO: make this a subclass of BaseComponent (e.g. use Pydantic)
|
||||
|
||||
def partial(self, **kwargs: Any) -> None:
|
||||
"""Update with partial arguments."""
|
||||
self.partial_dict.update(kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
||||
"""Set callback manager."""
|
||||
# TODO: refactor so that callback_manager is always passed in during runtime.
|
||||
|
||||
@property
|
||||
def free_req_input_keys(self) -> Set[str]:
|
||||
"""Get free input keys."""
|
||||
return self.input_keys.required_keys.difference(self.partial_dict.keys())
|
||||
|
||||
@abstractmethod
|
||||
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs during run_component."""
|
||||
|
||||
def _validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component outputs during run_component."""
|
||||
# override if needed
|
||||
return output
|
||||
|
||||
def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs."""
|
||||
# make sure set of input keys == self.input_keys
|
||||
self.input_keys.validate(set(input.keys()))
|
||||
return self._validate_component_inputs(input)
|
||||
|
||||
def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component outputs."""
|
||||
# make sure set of output keys == self.output_keys
|
||||
self.output_keys.validate(set(output.keys()))
|
||||
return self._validate_component_outputs(output)
|
||||
|
||||
def run_component(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Run component."""
|
||||
kwargs.update(self.partial_dict)
|
||||
kwargs = self.validate_component_inputs(kwargs)
|
||||
component_outputs = self._run_component(**kwargs)
|
||||
return self.validate_component_outputs(component_outputs)
|
||||
|
||||
async def arun_component(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Run component."""
|
||||
kwargs.update(self.partial_dict)
|
||||
kwargs = self.validate_component_inputs(kwargs)
|
||||
component_outputs = await self._arun_component(**kwargs)
|
||||
return self.validate_component_outputs(component_outputs)
|
||||
|
||||
@abstractmethod
|
||||
def _run_component(self, **kwargs: Any) -> Dict:
|
||||
"""Run component."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component (async)."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_keys(self) -> InputKeys:
|
||||
"""Input keys."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_keys(self) -> OutputKeys:
|
||||
"""Output keys."""
|
||||
|
||||
@property
|
||||
def sub_query_components(self) -> List["QueryComponent"]:
|
||||
"""Get sub query components.
|
||||
|
||||
Certain query components may have sub query components, e.g. a
|
||||
query pipeline will have sub query components, and so will
|
||||
an IfElseComponent.
|
||||
|
||||
"""
|
||||
return []
|
||||
|
||||
|
||||
class CustomQueryComponent(QueryComponent):
|
||||
"""Custom query component."""
|
||||
|
||||
callback_manager: CallbackManager = Field(
|
||||
default_factory=CallbackManager, description="Callback manager"
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
|
||||
"""Set callback manager."""
|
||||
self.callback_manager = callback_manager
|
||||
|
||||
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate component inputs during run_component."""
|
||||
# NOTE: user can override this method to validate inputs
|
||||
# but we do this by default for convenience
|
||||
return input
|
||||
|
||||
async def _arun_component(self, **kwargs: Any) -> Any:
|
||||
"""Run component (async)."""
|
||||
raise NotImplementedError("This component does not support async run.")
|
||||
|
||||
@property
|
||||
def _input_keys(self) -> Set[str]:
|
||||
"""Input keys dict."""
|
||||
raise NotImplementedError("Not implemented yet. Please override this method.")
|
||||
|
||||
@property
|
||||
def _optional_input_keys(self) -> Set[str]:
|
||||
"""Optional input keys dict."""
|
||||
return set()
|
||||
|
||||
@property
|
||||
def _output_keys(self) -> Set[str]:
|
||||
"""Output keys dict."""
|
||||
raise NotImplementedError("Not implemented yet. Please override this method.")
|
||||
|
||||
@property
|
||||
def input_keys(self) -> InputKeys:
|
||||
"""Input keys."""
|
||||
# NOTE: user can override this too, but we have them implement an
|
||||
# abstract method to make sure they do it
|
||||
|
||||
return InputKeys.from_keys(
|
||||
required_keys=self._input_keys, optional_keys=self._optional_input_keys
|
||||
)
|
||||
|
||||
@property
|
||||
def output_keys(self) -> OutputKeys:
|
||||
"""Output keys."""
|
||||
# NOTE: user can override this too, but we have them implement an
|
||||
# abstract method to make sure they do it
|
||||
return OutputKeys.from_keys(self._output_keys)
|
||||
|
||||
|
||||
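A hedged sketch of the minimal surface a `CustomQueryComponent` subclass needs to fill in; `ShoutComponent` is an illustrative name.

```python
class ShoutComponent(CustomQueryComponent):
    """Uppercases its input and appends an exclamation mark."""

    @property
    def _input_keys(self) -> Set[str]:
        return {"text"}

    @property
    def _output_keys(self) -> Set[str]:
        return {"output"}

    def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
        return {"output": f"{kwargs['text'].upper()}!"}

print(ShoutComponent().run_component(text="hello"))  # {"output": "HELLO!"}
```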
class Link(BaseModel):
|
||||
"""Link between two components."""
|
||||
|
||||
src: str = Field(..., description="Source component name")
|
||||
dest: str = Field(..., description="Destination component name")
|
||||
src_key: Optional[str] = Field(
|
||||
default=None, description="Source component output key"
|
||||
)
|
||||
dest_key: Optional[str] = Field(
|
||||
default=None, description="Destination component input key"
|
||||
)
|
||||
|
||||
condition_fn: Optional[Callable] = Field(
|
||||
default=None, description="Condition to determine if link should be followed"
|
||||
)
|
||||
input_fn: Optional[Callable] = Field(
|
||||
default=None, description="Input to destination component"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
src: str,
|
||||
dest: str,
|
||||
src_key: Optional[str] = None,
|
||||
dest_key: Optional[str] = None,
|
||||
condition_fn: Optional[Callable] = None,
|
||||
input_fn: Optional[Callable] = None,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
# NOTE: This is to enable positional args.
|
||||
super().__init__(
|
||||
src=src,
|
||||
dest=dest,
|
||||
src_key=src_key,
|
||||
dest_key=dest_key,
|
||||
condition_fn=condition_fn,
|
||||
input_fn=input_fn,
|
||||
)
|
||||
|
||||
|
||||
# accept both QueryComponent and ChainableMixin as inputs to query pipeline
|
||||
# ChainableMixin modules will be converted to components via `as_query_component`
|
||||
QUERY_COMPONENT_TYPE = Union[QueryComponent, ChainableMixin]
|
||||
|
|
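An illustrative sketch of constructing `Link` objects with the positional signature defined above; the component names are hypothetical.

```python
plain_link = Link("retriever", "synthesizer", dest_key="nodes")
conditional_link = Link(
    "router",
    "fallback_synthesizer",
    condition_fn=lambda output: output is None,   # follow the link only on empty output
    input_fn=lambda output: "fallback query",     # assumed: transforms the upstream output
)
```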
@ -0,0 +1,142 @@
|
|||
"""Response schema."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from llama_index.bridge.pydantic import BaseModel
|
||||
from llama_index.schema import NodeWithScore
|
||||
from llama_index.types import TokenGen
|
||||
from llama_index.utils import truncate_text
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
"""Response object.
|
||||
|
||||
Returned if streaming=False.
|
||||
|
||||
Attributes:
|
||||
response: The response text.
|
||||
|
||||
"""
|
||||
|
||||
response: Optional[str]
|
||||
source_nodes: List[NodeWithScore] = field(default_factory=list)
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Convert to string representation."""
|
||||
return self.response or "None"
|
||||
|
||||
def get_formatted_sources(self, length: int = 100) -> str:
|
||||
"""Get formatted sources text."""
|
||||
texts = []
|
||||
for source_node in self.source_nodes:
|
||||
fmt_text_chunk = truncate_text(source_node.node.get_content(), length)
|
||||
doc_id = source_node.node.node_id or "None"
|
||||
source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}"
|
||||
texts.append(source_text)
|
||||
return "\n\n".join(texts)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PydanticResponse:
|
||||
"""PydanticResponse object.
|
||||
|
||||
Returned if streaming=False.
|
||||
|
||||
Attributes:
|
||||
response: The response text.
|
||||
|
||||
"""
|
||||
|
||||
response: Optional[BaseModel]
|
||||
source_nodes: List[NodeWithScore] = field(default_factory=list)
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Convert to string representation."""
|
||||
return self.response.json() if self.response else "None"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Get attribute, but prioritize the pydantic response object."""
|
||||
if self.response is not None and name in self.response.dict():
|
||||
return getattr(self.response, name)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_formatted_sources(self, length: int = 100) -> str:
|
||||
"""Get formatted sources text."""
|
||||
texts = []
|
||||
for source_node in self.source_nodes:
|
||||
fmt_text_chunk = truncate_text(source_node.node.get_content(), length)
|
||||
doc_id = source_node.node.node_id or "None"
|
||||
source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}"
|
||||
texts.append(source_text)
|
||||
return "\n\n".join(texts)
|
||||
|
||||
def get_response(self) -> Response:
|
||||
"""Get a standard response object."""
|
||||
response_txt = self.response.json() if self.response else "None"
|
||||
return Response(response_txt, self.source_nodes, self.metadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamingResponse:
|
||||
"""StreamingResponse object.
|
||||
|
||||
Returned if streaming=True.
|
||||
|
||||
Attributes:
|
||||
response_gen: The response generator.
|
||||
|
||||
"""
|
||||
|
||||
response_gen: TokenGen
|
||||
source_nodes: List[NodeWithScore] = field(default_factory=list)
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
response_txt: Optional[str] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Convert to string representation."""
|
||||
if self.response_txt is None and self.response_gen is not None:
|
||||
response_txt = ""
|
||||
for text in self.response_gen:
|
||||
response_txt += text
|
||||
self.response_txt = response_txt
|
||||
return self.response_txt or "None"
|
||||
|
||||
def get_response(self) -> Response:
|
||||
"""Get a standard response object."""
|
||||
if self.response_txt is None and self.response_gen is not None:
|
||||
response_txt = ""
|
||||
for text in self.response_gen:
|
||||
response_txt += text
|
||||
self.response_txt = response_txt
|
||||
return Response(self.response_txt, self.source_nodes, self.metadata)
|
||||
|
||||
def print_response_stream(self) -> None:
|
||||
"""Print the response stream."""
|
||||
if self.response_txt is None and self.response_gen is not None:
|
||||
response_txt = ""
|
||||
for text in self.response_gen:
|
||||
print(text, end="", flush=True)
|
||||
response_txt += text
|
||||
self.response_txt = response_txt
|
||||
else:
|
||||
print(self.response_txt)
|
||||
|
||||
def get_formatted_sources(self, length: int = 100, trim_text: bool = True) -> str:
|
||||
"""Get formatted sources text."""
|
||||
texts = []
|
||||
for source_node in self.source_nodes:
|
||||
fmt_text_chunk = source_node.node.get_content()
|
||||
if trim_text:
|
||||
fmt_text_chunk = truncate_text(fmt_text_chunk, length)
|
||||
node_id = source_node.node.node_id or "None"
|
||||
source_text = f"> Source (Node id: {node_id}): {fmt_text_chunk}"
|
||||
texts.append(source_text)
|
||||
return "\n\n".join(texts)
|
||||
|
||||
|
||||
RESPONSE_TYPE = Union[Response, StreamingResponse, PydanticResponse]
|
||||
|
|
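A small usage sketch (not part of the diff) of the `Response` dataclass above, using node types imported the same way this file does.

```python
from llama_index.schema import NodeWithScore, TextNode

resp = Response(
    response="Query pipelines compose components into a DAG.",
    source_nodes=[NodeWithScore(node=TextNode(text="Pipelines chain components."), score=0.42)],
    metadata={"retriever": "toy"},
)
print(str(resp))                              # the response text
print(resp.get_formatted_sources(length=40))  # "> Source (Doc id: ...): ..."
```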
@ -0,0 +1,19 @@
|
|||
"""Init file."""
|
||||
|
||||
from llama_index.data_structs.data_structs import (
|
||||
IndexDict,
|
||||
IndexGraph,
|
||||
IndexList,
|
||||
KeywordTable,
|
||||
Node,
|
||||
)
|
||||
from llama_index.data_structs.table import StructDatapoint
|
||||
|
||||
__all__ = [
|
||||
"IndexGraph",
|
||||
"KeywordTable",
|
||||
"IndexList",
|
||||
"IndexDict",
|
||||
"StructDatapoint",
|
||||
"Node",
|
||||
]
|
||||
|
|
@ -0,0 +1,267 @@
|
|||
"""Data structures.
|
||||
|
||||
Nodes are decoupled from the indices.
|
||||
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Sequence, Set
|
||||
|
||||
from dataclasses_json import DataClassJsonMixin
|
||||
|
||||
from llama_index.data_structs.struct_type import IndexStructType
|
||||
from llama_index.schema import BaseNode, TextNode
|
||||
|
||||
# TODO: legacy backport of old Node class
|
||||
Node = TextNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexStruct(DataClassJsonMixin):
|
||||
"""A base data struct for a LlamaIndex."""
|
||||
|
||||
index_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
summary: Optional[str] = None
|
||||
|
||||
def get_summary(self) -> str:
|
||||
"""Get text summary."""
|
||||
if self.summary is None:
|
||||
raise ValueError("summary field of the index_struct not set.")
|
||||
return self.summary
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get index struct type."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexGraph(IndexStruct):
|
||||
"""A graph representing the tree-structured index."""
|
||||
|
||||
# mapping from index in tree to Node doc id.
|
||||
all_nodes: Dict[int, str] = field(default_factory=dict)
|
||||
root_nodes: Dict[int, str] = field(default_factory=dict)
|
||||
node_id_to_children_ids: Dict[str, List[str]] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def node_id_to_index(self) -> Dict[str, int]:
|
||||
"""Map from node id to index."""
|
||||
return {node_id: index for index, node_id in self.all_nodes.items()}
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get the size of the graph."""
|
||||
return len(self.all_nodes)
|
||||
|
||||
def get_index(self, node: BaseNode) -> int:
|
||||
"""Get index of node."""
|
||||
return self.node_id_to_index[node.node_id]
|
||||
|
||||
def insert(
|
||||
self,
|
||||
node: BaseNode,
|
||||
index: Optional[int] = None,
|
||||
children_nodes: Optional[Sequence[BaseNode]] = None,
|
||||
) -> None:
|
||||
"""Insert node."""
|
||||
index = index or self.size
|
||||
node_id = node.node_id
|
||||
|
||||
self.all_nodes[index] = node_id
|
||||
|
||||
if children_nodes is None:
|
||||
children_nodes = []
|
||||
children_ids = [n.node_id for n in children_nodes]
|
||||
self.node_id_to_children_ids[node_id] = children_ids
|
||||
|
||||
def get_children(self, parent_node: Optional[BaseNode]) -> Dict[int, str]:
|
||||
"""Get children nodes."""
|
||||
if parent_node is None:
|
||||
return self.root_nodes
|
||||
else:
|
||||
parent_id = parent_node.node_id
|
||||
children_ids = self.node_id_to_children_ids[parent_id]
|
||||
return {
|
||||
self.node_id_to_index[child_id]: child_id for child_id in children_ids
|
||||
}
|
||||
|
||||
def insert_under_parent(
|
||||
self,
|
||||
node: BaseNode,
|
||||
parent_node: Optional[BaseNode],
|
||||
new_index: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Insert under parent node."""
|
||||
new_index = new_index or self.size
|
||||
if parent_node is None:
|
||||
self.root_nodes[new_index] = node.node_id
|
||||
self.node_id_to_children_ids[node.node_id] = []
|
||||
else:
|
||||
if parent_node.node_id not in self.node_id_to_children_ids:
|
||||
self.node_id_to_children_ids[parent_node.node_id] = []
|
||||
self.node_id_to_children_ids[parent_node.node_id].append(node.node_id)
|
||||
|
||||
self.all_nodes[new_index] = node.node_id
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
return IndexStructType.TREE
|
||||
|
||||
|
||||
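A short sketch of how the tree struct above tracks parent/child relationships; `Node` is the `TextNode` alias defined at the top of this file.

```python
graph = IndexGraph()
root = Node(text="root summary")
leaf = Node(text="leaf chunk")

graph.insert_under_parent(root, parent_node=None)   # registered as a root node
graph.insert_under_parent(leaf, parent_node=root)   # registered as root's child

print(graph.size)                # 2
print(graph.get_children(root))  # {1: <leaf node_id>}
```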
@dataclass
|
||||
class KeywordTable(IndexStruct):
|
||||
"""A table of keywords mapping keywords to text chunks."""
|
||||
|
||||
table: Dict[str, Set[str]] = field(default_factory=dict)
|
||||
|
||||
def add_node(self, keywords: List[str], node: BaseNode) -> None:
|
||||
"""Add text to table."""
|
||||
for keyword in keywords:
|
||||
if keyword not in self.table:
|
||||
self.table[keyword] = set()
|
||||
self.table[keyword].add(node.node_id)
|
||||
|
||||
@property
|
||||
def node_ids(self) -> Set[str]:
|
||||
"""Get all node ids."""
|
||||
return set.union(*self.table.values())
|
||||
|
||||
@property
|
||||
def keywords(self) -> Set[str]:
|
||||
"""Get all keywords in the table."""
|
||||
return set(self.table.keys())
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get the size of the table."""
|
||||
return len(self.table)
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
return IndexStructType.KEYWORD_TABLE
|
||||
|
||||
|
||||
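A quick sketch of the keyword-to-node bookkeeping above.

```python
table = KeywordTable()
chunk = Node(text="Llamas are vegetarian.")
table.add_node(["llama", "diet"], chunk)

print(table.keywords)   # {"llama", "diet"}
print(table.node_ids)   # {chunk.node_id}
print(table.size)       # 2 (number of keywords)
```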
@dataclass
|
||||
class IndexList(IndexStruct):
|
||||
"""A list of documents."""
|
||||
|
||||
nodes: List[str] = field(default_factory=list)
|
||||
|
||||
def add_node(self, node: BaseNode) -> None:
|
||||
"""Add text to table, return current position in list."""
|
||||
# don't worry about child indices for now, nodes are all in order
|
||||
self.nodes.append(node.node_id)
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
return IndexStructType.LIST
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexDict(IndexStruct):
|
||||
"""A simple dictionary of documents."""
|
||||
|
||||
# TODO: slightly deprecated, should likely be a list or set now
|
||||
# mapping from vector store id to node doc_id
|
||||
nodes_dict: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# TODO: deprecated, not used
|
||||
# mapping from node doc_id to vector store id
|
||||
doc_id_dict: Dict[str, List[str]] = field(default_factory=dict)
|
||||
|
||||
# TODO: deprecated, not used
|
||||
# this should be empty for all other indices
|
||||
embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
node: BaseNode,
|
||||
text_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Add text to table, return current position in list."""
|
||||
# # don't worry about child indices for now, nodes are all in order
|
||||
# self.nodes_dict[int_id] = node
|
||||
vector_id = text_id if text_id is not None else node.node_id
|
||||
self.nodes_dict[vector_id] = node.node_id
|
||||
|
||||
return vector_id
|
||||
|
||||
def delete(self, doc_id: str) -> None:
|
||||
"""Delete a Node."""
|
||||
del self.nodes_dict[doc_id]
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
return IndexStructType.VECTOR_STORE
|
||||
|
||||
|
||||
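A brief sketch of the vector-store bookkeeping above: `add_node` returns the vector store id it registered.

```python
index_dict = IndexDict()
first = Node(text="first chunk")
second = Node(text="second chunk")

vector_id = index_dict.add_node(first)            # defaults to first.node_id
index_dict.add_node(second, text_id="custom-id")  # explicit vector store id
index_dict.delete(vector_id)

print(index_dict.nodes_dict)  # {"custom-id": <second.node_id>}
```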
@dataclass
|
||||
class MultiModelIndexDict(IndexDict):
|
||||
"""A simple dictionary of documents, but loads a MultiModelVectorStore."""
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
return IndexStructType.MULTIMODAL_VECTOR_STORE
|
||||
|
||||
|
||||
@dataclass
|
||||
class KG(IndexStruct):
|
||||
"""A table of keywords mapping keywords to text chunks."""
|
||||
|
||||
# Unidirectional
|
||||
|
||||
# table of keywords to node ids
|
||||
table: Dict[str, Set[str]] = field(default_factory=dict)
|
||||
|
||||
# TODO: legacy attribute, remove in future releases
|
||||
rel_map: Dict[str, List[List[str]]] = field(default_factory=dict)
|
||||
|
||||
# TBD, should support vector store, now we just persist the embedding memory
|
||||
# maybe chainable abstractions for *_stores could be designed
|
||||
embedding_dict: Dict[str, List[float]] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def node_ids(self) -> Set[str]:
|
||||
"""Get all node ids."""
|
||||
return set.union(*self.table.values())
|
||||
|
||||
def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
|
||||
"""Add embedding to dict."""
|
||||
self.embedding_dict[triplet_str] = embedding
|
||||
|
||||
def add_node(self, keywords: List[str], node: BaseNode) -> None:
|
||||
"""Add text to table."""
|
||||
node_id = node.node_id
|
||||
for keyword in keywords:
|
||||
if keyword not in self.table:
|
||||
self.table[keyword] = set()
|
||||
self.table[keyword].add(node_id)
|
||||
|
||||
def search_node_by_keyword(self, keyword: str) -> List[str]:
|
||||
"""Search for nodes by keyword."""
|
||||
if keyword not in self.table:
|
||||
return []
|
||||
return list(self.table[keyword])
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
return IndexStructType.KG
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmptyIndexStruct(IndexStruct):
|
||||
"""Empty index."""
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
return IndexStructType.EMPTY
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
"""Data struct for document summary index."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_index.data_structs.data_structs import IndexStruct
|
||||
from llama_index.data_structs.struct_type import IndexStructType
|
||||
from llama_index.schema import BaseNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexDocumentSummary(IndexStruct):
|
||||
"""A simple struct containing a mapping from summary node_id to doc node_ids.
|
||||
|
||||
It also stores the reverse mapping from doc node_id to summary node_id.
|
||||
|
||||
"""
|
||||
|
||||
summary_id_to_node_ids: Dict[str, List[str]] = field(default_factory=dict)
|
||||
node_id_to_summary_id: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# track mapping from doc id to node summary id
|
||||
doc_id_to_summary_id: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def add_summary_and_nodes(
|
||||
self,
|
||||
summary_node: BaseNode,
|
||||
nodes: List[BaseNode],
|
||||
) -> str:
|
||||
"""Add node and summary."""
|
||||
summary_id = summary_node.node_id
|
||||
ref_doc_id = summary_node.ref_doc_id
|
||||
if ref_doc_id is None:
|
||||
raise ValueError(
|
||||
"ref_doc_id of node cannot be None when building a document "
|
||||
"summary index"
|
||||
)
|
||||
self.doc_id_to_summary_id[ref_doc_id] = summary_id
|
||||
|
||||
for node in nodes:
|
||||
node_id = node.node_id
|
||||
if summary_id not in self.summary_id_to_node_ids:
|
||||
self.summary_id_to_node_ids[summary_id] = []
|
||||
self.summary_id_to_node_ids[summary_id].append(node_id)
|
||||
|
||||
self.node_id_to_summary_id[node_id] = summary_id
|
||||
|
||||
return summary_id
|
||||
|
||||
@property
|
||||
def summary_ids(self) -> List[str]:
|
||||
"""Get summary ids."""
|
||||
return list(self.summary_id_to_node_ids.keys())
|
||||
|
||||
def delete(self, doc_id: str) -> None:
|
||||
"""Delete a document and its nodes."""
|
||||
summary_id = self.doc_id_to_summary_id[doc_id]
|
||||
del self.doc_id_to_summary_id[doc_id]
|
||||
node_ids = self.summary_id_to_node_ids[summary_id]
|
||||
for node_id in node_ids:
|
||||
del self.node_id_to_summary_id[node_id]
|
||||
del self.summary_id_to_node_ids[summary_id]
|
||||
|
||||
def delete_nodes(self, node_ids: List[str]) -> None:
|
||||
for node_id in node_ids:
|
||||
summary_id = self.node_id_to_summary_id[node_id]
|
||||
self.summary_id_to_node_ids[summary_id].remove(node_id)
|
||||
del self.node_id_to_summary_id[node_id]
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
return IndexStructType.DOCUMENT_SUMMARY
|
||||
|
|
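A hedged sketch of the struct above; it assumes `NodeRelationship` and `RelatedNodeInfo` are importable from `llama_index.schema` (as the other node types in this diff are) so the summary node carries a `ref_doc_id`.

```python
from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode

summary = TextNode(text="Summary of doc-1")
summary.relationships[NodeRelationship.SOURCE] = RelatedNodeInfo(node_id="doc-1")
chunk = TextNode(text="A chunk of doc-1")

struct = IndexDocumentSummary()
summary_id = struct.add_summary_and_nodes(summary, [chunk])
print(struct.summary_ids)  # [summary_id]

struct.delete("doc-1")     # drops the summary and its node mappings
```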
@ -0,0 +1,30 @@
|
|||
"""Index registry."""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from llama_index.data_structs.data_structs import (
|
||||
KG,
|
||||
EmptyIndexStruct,
|
||||
IndexDict,
|
||||
IndexGraph,
|
||||
IndexList,
|
||||
IndexStruct,
|
||||
KeywordTable,
|
||||
MultiModelIndexDict,
|
||||
)
|
||||
from llama_index.data_structs.document_summary import IndexDocumentSummary
|
||||
from llama_index.data_structs.struct_type import IndexStructType
|
||||
from llama_index.data_structs.table import PandasStructTable, SQLStructTable
|
||||
|
||||
INDEX_STRUCT_TYPE_TO_INDEX_STRUCT_CLASS: Dict[IndexStructType, Type[IndexStruct]] = {
|
||||
IndexStructType.TREE: IndexGraph,
|
||||
IndexStructType.LIST: IndexList,
|
||||
IndexStructType.KEYWORD_TABLE: KeywordTable,
|
||||
IndexStructType.VECTOR_STORE: IndexDict,
|
||||
IndexStructType.SQL: SQLStructTable,
|
||||
IndexStructType.PANDAS: PandasStructTable,
|
||||
IndexStructType.KG: KG,
|
||||
IndexStructType.EMPTY: EmptyIndexStruct,
|
||||
IndexStructType.DOCUMENT_SUMMARY: IndexDocumentSummary,
|
||||
IndexStructType.MULTIMODAL_VECTOR_STORE: MultiModelIndexDict,
|
||||
}
|
||||
|
|
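A one-line sketch of how the registry above is typically consumed: resolve the struct class for a persisted index type.

```python
struct_cls = INDEX_STRUCT_TYPE_TO_INDEX_STRUCT_CLASS[IndexStructType.KEYWORD_TABLE]
assert struct_cls is KeywordTable
index_struct = struct_cls()  # every struct in the registry default-constructs its fields
```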
@ -0,0 +1,110 @@
|
|||
"""IndexStructType class."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class IndexStructType(str, Enum):
|
||||
"""Index struct type. Identifier for a "type" of index.
|
||||
|
||||
Attributes:
|
||||
TREE ("tree"): Tree index. See :ref:`Ref-Indices-Tree` for tree indices.
|
||||
LIST ("list"): Summary index. See :ref:`Ref-Indices-List` for summary indices.
|
||||
KEYWORD_TABLE ("keyword_table"): Keyword table index. See
|
||||
:ref:`Ref-Indices-Table`
|
||||
for keyword table indices.
|
||||
DICT ("dict"): Faiss Vector Store Index. See
|
||||
:ref:`Ref-Indices-VectorStore`
|
||||
for more information on the faiss vector store index.
|
||||
SIMPLE_DICT ("simple_dict"): Simple Vector Store Index. See
|
||||
:ref:`Ref-Indices-VectorStore`
|
||||
for more information on the simple vector store index.
|
||||
WEAVIATE ("weaviate"): Weaviate Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the Weaviate vector store index.
|
||||
PINECONE ("pinecone"): Pinecone Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the Pinecone vector store index.
|
||||
DEEPLAKE ("deeplake"): DeepLake Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the DeepLake vector store index.
|
||||
QDRANT ("qdrant"): Qdrant Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the Qdrant vector store index.
|
||||
LANCEDB ("lancedb"): LanceDB Vector Store Index
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the LanceDB vector store index.
|
||||
MILVUS ("milvus"): Milvus Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the Milvus vector store index.
|
||||
CHROMA ("chroma"): Chroma Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the Chroma vector store index.
|
||||
OPENSEARCH ("opensearch"): Opensearch Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the Opensearch vector store index.
|
||||
MYSCALE ("myscale"): MyScale Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the MyScale vector store index.
|
||||
EPSILLA ("epsilla"): Epsilla Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the Epsilla vector store index.
|
||||
CHATGPT_RETRIEVAL_PLUGIN ("chatgpt_retrieval_plugin"): ChatGPT
|
||||
retrieval plugin index.
|
||||
SQL ("SQL"): SQL Structured Store Index.
|
||||
See :ref:`Ref-Indices-StructStore`
|
||||
for more information on the SQL structured store index.
|
||||
DASHVECTOR ("dashvector"): DashVector Vector Store Index.
|
||||
See :ref:`Ref-Indices-VectorStore`
|
||||
for more information on the DashVector vector store index.
|
||||
KG ("kg"): Knowledge Graph index.
|
||||
See :ref:`Ref-Indices-Knowledge-Graph` for KG indices.
|
||||
DOCUMENT_SUMMARY ("document_summary"): Document Summary Index.
|
||||
See :ref:`Ref-Indices-Document-Summary` for Summary Indices.
|
||||
|
||||
"""
|
||||
|
||||
# TODO: refactor so these are properties on the base class
|
||||
|
||||
NODE = "node"
|
||||
TREE = "tree"
|
||||
LIST = "list"
|
||||
KEYWORD_TABLE = "keyword_table"
|
||||
|
||||
# faiss
|
||||
DICT = "dict"
|
||||
# simple
|
||||
SIMPLE_DICT = "simple_dict"
|
||||
WEAVIATE = "weaviate"
|
||||
PINECONE = "pinecone"
|
||||
QDRANT = "qdrant"
|
||||
LANCEDB = "lancedb"
|
||||
MILVUS = "milvus"
|
||||
CHROMA = "chroma"
|
||||
MYSCALE = "myscale"
|
||||
VECTOR_STORE = "vector_store"
|
||||
OPENSEARCH = "opensearch"
|
||||
DASHVECTOR = "dashvector"
|
||||
CHATGPT_RETRIEVAL_PLUGIN = "chatgpt_retrieval_plugin"
|
||||
DEEPLAKE = "deeplake"
|
||||
EPSILLA = "epsilla"
|
||||
# multimodal
|
||||
MULTIMODAL_VECTOR_STORE = "multimodal"
|
||||
# for SQL index
|
||||
SQL = "sql"
|
||||
# for KG index
|
||||
KG = "kg"
|
||||
SIMPLE_KG = "simple_kg"
|
||||
NEBULAGRAPH = "nebulagraph"
|
||||
FALKORDB = "falkordb"
|
||||
|
||||
# EMPTY
|
||||
EMPTY = "empty"
|
||||
COMPOSITE = "composite"
|
||||
|
||||
PANDAS = "pandas"
|
||||
|
||||
DOCUMENT_SUMMARY = "document_summary"
|
||||
|
||||
# Managed
|
||||
VECTARA = "vectara"
|
||||
ZILLIZ_CLOUD_PIPELINE = "zilliz_cloud_pipeline"
|
||||
|
|
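Because the enum above subclasses `str`, values round-trip cleanly from the strings stored in persisted index structs; a tiny sketch:

```python
assert IndexStructType("tree") is IndexStructType.TREE
assert IndexStructType.VECTOR_STORE == "vector_store"
print(IndexStructType.KG.value)  # "kg"
```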
@ -0,0 +1,45 @@
|
|||
"""Struct store schema."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict
|
||||
|
||||
from dataclasses_json import DataClassJsonMixin
|
||||
|
||||
from llama_index.data_structs.data_structs import IndexStruct
|
||||
from llama_index.data_structs.struct_type import IndexStructType
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructDatapoint(DataClassJsonMixin):
|
||||
"""Struct outputs."""
|
||||
|
||||
# map from field name to StructValue
|
||||
fields: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseStructTable(IndexStruct):
|
||||
"""Struct outputs."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SQLStructTable(BaseStructTable):
|
||||
"""SQL struct outputs."""
|
||||
|
||||
context_dict: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
# TODO: consolidate with IndexStructType
|
||||
return IndexStructType.SQL
|
||||
|
||||
|
||||
@dataclass
|
||||
class PandasStructTable(BaseStructTable):
|
||||
"""Pandas struct outputs."""
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> IndexStructType:
|
||||
"""Get type."""
|
||||
return IndexStructType.PANDAS
|
||||
|
|
@ -0,0 +1,262 @@
|
|||
"""Download."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import tqdm
|
||||
|
||||
from llama_index.download.module import LLAMA_HUB_URL
|
||||
from llama_index.download.utils import (
|
||||
get_file_content,
|
||||
get_file_content_bytes,
|
||||
initialize_directory,
|
||||
)
|
||||
|
||||
LLAMA_DATASETS_LFS_URL = (
|
||||
f"https://media.githubusercontent.com/media/run-llama/llama-datasets/main"
|
||||
)
|
||||
|
||||
LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL = (
|
||||
"https://github.com/run-llama/llama-datasets/tree/main"
|
||||
)
|
||||
LLAMA_SOURCE_FILES_PATH = "source_files"
|
||||
|
||||
DATASET_CLASS_FILENAME_REGISTRY = {
|
||||
"LabelledRagDataset": "rag_dataset.json",
|
||||
"LabeledRagDataset": "rag_dataset.json",
|
||||
"LabelledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
|
||||
"LabeledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
|
||||
"LabelledEvaluatorDataset": "evaluator_dataset.json",
|
||||
"LabeledEvaluatorDataset": "evaluator_dataset.json",
|
||||
}
|
||||
|
||||
|
||||
PATH_TYPE = Union[str, Path]
|
||||
|
||||
|
||||
def _resolve_dataset_file_name(class_name: str) -> str:
|
||||
"""Resolve filename based on dataset class."""
|
||||
try:
|
||||
return DATASET_CLASS_FILENAME_REGISTRY[class_name]
|
||||
except KeyError as err:
|
||||
raise ValueError("Invalid dataset filename.") from err
|
||||
|
||||
|
||||
def _get_source_files_list(source_tree_url: str, path: str) -> List[str]:
|
||||
"""Get the list of source files to download."""
|
||||
resp = requests.get(source_tree_url + path + "?recursive=1")
|
||||
payload = resp.json()["payload"]
|
||||
return [item["name"] for item in payload["tree"]["items"]]
|
||||
|
||||
|
||||
def get_dataset_info(
|
||||
local_dir_path: PATH_TYPE,
|
||||
remote_dir_path: PATH_TYPE,
|
||||
remote_source_dir_path: PATH_TYPE,
|
||||
dataset_class: str,
|
||||
refresh_cache: bool = False,
|
||||
library_path: str = "library.json",
|
||||
source_files_path: str = "source_files",
|
||||
disable_library_cache: bool = False,
|
||||
) -> Dict:
|
||||
"""Get dataset info."""
|
||||
if isinstance(local_dir_path, str):
|
||||
local_dir_path = Path(local_dir_path)
|
||||
|
||||
local_library_path = f"{local_dir_path}/{library_path}"
|
||||
dataset_id = None
|
||||
source_files = []
|
||||
|
||||
# Check cache first
|
||||
if not refresh_cache and os.path.exists(local_library_path):
|
||||
with open(local_library_path) as f:
|
||||
library = json.load(f)
|
||||
if dataset_class in library:
|
||||
dataset_id = library[dataset_class]["id"]
|
||||
source_files = library[dataset_class].get("source_files", [])
|
||||
|
||||
# Fetch up-to-date library from remote repo if dataset_id not found
|
||||
if dataset_id is None:
|
||||
library_raw_content, _ = get_file_content(
|
||||
str(remote_dir_path), f"/{library_path}"
|
||||
)
|
||||
library = json.loads(library_raw_content)
|
||||
if dataset_class not in library:
|
||||
raise ValueError("Loader class name not found in library")
|
||||
|
||||
dataset_id = library[dataset_class]["id"]
|
||||
|
||||
# get data card
|
||||
raw_card_content, _ = get_file_content(
|
||||
str(remote_dir_path), f"/{dataset_id}/card.json"
|
||||
)
|
||||
card = json.loads(raw_card_content)
|
||||
dataset_class_name = card["className"]
|
||||
|
||||
source_files = []
|
||||
if dataset_class_name == "LabelledRagDataset":
|
||||
source_files = _get_source_files_list(
|
||||
str(remote_source_dir_path), f"/{dataset_id}/{source_files_path}"
|
||||
)
|
||||
|
||||
# create cache dir if needed
|
||||
local_library_dir = os.path.dirname(local_library_path)
|
||||
if not disable_library_cache:
|
||||
if not os.path.exists(local_library_dir):
|
||||
os.makedirs(local_library_dir)
|
||||
|
||||
# Update cache
|
||||
with open(local_library_path, "w") as f:
|
||||
f.write(library_raw_content)
|
||||
|
||||
if dataset_id is None:
|
||||
raise ValueError("Dataset class name not found in library")
|
||||
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"dataset_class_name": dataset_class_name,
|
||||
"source_files": source_files,
|
||||
}
|
||||
|
||||
|
||||
def download_dataset_and_source_files(
|
||||
local_dir_path: PATH_TYPE,
|
||||
remote_lfs_dir_path: PATH_TYPE,
|
||||
source_files_dir_path: PATH_TYPE,
|
||||
dataset_id: str,
|
||||
dataset_class_name: str,
|
||||
source_files: List[str],
|
||||
refresh_cache: bool = False,
|
||||
base_file_name: str = "rag_dataset.json",
|
||||
override_path: bool = False,
|
||||
show_progress: bool = False,
|
||||
) -> None:
|
||||
"""Download dataset and source files."""
|
||||
if isinstance(local_dir_path, str):
|
||||
local_dir_path = Path(local_dir_path)
|
||||
|
||||
if override_path:
|
||||
module_path = str(local_dir_path)
|
||||
else:
|
||||
module_path = f"{local_dir_path}/{dataset_id}"
|
||||
|
||||
if refresh_cache or not os.path.exists(module_path):
|
||||
os.makedirs(module_path, exist_ok=True)
|
||||
|
||||
base_file_name = _resolve_dataset_file_name(dataset_class_name)
|
||||
|
||||
dataset_raw_content, _ = get_file_content(
|
||||
str(remote_lfs_dir_path), f"/{dataset_id}/{base_file_name}"
|
||||
)
|
||||
|
||||
with open(f"{module_path}/{base_file_name}", "w") as f:
|
||||
f.write(dataset_raw_content)
|
||||
|
||||
# Get content of source files
|
||||
if dataset_class_name == "LabelledRagDataset":
|
||||
os.makedirs(f"{module_path}/{source_files_dir_path}", exist_ok=True)
|
||||
if show_progress:
|
||||
source_files_iterator = tqdm.tqdm(source_files)
|
||||
else:
|
||||
source_files_iterator = source_files
|
||||
for source_file in source_files_iterator:
|
||||
if ".pdf" in source_file:
|
||||
source_file_raw_content_bytes, _ = get_file_content_bytes(
|
||||
str(remote_lfs_dir_path),
|
||||
f"/{dataset_id}/{source_files_dir_path}/{source_file}",
|
||||
)
|
||||
with open(
|
||||
f"{module_path}/{source_files_dir_path}/{source_file}", "wb"
|
||||
) as f:
|
||||
f.write(source_file_raw_content_bytes)
|
||||
else:
|
||||
source_file_raw_content, _ = get_file_content(
|
||||
str(remote_lfs_dir_path),
|
||||
f"/{dataset_id}/{source_files_dir_path}/{source_file}",
|
||||
)
|
||||
with open(
|
||||
f"{module_path}/{source_files_dir_path}/{source_file}", "w"
|
||||
) as f:
|
||||
f.write(source_file_raw_content)
|
||||
|
||||
|
||||
def download_llama_dataset(
|
||||
dataset_class: str,
|
||||
llama_hub_url: str = LLAMA_HUB_URL,
|
||||
llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
|
||||
llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
|
||||
refresh_cache: bool = False,
|
||||
custom_dir: Optional[str] = None,
|
||||
custom_path: Optional[str] = None,
|
||||
source_files_dirpath: str = LLAMA_SOURCE_FILES_PATH,
|
||||
library_path: str = "llama_datasets/library.json",
|
||||
disable_library_cache: bool = False,
|
||||
override_path: bool = False,
|
||||
show_progress: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Download a llama-dataset from LlamaHub.

Args:
    dataset_class: The name of the llama-dataset class you want to download,
        such as `LabelledRagDataset`.
    refresh_cache: If true, the local cache will be skipped and the
        dataset will be fetched directly from the remote repo.
    custom_dir: Custom dir name to download the dataset into (under parent folder).
    custom_path: Custom dirpath to download the dataset into.
    source_files_dirpath: Name of the subdirectory that holds the dataset's source files.
    library_path: File name of the library file.
    disable_library_cache: If true, skip writing the fetched library file to the local cache.
    override_path: If true, write the dataset files directly into the target directory
        instead of a per-dataset subdirectory.
    show_progress: If true, show a progress bar while downloading source files.

Returns:
    A tuple of (path to the downloaded dataset json file, path to the source files directory).
|
||||
"""
|
||||
# create directory / get path
|
||||
dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)
|
||||
|
||||
# fetch info from library.json file
|
||||
dataset_info = get_dataset_info(
|
||||
local_dir_path=dirpath,
|
||||
remote_dir_path=llama_hub_url,
|
||||
remote_source_dir_path=llama_datasets_source_files_tree_url,
|
||||
dataset_class=dataset_class,
|
||||
refresh_cache=refresh_cache,
|
||||
library_path=library_path,
|
||||
disable_library_cache=disable_library_cache,
|
||||
)
|
||||
dataset_id = dataset_info["dataset_id"]
|
||||
source_files = dataset_info["source_files"]
|
||||
dataset_class_name = dataset_info["dataset_class_name"]
|
||||
|
||||
dataset_filename = _resolve_dataset_file_name(dataset_class_name)
|
||||
|
||||
download_dataset_and_source_files(
|
||||
local_dir_path=dirpath,
|
||||
remote_lfs_dir_path=llama_datasets_lfs_url,
|
||||
source_files_dir_path=source_files_dirpath,
|
||||
dataset_id=dataset_id,
|
||||
dataset_class_name=dataset_class_name,
|
||||
source_files=source_files,
|
||||
refresh_cache=refresh_cache,
|
||||
override_path=override_path,
|
||||
show_progress=show_progress,
|
||||
)
|
||||
|
||||
if override_path:
|
||||
module_path = str(dirpath)
|
||||
else:
|
||||
module_path = f"{dirpath}/{dataset_id}"
|
||||
|
||||
return (
|
||||
f"{module_path}/{dataset_filename}",
|
||||
f"{module_path}/{LLAMA_SOURCE_FILES_PATH}",
|
||||
)
|
||||
|
|
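A hedged usage sketch of the downloader above; `PaulGrahamEssayDataset` is an example dataset name and may not exist, and the return value is the pair of local paths built at the end of the function.

```python
rag_dataset_path, source_files_dir = download_llama_dataset(
    "PaulGrahamEssayDataset",          # hypothetical dataset class name
    custom_path="./data/paul_graham",
    show_progress=True,
)
print(rag_dataset_path)   # ./data/paul_graham/<dataset_id>/rag_dataset.json
print(source_files_dir)   # ./data/paul_graham/<dataset_id>/source_files
```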
@ -0,0 +1,273 @@
|
|||
"""Download."""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from enum import Enum
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import pkg_resources
|
||||
import requests
|
||||
from pkg_resources import DistributionNotFound
|
||||
|
||||
from llama_index.download.utils import (
|
||||
get_exports,
|
||||
get_file_content,
|
||||
initialize_directory,
|
||||
rewrite_exports,
|
||||
)
|
||||
|
||||
LLAMA_HUB_CONTENTS_URL = f"https://raw.githubusercontent.com/run-llama/llama-hub/main"
|
||||
LLAMA_HUB_PATH = "/llama_hub"
|
||||
LLAMA_HUB_URL = LLAMA_HUB_CONTENTS_URL + LLAMA_HUB_PATH
|
||||
|
||||
PATH_TYPE = Union[str, Path]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LLAMAHUB_ANALYTICS_PROXY_SERVER = "https://llamahub.ai/api/analytics/downloads"
|
||||
|
||||
|
||||
class MODULE_TYPE(str, Enum):
|
||||
LOADER = "loader"
|
||||
TOOL = "tool"
|
||||
LLAMAPACK = "llamapack"
|
||||
DATASETS = "datasets"
|
||||
|
||||
|
||||
def get_module_info(
|
||||
local_dir_path: PATH_TYPE,
|
||||
remote_dir_path: PATH_TYPE,
|
||||
module_class: str,
|
||||
refresh_cache: bool = False,
|
||||
library_path: str = "library.json",
|
||||
disable_library_cache: bool = False,
|
||||
) -> Dict:
|
||||
"""Get module info."""
|
||||
if isinstance(local_dir_path, str):
|
||||
local_dir_path = Path(local_dir_path)
|
||||
|
||||
local_library_path = f"{local_dir_path}/{library_path}"
|
||||
module_id = None # e.g. `web/simple_web`
|
||||
extra_files = [] # e.g. `web/simple_web/utils.py`
|
||||
|
||||
# Check cache first
|
||||
if not refresh_cache and os.path.exists(local_library_path):
|
||||
with open(local_library_path) as f:
|
||||
library = json.load(f)
|
||||
if module_class in library:
|
||||
module_id = library[module_class]["id"]
|
||||
extra_files = library[module_class].get("extra_files", [])
|
||||
|
||||
# Fetch up-to-date library from remote repo if module_id not found
|
||||
if module_id is None:
|
||||
library_raw_content, _ = get_file_content(
|
||||
str(remote_dir_path), f"/{library_path}"
|
||||
)
|
||||
library = json.loads(library_raw_content)
|
||||
if module_class not in library:
|
||||
raise ValueError("Loader class name not found in library")
|
||||
|
||||
module_id = library[module_class]["id"]
|
||||
extra_files = library[module_class].get("extra_files", [])
|
||||
|
||||
# create cache dir if needed
|
||||
local_library_dir = os.path.dirname(local_library_path)
|
||||
if not disable_library_cache:
|
||||
if not os.path.exists(local_library_dir):
|
||||
os.makedirs(local_library_dir)
|
||||
|
||||
# Update cache
|
||||
with open(local_library_path, "w") as f:
|
||||
f.write(library_raw_content)
|
||||
|
||||
if module_id is None:
|
||||
raise ValueError("Loader class name not found in library")
|
||||
|
||||
return {
|
||||
"module_id": module_id,
|
||||
"extra_files": extra_files,
|
||||
}
|
||||
|
||||
|
||||
def download_module_and_reqs(
|
||||
local_dir_path: PATH_TYPE,
|
||||
remote_dir_path: PATH_TYPE,
|
||||
module_id: str,
|
||||
extra_files: List[str],
|
||||
refresh_cache: bool = False,
|
||||
use_gpt_index_import: bool = False,
|
||||
base_file_name: str = "base.py",
|
||||
override_path: bool = False,
|
||||
) -> None:
|
||||
"""Load module."""
|
||||
if isinstance(local_dir_path, str):
|
||||
local_dir_path = Path(local_dir_path)
|
||||
|
||||
if override_path:
|
||||
module_path = str(local_dir_path)
|
||||
else:
|
||||
module_path = f"{local_dir_path}/{module_id}"
|
||||
|
||||
if refresh_cache or not os.path.exists(module_path):
|
||||
os.makedirs(module_path, exist_ok=True)
|
||||
|
||||
basepy_raw_content, _ = get_file_content(
|
||||
str(remote_dir_path), f"/{module_id}/{base_file_name}"
|
||||
)
|
||||
# NOTE: leftover from the gpt_index -> llama_index migration; the two string
# replacements below are currently no-ops (old and new strings are identical).
if use_gpt_index_import:
|
||||
basepy_raw_content = basepy_raw_content.replace(
|
||||
"import llama_index", "import llama_index"
|
||||
)
|
||||
basepy_raw_content = basepy_raw_content.replace(
|
||||
"from llama_index", "from llama_index"
|
||||
)
|
||||
|
||||
with open(f"{module_path}/{base_file_name}", "w") as f:
|
||||
f.write(basepy_raw_content)
|
||||
|
||||
# Get content of extra files if there are any
|
||||
# and write them under the loader directory
|
||||
for extra_file in extra_files:
|
||||
extra_file_raw_content, _ = get_file_content(
|
||||
str(remote_dir_path), f"/{module_id}/{extra_file}"
|
||||
)
|
||||
# If the extra file is an __init__.py file, we need to
|
||||
# add the exports to the __init__.py file in the modules directory
|
||||
if extra_file == "__init__.py":
|
||||
loader_exports = get_exports(extra_file_raw_content)
|
||||
existing_exports = []
|
||||
init_file_path = local_dir_path / "__init__.py"
|
||||
# if the __init__.py file does not exist, we need to create it
|
||||
mode = "a+" if not os.path.exists(init_file_path) else "r+"
|
||||
with open(init_file_path, mode) as f:
|
||||
f.write(f"from .{module_id} import {', '.join(loader_exports)}")
|
||||
existing_exports = get_exports(f.read())
|
||||
rewrite_exports(existing_exports + loader_exports, str(local_dir_path))
|
||||
|
||||
with open(f"{module_path}/{extra_file}", "w") as f:
|
||||
f.write(extra_file_raw_content)
|
||||
|
||||
# install requirements
|
||||
requirements_path = f"{local_dir_path}/requirements.txt"
|
||||
|
||||
if not os.path.exists(requirements_path):
|
||||
# NOTE: need to check the status code
|
||||
response_txt, status_code = get_file_content(
|
||||
str(remote_dir_path), f"/{module_id}/requirements.txt"
|
||||
)
|
||||
if status_code == 200:
|
||||
with open(requirements_path, "w") as f:
|
||||
f.write(response_txt)
|
||||
|
||||
# Install dependencies if there are any and not already installed
|
||||
if os.path.exists(requirements_path):
|
||||
try:
|
||||
requirements = pkg_resources.parse_requirements(
|
||||
Path(requirements_path).open()
|
||||
)
|
||||
pkg_resources.require([str(r) for r in requirements])
|
||||
except DistributionNotFound:
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "-r", requirements_path]
|
||||
)
|
||||
|
||||
|
||||
def download_llama_module(
|
||||
module_class: str,
|
||||
llama_hub_url: str = LLAMA_HUB_URL,
|
||||
refresh_cache: bool = False,
|
||||
custom_dir: Optional[str] = None,
|
||||
custom_path: Optional[str] = None,
|
||||
library_path: str = "library.json",
|
||||
base_file_name: str = "base.py",
|
||||
use_gpt_index_import: bool = False,
|
||||
disable_library_cache: bool = False,
|
||||
override_path: bool = False,
|
||||
skip_load: bool = False,
|
||||
) -> Any:
|
||||
"""Download a module from LlamaHub.
|
||||
|
||||
Can be a loader, tool, pack, or more.
|
||||
|
||||
Args:
|
||||
module_class: The name of the llama module class you want to download,
    such as `GmailOpenAIAgentPack`.
refresh_cache: If true, the local cache will be skipped and the
    loader will be fetched directly from the remote repo.
custom_dir: Custom dir name to download loader into (under parent folder).
custom_path: Custom dirpath to download loader into.
library_path: File name of the library file.
base_file_name: File name of the module's entry point (defaults to `base.py`).
use_gpt_index_import: Legacy flag kept from the gpt_index -> llama_index
    migration; the import rewrite it triggers is currently a no-op.
disable_library_cache: If true, skip writing the fetched library file to the local cache.
override_path: If true, load the module files directly from the target directory
    instead of a per-module subdirectory.
skip_load: If true, download the module and its requirements without importing it.

Returns:
    The downloaded module class (or None if `skip_load` is True).
|
||||
"""
|
||||
# create directory / get path
|
||||
dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)
|
||||
|
||||
# fetch info from library.json file
|
||||
module_info = get_module_info(
|
||||
local_dir_path=dirpath,
|
||||
remote_dir_path=llama_hub_url,
|
||||
module_class=module_class,
|
||||
refresh_cache=refresh_cache,
|
||||
library_path=library_path,
|
||||
disable_library_cache=disable_library_cache,
|
||||
)
|
||||
module_id = module_info["module_id"]
|
||||
extra_files = module_info["extra_files"]
|
||||
|
||||
# download the module, install requirements
|
||||
download_module_and_reqs(
|
||||
local_dir_path=dirpath,
|
||||
remote_dir_path=llama_hub_url,
|
||||
module_id=module_id,
|
||||
extra_files=extra_files,
|
||||
refresh_cache=refresh_cache,
|
||||
use_gpt_index_import=use_gpt_index_import,
|
||||
base_file_name=base_file_name,
|
||||
override_path=override_path,
|
||||
)
|
||||
if skip_load:
|
||||
return None
|
||||
|
||||
# loads the module into memory
|
||||
if override_path:
|
||||
path = f"{dirpath}/{base_file_name}"
|
||||
spec = util.spec_from_file_location("custom_module", location=path)
|
||||
if spec is None:
|
||||
raise ValueError(f"Could not find file: {path}.")
|
||||
else:
|
||||
path = f"{dirpath}/{module_id}/{base_file_name}"
|
||||
spec = util.spec_from_file_location("custom_module", location=path)
|
||||
if spec is None:
|
||||
raise ValueError(f"Could not find file: {path}.")
|
||||
|
||||
module = util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
|
||||
return getattr(module, module_class)
|
||||
|
||||
|
||||
def track_download(module_class: str, module_type: str) -> None:
    """Track the number of downloads via the LlamaHub proxy.

    Args:
        module_class: The name of the llama module being downloaded, e.g., `GmailOpenAIAgentPack`.
        module_type: Can be "loader", "tool", "llamapack", or "datasets".

    """
    try:
        requests.post(
            LLAMAHUB_ANALYTICS_PROXY_SERVER,
            json={"type": module_type, "plugin": module_class},
        )
    except Exception as e:
        logger.info(f"Error tracking downloads for {module_class} : {e}")

@ -0,0 +1,88 @@
import os
from pathlib import Path
from typing import List, Optional, Tuple

import requests


def get_file_content(url: str, path: str) -> Tuple[str, int]:
    """Get the content of a file from the GitHub REST API."""
    resp = requests.get(url + path)
    return resp.text, resp.status_code


def get_file_content_bytes(url: str, path: str) -> Tuple[bytes, int]:
    """Get the content of a file from the GitHub REST API."""
    resp = requests.get(url + path)
    return resp.content, resp.status_code


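A rough sketch of how these helpers are called by the downloader. The base URL below is an assumption for illustration only; the real value comes from the `LLAMA_HUB_URL` constant used earlier.

base_url = "https://raw.githubusercontent.com/run-llama/llama-hub/main/llama_hub"  # assumed base URL
library_json, status = get_file_content(base_url, "/library.json")
if status == 200:
    print(f"fetched {len(library_json)} characters of library metadata")
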
def get_exports(raw_content: str) -> List:
    """Read the content of a Python file and return a list of exported class names.

    For example:
    ```python
    from .a import A
    from .b import B

    __all__ = ["A", "B"]
    ```
    will return `["A", "B"]`.

    Args:
        - raw_content: The content of a Python file as a string.

    Returns:
        A list of exported class names.

    """
    exports = []
    for line in raw_content.splitlines():
        line = line.strip()
        if line.startswith("__all__"):
            exports = line.split("=")[1].strip().strip("[").strip("]").split(",")
            exports = [export.strip().strip("'").strip('"') for export in exports]
    return exports


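A quick illustration of the parser's behavior (and its main limitation): only a single-line `__all__` assignment is recognized, so an `__all__` list wrapped across multiple lines would be missed.

sample = 'from .a import A\nfrom .b import B\n\n__all__ = ["A", "B"]\n'
print(get_exports(sample))  # ['A', 'B']
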
def rewrite_exports(exports: List[str], dirpath: str) -> None:
    """Write the `__all__` variable to the `__init__.py` file in the modules dir.

    Removes the line that contains `__all__` and appends a new line with the updated
    `__all__` variable.

    Args:
        - exports: A list of exported class names.
        - dirpath: Path to the directory containing the `__init__.py` to rewrite.

    """
    init_path = f"{dirpath}/__init__.py"
    with open(init_path) as f:
        lines = f.readlines()
    with open(init_path, "w") as f:
        for line in lines:
            line = line.strip()
            if line.startswith("__all__"):
                continue
            f.write(line + os.linesep)
        f.write(f"__all__ = {list(set(exports))}" + os.linesep)


def initialize_directory(
    custom_path: Optional[str] = None, custom_dir: Optional[str] = None
) -> Path:
    """Initialize directory."""
    if custom_path is not None and custom_dir is not None:
        raise ValueError(
            "You cannot specify both `custom_path` and `custom_dir` at the same time."
        )

    custom_dir = custom_dir or "llamadatasets"
    if custom_path is not None:
        dirpath = Path(custom_path)
    else:
        dirpath = Path(__file__).parent / custom_dir
    if not os.path.exists(dirpath):
        # Create a new directory because it does not exist
        os.makedirs(dirpath)

    return dirpath

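A short sketch of the two calling modes (illustrative; the paths are arbitrary):

# A fully custom location...
dirpath = initialize_directory(custom_path="/tmp/my_llama_modules")
# ...or a named subdirectory created next to this module.
dirpath = initialize_directory(custom_dir="my_llama_modules")
# Passing both raises a ValueError.
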
@ -0,0 +1,96 @@
"""Init file."""

from llama_index.embeddings.adapter import (
    AdapterEmbeddingModel,
    LinearAdapterEmbeddingModel,
)
from llama_index.embeddings.anyscale import AnyscaleEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.base import BaseEmbedding, SimilarityMode
from llama_index.embeddings.bedrock import BedrockEmbedding
from llama_index.embeddings.clarifai import ClarifaiEmbedding
from llama_index.embeddings.clip import ClipEmbedding
from llama_index.embeddings.cohereai import CohereEmbedding
from llama_index.embeddings.dashscope import (
    DashScopeBatchTextEmbeddingModels,
    DashScopeEmbedding,
    DashScopeMultiModalEmbeddingModels,
    DashScopeTextEmbeddingModels,
    DashScopeTextEmbeddingType,
)
from llama_index.embeddings.elasticsearch import (
    ElasticsearchEmbedding,
    ElasticsearchEmbeddings,
)
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.embeddings.google import GoogleUnivSentEncoderEmbedding
from llama_index.embeddings.google_palm import GooglePaLMEmbedding
from llama_index.embeddings.gradient import GradientEmbedding
from llama_index.embeddings.huggingface import (
    HuggingFaceEmbedding,
    HuggingFaceInferenceAPIEmbedding,
    HuggingFaceInferenceAPIEmbeddings,
)
from llama_index.embeddings.huggingface_optimum import OptimumEmbedding
from llama_index.embeddings.huggingface_utils import DEFAULT_HUGGINGFACE_EMBEDDING_MODEL
from llama_index.embeddings.instructor import InstructorEmbedding
from llama_index.embeddings.langchain import LangchainEmbedding
from llama_index.embeddings.llm_rails import LLMRailsEmbedding, LLMRailsEmbeddings
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.embeddings.nomic import NomicEmbedding
from llama_index.embeddings.ollama_embedding import OllamaEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.pooling import Pooling
from llama_index.embeddings.sagemaker_embedding_endpoint import (
    SageMakerEmbedding,
)
from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference
from llama_index.embeddings.together import TogetherEmbedding
from llama_index.embeddings.utils import resolve_embed_model
from llama_index.embeddings.voyageai import VoyageEmbedding

__all__ = [
    "AdapterEmbeddingModel",
    "BedrockEmbedding",
    "ClarifaiEmbedding",
    "ClipEmbedding",
    "CohereEmbedding",
    "BaseEmbedding",
    "DEFAULT_HUGGINGFACE_EMBEDDING_MODEL",
    "ElasticsearchEmbedding",
    "FastEmbedEmbedding",
    "GoogleUnivSentEncoderEmbedding",
    "GradientEmbedding",
    "HuggingFaceInferenceAPIEmbedding",
    "HuggingFaceEmbedding",
    "InstructorEmbedding",
    "LangchainEmbedding",
    "LinearAdapterEmbeddingModel",
    "LLMRailsEmbedding",
    "MistralAIEmbedding",
    "OpenAIEmbedding",
    "AzureOpenAIEmbedding",
    "AnyscaleEmbedding",
    "OptimumEmbedding",
    "Pooling",
    "SageMakerEmbedding",
    "GooglePaLMEmbedding",
    "SimilarityMode",
    "TextEmbeddingsInference",
    "TogetherEmbedding",
    "resolve_embed_model",
    "NomicEmbedding",
    # Deprecated, kept for backwards compatibility
    "LLMRailsEmbeddings",
    "ElasticsearchEmbeddings",
    "HuggingFaceInferenceAPIEmbeddings",
    "VoyageEmbedding",
    "OllamaEmbedding",
    "GeminiEmbedding",
    "DashScopeEmbedding",
    "DashScopeTextEmbeddingModels",
    "DashScopeTextEmbeddingType",
    "DashScopeBatchTextEmbeddingModels",
    "DashScopeMultiModalEmbeddingModels",
]

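A minimal sketch of the public surface exported above (assumes an `OPENAI_API_KEY` in the environment; any other `BaseEmbedding` subclass listed here works the same way):

from llama_index.embeddings import OpenAIEmbedding

embed_model = OpenAIEmbedding()
vector = embed_model.get_text_embedding("hello world")
print(len(vector))  # embedding dimensionality of the chosen model
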
@ -0,0 +1,116 @@
"""Embedding adapter model."""

import logging
from typing import Any, List, Optional, Type, cast

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.embeddings.base import BaseEmbedding
from llama_index.utils import infer_torch_device

logger = logging.getLogger(__name__)


class AdapterEmbeddingModel(BaseEmbedding):
    """Adapter for any embedding model.

    This is a wrapper around any embedding model that adds an adapter layer \
        on top of it.
    This is useful for finetuning an embedding model on a downstream task.
    The embedding model can be any model - it does not need to expose gradients.

    Args:
        base_embed_model (BaseEmbedding): Base embedding model.
        adapter_path (str): Path to adapter.
        adapter_cls (Optional[Type[Any]]): Adapter class. Defaults to None, in which \
            case a linear adapter is used.
        transform_query (bool): Whether to transform query embeddings. Defaults to True.
        device (Optional[str]): Device to use. Defaults to None.
        embed_batch_size (int): Batch size for embedding. Defaults to 10.
        callback_manager (Optional[CallbackManager]): Callback manager. \
            Defaults to None.

    """

    _base_embed_model: BaseEmbedding = PrivateAttr()
    _adapter: Any = PrivateAttr()
    _transform_query: bool = PrivateAttr()
    _device: Optional[str] = PrivateAttr()
    _target_device: Any = PrivateAttr()

    def __init__(
        self,
        base_embed_model: BaseEmbedding,
        adapter_path: str,
        adapter_cls: Optional[Type[Any]] = None,
        transform_query: bool = True,
        device: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Init params."""
        import torch

        from llama_index.embeddings.adapter_utils import BaseAdapter, LinearLayer

        if device is None:
            device = infer_torch_device()
            logger.info(f"Use pytorch device: {device}")
        self._target_device = torch.device(device)

        self._base_embed_model = base_embed_model

        if adapter_cls is None:
            adapter_cls = LinearLayer
        else:
            adapter_cls = cast(Type[BaseAdapter], adapter_cls)

        adapter = adapter_cls.load(adapter_path)
        self._adapter = cast(BaseAdapter, adapter)
        self._adapter.to(self._target_device)

        self._transform_query = transform_query
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=f"Adapter for {base_embed_model.model_name}",
        )

    @classmethod
    def class_name(cls) -> str:
        return "AdapterEmbeddingModel"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        import torch

        query_embedding = self._base_embed_model._get_query_embedding(query)
        if self._transform_query:
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        import torch

        query_embedding = await self._base_embed_model._aget_query_embedding(query)
        if self._transform_query:
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    def _get_text_embedding(self, text: str) -> List[float]:
        return self._base_embed_model._get_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return await self._base_embed_model._aget_text_embedding(text)


# Maintain for backwards compatibility
LinearAdapterEmbeddingModel = AdapterEmbeddingModel

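A usage sketch (illustrative; `./adapter_output` is a hypothetical directory produced by a finetuning run that saved an adapter via `BaseAdapter.save`):

from llama_index.embeddings import OpenAIEmbedding
from llama_index.embeddings.adapter import AdapterEmbeddingModel

base_model = OpenAIEmbedding()
embed_model = AdapterEmbeddingModel(
    base_embed_model=base_model,
    adapter_path="./adapter_output",  # hypothetical: contains config.json + pytorch_model.bin
)
# Query embeddings pass through the adapter; text embeddings are delegated unchanged.
query_vec = embed_model.get_query_embedding("What did the author do growing up?")
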
@ -0,0 +1,179 @@
"""Adapter utils."""

import json
import logging
import os
from abc import abstractmethod
from typing import Callable, Dict

import torch
import torch.nn.functional as F
from torch import Tensor, nn

logger = logging.getLogger(__name__)


class BaseAdapter(nn.Module):
    """Base adapter.

    Can be subclassed to implement custom adapters.
    To implement a custom adapter, subclass this class and implement the
    following methods:
    - get_config_dict
    - forward

    """

    @abstractmethod
    def get_config_dict(self) -> Dict:
        """Get config dict."""

    @abstractmethod
    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass."""

    def save(self, output_path: str) -> None:
        """Save model."""
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut)
        torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))

    @classmethod
    def load(cls, input_path: str) -> "BaseAdapter":
        """Load model."""
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)
        model = cls(**config)
        model.load_state_dict(
            torch.load(
                os.path.join(input_path, "pytorch_model.bin"),
                map_location=torch.device("cpu"),
            )
        )
        return model


class LinearLayer(BaseAdapter):
    """Linear transformation.

    Args:
        in_features (int): Input dimension.
        out_features (int): Output dimension.
        bias (bool): Whether to use bias. Defaults to False.

    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # seed with identity matrix and 0 bias
        # only works for square matrices
        self.linear.weight.data.copy_(torch.eye(in_features, out_features))
        if bias:
            self.linear.bias.data.copy_(torch.zeros(out_features))

    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass (Wv)."""
        return self.linear(embed)

    def get_config_dict(self) -> Dict:
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "bias": self.bias,
        }


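Because of the identity seeding above, a freshly constructed square `LinearLayer` is a no-op until trained, which makes for a handy sanity check (illustrative only):

import torch

layer = LinearLayer(in_features=4, out_features=4)
x = torch.randn(4)
assert torch.allclose(layer(x), x)  # identity weights, no bias
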
def get_activation_function(name: str) -> Callable:
    """Get activation function.

    Args:
        name (str): Name of activation function.

    """
    activations: Dict[str, Callable] = {
        "relu": F.relu,
        "sigmoid": torch.sigmoid,
        "tanh": torch.tanh,
        "leaky_relu": F.leaky_relu,
        # add more activations here as needed
    }
    if name not in activations:
        raise ValueError(f"Unknown activation function: {name}")
    return activations[name]


class TwoLayerNN(BaseAdapter):
    """Two-layer transformation.

    Args:
        in_features (int): Input dimension.
        hidden_features (int): Hidden dimension.
        out_features (int): Output dimension.
        bias (bool): Whether to use bias. Defaults to False.
        activation_fn_str (str): Name of activation function. Defaults to "relu".
        add_residual (bool): Whether to add a learnable, zero-initialized residual
            connection from the input to the output. Defaults to False.

    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        bias: bool = False,
        activation_fn_str: str = "relu",
        add_residual: bool = False,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.bias = bias
        self.activation_fn_str = activation_fn_str

        self.linear1 = nn.Linear(in_features, hidden_features, bias=True)
        self.linear2 = nn.Linear(hidden_features, out_features, bias=True)
        # self.linear1.weight.data.copy_(torch.zeros(hidden_features, in_features))
        # self.linear2.weight.data.copy_(torch.zeros(out_features, hidden_features))
        # if bias:
        #     self.linear1.bias.data.copy_(torch.zeros(hidden_features))
        #     self.linear2.bias.data.copy_(torch.zeros(out_features))

        self._activation_function = get_activation_function(activation_fn_str)
        self._add_residual = add_residual
        # if add_residual, then add residual_weight (init to 0)
        self.residual_weight = nn.Parameter(torch.zeros(1))

    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass (Wv).

        Args:
            embed (Tensor): Input tensor.

        """
        output1 = self.linear1(embed)
        output1 = self._activation_function(output1)
        output2 = self.linear2(output1)

        if self._add_residual:
            output2 = self.residual_weight * output2 + embed

        return output2

    def get_config_dict(self) -> Dict:
        """Get config dict."""
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "bias": self.bias,
            "activation_fn_str": self.activation_fn_str,
            "add_residual": self._add_residual,
        }
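Tying the adapter utilities together, a small round-trip sketch (illustrative; `/tmp/adapter` is an arbitrary path): build a `TwoLayerNN`, persist it with the `BaseAdapter` helpers, and reload it from its saved config.

import torch

adapter = TwoLayerNN(
    in_features=384,
    hidden_features=1024,
    out_features=384,
    add_residual=True,
)
adapter.save("/tmp/adapter")                # writes config.json and pytorch_model.bin
restored = TwoLayerNN.load("/tmp/adapter")  # reconstructed via cls(**config)
print(restored(torch.randn(384)).shape)     # torch.Size([384])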