faiss_rag_enterprise/llama_index/postprocessor/metadata_replacement.py

34 lines
1.0 KiB
Python

from typing import List, Optional
from llama_index.bridge.pydantic import Field
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle
class MetadataReplacementPostProcessor(BaseNodePostprocessor):
    """Node postprocessor that swaps node content for a metadata value.

    For every node handed to it, the content is overwritten with the value
    stored in the node's metadata under ``target_metadata_key``. When that
    key is absent, the node's existing content (requested with
    ``MetadataMode.NONE``) is written back instead.
    """

    target_metadata_key: str = Field(
        description="Target metadata key to replace node content with."
    )

    def __init__(self, target_metadata_key: str) -> None:
        """Forward ``target_metadata_key`` to the base-class initializer."""
        super().__init__(target_metadata_key=target_metadata_key)

    @classmethod
    def class_name(cls) -> str:
        """Return the fixed identifier string for this postprocessor."""
        return "MetadataReplacementPostProcessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Overwrite the content of each node in ``nodes`` and return the list.

        Args:
            nodes: Scored nodes to update. Each wrapped node is modified in
                place through ``set_content``.
            query_bundle: Accepted for interface compatibility; not read here.

        Returns:
            The same ``nodes`` list object that was passed in.
        """
        for scored in nodes:
            inner = scored.node
            # The fallback is computed up front on every iteration, exactly
            # as it would be when passed as the default of ``dict.get``.
            fallback_text = inner.get_content(metadata_mode=MetadataMode.NONE)
            replacement = inner.metadata.get(
                self.target_metadata_key, fallback_text
            )
            inner.set_content(replacement)

        return nodes