faiss_rag_enterprise/llama_index/callbacks/utils.py

61 lines
2.1 KiB
Python

import asyncio
import contextlib
import functools
import inspect
import logging
from typing import Any, Callable, ContextManager, Optional, cast

from llama_index.callbacks.base import CallbackManager
logger = logging.getLogger(__name__)
def trace_method(
    trace_id: str, callback_manager_attr: str = "callback_manager"
) -> Callable[[Callable], Callable]:
    """
    Decorator to trace a method.

    Example:
        @trace_method("my_trace_id")
        def my_method(self):
            pass

    Assumes that the self instance has a CallbackManager instance in an attribute
    named `callback_manager`.
    This can be overridden by passing in a `callback_manager_attr` keyword argument.

    Each call to the decorated method runs inside
    ``callback_manager.as_trace(trace_id)``. Both regular and ``async def``
    methods are supported; for coroutine functions the trace stays open until
    the awaited call completes.

    If the attribute is missing, a warning is logged and the method runs
    untraced. If the attribute is ``None``, the method also runs untraced.

    Args:
        trace_id: Trace identifier passed to ``as_trace`` on every call.
        callback_manager_attr: Name of the attribute on ``self`` that holds
            the callback manager.

    Returns:
        A decorator that wraps a method with the trace context.
    """

    def decorator(func: Callable) -> Callable:
        def trace_context(self: Any) -> ContextManager[Any]:
            """Return the trace context for `self`, or a no-op context if none."""
            # The manager is looked up on every call (not at decoration time),
            # so the instance's current attribute value is always used.
            try:
                callback_manager: Optional[CallbackManager] = getattr(
                    self, callback_manager_attr
                )
            except AttributeError:
                logger.warning(
                    "Could not find attribute %s on %s.",
                    callback_manager_attr,
                    type(self),
                )
                return contextlib.nullcontext()
            if callback_manager is None:
                # Run untraced instead of failing with
                # "'NoneType' object has no attribute 'as_trace'".
                return contextlib.nullcontext()
            return callback_manager.as_trace(trace_id)

        @functools.wraps(func)  # preserve signature, name, etc. of func
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            with trace_context(self):
                return func(self, *args, **kwargs)

        @functools.wraps(func)  # preserve signature, name, etc. of func
        async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            # The `with` spans the await, so the trace covers the whole coroutine.
            with trace_context(self):
                return await func(self, *args, **kwargs)

        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        return async_wrapper if inspect.iscoroutinefunction(func) else wrapper

    return decorator