# faiss_rag_enterprise/llama_index/postprocessor/node.py
"""Node postprocessor."""
import logging
from typing import Dict, List, Optional, cast
from llama_index.bridge.pydantic import Field, validator
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts.base import PromptTemplate
from llama_index.response_synthesizers import ResponseMode, get_response_synthesizer
from llama_index.schema import NodeRelationship, NodeWithScore, QueryBundle
from llama_index.service_context import ServiceContext
from llama_index.storage.docstore import BaseDocumentStore
logger = logging.getLogger(__name__)
class KeywordNodePostprocessor(BaseNodePostprocessor):
"""Keyword-based Node processor."""
required_keywords: List[str] = Field(default_factory=list)
exclude_keywords: List[str] = Field(default_factory=list)
lang: str = Field(default="en")
@classmethod
def class_name(cls) -> str:
return "KeywordNodePostprocessor"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
"""Postprocess nodes."""
try:
import spacy
except ImportError:
raise ImportError(
"Spacy is not installed, please install it with `pip install spacy`."
)
from spacy.matcher import PhraseMatcher
nlp = spacy.blank(self.lang)
required_matcher = PhraseMatcher(nlp.vocab)
exclude_matcher = PhraseMatcher(nlp.vocab)
required_matcher.add("RequiredKeywords", list(nlp.pipe(self.required_keywords)))
exclude_matcher.add("ExcludeKeywords", list(nlp.pipe(self.exclude_keywords)))
new_nodes = []
for node_with_score in nodes:
node = node_with_score.node
doc = nlp(node.get_content())
if self.required_keywords and not required_matcher(doc):
continue
if self.exclude_keywords and exclude_matcher(doc):
continue
new_nodes.append(node_with_score)
return new_nodes
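
# Usage sketch for KeywordNodePostprocessor -- a minimal illustration, not part
# of the library API; the node texts and keywords are made up, and spaCy must
# be installed for the filter to run.
#
#     from llama_index.schema import TextNode
#
#     postprocessor = KeywordNodePostprocessor(
#         required_keywords=["faiss"], exclude_keywords=["deprecated"]
#     )
#     nodes = [
#         NodeWithScore(node=TextNode(text="Building a faiss index."), score=0.9),
#         NodeWithScore(node=TextNode(text="A deprecated faiss helper."), score=0.8),
#     ]
#     # Only the first node survives: it contains "faiss" and not "deprecated".
#     filtered = postprocessor.postprocess_nodes(nodes)
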
class SimilarityPostprocessor(BaseNodePostprocessor):
"""Similarity-based Node processor."""
    similarity_cutoff: Optional[float] = Field(default=None)
@classmethod
def class_name(cls) -> str:
return "SimilarityPostprocessor"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
"""Postprocess nodes."""
sim_cutoff_exists = self.similarity_cutoff is not None
new_nodes = []
for node in nodes:
should_use_node = True
if sim_cutoff_exists:
similarity = node.score
if similarity is None:
should_use_node = False
elif cast(float, similarity) < cast(float, self.similarity_cutoff):
should_use_node = False
if should_use_node:
new_nodes.append(node)
return new_nodes
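
# Usage sketch for SimilarityPostprocessor -- scores are illustrative. Nodes
# whose score is missing or falls below the cutoff are dropped.
#
#     from llama_index.schema import TextNode
#
#     postprocessor = SimilarityPostprocessor(similarity_cutoff=0.75)
#     nodes = [
#         NodeWithScore(node=TextNode(text="close match"), score=0.9),
#         NodeWithScore(node=TextNode(text="weak match"), score=0.4),
#         NodeWithScore(node=TextNode(text="unscored"), score=None),
#     ]
#     kept = postprocessor.postprocess_nodes(nodes)  # only the 0.9 node remains
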
def get_forward_nodes(
node_with_score: NodeWithScore, num_nodes: int, docstore: BaseDocumentStore
) -> Dict[str, NodeWithScore]:
"""Get forward nodes."""
node = node_with_score.node
nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score}
cur_count = 0
# get forward nodes in an iterative manner
while cur_count < num_nodes:
if NodeRelationship.NEXT not in node.relationships:
break
next_node_info = node.next_node
if next_node_info is None:
break
next_node_id = next_node_info.node_id
next_node = docstore.get_node(next_node_id)
nodes[next_node.node_id] = NodeWithScore(node=next_node)
node = next_node
cur_count += 1
return nodes
def get_backward_nodes(
node_with_score: NodeWithScore, num_nodes: int, docstore: BaseDocumentStore
) -> Dict[str, NodeWithScore]:
"""Get backward nodes."""
node = node_with_score.node
# get backward nodes in an iterative manner
nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score}
cur_count = 0
while cur_count < num_nodes:
prev_node_info = node.prev_node
if prev_node_info is None:
break
prev_node_id = prev_node_info.node_id
prev_node = docstore.get_node(prev_node_id)
if prev_node is None:
break
nodes[prev_node.node_id] = NodeWithScore(node=prev_node)
node = prev_node
cur_count += 1
return nodes
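
# Sketch of how the two helpers above walk NEXT/PREVIOUS relationships. The
# two-node chain is fabricated for illustration; SimpleDocumentStore and
# as_related_node_info() are assumed from this version of llama_index.
#
#     from llama_index.schema import TextNode
#     from llama_index.storage.docstore import SimpleDocumentStore
#
#     first = TextNode(text="chunk 1", id_="n1")
#     second = TextNode(text="chunk 2", id_="n2")
#     first.relationships[NodeRelationship.NEXT] = second.as_related_node_info()
#     second.relationships[NodeRelationship.PREVIOUS] = first.as_related_node_info()
#
#     docstore = SimpleDocumentStore()
#     docstore.add_documents([first, second])
#
#     # Seed node plus one forward hop: {"n1": ..., "n2": ...}.
#     fetched = get_forward_nodes(NodeWithScore(node=first), 1, docstore)
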
class PrevNextNodePostprocessor(BaseNodePostprocessor):
"""Previous/Next Node post-processor.
Allows users to fetch additional nodes from the document store,
based on the relationships of the nodes.
NOTE: this is a beta feature.
Args:
docstore (BaseDocumentStore): The document store.
        num_nodes (int): The number of additional nodes to fetch (default: 1).
mode (str): The mode of the post-processor.
            Can be "previous", "next", or "both".
"""
docstore: BaseDocumentStore
num_nodes: int = Field(default=1)
mode: str = Field(default="next")
@validator("mode")
def _validate_mode(cls, v: str) -> str:
"""Validate mode."""
if v not in ["next", "previous", "both"]:
raise ValueError(f"Invalid mode: {v}")
return v
@classmethod
def class_name(cls) -> str:
return "PrevNextNodePostprocessor"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
"""Postprocess nodes."""
all_nodes: Dict[str, NodeWithScore] = {}
for node in nodes:
all_nodes[node.node.node_id] = node
if self.mode == "next":
all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
elif self.mode == "previous":
all_nodes.update(
get_backward_nodes(node, self.num_nodes, self.docstore)
)
elif self.mode == "both":
all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
all_nodes.update(
get_backward_nodes(node, self.num_nodes, self.docstore)
)
else:
raise ValueError(f"Invalid mode: {self.mode}")
all_nodes_values: List[NodeWithScore] = list(all_nodes.values())
sorted_nodes: List[NodeWithScore] = []
for node in all_nodes_values:
# variable to check if cand node is inserted
node_inserted = False
for i, cand in enumerate(sorted_nodes):
node_id = node.node.node_id
# prepend to current candidate
prev_node_info = cand.node.prev_node
next_node_info = cand.node.next_node
if prev_node_info is not None and node_id == prev_node_info.node_id:
node_inserted = True
sorted_nodes.insert(i, node)
break
# append to current candidate
elif next_node_info is not None and node_id == next_node_info.node_id:
node_inserted = True
sorted_nodes.insert(i + 1, node)
break
if not node_inserted:
sorted_nodes.append(node)
return sorted_nodes
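
# Usage sketch for PrevNextNodePostprocessor, reusing the fabricated docstore
# from the sketch above; "retrieved" stands in for a retriever's output.
#
#     postprocessor = PrevNextNodePostprocessor(
#         docstore=docstore, num_nodes=1, mode="both"
#     )
#     # Pulls up to one neighbor in each direction per retrieved node, then
#     # stitches the results back into document order.
#     expanded = postprocessor.postprocess_nodes(retrieved)
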
DEFAULT_INFER_PREV_NEXT_TMPL = (
"The current context information is provided. \n"
"A question is also provided. \n"
"You are a retrieval agent deciding whether to search the "
"document store for additional prior context or future context. \n"
"Given the context and question, return PREVIOUS or NEXT or NONE. \n"
"Examples: \n\n"
"Context: Describes the author's experience at Y Combinator."
"Question: What did the author do after his time at Y Combinator? \n"
"Answer: NEXT \n\n"
"Context: Describes the author's experience at Y Combinator."
"Question: What did the author do before his time at Y Combinator? \n"
"Answer: PREVIOUS \n\n"
"Context: Describe the author's experience at Y Combinator."
"Question: What did the author do at Y Combinator? \n"
"Answer: NONE \n\n"
"Context: {context_str}\n"
"Question: {query_str}\n"
"Answer: "
)
DEFAULT_REFINE_INFER_PREV_NEXT_TMPL = (
"The current context information is provided. \n"
"A question is also provided. \n"
"An existing answer is also provided.\n"
"You are a retrieval agent deciding whether to search the "
"document store for additional prior context or future context. \n"
"Given the context, question, and previous answer, "
"return PREVIOUS or NEXT or NONE.\n"
"Examples: \n\n"
"Context: {context_msg}\n"
"Question: {query_str}\n"
"Existing Answer: {existing_answer}\n"
"Answer: "
)
class AutoPrevNextNodePostprocessor(BaseNodePostprocessor):
"""Previous/Next Node post-processor.
Allows users to fetch additional nodes from the document store,
based on the prev/next relationships of the nodes.
NOTE: difference with PrevNextPostprocessor is that
this infers forward/backwards direction.
NOTE: this is a beta feature.
Args:
docstore (BaseDocumentStore): The document store.
        num_nodes (int): The number of additional nodes to fetch (default: 1).
infer_prev_next_tmpl (str): The template to use for inference.
Required fields are {context_str} and {query_str}.
"""
docstore: BaseDocumentStore
service_context: ServiceContext
num_nodes: int = Field(default=1)
infer_prev_next_tmpl: str = Field(default=DEFAULT_INFER_PREV_NEXT_TMPL)
refine_prev_next_tmpl: str = Field(default=DEFAULT_REFINE_INFER_PREV_NEXT_TMPL)
verbose: bool = Field(default=False)
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@classmethod
def class_name(cls) -> str:
return "AutoPrevNextNodePostprocessor"
def _parse_prediction(self, raw_pred: str) -> str:
"""Parse prediction."""
pred = raw_pred.strip().lower()
if "previous" in pred:
return "previous"
elif "next" in pred:
return "next"
elif "none" in pred:
return "none"
raise ValueError(f"Invalid prediction: {raw_pred}")
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
"""Postprocess nodes."""
if query_bundle is None:
raise ValueError("Missing query bundle.")
infer_prev_next_prompt = PromptTemplate(
self.infer_prev_next_tmpl,
)
refine_infer_prev_next_prompt = PromptTemplate(self.refine_prev_next_tmpl)
all_nodes: Dict[str, NodeWithScore] = {}
for node in nodes:
all_nodes[node.node.node_id] = node
# use response builder instead of llm directly
# to be more robust to handling long context
response_builder = get_response_synthesizer(
service_context=self.service_context,
text_qa_template=infer_prev_next_prompt,
refine_template=refine_infer_prev_next_prompt,
response_mode=ResponseMode.TREE_SUMMARIZE,
)
raw_pred = response_builder.get_response(
text_chunks=[node.node.get_content()],
query_str=query_bundle.query_str,
)
raw_pred = cast(str, raw_pred)
mode = self._parse_prediction(raw_pred)
logger.debug(f"> Postprocessor Predicted mode: {mode}")
if self.verbose:
print(f"> Postprocessor Predicted mode: {mode}")
if mode == "next":
all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
elif mode == "previous":
all_nodes.update(
get_backward_nodes(node, self.num_nodes, self.docstore)
)
elif mode == "none":
pass
else:
raise ValueError(f"Invalid mode: {mode}")
sorted_nodes = sorted(all_nodes.values(), key=lambda x: x.node.node_id)
return list(sorted_nodes)
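
# Usage sketch for AutoPrevNextNodePostprocessor -- shown as an outline only,
# since it calls an LLM through the service context; "docstore" and
# "retrieved" are placeholders as above.
#
#     from llama_index import ServiceContext
#
#     postprocessor = AutoPrevNextNodePostprocessor(
#         docstore=docstore,
#         service_context=ServiceContext.from_defaults(),
#         num_nodes=1,
#     )
#     # The LLM predicts PREVIOUS/NEXT/NONE per node; a query bundle is
#     # required, otherwise a ValueError is raised.
#     expanded = postprocessor.postprocess_nodes(
#         retrieved, query_bundle=QueryBundle(query_str="What happened next?")
#     )
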
class LongContextReorder(BaseNodePostprocessor):
"""
Models struggle to access significant details found
in the center of extended contexts. A study
(https://arxiv.org/abs/2307.03172) observed that the best
performance typically arises when crucial data is positioned
at the start or conclusion of the input context. Additionally,
as the input context lengthens, performance drops notably, even
    in models designed for long contexts.
"""
@classmethod
def class_name(cls) -> str:
return "LongContextReorder"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
"""Postprocess nodes."""
reordered_nodes: List[NodeWithScore] = []
ordered_nodes: List[NodeWithScore] = sorted(
nodes, key=lambda x: x.score if x.score is not None else 0
)
for i, node in enumerate(ordered_nodes):
if i % 2 == 0:
reordered_nodes.insert(0, node)
else:
reordered_nodes.append(node)
return reordered_nodes
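
# Usage sketch for LongContextReorder -- scores are illustrative. Nodes are
# sorted ascending by score, then alternately pushed to the front or appended,
# leaving the strongest nodes at the edges of the context window.
#
#     from llama_index.schema import TextNode
#
#     reorder = LongContextReorder()
#     nodes = [
#         NodeWithScore(node=TextNode(text=t), score=s)
#         for t, s in [("a", 0.1), ("b", 0.5), ("c", 0.9), ("d", 0.7)]
#     ]
#     # Reordered scores: [0.7, 0.1, 0.5, 0.9] -- the two strongest nodes sit
#     # at the ends, the weakest in the middle.
#     reordered = reorder.postprocess_nodes(nodes)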