248 lines
7.7 KiB
Python
248 lines
7.7 KiB
Python
"""
|
|
Callback handler for storing generation data in OpenInference format.
|
|
OpenInference is an open standard for capturing and storing AI model inferences.
|
|
It enables production LLM app servers to seamlessly integrate with LLM
|
|
observability solutions such as Arize and Phoenix.
|
|
|
|
For more information on the specification, see
|
|
https://github.com/Arize-ai/open-inference-spec
|
|
"""
|
|
|
|
import importlib
|
|
import uuid
|
|
from dataclasses import dataclass, field, fields
|
|
from datetime import datetime
|
|
from types import ModuleType
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, TypeVar
|
|
|
|
from llama_index.callbacks.base_handler import BaseCallbackHandler
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
|
|
if TYPE_CHECKING:
|
|
from pandas import DataFrame
|
|
|
|
|
|
# Metadata key under which a dataclass field stores its OpenInference column
# name; as_dataframe reads this key to name the resulting dataframe columns.
OPENINFERENCE_COLUMN_NAME = "openinference_column_name"
|
|
# Type alias for an embedding vector, represented as a list of floats.
Embedding = List[float]
|
|
|
|
|
|
def _generate_random_id() -> str:
    """Create a fresh random identifier.

    Returns:
        str: The string form of a newly generated version-4 UUID.
    """
    random_uuid = uuid.uuid4()
    return str(random_uuid)
|
|
|
|
|
|
@dataclass
class QueryData:
    """
    Record of a single query, laid out per the OpenInference specification.

    Every field stores, under the OPENINFERENCE_COLUMN_NAME metadata key, the
    OpenInference column name associated with that field.

    Attributes:
        id: Identifier of the query; a random ID is generated when omitted.
        timestamp: Timestamp of the query as an ISO 8601 string, if known.
        query_text: Text of the query (the prompt), if known.
        query_embedding: Embedding vector of the prompt, if known.
        response_text: Text of the response, if known.
        node_ids: IDs of the documents retrieved for the prompt.
        scores: Scores of the documents retrieved for the prompt.
    """

    id: str = field(
        metadata={OPENINFERENCE_COLUMN_NAME: ":id.id:"},
        default_factory=_generate_random_id,
    )
    timestamp: Optional[str] = field(
        metadata={OPENINFERENCE_COLUMN_NAME: ":timestamp.iso_8601:"},
        default=None,
    )
    query_text: Optional[str] = field(
        metadata={OPENINFERENCE_COLUMN_NAME: ":feature.text:prompt"},
        default=None,
    )
    query_embedding: Optional[Embedding] = field(
        metadata={OPENINFERENCE_COLUMN_NAME: ":feature.[float].embedding:prompt"},
        default=None,
    )
    response_text: Optional[str] = field(
        metadata={OPENINFERENCE_COLUMN_NAME: ":prediction.text:response"},
        default=None,
    )
    node_ids: List[str] = field(
        metadata={
            OPENINFERENCE_COLUMN_NAME: ":feature.[str].retrieved_document_ids:prompt",
        },
        default_factory=list,
    )
    scores: List[float] = field(
        metadata={
            OPENINFERENCE_COLUMN_NAME: ":feature.[float].retrieved_document_scores:prompt",
        },
        default_factory=list,
    )
|
|
|
|
|
|
@dataclass
class NodeData:
    """Record describing a single node.

    Attributes:
        id: Identifier of the node (required).
        node_text: Text of the node, if available.
        node_embedding: Embedding vector of the node, if available.
    """

    id: str
    node_text: Optional[str] = field(default=None)
    node_embedding: Optional[Embedding] = field(default=None)
|
|
|
|
|
|
# Type variable constrained to the two record types, QueryData and NodeData.
BaseDataType = TypeVar("BaseDataType", QueryData, NodeData)
|
|
|
|
|
|
def as_dataframe(data: Iterable[BaseDataType]) -> "DataFrame":
    """Build a pandas dataframe with one row per record.

    Every dataclass field of a record becomes a column. The column is named
    after the field's OPENINFERENCE_COLUMN_NAME metadata entry when present,
    and after the field's own name otherwise.

    Args:
        data (Iterable[BaseDataType]): The records to convert.

    Returns:
        DataFrame: The resulting pandas dataframe.
    """
    # pandas is imported lazily so it is only required when this is called.
    pandas = _import_package("pandas")

    def _to_row(record: BaseDataType) -> Dict[str, Any]:
        # Map column name -> field value for a single record.
        return {
            spec.metadata.get(OPENINFERENCE_COLUMN_NAME, spec.name): getattr(
                record, spec.name
            )
            for spec in fields(record)
        }

    rows = [_to_row(record) for record in data]
    return pandas.DataFrame(rows)
|
|
|
|
|
|
@dataclass
class TraceData:
    """Data gathered over one trace: a query record plus its node records.

    Attributes:
        query_data: The query record; a fresh QueryData is created by default.
        node_datas: The node records; an empty list by default.
    """

    query_data: QueryData = field(default_factory=QueryData)
    node_datas: List[NodeData] = field(default_factory=list)
|
|
|
|
|
|
def _import_package(package_name: str) -> ModuleType:
    """Dynamically import a package by name.

    Args:
        package_name (str): Name of the package to import.

    Raises:
        ImportError: If the package cannot be imported. The underlying import
            error is attached as the cause of the raised exception.

    Returns:
        ModuleType: The imported package.
    """
    try:
        return importlib.import_module(package_name)
    except ImportError as err:
        # Chain explicitly so the original failure (for example a broken
        # transitive import) stays attached to the friendlier message.
        raise ImportError(f"The {package_name} package must be installed.") from err
|
|
|
|
|
|
class OpenInferenceCallbackHandler(BaseCallbackHandler):
    """Callback handler for storing generation data in OpenInference format.

    OpenInference is an open standard for capturing and storing AI model
    inferences. It enables production LLM app servers to seamlessly integrate
    with LLM observability solutions such as Arize and Phoenix.

    For more information on the specification, see
    https://github.com/Arize-ai/open-inference-spec
    """

    def __init__(
        self,
        callback: Optional[Callable[[List[QueryData], List[NodeData]], None]] = None,
    ) -> None:
        """Initializes the OpenInferenceCallbackHandler.

        Args:
            callback (Optional[Callable[[List[QueryData], List[NodeData]], None]], optional): A
                callback function that will be called when a query trace is
                completed, often used for logging or persisting query data.
                It receives the handler's query data buffer and node data
                buffer.
        """
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self._callback = callback
        # Data of the trace currently in progress.
        self._trace_data = TraceData()
        # Records of completed traces, kept until flushed.
        self._query_data_buffer: List[QueryData] = []
        self._node_data_buffer: List[NodeData] = []

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Begins a trace; only traces with the ID "query" are recorded.

        For a "query" trace, the in-progress trace data is reset and its query
        record is given the current local time (ISO 8601) and a new random ID.

        Args:
            trace_id (Optional[str], optional): The ID of the trace.
        """
        if trace_id == "query":
            self._trace_data = TraceData()
            self._trace_data.query_data.timestamp = datetime.now().isoformat()
            self._trace_data.query_data.id = _generate_random_id()

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Ends a trace; only traces with the ID "query" are recorded.

        For a "query" trace, the in-progress query and node records are moved
        into the buffers, the in-progress trace data is reset, and the
        callback (if any) is invoked with both buffers.

        Args:
            trace_id (Optional[str], optional): The ID of the trace.
            trace_map (Optional[Dict[str, List[str]]], optional): Not used by
                this handler.
        """
        if trace_id == "query":
            self._query_data_buffer.append(self._trace_data.query_data)
            self._node_data_buffer.extend(self._trace_data.node_datas)
            self._trace_data = TraceData()
            if self._callback is not None:
                self._callback(self._query_data_buffer, self._node_data_buffer)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Records the query text when a QUERY event starts.

        Args:
            event_type (CBEventType): The type of the event.
            payload (Optional[Dict[str, Any]], optional): The event payload.
            event_id (str, optional): The ID of the event.
            parent_id (str, optional): Not used by this handler.

        Returns:
            str: The event ID that was passed in, unchanged.
        """
        if payload is not None:
            if event_type is CBEventType.QUERY:
                query_text = payload[EventPayload.QUERY_STR]
                self._trace_data.query_data.query_text = query_text
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Records retrieval, LLM and embedding results when an event ends.

        * RETRIEVE: appends each node's hash and score to the query record and
          stores a NodeData (hash and text) per node.
        * LLM: stores the response text, falling back to the completion text
          when the response is missing or empty.
        * EMBEDDING: stores the first embedding as the query embedding.

        Events without a payload are ignored.

        Args:
            event_type (CBEventType): The type of the event.
            payload (Optional[Dict[str, Any]], optional): The event payload.
            event_id (str, optional): Not used by this handler.
        """
        if payload is None:
            return
        if event_type is CBEventType.RETRIEVE:
            for node_with_score in payload[EventPayload.NODES]:
                node = node_with_score.node
                score = node_with_score.score
                self._trace_data.query_data.node_ids.append(node.hash)
                self._trace_data.query_data.scores.append(score)
                self._trace_data.node_datas.append(
                    NodeData(
                        id=node.hash,
                        node_text=node.text,
                    )
                )
        elif event_type is CBEventType.LLM:
            self._trace_data.query_data.response_text = self._payload_text(
                payload, EventPayload.RESPONSE
            ) or self._payload_text(payload, EventPayload.COMPLETION)
        elif event_type is CBEventType.EMBEDDING:
            embeddings = payload[EventPayload.EMBEDDINGS]
            # An empty list of embeddings carries nothing to record; indexing
            # it would raise IndexError. len() is used rather than truthiness
            # so array-like containers are handled as well.
            if len(embeddings) > 0:
                self._trace_data.query_data.query_embedding = embeddings[0]

    @staticmethod
    def _payload_text(payload: Dict[str, Any], key: Any) -> str:
        """Returns the payload value for key as text, or "" if absent or None.

        Treating None as missing avoids storing the literal string "None".
        """
        value = payload.get(key)
        return "" if value is None else str(value)

    def flush_query_data_buffer(self) -> List[QueryData]:
        """Clears the query data buffer and returns the data.

        Returns:
            List[QueryData]: The query data.
        """
        query_data_buffer = self._query_data_buffer
        self._query_data_buffer = []
        return query_data_buffer

    def flush_node_data_buffer(self) -> List[NodeData]:
        """Clears the node data buffer and returns the data.

        Returns:
            List[NodeData]: The node data.
        """
        node_data_buffer = self._node_data_buffer
        self._node_data_buffer = []
        return node_data_buffer
|