# faiss_rag_enterprise/llama_index/callbacks/open_inference_callback.py
"""
Callback handler for storing generation data in OpenInference format.
OpenInference is an open standard for capturing and storing AI model inferences.
It enables production LLMapp servers to seamlessly integrate with LLM
observability solutions such as Arize and Phoenix.
For more information on the specification, see
https://github.com/Arize-ai/open-inference-spec
"""
import importlib
import uuid
from dataclasses import dataclass, field, fields
from datetime import datetime
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, TypeVar
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
if TYPE_CHECKING:
from pandas import DataFrame
# Dataclass-field metadata key whose value is the OpenInference column name
# used when records are flattened to a dataframe (see ``as_dataframe``).
OPENINFERENCE_COLUMN_NAME = "openinference_column_name"
# An embedding vector is represented as a plain list of floats.
Embedding = List[float]
def _generate_random_id() -> str:
"""Generates a random ID.
Returns:
str: A random ID.
"""
return str(uuid.uuid4())
@dataclass
class QueryData:
    """
    Query data with column names following the OpenInference specification.

    Each field's OPENINFERENCE_COLUMN_NAME metadata entry supplies the
    column name used when this record is flattened by ``as_dataframe``.
    """
    # Unique identifier for the query; regenerated for every trace.
    id: str = field(
        default_factory=_generate_random_id,
        metadata={OPENINFERENCE_COLUMN_NAME: ":id.id:"},
    )
    # ISO-8601 timestamp set when the query trace starts.
    timestamp: Optional[str] = field(
        default=None, metadata={OPENINFERENCE_COLUMN_NAME: ":timestamp.iso_8601:"}
    )
    # Raw text of the user's query (captured from the QUERY event payload).
    query_text: Optional[str] = field(
        default=None,
        metadata={OPENINFERENCE_COLUMN_NAME: ":feature.text:prompt"},
    )
    # Embedding vector of the query (captured from the EMBEDDING event payload).
    query_embedding: Optional[Embedding] = field(
        default=None,
        metadata={OPENINFERENCE_COLUMN_NAME: ":feature.[float].embedding:prompt"},
    )
    # Text of the LLM's response (captured from the LLM event payload).
    response_text: Optional[str] = field(
        default=None, metadata={OPENINFERENCE_COLUMN_NAME: ":prediction.text:response"}
    )
    # Hashes of the retrieved nodes, in retrieval order.
    node_ids: List[str] = field(
        default_factory=list,
        metadata={
            OPENINFERENCE_COLUMN_NAME: ":feature.[str].retrieved_document_ids:prompt"
        },
    )
    # Retrieval scores, parallel to ``node_ids``.
    scores: List[float] = field(
        default_factory=list,
        metadata={
            OPENINFERENCE_COLUMN_NAME: (
                ":feature.[float].retrieved_document_scores:prompt"
            )
        },
    )
@dataclass
class NodeData:
    """Data for a single retrieved node."""
    # Node identifier (the node's content hash — see ``on_event_end``).
    id: str
    # Text content of the node, when available.
    node_text: Optional[str] = None
    # Embedding of the node text, when available.
    node_embedding: Optional[Embedding] = None
# Type variable constrained to the record types ``as_dataframe`` accepts.
BaseDataType = TypeVar("BaseDataType", QueryData, NodeData)
def as_dataframe(data: Iterable[BaseDataType]) -> "DataFrame":
    """Converts a list of BaseDataType to a pandas dataframe.

    Column names are taken from each dataclass field's
    OPENINFERENCE_COLUMN_NAME metadata entry when present, falling back
    to the field name.

    Args:
        data (Iterable[BaseDataType]): A list of BaseDataType.

    Raises:
        ImportError: If pandas is not installed.

    Returns:
        DataFrame: The converted pandas dataframe, one row per datum.
    """
    pandas = _import_package("pandas")
    # The loop variable is named ``data_field`` (not ``field``) so it does
    # not shadow ``dataclasses.field``, which is imported at module level.
    records = [
        {
            data_field.metadata.get(
                OPENINFERENCE_COLUMN_NAME, data_field.name
            ): getattr(datum, data_field.name)
            for data_field in fields(datum)
        }
        for datum in data
    ]
    return pandas.DataFrame(records)
@dataclass
class TraceData:
    """Trace data accumulated over the lifetime of a single query trace."""
    # Per-query record (id, timestamp, query/response text, retrieval info).
    query_data: QueryData = field(default_factory=QueryData)
    # One entry per node retrieved during this trace.
    node_datas: List[NodeData] = field(default_factory=list)
def _import_package(package_name: str) -> ModuleType:
"""Dynamically imports a package.
Args:
package_name (str): Name of the package to import.
Raises:
ImportError: If the package is not installed.
Returns:
ModuleType: The imported package.
"""
try:
package = importlib.import_module(package_name)
except ImportError:
raise ImportError(f"The {package_name} package must be installed.")
return package
class OpenInferenceCallbackHandler(BaseCallbackHandler):
    """Callback handler that records generation data in OpenInference format.

    OpenInference is an open standard for capturing and storing AI model
    inferences. It enables production LLMapp servers to seamlessly integrate
    with LLM observability solutions such as Arize and Phoenix.
    For more information on the specification, see
    https://github.com/Arize-ai/open-inference-spec
    """

    def __init__(
        self,
        callback: Optional[Callable[[List[QueryData], List[NodeData]], None]] = None,
    ) -> None:
        """Initializes the OpenInferenceCallbackHandler.

        Args:
            callback (Optional[Callable[[List[QueryData], List[NodeData]], None]], optional):
                A function invoked each time a query trace completes; it
                receives the query- and node-data buffers, typically to log
                or persist them.
        """
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self._callback = callback
        self._trace_data = TraceData()
        self._query_data_buffer: List[QueryData] = []
        self._node_data_buffer: List[NodeData] = []

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Resets per-trace state and stamps a fresh id/timestamp for queries."""
        if trace_id != "query":
            return
        fresh = TraceData()
        fresh.query_data.timestamp = datetime.now().isoformat()
        fresh.query_data.id = _generate_random_id()
        self._trace_data = fresh

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Moves the finished trace into the buffers and fires the callback."""
        if trace_id != "query":
            return
        finished = self._trace_data
        self._trace_data = TraceData()
        self._query_data_buffer.append(finished.query_data)
        self._node_data_buffer.extend(finished.node_datas)
        if self._callback is not None:
            self._callback(self._query_data_buffer, self._node_data_buffer)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Captures the query text when a QUERY event begins."""
        if payload is not None and event_type is CBEventType.QUERY:
            self._trace_data.query_data.query_text = payload[EventPayload.QUERY_STR]
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Captures retrieval results, LLM output, or the query embedding."""
        if payload is None:
            return
        query_data = self._trace_data.query_data
        if event_type is CBEventType.RETRIEVE:
            for scored in payload[EventPayload.NODES]:
                retrieved = scored.node
                query_data.node_ids.append(retrieved.hash)
                query_data.scores.append(scored.score)
                self._trace_data.node_datas.append(
                    NodeData(
                        id=retrieved.hash,
                        node_text=retrieved.text,
                    )
                )
        elif event_type is CBEventType.LLM:
            # Prefer the RESPONSE payload; fall back to COMPLETION when the
            # stringified response is empty.
            primary = str(payload.get(EventPayload.RESPONSE, ""))
            query_data.response_text = primary or str(
                payload.get(EventPayload.COMPLETION, "")
            )
        elif event_type is CBEventType.EMBEDDING:
            query_data.query_embedding = payload[EventPayload.EMBEDDINGS][0]

    def flush_query_data_buffer(self) -> List[QueryData]:
        """Clears the query data buffer and returns the data.

        Returns:
            List[QueryData]: The query data.
        """
        drained, self._query_data_buffer = self._query_data_buffer, []
        return drained

    def flush_node_data_buffer(self) -> List[NodeData]:
        """Clears the node data buffer and returns the data.

        Returns:
            List[NodeData]: The node data.
        """
        drained, self._node_data_buffer = self._node_data_buffer, []
        return drained