# faiss_rag_enterprise/llama_index/callbacks/llama_debug.py
#
# 206 lines
# 7.5 KiB
# Python

from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import (
BASE_TRACE_EVENT,
TIMESTAMP_FORMAT,
CBEvent,
CBEventType,
EventStats,
)
class LlamaDebugHandler(BaseCallbackHandler):
    """Callback handler that keeps track of debug info.

    NOTE: this is a beta feature. The usage within our codebase, and the interface
    may change.

    This handler simply keeps track of event starts/ends, separated by event types.
    You can use this callback handler to keep track of and debug events.

    Args:
        event_starts_to_ignore (Optional[List[CBEventType]]): list of event types to
            ignore when tracking event starts.
        event_ends_to_ignore (Optional[List[CBEventType]]): list of event types to
            ignore when tracking event ends.
        print_trace_on_end (bool): if True (the default), ``end_trace`` prints the
            trace map to stdout.

    """

    def __init__(
        self,
        event_starts_to_ignore: Optional[List[CBEventType]] = None,
        event_ends_to_ignore: Optional[List[CBEventType]] = None,
        print_trace_on_end: bool = True,
    ) -> None:
        """Initialize the llama debug handler."""
        # Every recorded event (start or end) is indexed three ways: by event
        # type, by event id, and in plain arrival order.
        self._event_pairs_by_type: Dict[CBEventType, List[CBEvent]] = defaultdict(list)
        self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)
        self._sequential_events: List[CBEvent] = []
        self._cur_trace_id: Optional[str] = None
        # Maps an event id to the ids of its child events.
        self._trace_map: Dict[str, List[str]] = defaultdict(list)
        self.print_trace_on_end = print_trace_on_end
        super().__init__(
            event_starts_to_ignore=event_starts_to_ignore or [],
            event_ends_to_ignore=event_ends_to_ignore or [],
        )

    def _record_event(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]],
        event_id: str,
    ) -> CBEvent:
        """Build a CBEvent and append it to all three internal indexes."""
        event = CBEvent(event_type, payload=payload, id_=event_id)
        self._event_pairs_by_type[event.event_type].append(event)
        self._event_pairs_by_id[event.id_].append(event)
        self._sequential_events.append(event)
        return event

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Store event start data by event type.

        Args:
            event_type (CBEventType): event type to store.
            payload (Optional[Dict[str, Any]]): payload to store.
            event_id (str): event id to store.
            parent_id (str): parent event id (accepted but not used here).

        Returns:
            str: the id of the stored event.

        """
        return self._record_event(event_type, payload, event_id).id_

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Store event end data by event type.

        Args:
            event_type (CBEventType): event type to store.
            payload (Optional[Dict[str, Any]]): payload to store.
            event_id (str): event id to store.

        """
        self._record_event(event_type, payload, event_id)
        # NOTE(review): the trace map is reset on every event end; end_trace()
        # installs the map it is given before printing. Kept as in the
        # original -- confirm whether this reset is intentional.
        self._trace_map = defaultdict(list)

    def get_events(self, event_type: Optional[CBEventType] = None) -> List[CBEvent]:
        """Get all events for a specific event type.

        With no ``event_type``, all events are returned in arrival order.
        The returned list is the handler's internal list, not a copy.
        """
        if event_type is not None:
            return self._event_pairs_by_type[event_type]
        return self._sequential_events

    def _get_event_pairs(self, events: List[CBEvent]) -> List[List[CBEvent]]:
        """Helper function to pair events according to their ID.

        Events sharing an id are grouped together; the groups are ordered by
        the timestamp of the first event in each group.
        """
        event_pairs: Dict[str, List[CBEvent]] = defaultdict(list)
        for event in events:
            event_pairs[event.id_].append(event)

        return sorted(
            event_pairs.values(),
            key=lambda x: datetime.strptime(x[0].time, TIMESTAMP_FORMAT),
        )

    def _get_time_stats_from_event_pairs(
        self, event_pairs: List[List[CBEvent]]
    ) -> EventStats:
        """Calculate time-based stats for a set of event pairs.

        The duration of a pair is the difference between the timestamps of its
        first and last events. With no pairs, all stats are zero.
        """
        total_secs = 0.0
        for event_pair in event_pairs:
            start_time = datetime.strptime(event_pair[0].time, TIMESTAMP_FORMAT)
            end_time = datetime.strptime(event_pair[-1].time, TIMESTAMP_FORMAT)
            total_secs += (end_time - start_time).total_seconds()

        total_count = len(event_pairs)
        # Guard against ZeroDivisionError when nothing has been recorded yet.
        average_secs = total_secs / total_count if total_count else 0.0
        return EventStats(
            total_secs=total_secs,
            average_secs=average_secs,
            total_count=total_count,
        )

    def get_event_pairs(
        self, event_type: Optional[CBEventType] = None
    ) -> List[List[CBEvent]]:
        """Pair events by ID, either all events or a specific type."""
        if event_type is not None:
            # .get() keeps this read from inserting an empty entry into the
            # defaultdict for event types that were never recorded.
            return self._get_event_pairs(self._event_pairs_by_type.get(event_type, []))
        return self._get_event_pairs(self._sequential_events)

    def get_llm_inputs_outputs(self) -> List[List[CBEvent]]:
        """Get the exact LLM inputs and outputs."""
        return self._get_event_pairs(
            self._event_pairs_by_type.get(CBEventType.LLM, [])
        )

    def get_event_time_info(
        self, event_type: Optional[CBEventType] = None
    ) -> EventStats:
        """Get timing stats for all events, or only those of ``event_type``."""
        event_pairs = self.get_event_pairs(event_type)
        return self._get_time_stats_from_event_pairs(event_pairs)

    def flush_event_logs(self) -> None:
        """Clear all events from memory."""
        self._event_pairs_by_type = defaultdict(list)
        self._event_pairs_by_id = defaultdict(list)
        self._sequential_events = []

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Launch a trace."""
        self._trace_map = defaultdict(list)
        self._cur_trace_id = trace_id

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Shutdown the current trace.

        Stores ``trace_map`` (or an empty map when none is given) and prints
        it when ``print_trace_on_end`` is set. ``trace_id`` is not used here.
        """
        self._trace_map = trace_map or defaultdict(list)
        if self.print_trace_on_end:
            self.print_trace_map()

    def _print_trace_map(self, cur_event_id: str, level: int = 0) -> None:
        """Recursively print trace map to terminal for debugging."""
        # .get() avoids inserting empty entries for ids with no recorded events.
        event_pair = self._event_pairs_by_id.get(cur_event_id)
        if event_pair:
            time_stats = self._get_time_stats_from_event_pairs([event_pair])
            indent = " " * level * 2
            print(
                f"{indent}|_{event_pair[0].event_type} -> ",
                f"{time_stats.total_secs} seconds",
                flush=True,
            )

        # end_trace() may install a plain dict, which has no entry for leaf
        # events; treat a missing id as "no children" instead of raising.
        child_event_ids = self._trace_map.get(cur_event_id, [])
        for child_event_id in child_event_ids:
            self._print_trace_map(child_event_id, level=level + 1)

    def print_trace_map(self) -> None:
        """Print simple trace map to terminal for debugging of the most recent trace."""
        print("*" * 10, flush=True)
        print(f"Trace: {self._cur_trace_id}", flush=True)
        self._print_trace_map(BASE_TRACE_EVENT, level=1)
        print("*" * 10, flush=True)

    @property
    def event_pairs_by_type(self) -> Dict[CBEventType, List[CBEvent]]:
        """Recorded events grouped by event type."""
        return self._event_pairs_by_type

    @property
    def events_pairs_by_id(self) -> Dict[str, List[CBEvent]]:
        """Recorded events grouped by event id."""
        return self._event_pairs_by_id

    @property
    def sequential_events(self) -> List[CBEvent]:
        """Recorded events in arrival order."""
        return self._sequential_events