148 lines
4.7 KiB
Python
148 lines
4.7 KiB
Python
"""Tool mapping."""
|
|
|
|
from typing import Any, Dict, Optional, Sequence
|
|
|
|
from llama_index.objects.base_node_mapping import (
|
|
DEFAULT_PERSIST_DIR,
|
|
DEFAULT_PERSIST_FNAME,
|
|
BaseObjectNodeMapping,
|
|
)
|
|
from llama_index.schema import BaseNode, TextNode
|
|
from llama_index.tools.query_engine import QueryEngineTool
|
|
from llama_index.tools.types import BaseTool
|
|
|
|
|
|
def convert_tool_to_node(tool: BaseTool) -> TextNode:
    """Build a retrievable ``TextNode`` that describes ``tool``.

    The node text contains the tool's name and description and, when
    ``tool.metadata.fn_schema`` is set, the output of that schema class's
    ``schema()`` method. The tool name is also stored under the ``"name"``
    metadata key so the tool can be looked up again from a retrieved node.
    That key is excluded from both the embedding and the LLM metadata views.

    Args:
        tool: The tool to describe.

    Returns:
        A ``TextNode`` describing ``tool``.
    """
    meta = tool.metadata
    pieces = [
        f"Tool name: {meta.name}\n",
        f"Tool description: {meta.description}\n",
    ]
    schema_cls = meta.fn_schema
    if schema_cls is not None:
        pieces.append(f"Tool schema: {schema_cls.schema()}\n")
    return TextNode(
        text="".join(pieces),
        metadata={"name": meta.name},
        # The name is only a lookup key; keep it out of embedding/LLM text.
        excluded_embed_metadata_keys=["name"],
        excluded_llm_metadata_keys=["name"],
    )
|
|
|
|
|
|
class BaseToolNodeMapping(BaseObjectNodeMapping[BaseTool]):
    """Base Tool node mapping.

    Accepts any ``BaseTool``. This base class does not implement
    persistence: ``obj_node_mapping``, ``persist`` and ``from_persist_dir``
    all raise ``NotImplementedError``.
    """

    def validate_object(self, obj: BaseTool) -> None:
        """Check that ``obj`` is a tool.

        Raises:
            ValueError: If ``obj`` is not a ``BaseTool`` instance.
        """
        if not isinstance(obj, BaseTool):
            raise ValueError(f"Object must be of type {BaseTool}")

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The mapping data structure between node and object.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs.

        The base implementation always raises. The defaults mirror those of
        ``from_persist_dir``; the previous ``...`` placeholders are not valid
        ``str`` values and only make sense in type stubs.

        Args:
            persist_dir: Target directory (unused here).
            obj_node_mapping_fname: Target file name (unused here).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseToolNodeMapping":
        """Load a mapping from disk.

        Raises:
            NotImplementedError: Always; this mapping cannot be persisted.
        """
        raise NotImplementedError(
            "This object node mapping does not support persist method."
        )
|
|
|
|
|
|
class SimpleToolNodeMapping(BaseToolNodeMapping):
    """Simple Tool mapping.

    In this setup, we assume that the tool name is unique, and
    that all tools are stored in memory.

    Tools are kept in a dict keyed by ``tool.metadata.name``. A later tool
    with the same name replaces an earlier one. Nodes are mapped back to
    tools through the ``"name"`` entry of the node metadata.
    """

    def __init__(self, objs: Optional[Sequence[BaseTool]] = None) -> None:
        """Index ``objs`` (if any) by tool name."""
        objs = objs or []
        # NOTE(review): unlike ``_add_object``, the constructor does not reject
        # unnamed tools; kept as-is to avoid breaking existing callers.
        self._tools = {tool.metadata.name: tool for tool in objs}

    @classmethod
    def from_objects(
        cls, objs: Sequence[BaseTool], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Build a mapping from ``objs``; extra arguments are ignored."""
        return cls(objs)

    def _add_object(self, tool: BaseTool) -> None:
        """Register ``tool`` under its name.

        Raises:
            ValueError: If the tool has no name. Storing it under a ``None``
                key would let distinct unnamed tools silently overwrite each
                other (same rule as ``SimpleQueryToolNodeMapping``).
        """
        if tool.metadata.name is None:
            raise ValueError("Tool name must be set")
        self._tools[tool.metadata.name] = tool

    def to_node(self, tool: BaseTool) -> TextNode:
        """To node."""
        return convert_tool_to_node(tool)

    def _from_node(self, node: BaseNode) -> BaseTool:
        """From node: look up the tool named by ``node.metadata["name"]``.

        Raises:
            ValueError: If the node has no metadata.
            KeyError: If the name is missing or no tool is registered under it.
        """
        if node.metadata is None:
            raise ValueError("Metadata must be set")
        return self._tools[node.metadata["name"]]
|
|
|
|
|
|
class BaseQueryToolNodeMapping(BaseObjectNodeMapping[QueryEngineTool]):
    """Base query tool node mapping.

    This base class does not implement persistence: ``from_persist_dir``,
    ``obj_node_mapping`` and ``persist`` all raise ``NotImplementedError``.
    """

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseQueryToolNodeMapping":
        """Load a mapping from disk.

        Raises:
            NotImplementedError: Always; this mapping cannot be persisted.
        """
        raise NotImplementedError(
            "This object node mapping does not support persist method."
        )

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The mapping data structure between node and object.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs.

        The base implementation always raises. The defaults mirror those of
        ``from_persist_dir``; the previous ``...`` placeholders are not valid
        ``str`` values and only make sense in type stubs.

        Args:
            persist_dir: Target directory (unused here).
            obj_node_mapping_fname: Target file name (unused here).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")
|
|
|
|
|
|
class SimpleQueryToolNodeMapping(BaseQueryToolNodeMapping):
    """Simple query tool mapping.

    Keeps every ``QueryEngineTool`` in memory, in a dict keyed by
    ``tool.metadata.name``. Retrieved nodes are resolved back to tools
    through the ``"name"`` entry of their metadata.
    """

    def __init__(self, objs: Optional[Sequence[QueryEngineTool]] = None) -> None:
        """Index ``objs`` (if any) by tool name; later duplicates win."""
        registry: Dict[Any, QueryEngineTool] = {}
        for query_tool in objs or []:
            registry[query_tool.metadata.name] = query_tool
        self._tools = registry

    def validate_object(self, obj: QueryEngineTool) -> None:
        """Raise ``ValueError`` unless ``obj`` is a ``QueryEngineTool``."""
        if isinstance(obj, QueryEngineTool):
            return
        raise ValueError(f"Object must be of type {QueryEngineTool}")

    @classmethod
    def from_objects(
        cls, objs: Sequence[QueryEngineTool], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Create a mapping holding ``objs``; extra arguments are ignored."""
        return cls(objs)

    def _add_object(self, tool: QueryEngineTool) -> None:
        """Store ``tool`` under its name.

        Raises:
            ValueError: If the tool has no name.
        """
        name = tool.metadata.name
        if name is None:
            raise ValueError("Tool name must be set")
        self._tools[name] = tool

    def to_node(self, obj: QueryEngineTool) -> TextNode:
        """To node."""
        return convert_tool_to_node(obj)

    def _from_node(self, node: BaseNode) -> QueryEngineTool:
        """From node: return the tool registered under ``node.metadata["name"]``.

        Raises:
            ValueError: If the node carries no metadata.
            KeyError: If the name is absent or unknown.
        """
        metadata = node.metadata
        if metadata is None:
            raise ValueError("Metadata must be set")
        tool_name = metadata["name"]
        return self._tools[tool_name]
|