"""JSON node parser."""
|
|
import json
|
|
from typing import Any, Dict, Generator, List, Optional, Sequence
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.node_parser.interface import NodeParser
|
|
from llama_index.node_parser.node_utils import build_nodes_from_splits
|
|
from llama_index.schema import BaseNode, MetadataMode, TextNode
|
|
from llama_index.utils import get_tqdm_iterable
|
|
|
|
|
|
class JSONNodeParser(NodeParser):
    """JSON node parser.

    Splits a document into Nodes using custom JSON splitting logic.

    Args:
        include_metadata (bool): whether to include metadata in nodes
        include_prev_next_rel (bool): whether to include prev/next relationships

    """

    @classmethod
    def from_defaults(
        cls,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        callback_manager: Optional[CallbackManager] = None,
    ) -> "JSONNodeParser":
        callback_manager = callback_manager or CallbackManager([])

        return cls(
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "JSONNodeParser"

    def _parse_nodes(
        self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        all_nodes: List[BaseNode] = []
        nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")

        for node in nodes_with_progress:
            parsed_nodes = self.get_nodes_from_node(node)
            all_nodes.extend(parsed_nodes)

        return all_nodes

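    # A top-level JSON object is flattened into a single text node, while a
    # top-level JSON array yields one text node per element; any other
    # top-level value raises ValueError.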
    def get_nodes_from_node(self, node: BaseNode) -> List[TextNode]:
        """Get nodes from document."""
        text = node.get_content(metadata_mode=MetadataMode.NONE)
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # The node's content is not valid JSON; skip it.
            return []

        json_nodes = []
        if isinstance(data, dict):
            lines = [*self._depth_first_yield(data, 0, [])]
            json_nodes.extend(
                build_nodes_from_splits(["\n".join(lines)], node, id_func=self.id_func)
            )
        elif isinstance(data, list):
            for json_object in data:
                lines = [*self._depth_first_yield(json_object, 0, [])]
                json_nodes.extend(
                    build_nodes_from_splits(
                        ["\n".join(lines)], node, id_func=self.id_func
                    )
                )
        else:
            raise ValueError("JSON is invalid")

        return json_nodes

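    # Example: for {"a": {"b": 1, "c": [2, 3]}} with levels_back=0,
    # _depth_first_yield produces "a b 1", "a c 2", and "a c 3", which
    # get_nodes_from_node joins with newlines into one node's text.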
    def _depth_first_yield(
        self, json_data: Dict, levels_back: int, path: List[str]
    ) -> Generator[str, None, None]:
        """Do depth first yield of all of the leaf nodes of a JSON.

        Combines keys in the JSON tree using spaces.

        If levels_back is set to 0, all levels are included in the output.

        """
        if isinstance(json_data, dict):
            for key, value in json_data.items():
                new_path = path[:]
                new_path.append(key)
                yield from self._depth_first_yield(value, levels_back, new_path)
        elif isinstance(json_data, list):
            for value in json_data:
                yield from self._depth_first_yield(value, levels_back, path)
        else:
            new_path = path[-levels_back:]
            new_path.append(str(json_data))
            yield " ".join(new_path)
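

# Minimal usage sketch (illustrative, not part of the parser itself). It
# assumes `Document` is importable from llama_index.schema in this version
# of the library; adjust the import for your install if needed.
if __name__ == "__main__":
    from llama_index.schema import Document

    sample = Document(text='{"user": {"name": "Ada", "tags": ["x", "y"]}}')
    parser = JSONNodeParser.from_defaults()
    for parsed in parser.get_nodes_from_node(sample):
        # Expect flattened lines such as "user name Ada" and "user tags x".
        print(parsed.get_content())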