"""Node postprocessor.""" import logging from typing import Dict, List, Optional, cast from llama_index.bridge.pydantic import Field, validator from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts.base import PromptTemplate from llama_index.response_synthesizers import ResponseMode, get_response_synthesizer from llama_index.schema import NodeRelationship, NodeWithScore, QueryBundle from llama_index.service_context import ServiceContext from llama_index.storage.docstore import BaseDocumentStore logger = logging.getLogger(__name__) class KeywordNodePostprocessor(BaseNodePostprocessor): """Keyword-based Node processor.""" required_keywords: List[str] = Field(default_factory=list) exclude_keywords: List[str] = Field(default_factory=list) lang: str = Field(default="en") @classmethod def class_name(cls) -> str: return "KeywordNodePostprocessor" def _postprocess_nodes( self, nodes: List[NodeWithScore], query_bundle: Optional[QueryBundle] = None, ) -> List[NodeWithScore]: """Postprocess nodes.""" try: import spacy except ImportError: raise ImportError( "Spacy is not installed, please install it with `pip install spacy`." ) from spacy.matcher import PhraseMatcher nlp = spacy.blank(self.lang) required_matcher = PhraseMatcher(nlp.vocab) exclude_matcher = PhraseMatcher(nlp.vocab) required_matcher.add("RequiredKeywords", list(nlp.pipe(self.required_keywords))) exclude_matcher.add("ExcludeKeywords", list(nlp.pipe(self.exclude_keywords))) new_nodes = [] for node_with_score in nodes: node = node_with_score.node doc = nlp(node.get_content()) if self.required_keywords and not required_matcher(doc): continue if self.exclude_keywords and exclude_matcher(doc): continue new_nodes.append(node_with_score) return new_nodes class SimilarityPostprocessor(BaseNodePostprocessor): """Similarity-based Node processor.""" similarity_cutoff: float = Field(default=None) @classmethod def class_name(cls) -> str: return "SimilarityPostprocessor" def _postprocess_nodes( self, nodes: List[NodeWithScore], query_bundle: Optional[QueryBundle] = None, ) -> List[NodeWithScore]: """Postprocess nodes.""" sim_cutoff_exists = self.similarity_cutoff is not None new_nodes = [] for node in nodes: should_use_node = True if sim_cutoff_exists: similarity = node.score if similarity is None: should_use_node = False elif cast(float, similarity) < cast(float, self.similarity_cutoff): should_use_node = False if should_use_node: new_nodes.append(node) return new_nodes def get_forward_nodes( node_with_score: NodeWithScore, num_nodes: int, docstore: BaseDocumentStore ) -> Dict[str, NodeWithScore]: """Get forward nodes.""" node = node_with_score.node nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score} cur_count = 0 # get forward nodes in an iterative manner while cur_count < num_nodes: if NodeRelationship.NEXT not in node.relationships: break next_node_info = node.next_node if next_node_info is None: break next_node_id = next_node_info.node_id next_node = docstore.get_node(next_node_id) nodes[next_node.node_id] = NodeWithScore(node=next_node) node = next_node cur_count += 1 return nodes def get_backward_nodes( node_with_score: NodeWithScore, num_nodes: int, docstore: BaseDocumentStore ) -> Dict[str, NodeWithScore]: """Get backward nodes.""" node = node_with_score.node # get backward nodes in an iterative manner nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score} cur_count = 0 while cur_count < num_nodes: prev_node_info = node.prev_node if prev_node_info is None: break prev_node_id = prev_node_info.node_id prev_node = docstore.get_node(prev_node_id) if prev_node is None: break nodes[prev_node.node_id] = NodeWithScore(node=prev_node) node = prev_node cur_count += 1 return nodes class PrevNextNodePostprocessor(BaseNodePostprocessor): """Previous/Next Node post-processor. Allows users to fetch additional nodes from the document store, based on the relationships of the nodes. NOTE: this is a beta feature. Args: docstore (BaseDocumentStore): The document store. num_nodes (int): The number of nodes to return (default: 1) mode (str): The mode of the post-processor. Can be "previous", "next", or "both. """ docstore: BaseDocumentStore num_nodes: int = Field(default=1) mode: str = Field(default="next") @validator("mode") def _validate_mode(cls, v: str) -> str: """Validate mode.""" if v not in ["next", "previous", "both"]: raise ValueError(f"Invalid mode: {v}") return v @classmethod def class_name(cls) -> str: return "PrevNextNodePostprocessor" def _postprocess_nodes( self, nodes: List[NodeWithScore], query_bundle: Optional[QueryBundle] = None, ) -> List[NodeWithScore]: """Postprocess nodes.""" all_nodes: Dict[str, NodeWithScore] = {} for node in nodes: all_nodes[node.node.node_id] = node if self.mode == "next": all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore)) elif self.mode == "previous": all_nodes.update( get_backward_nodes(node, self.num_nodes, self.docstore) ) elif self.mode == "both": all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore)) all_nodes.update( get_backward_nodes(node, self.num_nodes, self.docstore) ) else: raise ValueError(f"Invalid mode: {self.mode}") all_nodes_values: List[NodeWithScore] = list(all_nodes.values()) sorted_nodes: List[NodeWithScore] = [] for node in all_nodes_values: # variable to check if cand node is inserted node_inserted = False for i, cand in enumerate(sorted_nodes): node_id = node.node.node_id # prepend to current candidate prev_node_info = cand.node.prev_node next_node_info = cand.node.next_node if prev_node_info is not None and node_id == prev_node_info.node_id: node_inserted = True sorted_nodes.insert(i, node) break # append to current candidate elif next_node_info is not None and node_id == next_node_info.node_id: node_inserted = True sorted_nodes.insert(i + 1, node) break if not node_inserted: sorted_nodes.append(node) return sorted_nodes DEFAULT_INFER_PREV_NEXT_TMPL = ( "The current context information is provided. \n" "A question is also provided. \n" "You are a retrieval agent deciding whether to search the " "document store for additional prior context or future context. \n" "Given the context and question, return PREVIOUS or NEXT or NONE. \n" "Examples: \n\n" "Context: Describes the author's experience at Y Combinator." "Question: What did the author do after his time at Y Combinator? \n" "Answer: NEXT \n\n" "Context: Describes the author's experience at Y Combinator." "Question: What did the author do before his time at Y Combinator? \n" "Answer: PREVIOUS \n\n" "Context: Describe the author's experience at Y Combinator." "Question: What did the author do at Y Combinator? \n" "Answer: NONE \n\n" "Context: {context_str}\n" "Question: {query_str}\n" "Answer: " ) DEFAULT_REFINE_INFER_PREV_NEXT_TMPL = ( "The current context information is provided. \n" "A question is also provided. \n" "An existing answer is also provided.\n" "You are a retrieval agent deciding whether to search the " "document store for additional prior context or future context. \n" "Given the context, question, and previous answer, " "return PREVIOUS or NEXT or NONE.\n" "Examples: \n\n" "Context: {context_msg}\n" "Question: {query_str}\n" "Existing Answer: {existing_answer}\n" "Answer: " ) class AutoPrevNextNodePostprocessor(BaseNodePostprocessor): """Previous/Next Node post-processor. Allows users to fetch additional nodes from the document store, based on the prev/next relationships of the nodes. NOTE: difference with PrevNextPostprocessor is that this infers forward/backwards direction. NOTE: this is a beta feature. Args: docstore (BaseDocumentStore): The document store. num_nodes (int): The number of nodes to return (default: 1) infer_prev_next_tmpl (str): The template to use for inference. Required fields are {context_str} and {query_str}. """ docstore: BaseDocumentStore service_context: ServiceContext num_nodes: int = Field(default=1) infer_prev_next_tmpl: str = Field(default=DEFAULT_INFER_PREV_NEXT_TMPL) refine_prev_next_tmpl: str = Field(default=DEFAULT_REFINE_INFER_PREV_NEXT_TMPL) verbose: bool = Field(default=False) class Config: """Configuration for this pydantic object.""" arbitrary_types_allowed = True @classmethod def class_name(cls) -> str: return "AutoPrevNextNodePostprocessor" def _parse_prediction(self, raw_pred: str) -> str: """Parse prediction.""" pred = raw_pred.strip().lower() if "previous" in pred: return "previous" elif "next" in pred: return "next" elif "none" in pred: return "none" raise ValueError(f"Invalid prediction: {raw_pred}") def _postprocess_nodes( self, nodes: List[NodeWithScore], query_bundle: Optional[QueryBundle] = None, ) -> List[NodeWithScore]: """Postprocess nodes.""" if query_bundle is None: raise ValueError("Missing query bundle.") infer_prev_next_prompt = PromptTemplate( self.infer_prev_next_tmpl, ) refine_infer_prev_next_prompt = PromptTemplate(self.refine_prev_next_tmpl) all_nodes: Dict[str, NodeWithScore] = {} for node in nodes: all_nodes[node.node.node_id] = node # use response builder instead of llm directly # to be more robust to handling long context response_builder = get_response_synthesizer( service_context=self.service_context, text_qa_template=infer_prev_next_prompt, refine_template=refine_infer_prev_next_prompt, response_mode=ResponseMode.TREE_SUMMARIZE, ) raw_pred = response_builder.get_response( text_chunks=[node.node.get_content()], query_str=query_bundle.query_str, ) raw_pred = cast(str, raw_pred) mode = self._parse_prediction(raw_pred) logger.debug(f"> Postprocessor Predicted mode: {mode}") if self.verbose: print(f"> Postprocessor Predicted mode: {mode}") if mode == "next": all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore)) elif mode == "previous": all_nodes.update( get_backward_nodes(node, self.num_nodes, self.docstore) ) elif mode == "none": pass else: raise ValueError(f"Invalid mode: {mode}") sorted_nodes = sorted(all_nodes.values(), key=lambda x: x.node.node_id) return list(sorted_nodes) class LongContextReorder(BaseNodePostprocessor): """ Models struggle to access significant details found in the center of extended contexts. A study (https://arxiv.org/abs/2307.03172) observed that the best performance typically arises when crucial data is positioned at the start or conclusion of the input context. Additionally, as the input context lengthens, performance drops notably, even in models designed for long contexts.". """ @classmethod def class_name(cls) -> str: return "LongContextReorder" def _postprocess_nodes( self, nodes: List[NodeWithScore], query_bundle: Optional[QueryBundle] = None, ) -> List[NodeWithScore]: """Postprocess nodes.""" reordered_nodes: List[NodeWithScore] = [] ordered_nodes: List[NodeWithScore] = sorted( nodes, key=lambda x: x.score if x.score is not None else 0 ) for i, node in enumerate(ordered_nodes): if i % 2 == 0: reordered_nodes.insert(0, node) else: reordered_nodes.append(node) return reordered_nodes