206 lines
7.5 KiB
Python
206 lines
7.5 KiB
Python
from collections import defaultdict
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
|
from llama_index.callbacks.schema import (
|
|
BASE_TRACE_EVENT,
|
|
TIMESTAMP_FORMAT,
|
|
CBEvent,
|
|
CBEventType,
|
|
EventStats,
|
|
)
|
|
|
|
|
|
class LlamaDebugHandler(BaseCallbackHandler):
    """Callback handler that keeps track of debug info.

    NOTE: this is a beta feature. The usage within our codebase, and the interface
    may change.

    This handler simply keeps track of event starts/ends, separated by event types.
    You can use this callback handler to keep track of and debug events.

    Args:
        event_starts_to_ignore (Optional[List[CBEventType]]): list of event types to
            ignore when tracking event starts.
        event_ends_to_ignore (Optional[List[CBEventType]]): list of event types to
            ignore when tracking event ends.
        print_trace_on_end (bool): if True, ``end_trace`` prints the trace map to
            stdout via ``print_trace_map``. Defaults to True.

    """

    def __init__(
        self,
        event_starts_to_ignore: Optional[List[CBEventType]] = None,
        event_ends_to_ignore: Optional[List[CBEventType]] = None,
        print_trace_on_end: bool = True,
    ) -> None:
        """Initialize the llama debug handler."""
        # The same CBEvent objects are indexed three ways: by event type, by
        # event id (start + end of one event share an id), and in arrival order.
        self._event_pairs_by_type: Dict[CBEventType, List[CBEvent]] = defaultdict(list)
        self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)
        self._sequential_events: List[CBEvent] = []
        self._cur_trace_id: Optional[str] = None
        # Parent event id -> child event ids; replaced by start_trace/end_trace.
        self._trace_map: Dict[str, List[str]] = defaultdict(list)
        self.print_trace_on_end = print_trace_on_end
        super().__init__(
            event_starts_to_ignore=event_starts_to_ignore or [],
            event_ends_to_ignore=event_ends_to_ignore or [],
        )

    def _record_event(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]],
        event_id: str,
    ) -> str:
        """Create a CBEvent and index it by type, by id, and in arrival order.

        Returns:
            str: the ``id_`` of the created CBEvent (taken from the event object
            rather than from ``event_id``, mirroring what callers received before).
        """
        event = CBEvent(event_type, payload=payload, id_=event_id)
        self._event_pairs_by_type[event.event_type].append(event)
        self._event_pairs_by_id[event.id_].append(event)
        self._sequential_events.append(event)
        return event.id_

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Store event start data by event type.

        Args:
            event_type (CBEventType): event type to store.
            payload (Optional[Dict[str, Any]]): payload to store.
            event_id (str): event id to store.
            parent_id (str): parent event id (accepted for interface
                compatibility; not used by this handler).

        Returns:
            str: the id of the stored event.

        """
        return self._record_event(event_type, payload, event_id)

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Store event end data by event type.

        The trace map is intentionally left untouched here so that
        ``print_trace_map`` keeps reflecting the most recent trace; it is only
        replaced by ``start_trace`` / ``end_trace``.

        Args:
            event_type (CBEventType): event type to store.
            payload (Optional[Dict[str, Any]]): payload to store.
            event_id (str): event id to store.

        """
        self._record_event(event_type, payload, event_id)

    def get_events(self, event_type: Optional[CBEventType] = None) -> List[CBEvent]:
        """Get events of a specific type, or all events in arrival order.

        Args:
            event_type (Optional[CBEventType]): if given, only events of this
                type are returned; otherwise every recorded event is returned.

        Returns:
            List[CBEvent]: the handler's internal list (not a copy).
        """
        if event_type is not None:
            return self._event_pairs_by_type[event_type]

        return self._sequential_events

    def _get_event_pairs(self, events: List[CBEvent]) -> List[List[CBEvent]]:
        """Group events by ID and sort the groups by their first event's time."""
        event_pairs: Dict[str, List[CBEvent]] = defaultdict(list)
        for event in events:
            event_pairs[event.id_].append(event)

        return sorted(
            event_pairs.values(),
            key=lambda pair: datetime.strptime(pair[0].time, TIMESTAMP_FORMAT),
        )

    def _get_time_stats_from_event_pairs(
        self, event_pairs: List[List[CBEvent]]
    ) -> EventStats:
        """Calculate time-based stats for a set of event pairs.

        Each pair's duration is the time between its first and last event.
        An empty ``event_pairs`` yields zero totals and a zero average instead
        of raising ZeroDivisionError.
        """
        total_count = len(event_pairs)
        total_secs = 0.0
        for event_pair in event_pairs:
            start_time = datetime.strptime(event_pair[0].time, TIMESTAMP_FORMAT)
            end_time = datetime.strptime(event_pair[-1].time, TIMESTAMP_FORMAT)
            total_secs += (end_time - start_time).total_seconds()

        return EventStats(
            total_secs=total_secs,
            # Guard: no events of the requested type have been recorded yet.
            average_secs=total_secs / total_count if total_count else 0.0,
            total_count=total_count,
        )

    def get_event_pairs(
        self, event_type: Optional[CBEventType] = None
    ) -> List[List[CBEvent]]:
        """Pair events by ID, either all events or a specific type."""
        if event_type is not None:
            # .get avoids inserting an empty entry into the defaultdict just
            # because a type was queried.
            return self._get_event_pairs(self._event_pairs_by_type.get(event_type, []))

        return self._get_event_pairs(self._sequential_events)

    def get_llm_inputs_outputs(self) -> List[List[CBEvent]]:
        """Get the exact LLM inputs and outputs, paired by event ID."""
        return self.get_event_pairs(CBEventType.LLM)

    def get_event_time_info(
        self, event_type: Optional[CBEventType] = None
    ) -> EventStats:
        """Return aggregate timing stats over event pairs (all, or of one type)."""
        event_pairs = self.get_event_pairs(event_type)
        return self._get_time_stats_from_event_pairs(event_pairs)

    def flush_event_logs(self) -> None:
        """Clear all events from memory."""
        self._event_pairs_by_type = defaultdict(list)
        self._event_pairs_by_id = defaultdict(list)
        self._sequential_events = []

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Launch a trace: reset the trace map and remember the trace id."""
        self._trace_map = defaultdict(list)
        self._cur_trace_id = trace_id

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Shutdown the current trace.

        Args:
            trace_id (Optional[str]): unused by this handler.
            trace_map (Optional[Dict[str, List[str]]]): parent event id -> child
                event ids for the finished trace; stored for ``print_trace_map``.
        """
        self._trace_map = trace_map or defaultdict(list)
        if self.print_trace_on_end:
            self.print_trace_map()

    def _print_trace_map(self, cur_event_id: str, level: int = 0) -> None:
        """Recursively print trace map to terminal for debugging."""
        # .get (rather than []) keeps printing read-only: it never inserts empty
        # entries into the defaultdicts, and a plain-dict trace_map passed to
        # end_trace does not raise KeyError on leaf events.
        event_pair = self._event_pairs_by_id.get(cur_event_id, [])
        if event_pair:
            time_stats = self._get_time_stats_from_event_pairs([event_pair])
            indent = " " * level * 2
            print(
                f"{indent}|_{event_pair[0].event_type} -> ",
                f"{time_stats.total_secs} seconds",
                flush=True,
            )

        for child_event_id in self._trace_map.get(cur_event_id, []):
            self._print_trace_map(child_event_id, level=level + 1)

    def print_trace_map(self) -> None:
        """Print simple trace map to terminal for debugging of the most recent trace."""
        print("*" * 10, flush=True)
        print(f"Trace: {self._cur_trace_id}", flush=True)
        self._print_trace_map(BASE_TRACE_EVENT, level=1)
        print("*" * 10, flush=True)

    @property
    def event_pairs_by_type(self) -> Dict[CBEventType, List[CBEvent]]:
        """Recorded events grouped by event type."""
        return self._event_pairs_by_type

    @property
    def event_pairs_by_id(self) -> Dict[str, List[CBEvent]]:
        """Recorded events grouped by event id."""
        return self._event_pairs_by_id

    @property
    def events_pairs_by_id(self) -> Dict[str, List[CBEvent]]:
        """Recorded events grouped by event id.

        Misspelled name kept for backward compatibility; prefer
        ``event_pairs_by_id``.
        """
        return self._event_pairs_by_id

    @property
    def sequential_events(self) -> List[CBEvent]:
        """All recorded events in arrival order."""
        return self._sequential_events
|