faiss_rag_enterprise/llama_index/callbacks/base.py

275 lines
9.8 KiB
Python

import logging
import uuid
from abc import ABC
from collections import defaultdict
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import (
BASE_TRACE_EVENT,
LEAF_EVENTS,
CBEventType,
EventPayload,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Context-local stack of event ids; the last element is used as the parent of
# the next event. The default list is a single shared object, so the code in
# this module copies it before changing it and stores the copy back with
# ``set`` rather than mutating it in place.
global_stack_trace = ContextVar("trace", default=[BASE_TRACE_EVENT])
# Shared empty default for the trace-id stack below. It is likewise only ever
# copied by this module, never appended to directly.
empty_trace_ids: List[str] = []
# Context-local stack of the trace ids (names) of the traces in progress.
global_stack_trace_ids = ContextVar("trace_ids", default=empty_trace_ids)
class CallbackManager(BaseCallbackHandler, ABC):
    """
    Callback manager that handles callbacks for events within LlamaIndex.

    The callback manager provides a way to call handlers on event starts/ends.

    Additionally, the callback manager traces the current stack of events.
    It does this by using a few key attributes.
    - trace_stack - The current stack of events that have not ended yet.
                    When an event ends, it's removed from the stack.
                    Since this is a contextvar, it is unique to each
                    thread/task.
    - trace_map - A mapping of event ids to their children events.
                  On the start of events, the top of the trace stack
                  (its last element, i.e. the most recently started event
                  that is still open) is used as the current parent event
                  for the trace map.
    - trace_id - A simple name for the current trace, usually denoting the
                 entrypoint (query, index_construction, insert, etc.)

    Args:
        handlers (List[BaseCallbackHandler]): list of handlers to use.

    Usage:
        # ``event()`` fires the start callbacks itself, so only the end
        # payload (if any) has to be sent explicitly.
        with callback_manager.event(CBEventType.QUERY, payload={key: val}) as event:
            ...
            event.on_end(payload={key: val})  # optional

    """

    def __init__(self, handlers: Optional[List[BaseCallbackHandler]] = None):
        """Initialize the manager with a list of handlers.

        Args:
            handlers: Handlers to dispatch events to. The list is copied, so
                the caller's list object is never modified. If
                ``llama_index.global_handler`` is set, it is appended to the
                copy.

        Raises:
            ValueError: If a global handler is set and ``handlers`` already
                contains an instance of the same type.
        """
        # Imported here so the value is looked up each time a manager is built.
        from llama_index import global_handler

        # Defensive copy: appending the global handler below must not leak
        # into the list object owned by the caller (which may be reused to
        # build other managers).
        handlers = list(handlers) if handlers else []

        # add eval handlers based on global defaults
        if global_handler is not None:
            # refuse to register a second handler of the global handler's type
            if any(isinstance(h, type(global_handler)) for h in handlers):
                raise ValueError(
                    "Cannot add two handlers of the same type "
                    f"{type(global_handler)} to the callback manager."
                )
            handlers.append(global_handler)

        self.handlers = handlers
        self._trace_map: Dict[str, List[str]] = defaultdict(list)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: Optional[str] = None,
        parent_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Run handlers when an event starts and return id of event.

        Args:
            event_type: Type of the event being started.
            payload: Optional data forwarded to every handler.
            event_id: Id for the event; a random UUID string is generated
                when omitted or empty.
            parent_id: Id of the parent event; defaults to the last entry of
                the context-local trace stack.
            **kwargs: Forwarded unchanged to each handler.

        Returns:
            The id of the started event.
        """
        event_id = event_id or str(uuid.uuid4())

        # if no trace is running, start a default trace
        try:
            parent_id = parent_id or global_stack_trace.get()[-1]
        except IndexError:
            self.start_trace("llama-index")
            parent_id = global_stack_trace.get()[-1]

        # register the new event as a child of its parent
        self._trace_map[parent_id].append(event_id)

        for handler in self.handlers:
            if event_type not in handler.event_starts_to_ignore:
                handler.on_event_start(
                    event_type,
                    payload,
                    event_id=event_id,
                    parent_id=parent_id,
                    **kwargs,
                )

        # leaf events never become a parent, so they are not pushed
        if event_type not in LEAF_EVENTS:
            # copy the stack trace to prevent conflicts with threads/coroutines
            current_trace_stack = global_stack_trace.get().copy()
            current_trace_stack.append(event_id)
            global_stack_trace.set(current_trace_stack)

        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run handlers when an event ends.

        Args:
            event_type: Type of the event being ended.
            payload: Optional data forwarded to every handler.
            event_id: Id of the event; a random UUID string is generated
                when omitted or empty.
            **kwargs: Forwarded unchanged to each handler.

        Note:
            For non-leaf events the last entry of the trace stack is popped
            regardless of ``event_id``, so starts and ends are expected to be
            properly nested.
        """
        event_id = event_id or str(uuid.uuid4())

        for handler in self.handlers:
            if event_type not in handler.event_ends_to_ignore:
                handler.on_event_end(event_type, payload, event_id=event_id, **kwargs)

        if event_type not in LEAF_EVENTS:
            # copy the stack trace to prevent conflicts with threads/coroutines
            current_trace_stack = global_stack_trace.get().copy()
            current_trace_stack.pop()
            global_stack_trace.set(current_trace_stack)

    def add_handler(self, handler: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager.

        Raises:
            ValueError: If ``handler`` is not currently registered.
        """
        self.handlers.remove(handler)

    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = handlers

    @contextmanager
    def event(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: Optional[str] = None,
    ) -> Generator["EventContext", None, None]:
        """Context manager for launching and shutdown of events.

        Handles sending on_event_start and on_event_end to handlers for the
        specified event. The start callbacks fire on entry with ``payload``.
        The end callbacks fire on exit unless the body already called
        ``event.on_end``. If the body raises an ``Exception`` that has not yet
        been recorded by an enclosing event or trace, the end payload carries
        it under ``EventPayload.EXCEPTION``; the exception is always re-raised.

        Usage:
            with callback_manager.event(CBEventType.QUERY, payload={key: val}) as event:
                ...
                event.on_end(payload={key: val})  # optional
        """
        # create event context wrapper
        event = EventContext(self, event_type, event_id=event_id)
        event.on_start(payload=payload)

        # from here on ``payload`` holds the *end* payload
        payload = None
        try:
            yield event
        except Exception as e:
            # data already logged to trace?
            if not hasattr(e, "event_added"):
                payload = {EventPayload.EXCEPTION: e}
                # mark it so outer events/traces do not report it again
                e.event_added = True  # type: ignore
            raise
        finally:
            # ensure event is ended (runs for the error path as well)
            if not event.finished:
                event.on_end(payload=payload)

    @contextmanager
    def as_trace(self, trace_id: str) -> Generator[None, None, None]:
        """Context manager tracer for launching and shutdown of traces.

        Starts the trace on entry and always ends it on exit. An ``Exception``
        escaping the body that has not been recorded yet is reported through
        a ``CBEventType.EXCEPTION`` event start, then re-raised.
        """
        self.start_trace(trace_id=trace_id)

        try:
            yield
        except Exception as e:
            # event already added to trace?
            if not hasattr(e, "event_added"):
                self.on_event_start(
                    CBEventType.EXCEPTION, payload={EventPayload.EXCEPTION: e}
                )
                e.event_added = True  # type: ignore

            raise
        finally:
            # ensure trace is ended
            self.end_trace(trace_id=trace_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched.

        If no trace is active in the current context, the recorded events are
        reset and every handler's ``start_trace`` is called. If a trace is
        already active, ``trace_id`` is only pushed onto the trace-id stack
        and handlers are not notified. A ``trace_id`` of ``None`` changes
        nothing.
        """
        # work on a copy so other threads/coroutines never see our mutation
        current_trace_stack_ids = global_stack_trace_ids.get().copy()

        if trace_id is not None:
            if not current_trace_stack_ids:
                # outermost trace: start from a clean slate
                self._reset_trace_events()

                for handler in self.handlers:
                    handler.start_trace(trace_id=trace_id)

                current_trace_stack_ids = [trace_id]
            else:
                # nested trace: only remember its name
                current_trace_stack_ids.append(trace_id)

        global_stack_trace_ids.set(current_trace_stack_ids)

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited.

        Pops one entry from the trace-id stack; when that empties the stack,
        every handler's ``end_trace`` is called with ``trace_id`` and this
        manager's own trace map. Nothing happens if ``trace_id`` is ``None``
        or no trace is active.

        Args:
            trace_id: Name of the trace being ended.
            trace_map: Accepted but not used here; handlers always receive
                this manager's own trace map.
        """
        # work on a copy so other threads/coroutines never see our mutation
        current_trace_stack_ids = global_stack_trace_ids.get().copy()

        if trace_id is not None and current_trace_stack_ids:
            current_trace_stack_ids.pop()
            if not current_trace_stack_ids:
                # outermost trace finished: hand the event tree to handlers
                for handler in self.handlers:
                    handler.end_trace(trace_id=trace_id, trace_map=self._trace_map)

        global_stack_trace_ids.set(current_trace_stack_ids)

    def _reset_trace_events(self) -> None:
        """Helper function to reset the current trace."""
        self._trace_map = defaultdict(list)
        global_stack_trace.set([BASE_TRACE_EVENT])

    @property
    def trace_map(self) -> Dict[str, List[str]]:
        """Mapping of parent event id to the ids of its child events."""
        return self._trace_map
class EventContext:
    """
    Lightweight handle for a single event of a callback manager.

    It remembers the event type and id so that the start and end callbacks
    can be fired through the manager without repeating them, and it tracks
    whether each of the two has already been sent.
    """

    def __init__(
        self,
        callback_manager: CallbackManager,
        event_type: CBEventType,
        event_id: Optional[str] = None,
    ):
        """Bind the context to a manager, an event type and an event id."""
        # Neither callback has been sent yet.
        self.started = False
        self.finished = False
        self._event_type = event_type
        self._callback_manager = callback_manager
        # Fall back to a random UUID when the caller did not supply an id.
        self._event_id = event_id if event_id else str(uuid.uuid4())

    def on_start(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
        """Send the start callback once; later calls only log a warning."""
        if self.started:
            logger.warning(
                f"Event {self._event_type!s}: {self._event_id} already started!"
            )
            return

        # Flag first, then dispatch, so the flag is set even if dispatch raises.
        self.started = True
        manager = self._callback_manager
        manager.on_event_start(
            self._event_type, payload=payload, event_id=self._event_id, **kwargs
        )

    def on_end(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
        """Send the end callback once; later calls are silently ignored."""
        if self.finished:
            return

        # Flag first, then dispatch, so the flag is set even if dispatch raises.
        self.finished = True
        manager = self._callback_manager
        manager.on_event_end(
            self._event_type, payload=payload, event_id=self._event_id, **kwargs
        )