from typing import Any, Dict, List, Optional, cast

from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload


class SimpleLLMHandler(BaseCallbackHandler):
    """Callback handler for printing LLM inputs/outputs."""

    def __init__(self) -> None:
        # Nothing is filtered out: this handler receives every event type.
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        # No-op: this handler keeps no per-trace state.
        return

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        # No-op: nothing to flush or summarize at the end of a trace.
        return

    def _print_llm_event(self, payload: dict) -> None:
        # Local import keeps the llms module from loading until it is needed.
        from llama_index.llms import ChatMessage

        if EventPayload.PROMPT in payload:
            # Completion-style call: a raw prompt string and its completion.
            prompt = str(payload.get(EventPayload.PROMPT))
            completion = str(payload.get(EventPayload.COMPLETION))

            print(f"** Prompt: **\n{prompt}")
            print("*" * 50)
            print(f"** Completion: **\n{completion}")
            print("*" * 50)
            print("\n")
        elif EventPayload.MESSAGES in payload:
            # Chat-style call: a list of ChatMessage objects and the response.
            messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
            messages_str = "\n".join([str(x) for x in messages])
            response = str(payload.get(EventPayload.RESPONSE))

            print(f"** Messages: **\n{messages_str}")
            print("*" * 50)
            print(f"** Response: **\n{response}")
            print("*" * 50)
            print("\n")

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        # Nothing to print at event start; just propagate the event id.
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Print the LLM inputs/outputs when an LLM event ends."""
        if event_type == CBEventType.LLM and payload is not None:
            self._print_llm_event(payload)
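

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: wire the handler
    # into a CallbackManager so every LLM call gets printed. Assumes the legacy
    # `llama_index` package matching the imports above and a configured LLM
    # (e.g. an OPENAI_API_KEY in the environment).
    from llama_index import ServiceContext, set_global_service_context
    from llama_index.callbacks import CallbackManager

    callback_manager = CallbackManager([SimpleLLMHandler()])
    service_context = ServiceContext.from_defaults(callback_manager=callback_manager)
    set_global_service_context(service_context)
    # Any subsequent query through llama_index now prints prompts/completions.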