faiss_rag_enterprise/llama_index/extractors/interface.py

166 lines
5.3 KiB
Python

"""Node parser interface."""
import asyncio
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, cast
from typing_extensions import Self
from llama_index.bridge.pydantic import Field
from llama_index.schema import BaseNode, MetadataMode, TextNode, TransformComponent
# Default template for rendering a node when extracting metadata: frames the
# node as a document excerpt, placing the metadata string before the content.
# Consumed by BaseExtractor.node_text_template below.
DEFAULT_NODE_TEXT_TEMPLATE = """\
[Excerpt from document]\n{metadata_str}\n\
Excerpt:\n-----\n{content}\n-----\n"""
class BaseExtractor(TransformComponent):
    """Base class for metadata extractors.

    Subclasses implement :meth:`aextract` to compute one metadata dict per
    node; this base class handles merging those dicts into the nodes, key
    exclusions, template rewriting, and the sync/async plumbing.
    """

    # Whether this extractor operates only on text nodes.
    is_text_node_only: bool = True
    show_progress: bool = Field(default=True, description="Whether to show progress.")
    metadata_mode: MetadataMode = Field(
        default=MetadataMode.ALL, description="Metadata mode to use when reading nodes."
    )
    node_text_template: str = Field(
        default=DEFAULT_NODE_TEXT_TEMPLATE,
        description="Template to represent how node text is mixed with metadata text.",
    )
    disable_template_rewrite: bool = Field(
        default=False, description="Disable the node template rewrite."
    )
    in_place: bool = Field(
        default=True, description="Whether to process nodes in place."
    )
    num_workers: int = Field(
        default=4,
        description="Number of workers to use for concurrent async processing.",
    )

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        """Construct an extractor from a serialized dict.

        Serialized ``llm_predictor`` / ``llm`` entries are re-hydrated via
        the loading helpers before instantiation.

        Args:
            data (Dict[str, Any]): serialized extractor fields.
            **kwargs: extra fields that override entries in ``data``.

        Returns:
            Self: a new extractor instance. The caller's ``data`` dict is
            not mutated.
        """
        # Work on a shallow copy so the caller's dict is left untouched
        # (the previous implementation updated/popped `data` in place).
        # Note: `**kwargs` is always a dict, so no isinstance guard is needed.
        data = dict(data)
        data.update(kwargs)
        # `class_name` is serialization bookkeeping, not a constructor field.
        data.pop("class_name", None)

        llm_predictor = data.get("llm_predictor", None)
        if llm_predictor:
            # Local import avoids a circular dependency at module load time.
            from llama_index.llm_predictor.loading import load_predictor

            data["llm_predictor"] = load_predictor(llm_predictor)

        llm = data.get("llm", None)
        if llm:
            from llama_index.llms.loading import load_llm

            data["llm"] = load_llm(llm)

        return cls(**data)

    @classmethod
    def class_name(cls) -> str:
        """Get class name.

        NOTE(review): returns "MetadataExtractor" rather than the actual
        class name — presumably kept for backward compatibility with
        previously serialized data; confirm before changing.
        """
        return "MetadataExtractor"

    @abstractmethod
    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extract metadata for a sequence of nodes.

        Args:
            nodes (Sequence[BaseNode]): nodes to extract metadata from.

        Returns:
            List[Dict]: one metadata dict per input node, in order.
        """

    def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Synchronous wrapper around :meth:`aextract`.

        Args:
            nodes (Sequence[BaseNode]): nodes to extract metadata from.

        Returns:
            List[Dict]: one metadata dict per input node, in order.

        NOTE: uses ``asyncio.run`` and therefore cannot be called from
        inside an already-running event loop; use :meth:`aextract` there.
        """
        return asyncio.run(self.aextract(nodes))

    async def aprocess_nodes(
        self,
        nodes: List[BaseNode],
        excluded_embed_metadata_keys: Optional[List[str]] = None,
        excluded_llm_metadata_keys: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Post-process nodes parsed from documents.

        Allows extractors to be chained.

        Args:
            nodes (List[BaseNode]): nodes to post-process.
            excluded_embed_metadata_keys (Optional[List[str]]):
                keys to exclude from embed metadata.
            excluded_llm_metadata_keys (Optional[List[str]]):
                keys to exclude from llm metadata.

        Returns:
            List[BaseNode]: the processed nodes — the same objects when
            ``in_place`` is True, deep copies otherwise.
        """
        if self.in_place:
            new_nodes = nodes
        else:
            new_nodes = [deepcopy(node) for node in nodes]

        cur_metadata_list = await self.aextract(new_nodes)

        # Single pass: merge extracted metadata, apply key exclusions, and
        # rewrite the text template. (The original used two loops, but the
        # per-node operations are independent, so one pass is equivalent.)
        for node, metadata in zip(new_nodes, cur_metadata_list):
            node.metadata.update(metadata)
            if excluded_embed_metadata_keys is not None:
                node.excluded_embed_metadata_keys.extend(excluded_embed_metadata_keys)
            if excluded_llm_metadata_keys is not None:
                node.excluded_llm_metadata_keys.extend(excluded_llm_metadata_keys)
            # isinstance already narrows the type; no cast() needed.
            if not self.disable_template_rewrite and isinstance(node, TextNode):
                node.text_template = self.node_text_template

        return new_nodes

    def process_nodes(
        self,
        nodes: List[BaseNode],
        excluded_embed_metadata_keys: Optional[List[str]] = None,
        excluded_llm_metadata_keys: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Synchronous wrapper around :meth:`aprocess_nodes`.

        Args:
            nodes (List[BaseNode]): nodes to post-process.
            excluded_embed_metadata_keys (Optional[List[str]]):
                keys to exclude from embed metadata.
            excluded_llm_metadata_keys (Optional[List[str]]):
                keys to exclude from llm metadata.

        Returns:
            List[BaseNode]: the processed nodes.
        """
        return asyncio.run(
            self.aprocess_nodes(
                nodes,
                excluded_embed_metadata_keys=excluded_embed_metadata_keys,
                excluded_llm_metadata_keys=excluded_llm_metadata_keys,
                **kwargs,
            )
        )

    def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        """Post-process nodes parsed from documents.

        Allows extractors to be chained.

        Args:
            nodes (List[BaseNode]): nodes to post-process.
        """
        return self.process_nodes(nodes, **kwargs)

    async def acall(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        """Async version of :meth:`__call__`.

        Allows extractors to be chained.

        Args:
            nodes (List[BaseNode]): nodes to post-process.
        """
        return await self.aprocess_nodes(nodes, **kwargs)