# faiss_rag_enterprise/llama_index/query_pipeline/query.py

"""Query Pipeline."""
import json
import uuid
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
get_args,
)
import networkx
from llama_index.async_utils import run_jobs
from llama_index.bridge.pydantic import Field
from llama_index.callbacks import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.query_pipeline.query_component import (
QUERY_COMPONENT_TYPE,
ChainableMixin,
InputKeys,
Link,
OutputKeys,
QueryComponent,
)
from llama_index.utils import print_text
def get_output(
src_key: Optional[str],
output_dict: Dict[str, Any],
) -> Any:
"""Add input to module deps inputs."""
# get relevant output from link
if src_key is None:
# ensure that output_dict only has one key
if len(output_dict) != 1:
raise ValueError("Output dict must have exactly one key.")
output = next(iter(output_dict.values()))
else:
output = output_dict[src_key]
return output
def add_output_to_module_inputs(
    dest_key: Optional[str],
output: Any,
module: QueryComponent,
module_inputs: Dict[str, Any],
) -> None:
"""Add input to module deps inputs."""
# now attach output to relevant input key for module
if dest_key is None:
free_keys = module.free_req_input_keys
# ensure that there is only one remaining key given partials
if len(free_keys) != 1:
raise ValueError(
"Module input keys must have exactly one key if "
"dest_key is not specified. Remaining keys: "
f"in module: {free_keys}"
)
module_inputs[next(iter(free_keys))] = output
else:
module_inputs[dest_key] = output
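
# Illustrative example of how the two helpers above route a value between
# modules (`mod_b` and `inputs_b` are hypothetical):
#     output_dict = {"output": "some text"}       # produced by an upstream module
#     out = get_output(None, output_dict)         # -> "some text"
#     add_output_to_module_inputs(None, out, mod_b, inputs_b)
#     # inputs_b now maps mod_b's single free input key to "some text"
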
def print_debug_input(
module_key: str,
input: Dict[str, Any],
val_str_len: int = 200,
) -> None:
"""Print debug input."""
output = f"> Running module {module_key} with input: \n"
for key, value in input.items():
# stringify and truncate output
val_str = (
str(value)[:val_str_len] + "..."
if len(str(value)) > val_str_len
else str(value)
)
output += f"{key}: {val_str}\n"
print_text(output + "\n", color="llama_lavender")
def print_debug_input_multi(
module_keys: List[str],
module_inputs: List[Dict[str, Any]],
val_str_len: int = 200,
) -> None:
"""Print debug input."""
output = f"> Running modules and inputs in parallel: \n"
for module_key, input in zip(module_keys, module_inputs):
cur_output = f"Module key: {module_key}. Input: \n"
for key, value in input.items():
# stringify and truncate output
val_str = (
str(value)[:val_str_len] + "..."
if len(str(value)) > val_str_len
else str(value)
)
cur_output += f"{key}: {val_str}\n"
output += cur_output + "\n"
print_text(output + "\n", color="llama_lavender")
# Function to clean non-serializable attributes and return a copy of the graph
# https://stackoverflow.com/questions/23268421/networkx-how-to-access-attributes-of-objects-as-nodes
def clean_graph_attributes_copy(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
    # Copy the graph so the original is preserved
    graph_copy = graph.copy()
    # Drop callable node attributes, which are not serializable
    for node, attributes in graph_copy.nodes(data=True):
        for key, value in list(attributes.items()):
            if callable(value):
                del attributes[key]
    # Do the same for edge attributes (e.g. condition_fn / input_fn)
    for u, v, attributes in graph_copy.edges(data=True):
        for key, value in list(attributes.items()):
            if callable(value):
                del attributes[key]
return graph_copy
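
# Illustrative use of the helper above: strip callables so the DAG can be
# serialized (`pipeline` is a hypothetical QueryPipeline instance):
#     data = networkx.node_link_data(clean_graph_attributes_copy(pipeline.dag))
#     json.dumps(data)
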
CHAIN_COMPONENT_TYPE = Union[QUERY_COMPONENT_TYPE, str]
class QueryPipeline(QueryComponent):
"""A query pipeline that can allow arbitrary chaining of different modules.
A pipeline itself is a query component, and can be used as a module in another pipeline.
"""
callback_manager: CallbackManager = Field(
default_factory=lambda: CallbackManager([]), exclude=True
)
module_dict: Dict[str, QueryComponent] = Field(
default_factory=dict, description="The modules in the pipeline."
)
dag: networkx.MultiDiGraph = Field(
default_factory=networkx.MultiDiGraph, description="The DAG of the pipeline."
)
verbose: bool = Field(
default=False, description="Whether to print intermediate steps."
)
show_progress: bool = Field(
default=False,
description="Whether to show progress bar (currently async only).",
)
num_workers: int = Field(
default=4, description="Number of workers to use (currently async only)."
)
class Config:
arbitrary_types_allowed = True
def __init__(
self,
callback_manager: Optional[CallbackManager] = None,
chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None,
modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None,
links: Optional[List[Link]] = None,
**kwargs: Any,
):
super().__init__(
callback_manager=callback_manager or CallbackManager([]),
**kwargs,
)
self._init_graph(chain=chain, modules=modules, links=links)
def _init_graph(
self,
chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None,
modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None,
links: Optional[List[Link]] = None,
) -> None:
"""Initialize graph."""
if chain is not None:
if modules is not None or links is not None:
raise ValueError("Cannot specify both chain and modules/links in init.")
self.add_chain(chain)
elif modules is not None:
self.add_modules(modules)
if links is not None:
for link in links:
self.add_link(**link.dict())
def add_chain(self, chain: Sequence[CHAIN_COMPONENT_TYPE]) -> None:
"""Add a chain of modules to the pipeline.
This is a special form of pipeline that is purely sequential/linear.
This allows a more concise way of specifying a pipeline.
"""
# first add all modules
module_keys: List[str] = []
for module in chain:
if isinstance(module, get_args(QUERY_COMPONENT_TYPE)):
module_key = str(uuid.uuid4())
self.add(module_key, cast(QUERY_COMPONENT_TYPE, module))
module_keys.append(module_key)
elif isinstance(module, str):
module_keys.append(module)
else:
raise ValueError("Chain must be a sequence of modules or module keys.")
# then add all links
for i in range(len(chain) - 1):
self.add_link(src=module_keys[i], dest=module_keys[i + 1])
def add_links(
self,
links: List[Link],
) -> None:
"""Add links to the pipeline."""
for link in links:
if isinstance(link, Link):
self.add_link(**link.dict())
            else:
                raise ValueError("Link must be of type `Link`.")
def add_modules(self, module_dict: Dict[str, QUERY_COMPONENT_TYPE]) -> None:
"""Add modules to the pipeline."""
for module_key, module in module_dict.items():
self.add(module_key, module)
def add(self, module_key: str, module: QUERY_COMPONENT_TYPE) -> None:
"""Add a module to the pipeline."""
# if already exists, raise error
if module_key in self.module_dict:
raise ValueError(f"Module {module_key} already exists in pipeline.")
        if isinstance(module, ChainableMixin):
            module = module.as_query_component()
self.module_dict[module_key] = cast(QueryComponent, module)
self.dag.add_node(module_key)
def add_link(
self,
src: str,
dest: str,
src_key: Optional[str] = None,
dest_key: Optional[str] = None,
condition_fn: Optional[Callable] = None,
input_fn: Optional[Callable] = None,
) -> None:
"""Add a link between two modules."""
        if src not in self.module_dict:
            raise ValueError(f"Module {src} does not exist in pipeline.")
        if dest not in self.module_dict:
            raise ValueError(f"Module {dest} does not exist in pipeline.")
self.dag.add_edge(
src,
dest,
src_key=src_key,
dest_key=dest_key,
condition_fn=condition_fn,
input_fn=input_fn,
)
def get_root_keys(self) -> List[str]:
"""Get root keys."""
return self._get_root_keys()
def get_leaf_keys(self) -> List[str]:
"""Get leaf keys."""
return self._get_leaf_keys()
def _get_root_keys(self) -> List[str]:
"""Get root keys."""
return [v for v, d in self.dag.in_degree() if d == 0]
def _get_leaf_keys(self) -> List[str]:
"""Get leaf keys."""
# get all modules without downstream dependencies
return [v for v, d in self.dag.out_degree() if d == 0]
def set_callback_manager(self, callback_manager: CallbackManager) -> None:
"""Set callback manager."""
# go through every module in module dict and set callback manager
self.callback_manager = callback_manager
for module in self.module_dict.values():
module.set_callback_manager(callback_manager)
def run(
self,
*args: Any,
return_values_direct: bool = True,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
) -> Any:
"""Run the pipeline."""
# first set callback manager
callback_manager = callback_manager or self.callback_manager
self.set_callback_manager(callback_manager)
with self.callback_manager.as_trace("query"):
# try to get query payload
try:
query_payload = json.dumps(kwargs)
except TypeError:
query_payload = json.dumps(str(kwargs))
with self.callback_manager.event(
CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload}
) as query_event:
return self._run(
*args, return_values_direct=return_values_direct, **kwargs
)
def run_multi(
self,
module_input_dict: Dict[str, Any],
callback_manager: Optional[CallbackManager] = None,
) -> Dict[str, Any]:
"""Run the pipeline for multiple roots."""
callback_manager = callback_manager or self.callback_manager
self.set_callback_manager(callback_manager)
with self.callback_manager.as_trace("query"):
with self.callback_manager.event(
CBEventType.QUERY,
payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)},
) as query_event:
return self._run_multi(module_input_dict)
async def arun(
self,
*args: Any,
return_values_direct: bool = True,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
) -> Any:
"""Run the pipeline."""
# first set callback manager
callback_manager = callback_manager or self.callback_manager
self.set_callback_manager(callback_manager)
with self.callback_manager.as_trace("query"):
try:
query_payload = json.dumps(kwargs)
except TypeError:
query_payload = json.dumps(str(kwargs))
with self.callback_manager.event(
CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload}
) as query_event:
return await self._arun(
*args, return_values_direct=return_values_direct, **kwargs
)
async def arun_multi(
self,
module_input_dict: Dict[str, Any],
callback_manager: Optional[CallbackManager] = None,
) -> Dict[str, Any]:
"""Run the pipeline for multiple roots."""
callback_manager = callback_manager or self.callback_manager
self.set_callback_manager(callback_manager)
with self.callback_manager.as_trace("query"):
with self.callback_manager.event(
CBEventType.QUERY,
payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)},
) as query_event:
return await self._arun_multi(module_input_dict)
def _get_root_key_and_kwargs(
self, *args: Any, **kwargs: Any
) -> Tuple[str, Dict[str, Any]]:
"""Get root key and kwargs.
This is for `_run`.
"""
        # Assume there is exactly one root; for multiple roots, use `run_multi`.
root_keys = self._get_root_keys()
if len(root_keys) != 1:
raise ValueError("Only one root is supported.")
root_key = root_keys[0]
root_module = self.module_dict[root_key]
if len(args) > 0:
            # if args is specified, validate: only one arg is allowed, and the
            # module may have only one free input key
if len(args) > 1:
raise ValueError("Only one arg is allowed.")
if len(kwargs) > 0:
raise ValueError("No kwargs allowed if args is specified.")
if len(root_module.free_req_input_keys) != 1:
raise ValueError("Only one free input key is allowed.")
# set kwargs
kwargs[next(iter(root_module.free_req_input_keys))] = args[0]
return root_key, kwargs
def _get_single_result_output(
self,
result_outputs: Dict[str, Any],
return_values_direct: bool,
) -> Any:
"""Get result output from a single module.
If output dict is a single key, return the value directly
if return_values_direct is True.
"""
if len(result_outputs) != 1:
raise ValueError("Only one output is supported.")
result_output = next(iter(result_outputs.values()))
# return_values_direct: if True, return the value directly
# without the key
# if it's a dict with one key, return the value
if (
isinstance(result_output, dict)
and len(result_output) == 1
and return_values_direct
):
return next(iter(result_output.values()))
else:
return result_output
def _run(self, *args: Any, return_values_direct: bool = True, **kwargs: Any) -> Any:
"""Run the pipeline.
Assume that there is a single root module and a single output module.
For multi-input and multi-outputs, please see `run_multi`.
"""
root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
# call run_multi with one root key
result_outputs = self._run_multi({root_key: kwargs})
return self._get_single_result_output(result_outputs, return_values_direct)
async def _arun(
self, *args: Any, return_values_direct: bool = True, **kwargs: Any
) -> Any:
"""Run the pipeline.
Assume that there is a single root module and a single output module.
For multi-input and multi-outputs, please see `run_multi`.
"""
root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
# call run_multi with one root key
result_outputs = await self._arun_multi({root_key: kwargs})
return self._get_single_result_output(result_outputs, return_values_direct)
def _validate_inputs(self, module_input_dict: Dict[str, Any]) -> None:
root_keys = self._get_root_keys()
# if root keys don't match up with kwargs keys, raise error
if set(root_keys) != set(module_input_dict.keys()):
raise ValueError(
"Expected root keys do not match up with input keys.\n"
f"Expected root keys: {root_keys}\n"
f"Input keys: {module_input_dict.keys()}\n"
)
def _process_component_output(
self,
queue: List[str],
output_dict: Dict[str, Any],
module_key: str,
all_module_inputs: Dict[str, Dict[str, Any]],
result_outputs: Dict[str, Any],
) -> List[str]:
"""Process component output."""
new_queue = queue.copy()
        # if the module is a leaf (no outgoing edges), add its result to the output
if module_key in self._get_leaf_keys():
result_outputs[module_key] = output_dict
else:
edge_list = list(self.dag.edges(module_key, data=True))
            # propagate the output along each outgoing edge
for _, dest, attr in edge_list:
output = get_output(attr.get("src_key"), output_dict)
# if input_fn is not None, use it to modify the input
if attr["input_fn"] is not None:
dest_output = attr["input_fn"](output)
else:
dest_output = output
add_edge = True
if attr["condition_fn"] is not None:
conditional_val = attr["condition_fn"](output)
if not conditional_val:
add_edge = False
if add_edge:
add_output_to_module_inputs(
attr.get("dest_key"),
dest_output,
self.module_dict[dest],
all_module_inputs[dest],
)
else:
# remove dest from queue
new_queue.remove(dest)
return new_queue
def _run_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Run the pipeline for multiple roots.
kwargs is in the form of module_dict -> input_dict
input_dict is in the form of input_key -> input
"""
self._validate_inputs(module_input_dict)
queue = list(networkx.topological_sort(self.dag))
        # all_module_inputs maps module_key -> dict of input_key -> input.
        # Initialize with an empty dict for every module key; each module's
        # inputs are populated as its upstream modules run.
all_module_inputs: Dict[str, Dict[str, Any]] = {
module_key: {} for module_key in self.module_dict
}
result_outputs: Dict[str, Any] = {}
# add root inputs to all_module_inputs
for module_key, module_input in module_input_dict.items():
all_module_inputs[module_key] = module_input
while len(queue) > 0:
module_key = queue.pop(0)
module = self.module_dict[module_key]
module_input = all_module_inputs[module_key]
if self.verbose:
print_debug_input(module_key, module_input)
output_dict = module.run_component(**module_input)
            # propagate outputs downstream and update the queue
queue = self._process_component_output(
queue, output_dict, module_key, all_module_inputs, result_outputs
)
return result_outputs
async def _arun_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Run the pipeline for multiple roots.
kwargs is in the form of module_dict -> input_dict
input_dict is in the form of input_key -> input
"""
self._validate_inputs(module_input_dict)
queue = list(networkx.topological_sort(self.dag))
        # all_module_inputs maps module_key -> dict of input_key -> input.
        # Initialize with an empty dict for every module key; each module's
        # inputs are populated as its upstream modules run.
all_module_inputs: Dict[str, Dict[str, Any]] = {
module_key: {} for module_key in self.module_dict
}
result_outputs: Dict[str, Any] = {}
# add root inputs to all_module_inputs
for module_key, module_input in module_input_dict.items():
all_module_inputs[module_key] = module_input
while len(queue) > 0:
popped_indices = set()
popped_nodes = []
            # get the subset of nodes that have no ancestors still in the
            # queue; these can run in parallel
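            # e.g. with edges a->b, a->c, b->d, c->d: pass 1 pops {a},
            # pass 2 pops {b, c} (run concurrently), pass 3 pops {d}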
for i, module_key in enumerate(queue):
module_ancestors = networkx.ancestors(self.dag, module_key)
if len(set(module_ancestors).intersection(queue)) == 0:
popped_indices.add(i)
popped_nodes.append(module_key)
# update queue
queue = [
module_key
for i, module_key in enumerate(queue)
if i not in popped_indices
]
if self.verbose:
print_debug_input_multi(
popped_nodes,
[all_module_inputs[module_key] for module_key in popped_nodes],
)
# create tasks from popped nodes
tasks = []
for module_key in popped_nodes:
module = self.module_dict[module_key]
module_input = all_module_inputs[module_key]
tasks.append(module.arun_component(**module_input))
# run tasks
output_dicts = await run_jobs(
tasks, show_progress=self.show_progress, workers=self.num_workers
)
for output_dict, module_key in zip(output_dicts, popped_nodes):
                # propagate outputs downstream and update the queue
queue = self._process_component_output(
queue, output_dict, module_key, all_module_inputs, result_outputs
)
return result_outputs
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
raise NotImplementedError
def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs."""
return input
    def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component outputs during run_component."""
        raise NotImplementedError
def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component outputs."""
# NOTE: we override this to do nothing
return output
def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
"""Run component."""
return self.run(return_values_direct=False, **kwargs)
async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
"""Run component."""
return await self.arun(return_values_direct=False, **kwargs)
@property
def input_keys(self) -> InputKeys:
"""Input keys."""
# get input key of first module
root_keys = self._get_root_keys()
if len(root_keys) != 1:
raise ValueError("Only one root is supported.")
root_module = self.module_dict[root_keys[0]]
return root_module.input_keys
@property
def output_keys(self) -> OutputKeys:
"""Output keys."""
# get output key of last module
leaf_keys = self._get_leaf_keys()
if len(leaf_keys) != 1:
raise ValueError("Only one leaf is supported.")
leaf_module = self.module_dict[leaf_keys[0]]
return leaf_module.output_keys
@property
def sub_query_components(self) -> List[QueryComponent]:
"""Sub query components."""
return list(self.module_dict.values())
@property
def clean_dag(self) -> networkx.DiGraph:
"""Clean dag."""
return clean_graph_attributes_copy(self.dag)
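

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only). Assumes `FnComponent`
    # is importable from `llama_index.query_pipeline` in this version of
    # llama_index; swap in any two query components you have on hand.
    from llama_index.query_pipeline import FnComponent

    double = FnComponent(fn=lambda x: x * 2)
    stringify = FnComponent(fn=lambda x: f"result: {x}")

    demo_pipeline = QueryPipeline(chain=[double, stringify], verbose=True)
    print(demo_pipeline.run(x=3))  # expected: "result: 6"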