"""Base schema for data structures.""" import json import textwrap import uuid from abc import abstractmethod from dataclasses import dataclass from enum import Enum, auto from hashlib import sha256 from io import BytesIO from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from dataclasses_json import DataClassJsonMixin from typing_extensions import Self from llama_index.bridge.pydantic import BaseModel, Field from llama_index.utils import SAMPLE_TEXT, truncate_text if TYPE_CHECKING: from haystack.schema import Document as HaystackDocument from semantic_kernel.memory.memory_record import MemoryRecord from llama_index.bridge.langchain import Document as LCDocument DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}" DEFAULT_METADATA_TMPL = "{key}: {value}" # NOTE: for pretty printing TRUNCATE_LENGTH = 350 WRAP_WIDTH = 70 ImageType = Union[str, BytesIO] class BaseComponent(BaseModel): """Base component object to capture class names.""" class Config: @staticmethod def schema_extra(schema: Dict[str, Any], model: "BaseComponent") -> None: """Add class name to schema.""" schema["properties"]["class_name"] = { "title": "Class Name", "type": "string", "default": model.class_name(), } @classmethod def class_name(cls) -> str: """ Get the class name, used as a unique ID in serialization. This provides a key that makes serialization robust against actual class name changes. """ return "base_component" def json(self, **kwargs: Any) -> str: return self.to_json(**kwargs) def dict(self, **kwargs: Any) -> Dict[str, Any]: data = super().dict(**kwargs) data["class_name"] = self.class_name() return data def __getstate__(self) -> Dict[str, Any]: state = super().__getstate__() # tiktoken is not pickleable # state["__dict__"] = self.dict() state["__dict__"].pop("tokenizer", None) # remove local functions keys_to_remove = [] for key, val in state["__dict__"].items(): if key.endswith("_fn"): keys_to_remove.append(key) if "" in str(val): keys_to_remove.append(key) for key in keys_to_remove: state["__dict__"].pop(key, None) # remove private attributes -- kind of dangerous state["__private_attribute_values__"] = {} return state def __setstate__(self, state: Dict[str, Any]) -> None: # Use the __dict__ and __init__ method to set state # so that all variable initialize try: self.__init__(**state["__dict__"]) # type: ignore except Exception: # Fall back to the default __setstate__ method super().__setstate__(state) def to_dict(self, **kwargs: Any) -> Dict[str, Any]: data = self.dict(**kwargs) data["class_name"] = self.class_name() return data def to_json(self, **kwargs: Any) -> str: data = self.to_dict(**kwargs) return json.dumps(data) # TODO: return type here not supported by current mypy version @classmethod def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore if isinstance(kwargs, dict): data.update(kwargs) data.pop("class_name", None) return cls(**data) @classmethod def from_json(cls, data_str: str, **kwargs: Any) -> Self: # type: ignore data = json.loads(data_str) return cls.from_dict(data, **kwargs) class TransformComponent(BaseComponent): """Base class for transform components.""" class Config: arbitrary_types_allowed = True @abstractmethod def __call__(self, nodes: List["BaseNode"], **kwargs: Any) -> List["BaseNode"]: """Transform nodes.""" async def acall(self, nodes: List["BaseNode"], **kwargs: Any) -> List["BaseNode"]: """Async transform nodes.""" return self.__call__(nodes, **kwargs) class NodeRelationship(str, Enum): """Node relationships used in `BaseNode` class. Attributes: SOURCE: The node is the source document. PREVIOUS: The node is the previous node in the document. NEXT: The node is the next node in the document. PARENT: The node is the parent node in the document. CHILD: The node is a child node in the document. """ SOURCE = auto() PREVIOUS = auto() NEXT = auto() PARENT = auto() CHILD = auto() class ObjectType(str, Enum): TEXT = auto() IMAGE = auto() INDEX = auto() DOCUMENT = auto() class MetadataMode(str, Enum): ALL = "all" EMBED = "embed" LLM = "llm" NONE = "none" class RelatedNodeInfo(BaseComponent): node_id: str node_type: Optional[ObjectType] = None metadata: Dict[str, Any] = Field(default_factory=dict) hash: Optional[str] = None @classmethod def class_name(cls) -> str: return "RelatedNodeInfo" RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]] # Node classes for indexes class BaseNode(BaseComponent): """Base node Object. Generic abstract interface for retrievable nodes """ class Config: allow_population_by_field_name = True # hash is computed on local field, during the validation process validate_assignment = True id_: str = Field( default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node." ) embedding: Optional[List[float]] = Field( default=None, description="Embedding of the node." ) """" metadata fields - injected as part of the text shown to LLMs as context - injected as part of the text for generating embeddings - used by vector DBs for metadata filtering """ metadata: Dict[str, Any] = Field( default_factory=dict, description="A flat dictionary of metadata fields", alias="extra_info", ) excluded_embed_metadata_keys: List[str] = Field( default_factory=list, description="Metadata keys that are excluded from text for the embed model.", ) excluded_llm_metadata_keys: List[str] = Field( default_factory=list, description="Metadata keys that are excluded from text for the LLM.", ) relationships: Dict[NodeRelationship, RelatedNodeType] = Field( default_factory=dict, description="A mapping of relationships to other node information.", ) @classmethod @abstractmethod def get_type(cls) -> str: """Get Object type.""" @abstractmethod def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: """Get object content.""" @abstractmethod def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: """Metadata string.""" @abstractmethod def set_content(self, value: Any) -> None: """Set the content of the node.""" @property @abstractmethod def hash(self) -> str: """Get hash of node.""" @property def node_id(self) -> str: return self.id_ @node_id.setter def node_id(self, value: str) -> None: self.id_ = value @property def source_node(self) -> Optional[RelatedNodeInfo]: """Source object node. Extracted from the relationships field. """ if NodeRelationship.SOURCE not in self.relationships: return None relation = self.relationships[NodeRelationship.SOURCE] if isinstance(relation, list): raise ValueError("Source object must be a single RelatedNodeInfo object") return relation @property def prev_node(self) -> Optional[RelatedNodeInfo]: """Prev node.""" if NodeRelationship.PREVIOUS not in self.relationships: return None relation = self.relationships[NodeRelationship.PREVIOUS] if not isinstance(relation, RelatedNodeInfo): raise ValueError("Previous object must be a single RelatedNodeInfo object") return relation @property def next_node(self) -> Optional[RelatedNodeInfo]: """Next node.""" if NodeRelationship.NEXT not in self.relationships: return None relation = self.relationships[NodeRelationship.NEXT] if not isinstance(relation, RelatedNodeInfo): raise ValueError("Next object must be a single RelatedNodeInfo object") return relation @property def parent_node(self) -> Optional[RelatedNodeInfo]: """Parent node.""" if NodeRelationship.PARENT not in self.relationships: return None relation = self.relationships[NodeRelationship.PARENT] if not isinstance(relation, RelatedNodeInfo): raise ValueError("Parent object must be a single RelatedNodeInfo object") return relation @property def child_nodes(self) -> Optional[List[RelatedNodeInfo]]: """Child nodes.""" if NodeRelationship.CHILD not in self.relationships: return None relation = self.relationships[NodeRelationship.CHILD] if not isinstance(relation, list): raise ValueError("Child objects must be a list of RelatedNodeInfo objects.") return relation @property def ref_doc_id(self) -> Optional[str]: """Deprecated: Get ref doc id.""" source_node = self.source_node if source_node is None: return None return source_node.node_id @property def extra_info(self) -> Dict[str, Any]: """TODO: DEPRECATED: Extra info.""" return self.metadata def __str__(self) -> str: source_text_truncated = truncate_text( self.get_content().strip(), TRUNCATE_LENGTH ) source_text_wrapped = textwrap.fill( f"Text: {source_text_truncated}\n", width=WRAP_WIDTH ) return f"Node ID: {self.node_id}\n{source_text_wrapped}" def get_embedding(self) -> List[float]: """Get embedding. Errors if embedding is None. """ if self.embedding is None: raise ValueError("embedding not set.") return self.embedding def as_related_node_info(self) -> RelatedNodeInfo: """Get node as RelatedNodeInfo.""" return RelatedNodeInfo( node_id=self.node_id, node_type=self.get_type(), metadata=self.metadata, hash=self.hash, ) class TextNode(BaseNode): text: str = Field(default="", description="Text content of the node.") start_char_idx: Optional[int] = Field( default=None, description="Start char index of the node." ) end_char_idx: Optional[int] = Field( default=None, description="End char index of the node." ) text_template: str = Field( default=DEFAULT_TEXT_NODE_TMPL, description=( "Template for how text is formatted, with {content} and " "{metadata_str} placeholders." ), ) metadata_template: str = Field( default=DEFAULT_METADATA_TMPL, description=( "Template for how metadata is formatted, with {key} and " "{value} placeholders." ), ) metadata_seperator: str = Field( default="\n", description="Separator between metadata fields when converting to string.", ) @classmethod def class_name(cls) -> str: return "TextNode" @property def hash(self) -> str: doc_identity = str(self.text) + str(self.metadata) return str(sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()) @classmethod def get_type(cls) -> str: """Get Object type.""" return ObjectType.TEXT def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: """Get object content.""" metadata_str = self.get_metadata_str(mode=metadata_mode).strip() if not metadata_str: return self.text return self.text_template.format( content=self.text, metadata_str=metadata_str ).strip() def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: """Metadata info string.""" if mode == MetadataMode.NONE: return "" usable_metadata_keys = set(self.metadata.keys()) if mode == MetadataMode.LLM: for key in self.excluded_llm_metadata_keys: if key in usable_metadata_keys: usable_metadata_keys.remove(key) elif mode == MetadataMode.EMBED: for key in self.excluded_embed_metadata_keys: if key in usable_metadata_keys: usable_metadata_keys.remove(key) return self.metadata_seperator.join( [ self.metadata_template.format(key=key, value=str(value)) for key, value in self.metadata.items() if key in usable_metadata_keys ] ) def set_content(self, value: str) -> None: """Set the content of the node.""" self.text = value def get_node_info(self) -> Dict[str, Any]: """Get node info.""" return {"start": self.start_char_idx, "end": self.end_char_idx} def get_text(self) -> str: return self.get_content(metadata_mode=MetadataMode.NONE) @property def node_info(self) -> Dict[str, Any]: """Deprecated: Get node info.""" return self.get_node_info() # TODO: legacy backport of old Node class Node = TextNode class ImageNode(TextNode): """Node with image.""" # TODO: store reference instead of actual image # base64 encoded image str image: Optional[str] = None image_path: Optional[str] = None image_url: Optional[str] = None image_mimetype: Optional[str] = None text_embedding: Optional[List[float]] = Field( default=None, description="Text embedding of image node, if text field is filled out", ) @classmethod def get_type(cls) -> str: return ObjectType.IMAGE @classmethod def class_name(cls) -> str: return "ImageNode" def resolve_image(self) -> ImageType: """Resolve an image such that PIL can read it.""" if self.image is not None: import base64 return BytesIO(base64.b64decode(self.image)) elif self.image_path is not None: return self.image_path elif self.image_url is not None: # load image from URL import requests response = requests.get(self.image_url) return BytesIO(response.content) else: raise ValueError("No image found in node.") class IndexNode(TextNode): """Node with reference to any object. This can include other indices, query engines, retrievers. This can also include other nodes (though this is overlapping with `relationships` on the Node class). """ index_id: str obj: Any = Field(exclude=True) @classmethod def from_text_node( cls, node: TextNode, index_id: str, ) -> "IndexNode": """Create index node from text node.""" # copy all attributes from text node, add index id return cls( **node.dict(), index_id=index_id, ) @classmethod def get_type(cls) -> str: return ObjectType.INDEX @classmethod def class_name(cls) -> str: return "IndexNode" class NodeWithScore(BaseComponent): node: BaseNode score: Optional[float] = None def __str__(self) -> str: score_str = "None" if self.score is None else f"{self.score: 0.3f}" return f"{self.node}\nScore: {score_str}\n" def get_score(self, raise_error: bool = False) -> float: """Get score.""" if self.score is None: if raise_error: raise ValueError("Score not set.") else: return 0.0 else: return self.score @classmethod def class_name(cls) -> str: return "NodeWithScore" ##### pass through methods to BaseNode ##### @property def node_id(self) -> str: return self.node.node_id @property def id_(self) -> str: return self.node.id_ @property def text(self) -> str: if isinstance(self.node, TextNode): return self.node.text else: raise ValueError("Node must be a TextNode to get text.") @property def metadata(self) -> Dict[str, Any]: return self.node.metadata @property def embedding(self) -> Optional[List[float]]: return self.node.embedding def get_text(self) -> str: if isinstance(self.node, TextNode): return self.node.get_text() else: raise ValueError("Node must be a TextNode to get text.") def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: return self.node.get_content(metadata_mode=metadata_mode) def get_embedding(self) -> List[float]: return self.node.get_embedding() # Document Classes for Readers class Document(TextNode): """Generic interface for a data document. This document connects to data sources. """ # TODO: A lot of backwards compatibility logic here, clean up id_: str = Field( default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node.", alias="doc_id", ) _compat_fields = {"doc_id": "id_", "extra_info": "metadata"} @classmethod def get_type(cls) -> str: """Get Document type.""" return ObjectType.DOCUMENT @property def doc_id(self) -> str: """Get document ID.""" return self.id_ def __str__(self) -> str: source_text_truncated = truncate_text( self.get_content().strip(), TRUNCATE_LENGTH ) source_text_wrapped = textwrap.fill( f"Text: {source_text_truncated}\n", width=WRAP_WIDTH ) return f"Doc ID: {self.doc_id}\n{source_text_wrapped}" def get_doc_id(self) -> str: """TODO: Deprecated: Get document ID.""" return self.id_ def __setattr__(self, name: str, value: object) -> None: if name in self._compat_fields: name = self._compat_fields[name] super().__setattr__(name, value) def to_langchain_format(self) -> "LCDocument": """Convert struct to LangChain document format.""" from llama_index.bridge.langchain import Document as LCDocument metadata = self.metadata or {} return LCDocument(page_content=self.text, metadata=metadata) @classmethod def from_langchain_format(cls, doc: "LCDocument") -> "Document": """Convert struct from LangChain document format.""" return cls(text=doc.page_content, metadata=doc.metadata) def to_haystack_format(self) -> "HaystackDocument": """Convert struct to Haystack document format.""" from haystack.schema import Document as HaystackDocument return HaystackDocument( content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_ ) @classmethod def from_haystack_format(cls, doc: "HaystackDocument") -> "Document": """Convert struct from Haystack document format.""" return cls( text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id ) def to_embedchain_format(self) -> Dict[str, Any]: """Convert struct to EmbedChain document format.""" return { "doc_id": self.id_, "data": {"content": self.text, "meta_data": self.metadata}, } @classmethod def from_embedchain_format(cls, doc: Dict[str, Any]) -> "Document": """Convert struct from EmbedChain document format.""" return cls( text=doc["data"]["content"], metadata=doc["data"]["meta_data"], id_=doc["doc_id"], ) def to_semantic_kernel_format(self) -> "MemoryRecord": """Convert struct to Semantic Kernel document format.""" import numpy as np from semantic_kernel.memory.memory_record import MemoryRecord return MemoryRecord( id=self.id_, text=self.text, additional_metadata=self.get_metadata_str(), embedding=np.array(self.embedding) if self.embedding else None, ) @classmethod def from_semantic_kernel_format(cls, doc: "MemoryRecord") -> "Document": """Convert struct from Semantic Kernel document format.""" return cls( text=doc._text, metadata={"additional_metadata": doc._additional_metadata}, embedding=doc._embedding.tolist() if doc._embedding is not None else None, id_=doc._id, ) def to_vectorflow(self, client: Any) -> None: """Send a document to vectorflow, since they don't have a document object.""" # write document to temp file import tempfile with tempfile.NamedTemporaryFile() as f: f.write(self.text.encode("utf-8")) f.flush() client.embed(f.name) @classmethod def example(cls) -> "Document": return Document( text=SAMPLE_TEXT, metadata={"filename": "README.md", "category": "codebase"}, ) @classmethod def class_name(cls) -> str: return "Document" class ImageDocument(Document, ImageNode): """Data document containing an image.""" @classmethod def class_name(cls) -> str: return "ImageDocument" @dataclass class QueryBundle(DataClassJsonMixin): """ Query bundle. This dataclass contains the original query string and associated transformations. Args: query_str (str): the original user-specified query string. This is currently used by all non embedding-based queries. custom_embedding_strs (list[str]): list of strings used for embedding the query. This is currently used by all embedding-based queries. embedding (list[float]): the stored embedding for the query. """ query_str: str # using single image path as query input image_path: Optional[str] = None custom_embedding_strs: Optional[List[str]] = None embedding: Optional[List[float]] = None @property def embedding_strs(self) -> List[str]: """Use custom embedding strs if specified, otherwise use query str.""" if self.custom_embedding_strs is None: if len(self.query_str) == 0: return [] return [self.query_str] else: return self.custom_embedding_strs @property def embedding_image(self) -> List[ImageType]: """Use image path for image retrieval.""" if self.image_path is None: return [] return [self.image_path] def __str__(self) -> str: """Convert to string representation.""" return self.query_str QueryType = Union[str, QueryBundle]