"""Query Pipeline.""" import json import uuid from typing import ( Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast, get_args, ) import networkx from llama_index.async_utils import run_jobs from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.core.query_pipeline.query_component import ( QUERY_COMPONENT_TYPE, ChainableMixin, InputKeys, Link, OutputKeys, QueryComponent, ) from llama_index.utils import print_text def get_output( src_key: Optional[str], output_dict: Dict[str, Any], ) -> Any: """Add input to module deps inputs.""" # get relevant output from link if src_key is None: # ensure that output_dict only has one key if len(output_dict) != 1: raise ValueError("Output dict must have exactly one key.") output = next(iter(output_dict.values())) else: output = output_dict[src_key] return output def add_output_to_module_inputs( dest_key: str, output: Any, module: QueryComponent, module_inputs: Dict[str, Any], ) -> None: """Add input to module deps inputs.""" # now attach output to relevant input key for module if dest_key is None: free_keys = module.free_req_input_keys # ensure that there is only one remaining key given partials if len(free_keys) != 1: raise ValueError( "Module input keys must have exactly one key if " "dest_key is not specified. Remaining keys: " f"in module: {free_keys}" ) module_inputs[next(iter(free_keys))] = output else: module_inputs[dest_key] = output def print_debug_input( module_key: str, input: Dict[str, Any], val_str_len: int = 200, ) -> None: """Print debug input.""" output = f"> Running module {module_key} with input: \n" for key, value in input.items(): # stringify and truncate output val_str = ( str(value)[:val_str_len] + "..." if len(str(value)) > val_str_len else str(value) ) output += f"{key}: {val_str}\n" print_text(output + "\n", color="llama_lavender") def print_debug_input_multi( module_keys: List[str], module_inputs: List[Dict[str, Any]], val_str_len: int = 200, ) -> None: """Print debug input.""" output = f"> Running modules and inputs in parallel: \n" for module_key, input in zip(module_keys, module_inputs): cur_output = f"Module key: {module_key}. Input: \n" for key, value in input.items(): # stringify and truncate output val_str = ( str(value)[:val_str_len] + "..." if len(str(value)) > val_str_len else str(value) ) cur_output += f"{key}: {val_str}\n" output += cur_output + "\n" print_text(output + "\n", color="llama_lavender") # Function to clean non-serializable attributes and return a copy of the graph # https://stackoverflow.com/questions/23268421/networkx-how-to-access-attributes-of-objects-as-nodes def clean_graph_attributes_copy(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph: # Create a deep copy of the graph to preserve the original graph_copy = graph.copy() # Iterate over nodes and clean attributes for node, attributes in graph_copy.nodes(data=True): for key, value in list(attributes.items()): if callable(value): # Checks if the value is a function del attributes[key] # Remove the attribute if it's non-serializable # Similarly, you can extend this to clean edge attributes if necessary for u, v, attributes in graph_copy.edges(data=True): for key, value in list(attributes.items()): if callable(value): # Checks if the value is a function del attributes[key] # Remove the attribute if it's non-serializable return graph_copy CHAIN_COMPONENT_TYPE = Union[QUERY_COMPONENT_TYPE, str] class QueryPipeline(QueryComponent): """A query pipeline that can allow arbitrary chaining of different modules. A pipeline itself is a query component, and can be used as a module in another pipeline. """ callback_manager: CallbackManager = Field( default_factory=lambda: CallbackManager([]), exclude=True ) module_dict: Dict[str, QueryComponent] = Field( default_factory=dict, description="The modules in the pipeline." ) dag: networkx.MultiDiGraph = Field( default_factory=networkx.MultiDiGraph, description="The DAG of the pipeline." ) verbose: bool = Field( default=False, description="Whether to print intermediate steps." ) show_progress: bool = Field( default=False, description="Whether to show progress bar (currently async only).", ) num_workers: int = Field( default=4, description="Number of workers to use (currently async only)." ) class Config: arbitrary_types_allowed = True def __init__( self, callback_manager: Optional[CallbackManager] = None, chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None, modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None, links: Optional[List[Link]] = None, **kwargs: Any, ): super().__init__( callback_manager=callback_manager or CallbackManager([]), **kwargs, ) self._init_graph(chain=chain, modules=modules, links=links) def _init_graph( self, chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None, modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None, links: Optional[List[Link]] = None, ) -> None: """Initialize graph.""" if chain is not None: if modules is not None or links is not None: raise ValueError("Cannot specify both chain and modules/links in init.") self.add_chain(chain) elif modules is not None: self.add_modules(modules) if links is not None: for link in links: self.add_link(**link.dict()) def add_chain(self, chain: Sequence[CHAIN_COMPONENT_TYPE]) -> None: """Add a chain of modules to the pipeline. This is a special form of pipeline that is purely sequential/linear. This allows a more concise way of specifying a pipeline. """ # first add all modules module_keys: List[str] = [] for module in chain: if isinstance(module, get_args(QUERY_COMPONENT_TYPE)): module_key = str(uuid.uuid4()) self.add(module_key, cast(QUERY_COMPONENT_TYPE, module)) module_keys.append(module_key) elif isinstance(module, str): module_keys.append(module) else: raise ValueError("Chain must be a sequence of modules or module keys.") # then add all links for i in range(len(chain) - 1): self.add_link(src=module_keys[i], dest=module_keys[i + 1]) def add_links( self, links: List[Link], ) -> None: """Add links to the pipeline.""" for link in links: if isinstance(link, Link): self.add_link(**link.dict()) else: raise ValueError("Link must be of type `Link` or `ConditionalLinks`.") def add_modules(self, module_dict: Dict[str, QUERY_COMPONENT_TYPE]) -> None: """Add modules to the pipeline.""" for module_key, module in module_dict.items(): self.add(module_key, module) def add(self, module_key: str, module: QUERY_COMPONENT_TYPE) -> None: """Add a module to the pipeline.""" # if already exists, raise error if module_key in self.module_dict: raise ValueError(f"Module {module_key} already exists in pipeline.") if isinstance(module, ChainableMixin): module = module.as_query_component() else: pass self.module_dict[module_key] = cast(QueryComponent, module) self.dag.add_node(module_key) def add_link( self, src: str, dest: str, src_key: Optional[str] = None, dest_key: Optional[str] = None, condition_fn: Optional[Callable] = None, input_fn: Optional[Callable] = None, ) -> None: """Add a link between two modules.""" if src not in self.module_dict: raise ValueError(f"Module {src} does not exist in pipeline.") self.dag.add_edge( src, dest, src_key=src_key, dest_key=dest_key, condition_fn=condition_fn, input_fn=input_fn, ) def get_root_keys(self) -> List[str]: """Get root keys.""" return self._get_root_keys() def get_leaf_keys(self) -> List[str]: """Get leaf keys.""" return self._get_leaf_keys() def _get_root_keys(self) -> List[str]: """Get root keys.""" return [v for v, d in self.dag.in_degree() if d == 0] def _get_leaf_keys(self) -> List[str]: """Get leaf keys.""" # get all modules without downstream dependencies return [v for v, d in self.dag.out_degree() if d == 0] def set_callback_manager(self, callback_manager: CallbackManager) -> None: """Set callback manager.""" # go through every module in module dict and set callback manager self.callback_manager = callback_manager for module in self.module_dict.values(): module.set_callback_manager(callback_manager) def run( self, *args: Any, return_values_direct: bool = True, callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> Any: """Run the pipeline.""" # first set callback manager callback_manager = callback_manager or self.callback_manager self.set_callback_manager(callback_manager) with self.callback_manager.as_trace("query"): # try to get query payload try: query_payload = json.dumps(kwargs) except TypeError: query_payload = json.dumps(str(kwargs)) with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload} ) as query_event: return self._run( *args, return_values_direct=return_values_direct, **kwargs ) def run_multi( self, module_input_dict: Dict[str, Any], callback_manager: Optional[CallbackManager] = None, ) -> Dict[str, Any]: """Run the pipeline for multiple roots.""" callback_manager = callback_manager or self.callback_manager self.set_callback_manager(callback_manager) with self.callback_manager.as_trace("query"): with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)}, ) as query_event: return self._run_multi(module_input_dict) async def arun( self, *args: Any, return_values_direct: bool = True, callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> Any: """Run the pipeline.""" # first set callback manager callback_manager = callback_manager or self.callback_manager self.set_callback_manager(callback_manager) with self.callback_manager.as_trace("query"): try: query_payload = json.dumps(kwargs) except TypeError: query_payload = json.dumps(str(kwargs)) with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload} ) as query_event: return await self._arun( *args, return_values_direct=return_values_direct, **kwargs ) async def arun_multi( self, module_input_dict: Dict[str, Any], callback_manager: Optional[CallbackManager] = None, ) -> Dict[str, Any]: """Run the pipeline for multiple roots.""" callback_manager = callback_manager or self.callback_manager self.set_callback_manager(callback_manager) with self.callback_manager.as_trace("query"): with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)}, ) as query_event: return await self._arun_multi(module_input_dict) def _get_root_key_and_kwargs( self, *args: Any, **kwargs: Any ) -> Tuple[str, Dict[str, Any]]: """Get root key and kwargs. This is for `_run`. """ ## run pipeline ## assume there is only one root - for multiple roots, need to specify `run_multi` root_keys = self._get_root_keys() if len(root_keys) != 1: raise ValueError("Only one root is supported.") root_key = root_keys[0] root_module = self.module_dict[root_key] if len(args) > 0: # if args is specified, validate. only one arg is allowed, and there can only be one free # input key in the module if len(args) > 1: raise ValueError("Only one arg is allowed.") if len(kwargs) > 0: raise ValueError("No kwargs allowed if args is specified.") if len(root_module.free_req_input_keys) != 1: raise ValueError("Only one free input key is allowed.") # set kwargs kwargs[next(iter(root_module.free_req_input_keys))] = args[0] return root_key, kwargs def _get_single_result_output( self, result_outputs: Dict[str, Any], return_values_direct: bool, ) -> Any: """Get result output from a single module. If output dict is a single key, return the value directly if return_values_direct is True. """ if len(result_outputs) != 1: raise ValueError("Only one output is supported.") result_output = next(iter(result_outputs.values())) # return_values_direct: if True, return the value directly # without the key # if it's a dict with one key, return the value if ( isinstance(result_output, dict) and len(result_output) == 1 and return_values_direct ): return next(iter(result_output.values())) else: return result_output def _run(self, *args: Any, return_values_direct: bool = True, **kwargs: Any) -> Any: """Run the pipeline. Assume that there is a single root module and a single output module. For multi-input and multi-outputs, please see `run_multi`. """ root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs) # call run_multi with one root key result_outputs = self._run_multi({root_key: kwargs}) return self._get_single_result_output(result_outputs, return_values_direct) async def _arun( self, *args: Any, return_values_direct: bool = True, **kwargs: Any ) -> Any: """Run the pipeline. Assume that there is a single root module and a single output module. For multi-input and multi-outputs, please see `run_multi`. """ root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs) # call run_multi with one root key result_outputs = await self._arun_multi({root_key: kwargs}) return self._get_single_result_output(result_outputs, return_values_direct) def _validate_inputs(self, module_input_dict: Dict[str, Any]) -> None: root_keys = self._get_root_keys() # if root keys don't match up with kwargs keys, raise error if set(root_keys) != set(module_input_dict.keys()): raise ValueError( "Expected root keys do not match up with input keys.\n" f"Expected root keys: {root_keys}\n" f"Input keys: {module_input_dict.keys()}\n" ) def _process_component_output( self, queue: List[str], output_dict: Dict[str, Any], module_key: str, all_module_inputs: Dict[str, Dict[str, Any]], result_outputs: Dict[str, Any], ) -> List[str]: """Process component output.""" new_queue = queue.copy() # if there's no more edges, add result to output if module_key in self._get_leaf_keys(): result_outputs[module_key] = output_dict else: edge_list = list(self.dag.edges(module_key, data=True)) # everything not in conditional_edge_list is regular for _, dest, attr in edge_list: output = get_output(attr.get("src_key"), output_dict) # if input_fn is not None, use it to modify the input if attr["input_fn"] is not None: dest_output = attr["input_fn"](output) else: dest_output = output add_edge = True if attr["condition_fn"] is not None: conditional_val = attr["condition_fn"](output) if not conditional_val: add_edge = False if add_edge: add_output_to_module_inputs( attr.get("dest_key"), dest_output, self.module_dict[dest], all_module_inputs[dest], ) else: # remove dest from queue new_queue.remove(dest) return new_queue def _run_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]: """Run the pipeline for multiple roots. kwargs is in the form of module_dict -> input_dict input_dict is in the form of input_key -> input """ self._validate_inputs(module_input_dict) queue = list(networkx.topological_sort(self.dag)) # module_deps_inputs is a dict to collect inputs for a module # mapping of module_key -> dict of input_key -> input # initialize with blank dict for every module key # the input dict of each module key will be populated as the upstream modules are run all_module_inputs: Dict[str, Dict[str, Any]] = { module_key: {} for module_key in self.module_dict } result_outputs: Dict[str, Any] = {} # add root inputs to all_module_inputs for module_key, module_input in module_input_dict.items(): all_module_inputs[module_key] = module_input while len(queue) > 0: module_key = queue.pop(0) module = self.module_dict[module_key] module_input = all_module_inputs[module_key] if self.verbose: print_debug_input(module_key, module_input) output_dict = module.run_component(**module_input) # get new nodes and is_leaf queue = self._process_component_output( queue, output_dict, module_key, all_module_inputs, result_outputs ) return result_outputs async def _arun_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]: """Run the pipeline for multiple roots. kwargs is in the form of module_dict -> input_dict input_dict is in the form of input_key -> input """ self._validate_inputs(module_input_dict) queue = list(networkx.topological_sort(self.dag)) # module_deps_inputs is a dict to collect inputs for a module # mapping of module_key -> dict of input_key -> input # initialize with blank dict for every module key # the input dict of each module key will be populated as the upstream modules are run all_module_inputs: Dict[str, Dict[str, Any]] = { module_key: {} for module_key in self.module_dict } result_outputs: Dict[str, Any] = {} # add root inputs to all_module_inputs for module_key, module_input in module_input_dict.items(): all_module_inputs[module_key] = module_input while len(queue) > 0: popped_indices = set() popped_nodes = [] # get subset of nodes who don't have ancestors also in the queue # these are tasks that are parallelizable for i, module_key in enumerate(queue): module_ancestors = networkx.ancestors(self.dag, module_key) if len(set(module_ancestors).intersection(queue)) == 0: popped_indices.add(i) popped_nodes.append(module_key) # update queue queue = [ module_key for i, module_key in enumerate(queue) if i not in popped_indices ] if self.verbose: print_debug_input_multi( popped_nodes, [all_module_inputs[module_key] for module_key in popped_nodes], ) # create tasks from popped nodes tasks = [] for module_key in popped_nodes: module = self.module_dict[module_key] module_input = all_module_inputs[module_key] tasks.append(module.arun_component(**module_input)) # run tasks output_dicts = await run_jobs( tasks, show_progress=self.show_progress, workers=self.num_workers ) for output_dict, module_key in zip(output_dicts, popped_nodes): # get new nodes and is_leaf queue = self._process_component_output( queue, output_dict, module_key, all_module_inputs, result_outputs ) return result_outputs def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]: """Validate component inputs during run_component.""" raise NotImplementedError def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]: """Validate component inputs.""" return input def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]: raise NotImplementedError def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]: """Validate component outputs.""" # NOTE: we override this to do nothing return output def _run_component(self, **kwargs: Any) -> Dict[str, Any]: """Run component.""" return self.run(return_values_direct=False, **kwargs) async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]: """Run component.""" return await self.arun(return_values_direct=False, **kwargs) @property def input_keys(self) -> InputKeys: """Input keys.""" # get input key of first module root_keys = self._get_root_keys() if len(root_keys) != 1: raise ValueError("Only one root is supported.") root_module = self.module_dict[root_keys[0]] return root_module.input_keys @property def output_keys(self) -> OutputKeys: """Output keys.""" # get output key of last module leaf_keys = self._get_leaf_keys() if len(leaf_keys) != 1: raise ValueError("Only one leaf is supported.") leaf_module = self.module_dict[leaf_keys[0]] return leaf_module.output_keys @property def sub_query_components(self) -> List[QueryComponent]: """Sub query components.""" return list(self.module_dict.values()) @property def clean_dag(self) -> networkx.DiGraph: """Clean dag.""" return clean_graph_attributes_copy(self.dag)