faiss_rag_enterprise/llama_index/tools/ondemand_loader_tool.py

163 lines
5.7 KiB
Python

"""Ad-hoc data loader tool.
Tool that wraps any data loader, and is able to load data on-demand.
"""
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from llama_index.bridge.pydantic import BaseModel
from llama_index.indices.base import BaseIndex
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
from llama_index.tools.function_tool import FunctionTool
from llama_index.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput
from llama_index.tools.utils import create_schema_from_function
class OnDemandLoaderTool(AsyncBaseTool):
"""On-demand data loader tool.
Loads data with by calling the provided loader function,
stores in index, and queries for relevant data with a
natural language query string.
"""
def __init__(
self,
loader: Callable[..., List[Document]],
index_cls: Type[BaseIndex],
index_kwargs: Dict,
metadata: ToolMetadata,
use_query_str_in_loader: bool = False,
query_str_kwargs_key: str = "query_str",
) -> None:
"""Init params."""
self._loader = loader
self._index_cls = index_cls
self._index_kwargs = index_kwargs
self._use_query_str_in_loader = use_query_str_in_loader
self._metadata = metadata
self._query_str_kwargs_key = query_str_kwargs_key
@property
def metadata(self) -> ToolMetadata:
return self._metadata
@classmethod
def from_defaults(
cls,
reader: BaseReader,
index_cls: Optional[Type[BaseIndex]] = None,
index_kwargs: Optional[Dict] = None,
use_query_str_in_loader: bool = False,
query_str_kwargs_key: str = "query_str",
name: Optional[str] = None,
description: Optional[str] = None,
fn_schema: Optional[Type[BaseModel]] = None,
) -> "OnDemandLoaderTool":
"""From defaults."""
# NOTE: fn_schema should be specified if you want to use as langchain Tool
index_cls = index_cls or VectorStoreIndex
index_kwargs = index_kwargs or {}
if description is None:
description = f"Tool to load data from {reader.__class__.__name__}"
if fn_schema is None:
fn_schema = create_schema_from_function(
name or "LoadData",
reader.load_data,
[(query_str_kwargs_key, str, None)],
)
metadata = ToolMetadata(name=name, description=description, fn_schema=fn_schema)
return cls(
loader=reader.load_data,
index_cls=index_cls,
index_kwargs=index_kwargs,
use_query_str_in_loader=use_query_str_in_loader,
query_str_kwargs_key=query_str_kwargs_key,
metadata=metadata,
)
@classmethod
def from_tool(
cls,
tool: FunctionTool,
index_cls: Optional[Type[BaseIndex]] = None,
index_kwargs: Optional[Dict] = None,
use_query_str_in_loader: bool = False,
query_str_kwargs_key: str = "query_str",
name: Optional[str] = None,
description: Optional[str] = None,
fn_schema: Optional[Type[BaseModel]] = None,
) -> "OnDemandLoaderTool":
"""From defaults."""
# NOTE: fn_schema should be specified if you want to use as langchain Tool
index_cls = index_cls or VectorStoreIndex
index_kwargs = index_kwargs or {}
if description is None:
description = f"Tool to load data from {tool.__class__.__name__}"
if fn_schema is None:
fn_schema = create_schema_from_function(
name or "LoadData", tool._fn, [(query_str_kwargs_key, str, None)]
)
metadata = ToolMetadata(name=name, description=description, fn_schema=fn_schema)
return cls(
loader=tool._fn,
index_cls=index_cls,
index_kwargs=index_kwargs,
use_query_str_in_loader=use_query_str_in_loader,
query_str_kwargs_key=query_str_kwargs_key,
metadata=metadata,
)
def _parse_args(self, *args: Any, **kwargs: Any) -> Tuple[str, List[Document]]:
if self._query_str_kwargs_key not in kwargs:
raise ValueError(
"Missing query_str in kwargs with parameter name: "
f"{self._query_str_kwargs_key}"
)
if self._use_query_str_in_loader:
query_str = kwargs[self._query_str_kwargs_key]
else:
query_str = kwargs.pop(self._query_str_kwargs_key)
docs = self._loader(*args, **kwargs)
return query_str, docs
def call(self, *args: Any, **kwargs: Any) -> ToolOutput:
"""Call."""
query_str, docs = self._parse_args(*args, **kwargs)
index = self._index_cls.from_documents(docs, **self._index_kwargs)
# TODO: add query kwargs
query_engine = index.as_query_engine()
response = query_engine.query(query_str)
return ToolOutput(
content=str(response),
tool_name=self.metadata.name,
raw_input={"query": query_str},
raw_output=response,
)
async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput:
"""Async Call."""
query_str, docs = self._parse_args(*args, **kwargs)
index = self._index_cls.from_documents(docs, **self._index_kwargs)
# TODO: add query kwargs
query_engine = index.as_query_engine()
response = await query_engine.aquery(query_str)
return ToolOutput(
content=str(response),
tool_name=self.metadata.name,
raw_input={"query": query_str},
raw_output=response,
)