275 lines
9.8 KiB
Python
275 lines
9.8 KiB
Python
import logging
|
|
import uuid
|
|
from abc import ABC
|
|
from collections import defaultdict
|
|
from contextlib import contextmanager
|
|
from contextvars import ContextVar
|
|
from typing import Any, Dict, Generator, List, Optional
|
|
|
|
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
|
from llama_index.callbacks.schema import (
|
|
BASE_TRACE_EVENT,
|
|
LEAF_EVENTS,
|
|
CBEventType,
|
|
EventPayload,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Per-context stack of ids of events that have started but not yet ended,
# seeded with the root BASE_TRACE_EVENT. Code in this module always copies the
# list returned by .get() before mutating it, so the default list object is
# never modified in place.
global_stack_trace: "ContextVar[List[str]]" = ContextVar(
    "trace", default=[BASE_TRACE_EVENT]
)

# Per-context stack of open trace ids; empty means no trace is running.
empty_trace_ids: List[str] = []
global_stack_trace_ids: "ContextVar[List[str]]" = ContextVar(
    "trace_ids", default=empty_trace_ids
)
|
|
|
|
|
|
class CallbackManager(BaseCallbackHandler, ABC):
    """
    Callback manager that handles callbacks for events within LlamaIndex.

    The callback manager provides a way to call handlers on event starts/ends.

    Additionally, the callback manager traces the current stack of events.
    It does this by using a few key attributes.
    - trace_stack - The current stack of events that have not ended yet.
                    When an event ends, it's removed from the stack.
                    Since this is a contextvar, it is unique to each
                    thread/task.
    - trace_map - A mapping of event ids to their children events.
                  On the start of events, the top of the trace stack (the
                  most recently started event that has not ended yet) is
                  used as the current parent event for the trace map.
    - trace_id - A simple name for the current trace, usually denoting the
                 entrypoint (query, index_construction, insert, etc.)

    Args:
        handlers (List[BaseCallbackHandler]): list of handlers to use.

    Usage:
        with callback_manager.event(CBEventType.QUERY, payload={key: val}) as event:
            ...
            event.on_end(payload={key: val})  # optional
    """

    def __init__(self, handlers: Optional[List[BaseCallbackHandler]] = None):
        """Initialize the manager with a list of handlers.

        Args:
            handlers: handlers to notify. The list is copied, so this manager
                never mutates the caller's list.

        Raises:
            ValueError: if ``llama_index.global_handler`` is set and
                ``handlers`` already contains a handler of the same type.
        """
        from llama_index import global_handler

        # Copy instead of aliasing: appending the global handler to the
        # caller's list would leak it into that list, and reusing the list
        # for a second manager would then trip the duplicate-type check below.
        handlers = list(handlers) if handlers else []

        # add eval handlers based on global defaults
        if global_handler is not None:
            new_handler = global_handler
            # refuse to register two handlers of the global handler's type
            for existing_handler in handlers:
                if isinstance(existing_handler, type(new_handler)):
                    raise ValueError(
                        "Cannot add two handlers of the same type "
                        f"{type(new_handler)} to the callback manager."
                    )
            handlers.append(new_handler)

        self.handlers = handlers
        # parent event id -> ids of the child events started under it
        self._trace_map: Dict[str, List[str]] = defaultdict(list)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: Optional[str] = None,
        parent_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Run handlers when an event starts and return id of event.

        Args:
            event_type: type of the event; handlers listing it in
                ``event_starts_to_ignore`` are skipped.
            payload: optional event data forwarded to handlers.
            event_id: id to use; a random UUID string is generated if omitted.
            parent_id: parent event id; defaults to the top of the trace stack.
            **kwargs: forwarded to each handler's ``on_event_start``.

        Returns:
            The id of the started event.
        """
        event_id = event_id or str(uuid.uuid4())

        try:
            parent_id = parent_id or global_stack_trace.get()[-1]
        except IndexError:
            # The event stack is empty (e.g. after unbalanced event ends).
            if not global_stack_trace_ids.get():
                # No trace is running: start a default trace, which also
                # resets the stack to [BASE_TRACE_EVENT].
                self.start_trace("llama-index")
            else:
                # A trace is running, so start_trace would not reset the
                # stack and indexing it would fail again; re-seed the root.
                global_stack_trace.set([BASE_TRACE_EVENT])
            parent_id = global_stack_trace.get()[-1]

        self._trace_map[parent_id].append(event_id)
        for handler in self.handlers:
            if event_type not in handler.event_starts_to_ignore:
                handler.on_event_start(
                    event_type,
                    payload,
                    event_id=event_id,
                    parent_id=parent_id,
                    **kwargs,
                )

        if event_type not in LEAF_EVENTS:
            # copy the stack trace to prevent conflicts with threads/coroutines
            current_trace_stack = global_stack_trace.get().copy()
            current_trace_stack.append(event_id)
            global_stack_trace.set(current_trace_stack)

        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run handlers when an event ends.

        For non-leaf events the top of the trace stack is popped. The popped
        entry is not checked against ``event_id``, so starts and ends are
        expected to be balanced and properly nested.
        """
        event_id = event_id or str(uuid.uuid4())
        for handler in self.handlers:
            if event_type not in handler.event_ends_to_ignore:
                handler.on_event_end(event_type, payload, event_id=event_id, **kwargs)

        if event_type not in LEAF_EVENTS:
            # copy the stack trace to prevent conflicts with threads/coroutines
            current_trace_stack = global_stack_trace.get().copy()
            # an unbalanced end must not crash with "pop from empty list"
            if current_trace_stack:
                current_trace_stack.pop()
            global_stack_trace.set(current_trace_stack)

    def add_handler(self, handler: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager.

        Raises:
            ValueError: if ``handler`` is not registered (from ``list.remove``).
        """
        self.handlers.remove(handler)

    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = handlers

    @contextmanager
    def event(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: Optional[str] = None,
    ) -> Generator["EventContext", None, None]:
        """Context manager for launching and shutting down events.

        Handles sending on_event_start and on_event_end to handlers for the
        specified event. If the body raises and the exception has not already
        been recorded, the event is ended with an EXCEPTION payload and the
        exception is re-raised.

        Usage:
            with callback_manager.event(CBEventType.QUERY, payload={key: val}) as event:
                ...
                event.on_end(payload={key: val})  # optional
        """
        # create event context wrapper
        event = EventContext(self, event_type, event_id=event_id)
        event.on_start(payload=payload)

        # the start payload must not be reused as the end payload
        payload = None
        try:
            yield event
        except Exception as e:
            # data already logged to trace? (marked by an inner event/trace)
            if not hasattr(e, "event_added"):
                payload = {EventPayload.EXCEPTION: e}
                e.event_added = True  # type: ignore
                if not event.finished:
                    event.on_end(payload=payload)
            raise
        finally:
            # ensure event is ended
            if not event.finished:
                event.on_end(payload=payload)

    @contextmanager
    def as_trace(self, trace_id: str) -> Generator[None, None, None]:
        """Context manager tracer for launching and shutting down traces.

        If the body raises an exception that has not already been recorded,
        an EXCEPTION event is started for it before the exception is
        re-raised. The trace is always ended.
        """
        self.start_trace(trace_id=trace_id)

        try:
            yield
        except Exception as e:
            # event already added to trace?
            if not hasattr(e, "event_added"):
                self.on_event_start(
                    CBEventType.EXCEPTION, payload={EventPayload.EXCEPTION: e}
                )
                e.event_added = True  # type: ignore

            raise
        finally:
            # ensure trace is ended
            self.end_trace(trace_id=trace_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched.

        Only the outermost trace (no trace id on the stack yet) resets the
        trace map and event stack and notifies handlers; nested traces just
        push their id. A ``None`` trace_id changes nothing.
        """
        current_trace_stack_ids = global_stack_trace_ids.get().copy()
        if trace_id is not None:
            if len(current_trace_stack_ids) == 0:
                self._reset_trace_events()

                for handler in self.handlers:
                    handler.start_trace(trace_id=trace_id)

                current_trace_stack_ids = [trace_id]
            else:
                current_trace_stack_ids.append(trace_id)

        global_stack_trace_ids.set(current_trace_stack_ids)

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited.

        Pops one trace id. Handlers' ``end_trace`` is called only when the
        outermost trace ends. The ``trace_map`` argument is ignored; handlers
        receive this manager's own trace map.
        """
        current_trace_stack_ids = global_stack_trace_ids.get().copy()
        if trace_id is not None and len(current_trace_stack_ids) > 0:
            current_trace_stack_ids.pop()
            if len(current_trace_stack_ids) == 0:
                for handler in self.handlers:
                    handler.end_trace(trace_id=trace_id, trace_map=self._trace_map)

        global_stack_trace_ids.set(current_trace_stack_ids)

    def _reset_trace_events(self) -> None:
        """Helper function to reset the current trace."""
        self._trace_map = defaultdict(list)
        global_stack_trace.set([BASE_TRACE_EVENT])

    @property
    def trace_map(self) -> Dict[str, List[str]]:
        """Mapping of parent event id to the ids of its child events."""
        return self._trace_map
|
|
|
|
|
|
class EventContext:
    """
    Thin wrapper binding an event type and id to a callback manager.

    Lets callers start and end one event without repeating its type and id.
    Each of ``on_start`` / ``on_end`` is forwarded to the manager at most once.
    """

    def __init__(
        self,
        callback_manager: CallbackManager,
        event_type: CBEventType,
        event_id: Optional[str] = None,
    ):
        self._callback_manager = callback_manager
        self._event_type = event_type
        # fall back to a fresh UUID string so start and end share one id
        self._event_id = event_id or str(uuid.uuid4())
        self.started = self.finished = False

    def on_start(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
        """Forward the start of this event to the manager on the first call.

        Any later call does not re-notify the manager; it logs a warning.
        """
        if self.started:
            logger.warning(
                f"Event {self._event_type!s}: {self._event_id} already started!"
            )
            return

        self.started = True
        manager = self._callback_manager
        manager.on_event_start(
            self._event_type, payload=payload, event_id=self._event_id, **kwargs
        )

    def on_end(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
        """Forward the end of this event to the manager on the first call.

        Later calls are silently ignored.
        """
        if self.finished:
            return

        self.finished = True
        manager = self._callback_manager
        manager.on_event_end(
            self._event_type, payload=payload, event_id=self._event_id, **kwargs
        )
|