122 lines
4.4 KiB
Python
122 lines
4.4 KiB
Python
"""Base tool spec class."""
|
|
|
|
|
|
import asyncio
|
|
from inspect import signature
|
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
|
|
|
from llama_index.bridge.pydantic import BaseModel
|
|
from llama_index.tools.function_tool import FunctionTool
|
|
from llama_index.tools.types import ToolMetadata
|
|
from llama_index.tools.utils import create_schema_from_function
|
|
|
|
# Any callable whose call result is awaitable; used to annotate the
# argument of ``patch_sync`` below.
AsyncCallable = Callable[..., Awaitable[Any]]
|
|
|
|
|
|
# TODO: deprecate the Tuple (there's no use for it)
|
|
# One entry of ``spec_functions``: either a single method name, or a pair of
# method names that ``BaseToolSpec.to_tool_list`` reads as (sync, async).
SPEC_FUNCTION_TYPE = Union[str, Tuple[str, str]]
|
|
|
|
|
|
class BaseToolSpec:
    """Base tool spec class.

    A subclass lists, in ``spec_functions``, the names of the methods it wants
    to expose; ``to_tool_list`` turns every entry into a ``FunctionTool``.
    """

    # list of functions that you'd want to convert to spec
    spec_functions: List[SPEC_FUNCTION_TYPE]

    def get_fn_schema_from_fn_name(
        self, fn_name: str, spec_functions: Optional[List[SPEC_FUNCTION_TYPE]] = None
    ) -> Optional[Type[BaseModel]]:
        """Return the argument schema for the spec function named ``fn_name``.

        Return type is Optional, meaning that the schema can be None.
        In this case, it's up to the downstream tool implementation to infer
        the schema.

        Args:
            fn_name: Name of a method on this object.
            spec_functions: Spec entries to validate ``fn_name`` against;
                falls back to ``self.spec_functions`` when omitted or empty.

        Returns:
            The schema built by ``create_schema_from_function`` for the method.

        Raises:
            ValueError: If ``fn_name`` is not named by any spec entry.
        """
        spec_functions = spec_functions or self.spec_functions
        for fn in spec_functions:
            # An entry is either a bare name or a (sync, async) pair of names;
            # a pair matches when ``fn_name`` is either of its members.
            if fn == fn_name or (isinstance(fn, tuple) and fn_name in fn):
                return create_schema_from_function(fn_name, getattr(self, fn_name))

        raise ValueError(f"Invalid function name: {fn_name}")

    def get_metadata_from_fn_name(
        self, fn_name: str, spec_functions: Optional[List[SPEC_FUNCTION_TYPE]] = None
    ) -> Optional[ToolMetadata]:
        """Build the tool metadata for the spec function named ``fn_name``.

        The description is the function name followed by its signature, a
        newline, and the function's docstring (empty if it has none).

        Args:
            fn_name: Name of a method on this object.
            spec_functions: Forwarded to ``get_fn_schema_from_fn_name``.

        Returns:
            The metadata, or None when this object has no attribute
            ``fn_name``.

        Raises:
            ValueError: Propagated from ``get_fn_schema_from_fn_name`` when
                ``fn_name`` is not named by any spec entry.
        """
        try:
            func = getattr(self, fn_name)
        except AttributeError:
            return None
        name = fn_name
        docstring = func.__doc__ or ""
        description = f"{name}{signature(func)}\n{docstring}"
        fn_schema = self.get_fn_schema_from_fn_name(
            fn_name, spec_functions=spec_functions
        )
        return ToolMetadata(name=name, description=description, fn_schema=fn_schema)

    def to_tool_list(
        self,
        spec_functions: Optional[List[SPEC_FUNCTION_TYPE]] = None,
        func_to_metadata_mapping: Optional[Dict[str, ToolMetadata]] = None,
    ) -> List[FunctionTool]:
        """Convert tool spec to list of tools.

        Args:
            spec_functions: Entries to convert; falls back to
                ``self.spec_functions`` when omitted or empty.
            func_to_metadata_mapping: Optional per-function metadata, keyed by
                function name, used instead of the generated metadata. For a
                (sync, async) pair the sync name is looked up first, then the
                async name.

        Returns:
            One ``FunctionTool`` per spec entry, in the same order.

        Raises:
            ValueError: If an entry is neither a string nor a 2-tuple, or if
                no function could be resolved for it.
            AttributeError: If an entry names a method this object lacks.
        """
        spec_functions = spec_functions or self.spec_functions
        func_to_metadata_mapping = func_to_metadata_mapping or {}
        tool_list = []
        for func_spec in spec_functions:
            func_sync = None
            func_async = None
            if isinstance(func_spec, str):
                func = getattr(self, func_spec)
                if asyncio.iscoroutinefunction(func):
                    func_async = func
                else:
                    func_sync = func
                metadata = func_to_metadata_mapping.get(func_spec, None)
                if metadata is None:
                    # Validate against the list actually being converted, so a
                    # caller-supplied override list is honoured.
                    metadata = self.get_metadata_from_fn_name(
                        func_spec, spec_functions=spec_functions
                    )
            elif isinstance(func_spec, tuple) and len(func_spec) == 2:
                func_sync = getattr(self, func_spec[0])
                func_async = getattr(self, func_spec[1])
                metadata = func_to_metadata_mapping.get(func_spec[0], None)
                if metadata is None:
                    metadata = func_to_metadata_mapping.get(func_spec[1], None)
                    if metadata is None:
                        metadata = self.get_metadata_from_fn_name(
                            func_spec[0], spec_functions=spec_functions
                        )
            else:
                raise ValueError(
                    "spec_functions must be of type: List[Union[str, Tuple[str, str]]]"
                )

            if func_sync is None:
                if func_async is not None:
                    # Async-only entry: derive a blocking variant from it.
                    func_sync = patch_sync(func_async)
                else:
                    raise ValueError(
                        f"Could not retrieve a function for spec: {func_spec}"
                    )

            tool = FunctionTool.from_defaults(
                fn=func_sync,
                async_fn=func_async,
                tool_metadata=metadata,
            )
            tool_list.append(tool)
        return tool_list
|
|
|
|
|
|
def patch_sync(func_async: AsyncCallable) -> Callable:
    """Patch sync function from async function.

    Args:
        func_async: Async function to wrap.

    Returns:
        A plain callable that forwards its arguments to ``func_async``, runs
        the resulting awaitable to completion on the current thread's event
        loop, and returns its result.

    Note:
        The wrapper relies on ``loop.run_until_complete`` and therefore
        cannot be called from inside an already-running event loop.
    """

    def patched_sync(*args: Any, **kwargs: Any) -> Any:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop: worker thread, main thread after
            # ``asyncio.run()`` / ``set_event_loop(None)``, or Python 3.14+.
            loop = None
        if loop is None or loop.is_closed():
            # Install a fresh loop so later calls on this thread reuse it.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(func_async(*args, **kwargs))

    return patched_sync
|