# faiss_rag_enterprise/llama_index/objects/tool_node_mapping.py
#
# 148 lines
# 4.7 KiB
# Python

"""Tool mapping."""
from typing import Any, Dict, Optional, Sequence
from llama_index.objects.base_node_mapping import (
DEFAULT_PERSIST_DIR,
DEFAULT_PERSIST_FNAME,
BaseObjectNodeMapping,
)
from llama_index.schema import BaseNode, TextNode
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.tools.types import BaseTool
def convert_tool_to_node(tool: BaseTool) -> TextNode:
    """Render a tool's metadata as a retrievable ``TextNode``.

    The node text lists the tool name and description, plus the tool's
    argument schema when ``fn_schema`` is set. The tool name is also stored
    in the node metadata under ``"name"`` so the tool can be looked up again.
    That key is excluded from both the embedding text and the LLM-facing text.

    Args:
        tool: The tool to convert.

    Returns:
        A ``TextNode`` describing ``tool``.
    """
    meta = tool.metadata
    parts = [
        f"Tool name: {meta.name}\n",
        f"Tool description: {meta.description}\n",
    ]
    if meta.fn_schema is not None:
        # Assumes fn_schema is a pydantic-style model class exposing
        # ``.schema()`` (v1 API) -- confirm against ToolMetadata.
        parts.append(f"Tool schema: {meta.fn_schema.schema()}\n")
    return TextNode(
        text="".join(parts),
        metadata={"name": meta.name},
        # The name is used only as a lookup key, never as content.
        excluded_embed_metadata_keys=["name"],
        excluded_llm_metadata_keys=["name"],
    )
class BaseToolNodeMapping(BaseObjectNodeMapping[BaseTool]):
    """Base Tool node mapping.

    Validates that mapped objects are ``BaseTool`` instances. The
    ``obj_node_mapping`` accessor and ``persist`` are not implemented here,
    and ``from_persist_dir`` is unsupported. Each of these raises
    ``NotImplementedError``.
    """

    def validate_object(self, obj: BaseTool) -> None:
        """Check that ``obj`` is a ``BaseTool``.

        Raises:
            ValueError: If ``obj`` is not a ``BaseTool`` instance.
        """
        if not isinstance(obj, BaseTool):
            raise ValueError(f"Object must be of type {BaseTool}")

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The mapping data structure between node and object.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Subclasses should implement this!")

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs.

        The defaults match ``from_persist_dir``. An Ellipsis (``...``) is not
        a valid ``str`` default outside of stub files.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Subclasses should implement this!")

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseToolNodeMapping":
        """Load a mapping from disk (unsupported).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "This object node mapping does not support persist method."
        )
class SimpleToolNodeMapping(BaseToolNodeMapping):
    """In-memory mapping between tools and their nodes, keyed by tool name.

    Tool names are assumed to be unique. If two tools share a name, the one
    registered later replaces the earlier one. All tools are kept in memory.
    """

    def __init__(self, objs: Optional[Sequence[BaseTool]] = None) -> None:
        # ``None`` or an empty sequence yields an empty mapping.
        self._tools = {}
        for tool in objs or []:
            self._tools[tool.metadata.name] = tool

    @classmethod
    def from_objects(
        cls, objs: Sequence[BaseTool], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Build a mapping from ``objs``; extra arguments are ignored."""
        mapping = cls(objs)
        return mapping

    def _add_object(self, tool: BaseTool) -> None:
        """Register ``tool`` under its name, replacing any existing entry."""
        key = tool.metadata.name
        self._tools[key] = tool

    def to_node(self, tool: BaseTool) -> TextNode:
        """Convert ``tool`` to a ``TextNode`` via ``convert_tool_to_node``."""
        return convert_tool_to_node(tool)

    def _from_node(self, node: BaseNode) -> BaseTool:
        """Return the registered tool named by ``node.metadata["name"]``.

        Raises:
            ValueError: If the node's metadata is ``None``.
            KeyError: If the metadata has no ``"name"`` key, or no tool with
                that name is registered.
        """
        metadata = node.metadata
        if metadata is None:
            raise ValueError("Metadata must be set")
        return self._tools[metadata["name"]]
class BaseQueryToolNodeMapping(BaseObjectNodeMapping[QueryEngineTool]):
    """Base query tool node mapping.

    The ``obj_node_mapping`` accessor and ``persist`` are not implemented
    here, and ``from_persist_dir`` is unsupported. Each of these raises
    ``NotImplementedError``.
    """

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> "BaseQueryToolNodeMapping":
        """Load a mapping from disk (unsupported).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "This object node mapping does not support persist method."
        )

    @property
    def obj_node_mapping(self) -> Dict[int, Any]:
        """The mapping data structure between node and object.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Subclasses should implement this!")

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist objs.

        The defaults match ``from_persist_dir``. An Ellipsis (``...``) is not
        a valid ``str`` default outside of stub files.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Subclasses should implement this!")
class SimpleQueryToolNodeMapping(BaseQueryToolNodeMapping):
    """In-memory mapping between query engine tools and their nodes.

    Tools are keyed by ``tool.metadata.name``. A later tool with the same
    name replaces an earlier one.
    """

    def __init__(self, objs: Optional[Sequence[QueryEngineTool]] = None) -> None:
        # ``None`` or an empty sequence yields an empty mapping. Unlike
        # ``_add_object``, names are not checked for ``None`` here.
        self._tools = {}
        for tool in objs or []:
            self._tools[tool.metadata.name] = tool

    def validate_object(self, obj: QueryEngineTool) -> None:
        """Check that ``obj`` is a ``QueryEngineTool``.

        Raises:
            ValueError: If ``obj`` is not a ``QueryEngineTool`` instance.
        """
        if isinstance(obj, QueryEngineTool):
            return
        raise ValueError(f"Object must be of type {QueryEngineTool}")

    @classmethod
    def from_objects(
        cls, objs: Sequence[QueryEngineTool], *args: Any, **kwargs: Any
    ) -> "BaseObjectNodeMapping":
        """Build a mapping from ``objs``; extra arguments are ignored."""
        mapping = cls(objs)
        return mapping

    def _add_object(self, tool: QueryEngineTool) -> None:
        """Register ``tool`` under its name, replacing any existing entry.

        Raises:
            ValueError: If the tool has no name.
        """
        name = tool.metadata.name
        if name is None:
            raise ValueError("Tool name must be set")
        self._tools[name] = tool

    def to_node(self, obj: QueryEngineTool) -> TextNode:
        """Convert ``obj`` to a ``TextNode`` via ``convert_tool_to_node``."""
        return convert_tool_to_node(obj)

    def _from_node(self, node: BaseNode) -> QueryEngineTool:
        """Return the registered tool named by ``node.metadata["name"]``.

        Raises:
            ValueError: If the node's metadata is ``None``.
            KeyError: If the metadata has no ``"name"`` key, or no tool with
                that name is registered.
        """
        metadata = node.metadata
        if metadata is None:
            raise ValueError("Metadata must be set")
        return self._tools[metadata["name"]]