faiss_rag_enterprise/llama_index/tools/tool_spec/base.py

122 lines
4.4 KiB
Python

"""Base tool spec class."""
import asyncio
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from llama_index.bridge.pydantic import BaseModel
from llama_index.tools.function_tool import FunctionTool
from llama_index.tools.types import ToolMetadata
from llama_index.tools.utils import create_schema_from_function
# Any callable that returns an awaitable when invoked (e.g. an ``async def``).
AsyncCallable = Callable[..., Awaitable[Any]]
# One entry of ``BaseToolSpec.spec_functions``: either a single method name, or
# a pair of method names that ``BaseToolSpec.to_tool_list`` reads as
# ``(sync method name, async method name)``.
# TODO: deprecate the Tuple (there's no use for it)
SPEC_FUNCTION_TYPE = Union[str, Tuple[str, str]]
class BaseToolSpec:
    """Base class for a group of methods that are exposed as ``FunctionTool``s.

    Subclasses implement the tool functions as methods and list their names in
    ``spec_functions``; ``to_tool_list`` converts each entry into a tool.
    """

    # Names of the methods to convert into tools. Each entry is either one
    # method name or a ``(sync_name, async_name)`` pair of method names.
    spec_functions: List[SPEC_FUNCTION_TYPE]

    def get_fn_schema_from_fn_name(
        self, fn_name: str, spec_functions: Optional[List[SPEC_FUNCTION_TYPE]] = None
    ) -> Optional[Type[BaseModel]]:
        """Return the argument schema for the spec function named ``fn_name``.

        Return type is Optional, meaning that the schema can be None.
        In this case, it's up to the downstream tool implementation to infer
        the schema.

        Args:
            fn_name: Name of a method on this spec.
            spec_functions: Entries to search. Falls back to
                ``self.spec_functions`` when omitted or empty.

        Returns:
            The result of ``create_schema_from_function`` for the method.

        Raises:
            ValueError: If ``fn_name`` does not appear in the searched entries.
        """
        spec_functions = spec_functions or self.spec_functions
        for fn in spec_functions:
            # An entry may be a (sync_name, async_name) tuple; either member
            # counts as a match. The isinstance guard keeps plain strings on
            # exact equality (``in`` on a str would be a substring test).
            if fn == fn_name or (isinstance(fn, tuple) and fn_name in fn):
                return create_schema_from_function(fn_name, getattr(self, fn_name))
        raise ValueError(f"Invalid function name: {fn_name}")

    def get_metadata_from_fn_name(
        self, fn_name: str, spec_functions: Optional[List[SPEC_FUNCTION_TYPE]] = None
    ) -> Optional[ToolMetadata]:
        """Build default ``ToolMetadata`` for the method named ``fn_name``.

        The description is the method name immediately followed by its
        signature, a newline, and the method's docstring (empty if it has
        none). The schema comes from ``get_fn_schema_from_fn_name``.

        Args:
            fn_name: Name of a method on this spec.
            spec_functions: Entries passed through to the schema lookup.

        Returns:
            The metadata, or ``None`` when this object has no attribute named
            ``fn_name``.

        Raises:
            ValueError: Propagated from ``get_fn_schema_from_fn_name`` when the
                name is not among the spec entries.
        """
        try:
            func = getattr(self, fn_name)
        except AttributeError:
            return None
        name = fn_name
        docstring = func.__doc__ or ""
        description = f"{name}{signature(func)}\n{docstring}"
        fn_schema = self.get_fn_schema_from_fn_name(
            fn_name, spec_functions=spec_functions
        )
        return ToolMetadata(name=name, description=description, fn_schema=fn_schema)

    def to_tool_list(
        self,
        spec_functions: Optional[List[SPEC_FUNCTION_TYPE]] = None,
        func_to_metadata_mapping: Optional[Dict[str, ToolMetadata]] = None,
    ) -> List[FunctionTool]:
        """Convert tool spec to list of tools.

        Args:
            spec_functions: Entries to convert. Falls back to
                ``self.spec_functions`` when omitted or empty.
            func_to_metadata_mapping: Optional per-method metadata keyed by
                method name; it takes precedence over generated metadata.

        Returns:
            One ``FunctionTool`` per entry, in entry order.

        Raises:
            ValueError: If an entry is neither a string nor a 2-tuple, if no
                callable could be resolved for an entry, or if generated
                metadata is requested for a name missing from the entries.
        """
        spec_functions = spec_functions or self.spec_functions
        func_to_metadata_mapping = func_to_metadata_mapping or {}
        tool_list = []
        for func_spec in spec_functions:
            func_sync = None
            func_async = None
            if isinstance(func_spec, str):
                func = getattr(self, func_spec)
                # A single name fills exactly one slot, chosen by whether the
                # method is a coroutine function.
                if asyncio.iscoroutinefunction(func):
                    func_async = func
                else:
                    func_sync = func
                metadata = func_to_metadata_mapping.get(func_spec, None)
                if metadata is None:
                    # Forward the effective entries so a caller-supplied
                    # override is validated against itself rather than
                    # against self.spec_functions.
                    metadata = self.get_metadata_from_fn_name(
                        func_spec, spec_functions=spec_functions
                    )
            elif isinstance(func_spec, tuple) and len(func_spec) == 2:
                func_sync = getattr(self, func_spec[0])
                func_async = getattr(self, func_spec[1])
                # Explicit metadata may be registered under either name; the
                # sync name wins when both are present.
                metadata = func_to_metadata_mapping.get(func_spec[0], None)
                if metadata is None:
                    metadata = func_to_metadata_mapping.get(func_spec[1], None)
                if metadata is None:
                    metadata = self.get_metadata_from_fn_name(
                        func_spec[0], spec_functions=spec_functions
                    )
            else:
                raise ValueError(
                    "spec_functions must be of type: List[Union[str, Tuple[str, str]]]"
                )
            if func_sync is None:
                if func_async is not None:
                    # Async-only entry: derive the sync callable by wrapping
                    # the coroutine function.
                    func_sync = patch_sync(func_async)
                else:
                    raise ValueError(
                        f"Could not retrieve a function for spec: {func_spec}"
                    )
            tool = FunctionTool.from_defaults(
                fn=func_sync,
                async_fn=func_async,
                tool_metadata=metadata,
            )
            tool_list.append(tool)
        return tool_list
def patch_sync(func_async: "AsyncCallable") -> Callable:
    """Return a blocking wrapper around the coroutine function ``func_async``.

    The wrapper forwards its arguments to ``func_async`` and runs the resulting
    awaitable to completion on the calling thread's event loop, returning its
    result.

    Args:
        func_async: Callable that returns an awaitable when invoked.

    Returns:
        A synchronous callable accepting the same arguments.

    Raises:
        RuntimeError: From ``run_until_complete`` if the thread's event loop is
            already running when the wrapper is called.
    """

    def patched_sync(*args: Any, **kwargs: Any) -> Any:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop for this thread: worker threads, the main thread
            # after asyncio.run() has finished, and newer Python versions.
            loop = None
        if loop is None or loop.is_closed():
            # Install a fresh loop as the thread's current one so that later
            # calls on this thread reuse it.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(func_async(*args, **kwargs))

    return patched_sync