faiss_rag_enterprise/llama_index/schema.py

773 lines
23 KiB
Python

"""Base schema for data structures."""
import json
import textwrap
import uuid
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from hashlib import sha256
from io import BytesIO
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from dataclasses_json import DataClassJsonMixin
from typing_extensions import Self
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.utils import SAMPLE_TEXT, truncate_text
if TYPE_CHECKING:
    # Only needed for type annotations of the format converters on Document.
    from haystack.schema import Document as HaystackDocument
    from semantic_kernel.memory.memory_record import MemoryRecord

    from llama_index.bridge.langchain import Document as LCDocument

# Default node text layout: rendered metadata, a blank line, then the content.
DEFAULT_TEXT_NODE_TMPL: str = "{metadata_str}\n\n{content}"
# Default layout of a single metadata entry.
DEFAULT_METADATA_TMPL: str = "{key}: {value}"

# NOTE: for pretty printing (__str__ of nodes/documents)
TRUNCATE_LENGTH: int = 350
WRAP_WIDTH: int = 70

# Image inputs handed back by ImageNode.resolve_image: a path or an in-memory buffer.
ImageType = Union[str, BytesIO]
class BaseComponent(BaseModel):
    """Base component object to capture class names.

    Every serialized form produced by this class (``dict``, ``to_dict``,
    ``to_json`` / ``json``) and its JSON schema carry an extra ``class_name``
    entry whose value comes from :meth:`class_name`.
    """

    class Config:
        @staticmethod
        def schema_extra(schema: Dict[str, Any], model: "BaseComponent") -> None:
            """Add class name to schema."""
            schema["properties"]["class_name"] = {
                "title": "Class Name",
                "type": "string",
                "default": model.class_name(),
            }

    @classmethod
    def class_name(cls) -> str:
        """
        Get the class name, used as a unique ID in serialization.

        This provides a key that makes serialization robust against actual class
        name changes.
        """
        return "base_component"

    def json(self, **kwargs: Any) -> str:
        """Serialize to JSON via :meth:`to_json` (so ``class_name`` is included)."""
        return self.to_json(**kwargs)

    def dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Return pydantic's ``dict()`` output plus a ``class_name`` entry."""
        data = super().dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def __getstate__(self) -> Dict[str, Any]:
        """Return a pickle-friendly copy of the model state.

        Drops the ``tokenizer`` attribute, every attribute whose name ends in
        ``_fn`` or whose string form contains ``<lambda>``, and all private
        attribute values.
        """
        state = super().__getstate__()
        attrs = state["__dict__"]
        # tiktoken is not pickleable
        attrs.pop("tokenizer", None)
        # remove local functions / lambdas, which pickle cannot serialize
        keys_to_remove = [
            key
            for key, val in attrs.items()
            if key.endswith("_fn") or "<lambda>" in str(val)
        ]
        for key in keys_to_remove:
            attrs.pop(key, None)
        # remove private attributes -- kind of dangerous
        state["__private_attribute_values__"] = {}
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore a pickled model.

        Re-runs ``__init__`` on the saved field values so that every variable
        is initialized again; if that raises, falls back to pydantic's default
        ``__setstate__``.
        """
        try:
            self.__init__(**state["__dict__"])  # type: ignore
        except Exception:
            # Fall back to the default __setstate__ method
            super().__setstate__(state)

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Return ``self.dict(**kwargs)`` with ``class_name`` guaranteed present."""
        data = self.dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def to_json(self, **kwargs: Any) -> str:
        """Serialize :meth:`to_dict` output with ``json.dumps``."""
        data = self.to_dict(**kwargs)
        return json.dumps(data)

    # TODO: return type here not supported by current mypy version
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        """Build an instance from ``data`` with ``kwargs`` overriding its keys.

        The ``class_name`` entry written by :meth:`to_dict` is ignored. ``data``
        itself is left unmodified.
        """
        # Merge into a fresh dict: updating/popping ``data`` directly would
        # mutate the caller's object (e.g. strip its "class_name").
        merged = {**data, **kwargs}
        merged.pop("class_name", None)
        return cls(**merged)

    @classmethod
    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
        """Parse ``data_str`` as JSON and delegate to :meth:`from_dict`."""
        data = json.loads(data_str)
        return cls.from_dict(data, **kwargs)
class TransformComponent(BaseComponent):
    """Base class for transform components.

    A transform maps a list of nodes to a list of nodes. Subclasses must
    implement ``__call__``; ``acall`` is an awaitable wrapper around it.
    """

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    def __call__(self, nodes: List["BaseNode"], **kwargs: Any) -> List["BaseNode"]:
        """Transform nodes."""

    async def acall(self, nodes: List["BaseNode"], **kwargs: Any) -> List["BaseNode"]:
        """Async transform nodes.

        Default implementation: run the synchronous ``__call__`` inline and
        return its result unchanged.
        """
        transformed = self.__call__(nodes, **kwargs)
        return transformed
class NodeRelationship(str, Enum):
    """Node relationships used in `BaseNode` class.

    Used as the keys of ``BaseNode.relationships``.

    Attributes:
        SOURCE: The node is the source document.
        PREVIOUS: The node is the previous node in the document.
        NEXT: The node is the next node in the document.
        PARENT: The node is the parent node in the document.
        CHILD: The node is a child node in the document.
    """

    # NOTE(review): with the ``str`` mixin, ``auto()`` values come from Enum's
    # default ``_generate_next_value_`` and are coerced through ``str``;
    # presumably "1", "2", ... in declaration order -- confirm on the target
    # Python version before relying on them in persisted data. Because auto()
    # numbers by position, reordering or inserting members would likely change
    # the values.
    SOURCE = auto()
    PREVIOUS = auto()
    NEXT = auto()
    PARENT = auto()
    CHILD = auto()
class ObjectType(str, Enum):
    """Kind of object a node represents; returned by the nodes' ``get_type``."""

    # NOTE(review): ``auto()`` with the ``str`` mixin presumably yields "1",
    # "2", ... in declaration order -- confirm on the target Python version
    # before persisting these values; keep member order stable.
    TEXT = auto()
    IMAGE = auto()
    INDEX = auto()
    DOCUMENT = auto()
class MetadataMode(str, Enum):
    """Which metadata keys are rendered when building node text.

    As interpreted by ``TextNode.get_metadata_str``: ALL keeps every key,
    EMBED drops ``excluded_embed_metadata_keys``, LLM drops
    ``excluded_llm_metadata_keys`` and NONE renders no metadata at all.
    """

    ALL = "all"
    EMBED = "embed"
    LLM = "llm"
    NONE = "none"
class RelatedNodeInfo(BaseComponent):
    """Reference to another node, stored as a value of ``BaseNode.relationships``.

    ``BaseNode.as_related_node_info`` builds one from a node.
    """

    # ID of the referenced node.
    node_id: str
    # ObjectType of the referenced node, if known.
    node_type: Optional[ObjectType] = None
    # Metadata of the referenced node.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Hash of the referenced node (as_related_node_info fills it from node.hash).
    hash: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        return "RelatedNodeInfo"


# A relationship value is either a single reference or a list of references
# (BaseNode.child_nodes expects a list for CHILD; the other accessors expect
# a single RelatedNodeInfo).
RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
# Node classes for indexes
class BaseNode(BaseComponent):
    """Base node Object.

    Generic abstract interface for retrievable nodes. Subclasses implement
    ``get_type``, ``get_content``, ``get_metadata_str``, ``set_content`` and
    the ``hash`` property.
    """

    class Config:
        allow_population_by_field_name = True
        # hash is computed on local field, during the validation process
        validate_assignment = True

    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
    )
    embedding: Optional[List[float]] = Field(
        default=None, description="Embedding of the node."
    )
    # metadata fields
    # - injected as part of the text shown to LLMs as context
    # - injected as part of the text for generating embeddings
    # - used by vector DBs for metadata filtering
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="A flat dictionary of metadata fields",
        alias="extra_info",
    )
    excluded_embed_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are excluded from text for the embed model.",
    )
    excluded_llm_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are excluded from text for the LLM.",
    )
    relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
        default_factory=dict,
        description="A mapping of relationships to other node information.",
    )

    @classmethod
    @abstractmethod
    def get_type(cls) -> str:
        """Get Object type."""

    @abstractmethod
    def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
        """Get object content."""

    @abstractmethod
    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata string."""

    @abstractmethod
    def set_content(self, value: Any) -> None:
        """Set the content of the node."""

    @property
    @abstractmethod
    def hash(self) -> str:
        """Get hash of node."""

    @property
    def node_id(self) -> str:
        """Alias for ``id_``."""
        return self.id_

    @node_id.setter
    def node_id(self, value: str) -> None:
        self.id_ = value

    def _get_single_relation(
        self, relationship: NodeRelationship, label: str
    ) -> Optional[RelatedNodeInfo]:
        """Return the single RelatedNodeInfo stored under ``relationship``.

        Args:
            relationship: key to look up in ``self.relationships``.
            label: human-readable name used in the error message.

        Returns:
            None when the relationship is absent, otherwise the stored info.

        Raises:
            ValueError: if the stored value is not a single RelatedNodeInfo
                (e.g. a list).
        """
        if relationship not in self.relationships:
            return None
        relation = self.relationships[relationship]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError(f"{label} object must be a single RelatedNodeInfo object")
        return relation

    @property
    def source_node(self) -> Optional[RelatedNodeInfo]:
        """Source object node.

        Extracted from the relationships field.
        """
        return self._get_single_relation(NodeRelationship.SOURCE, "Source")

    @property
    def prev_node(self) -> Optional[RelatedNodeInfo]:
        """Prev node."""
        return self._get_single_relation(NodeRelationship.PREVIOUS, "Previous")

    @property
    def next_node(self) -> Optional[RelatedNodeInfo]:
        """Next node."""
        return self._get_single_relation(NodeRelationship.NEXT, "Next")

    @property
    def parent_node(self) -> Optional[RelatedNodeInfo]:
        """Parent node."""
        return self._get_single_relation(NodeRelationship.PARENT, "Parent")

    @property
    def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
        """Child nodes.

        Raises:
            ValueError: if the CHILD relationship is present but not a list.
        """
        if NodeRelationship.CHILD not in self.relationships:
            return None
        relation = self.relationships[NodeRelationship.CHILD]
        if not isinstance(relation, list):
            raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
        return relation

    @property
    def ref_doc_id(self) -> Optional[str]:
        """Deprecated: Get ref doc id (the node_id of ``source_node``)."""
        source_node = self.source_node
        if source_node is None:
            return None
        return source_node.node_id

    @property
    def extra_info(self) -> Dict[str, Any]:
        """TODO: DEPRECATED: Extra info."""
        return self.metadata

    def __str__(self) -> str:
        source_text_truncated = truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Node ID: {self.node_id}\n{source_text_wrapped}"

    def get_embedding(self) -> List[float]:
        """Get embedding.

        Errors if embedding is None.
        """
        if self.embedding is None:
            raise ValueError("embedding not set.")
        return self.embedding

    def as_related_node_info(self) -> RelatedNodeInfo:
        """Get node as RelatedNodeInfo."""
        return RelatedNodeInfo(
            node_id=self.node_id,
            node_type=self.get_type(),
            metadata=self.metadata,
            hash=self.hash,
        )
class TextNode(BaseNode):
    """Node whose content is a text string.

    ``metadata_template`` renders each metadata entry, entries are joined with
    ``metadata_seperator``, and ``text_template`` combines that metadata string
    with ``text``.
    """

    text: str = Field(default="", description="Text content of the node.")
    start_char_idx: Optional[int] = Field(
        default=None, description="Start char index of the node."
    )
    end_char_idx: Optional[int] = Field(
        default=None, description="End char index of the node."
    )
    text_template: str = Field(
        default=DEFAULT_TEXT_NODE_TMPL,
        description=(
            "Template for how text is formatted, with {content} and "
            "{metadata_str} placeholders."
        ),
    )
    metadata_template: str = Field(
        default=DEFAULT_METADATA_TMPL,
        description=(
            "Template for how metadata is formatted, with {key} and "
            "{value} placeholders."
        ),
    )
    # NOTE: the (misspelled) field name is part of the serialized schema.
    metadata_seperator: str = Field(
        default="\n",
        description="Separator between metadata fields when converting to string.",
    )

    @classmethod
    def class_name(cls) -> str:
        return "TextNode"

    @property
    def hash(self) -> str:
        """SHA-256 hex digest of ``str(text)`` followed by ``str(metadata)``."""
        identity = "".join((str(self.text), str(self.metadata)))
        encoded = identity.encode("utf-8", "surrogatepass")
        return str(sha256(encoded).hexdigest())

    @classmethod
    def get_type(cls) -> str:
        """Get Object type."""
        return ObjectType.TEXT

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """Return the text, combined with rendered metadata when there is any.

        If the metadata string for ``metadata_mode`` is empty after stripping,
        the raw ``text`` is returned as-is; otherwise ``text_template`` is
        filled and the result is stripped.
        """
        rendered_metadata = self.get_metadata_str(mode=metadata_mode).strip()
        if rendered_metadata:
            combined = self.text_template.format(
                content=self.text, metadata_str=rendered_metadata
            )
            return combined.strip()
        return self.text

    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Render metadata entries for ``mode``, in the metadata dict's order."""
        if mode == MetadataMode.NONE:
            return ""

        if mode == MetadataMode.LLM:
            excluded = set(self.excluded_llm_metadata_keys)
        elif mode == MetadataMode.EMBED:
            excluded = set(self.excluded_embed_metadata_keys)
        else:
            excluded = set()

        entries = [
            self.metadata_template.format(key=key, value=str(value))
            for key, value in self.metadata.items()
            if key not in excluded
        ]
        return self.metadata_seperator.join(entries)

    def set_content(self, value: str) -> None:
        """Set the content of the node."""
        self.text = value

    def get_node_info(self) -> Dict[str, Any]:
        """Get node info (character span of this node)."""
        span = {"start": self.start_char_idx, "end": self.end_char_idx}
        return span

    def get_text(self) -> str:
        """Return the content rendered with ``MetadataMode.NONE``."""
        return self.get_content(metadata_mode=MetadataMode.NONE)

    @property
    def node_info(self) -> Dict[str, Any]:
        """Deprecated: Get node info."""
        return self.get_node_info()


# TODO: legacy backport of old Node class
Node = TextNode
class ImageNode(TextNode):
    """Node with image."""

    # TODO: store reference instead of actual image
    # base64 encoded image str
    image: Optional[str] = None
    image_path: Optional[str] = None
    image_url: Optional[str] = None
    image_mimetype: Optional[str] = None
    text_embedding: Optional[List[float]] = Field(
        default=None,
        description="Text embedding of image node, if text field is filled out",
    )

    @classmethod
    def get_type(cls) -> str:
        return ObjectType.IMAGE

    @classmethod
    def class_name(cls) -> str:
        return "ImageNode"

    def resolve_image(self, timeout: Optional[float] = 60.0) -> ImageType:
        """Resolve an image such that PIL can read it.

        Sources are tried in order: inline base64 ``image``, then
        ``image_path``, then ``image_url``.

        Args:
            timeout: seconds to wait for the HTTP download when the image is
                fetched from ``image_url``; ``None`` waits indefinitely.

        Returns:
            A ``BytesIO`` with the image bytes, or the ``image_path`` string.

        Raises:
            ValueError: if none of the image sources is set.
            requests.HTTPError: if downloading ``image_url`` returns an error
                status (instead of handing the error page bytes to PIL).
        """
        if self.image is not None:
            import base64

            return BytesIO(base64.b64decode(self.image))
        if self.image_path is not None:
            return self.image_path
        if self.image_url is not None:
            # load image from URL
            import requests

            # Without a timeout requests can block forever on a stalled server.
            response = requests.get(self.image_url, timeout=timeout)
            response.raise_for_status()
            return BytesIO(response.content)
        raise ValueError("No image found in node.")
class IndexNode(TextNode):
    """Node with reference to any object.

    This can include other indices, query engines, retrievers.

    This can also include other nodes (though this is overlapping with `relationships`
    on the Node class).
    """

    index_id: str
    # Referenced object; excluded from serialization.
    obj: Any = Field(exclude=True)

    @classmethod
    def from_text_node(
        cls,
        node: TextNode,
        index_id: str,
    ) -> "IndexNode":
        """Create index node from text node.

        Copies every serialized field of ``node`` and sets ``index_id``. ``node``
        may itself be an IndexNode; its ``index_id`` is replaced.
        """
        data = node.dict()
        # node.dict() may already contain "index_id" (when node is an
        # IndexNode); passing it alongside index_id=... would raise
        # "got multiple values for keyword argument".
        data["index_id"] = index_id
        # Serialization marker added by BaseComponent.dict(); not a field.
        data.pop("class_name", None)
        return cls(**data)

    @classmethod
    def get_type(cls) -> str:
        return ObjectType.INDEX

    @classmethod
    def class_name(cls) -> str:
        return "IndexNode"
class NodeWithScore(BaseComponent):
    """A node paired with an optional retrieval score.

    Exposes pass-through accessors so callers can treat it like the wrapped node.
    """

    node: BaseNode
    score: Optional[float] = None

    def __str__(self) -> str:
        if self.score is None:
            score_str = "None"
        else:
            score_str = f"{self.score: 0.3f}"
        return f"{self.node}\nScore: {score_str}\n"

    def get_score(self, raise_error: bool = False) -> float:
        """Get score.

        Returns 0.0 when no score is set, unless ``raise_error`` is true, in
        which case a ValueError is raised instead.
        """
        if self.score is not None:
            return self.score
        if raise_error:
            raise ValueError("Score not set.")
        return 0.0

    @classmethod
    def class_name(cls) -> str:
        return "NodeWithScore"

    ##### pass through methods to BaseNode #####

    @property
    def node_id(self) -> str:
        return self.node.node_id

    @property
    def id_(self) -> str:
        return self.node.id_

    @property
    def text(self) -> str:
        wrapped = self.node
        if not isinstance(wrapped, TextNode):
            raise ValueError("Node must be a TextNode to get text.")
        return wrapped.text

    @property
    def metadata(self) -> Dict[str, Any]:
        return self.node.metadata

    @property
    def embedding(self) -> Optional[List[float]]:
        return self.node.embedding

    def get_text(self) -> str:
        wrapped = self.node
        if not isinstance(wrapped, TextNode):
            raise ValueError("Node must be a TextNode to get text.")
        return wrapped.get_text()

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        return self.node.get_content(metadata_mode=metadata_mode)

    def get_embedding(self) -> List[float]:
        return self.node.get_embedding()
# Document Classes for Readers
class Document(TextNode):
    """Generic interface for a data document.

    This document connects to data sources. It provides converters to and from
    the LangChain, Haystack, EmbedChain and Semantic Kernel document formats,
    plus a helper that sends its text to vectorflow.
    """

    # TODO: A lot of backwards compatibility logic here, clean up
    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the node.",
        alias="doc_id",
    )

    # Legacy attribute names that __setattr__ redirects to the current fields.
    _compat_fields = {"doc_id": "id_", "extra_info": "metadata"}

    @classmethod
    def get_type(cls) -> str:
        """Get Document type."""
        return ObjectType.DOCUMENT

    @property
    def doc_id(self) -> str:
        """Get document ID."""
        return self.id_

    def __str__(self) -> str:
        source_text_truncated = truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"

    def get_doc_id(self) -> str:
        """TODO: Deprecated: Get document ID."""
        return self.id_

    def __setattr__(self, name: str, value: object) -> None:
        # Map legacy names (e.g. ``doc.extra_info = ...``) onto the real fields.
        if name in self._compat_fields:
            name = self._compat_fields[name]
        super().__setattr__(name, value)

    def to_langchain_format(self) -> "LCDocument":
        """Convert struct to LangChain document format."""
        from llama_index.bridge.langchain import Document as LCDocument

        metadata = self.metadata or {}
        return LCDocument(page_content=self.text, metadata=metadata)

    @classmethod
    def from_langchain_format(cls, doc: "LCDocument") -> "Document":
        """Convert struct from LangChain document format."""
        return cls(text=doc.page_content, metadata=doc.metadata)

    def to_haystack_format(self) -> "HaystackDocument":
        """Convert struct to Haystack document format."""
        from haystack.schema import Document as HaystackDocument

        return HaystackDocument(
            content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_
        )

    @classmethod
    def from_haystack_format(cls, doc: "HaystackDocument") -> "Document":
        """Convert struct from Haystack document format."""
        return cls(
            text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id
        )

    def to_embedchain_format(self) -> Dict[str, Any]:
        """Convert struct to EmbedChain document format."""
        return {
            "doc_id": self.id_,
            "data": {"content": self.text, "meta_data": self.metadata},
        }

    @classmethod
    def from_embedchain_format(cls, doc: Dict[str, Any]) -> "Document":
        """Convert struct from EmbedChain document format."""
        return cls(
            text=doc["data"]["content"],
            metadata=doc["data"]["meta_data"],
            id_=doc["doc_id"],
        )

    def to_semantic_kernel_format(self) -> "MemoryRecord":
        """Convert struct to Semantic Kernel document format.

        Metadata is flattened to a string via ``get_metadata_str()``; an unset
        or empty embedding becomes ``None``.
        """
        import numpy as np
        from semantic_kernel.memory.memory_record import MemoryRecord

        return MemoryRecord(
            id=self.id_,
            text=self.text,
            additional_metadata=self.get_metadata_str(),
            embedding=np.array(self.embedding) if self.embedding else None,
        )

    @classmethod
    def from_semantic_kernel_format(cls, doc: "MemoryRecord") -> "Document":
        """Convert struct from Semantic Kernel document format."""
        return cls(
            text=doc._text,
            metadata={"additional_metadata": doc._additional_metadata},
            embedding=doc._embedding.tolist() if doc._embedding is not None else None,
            id_=doc._id,
        )

    def to_vectorflow(self, client: Any) -> None:
        """Send a document to vectorflow, since they don't have a document object.

        The text is written (UTF-8) to a temporary file whose path is passed to
        ``client.embed``; the file is removed afterwards.
        """
        import os
        import tempfile

        # Write and close the file before handing its path to the client: on
        # Windows a NamedTemporaryFile opened with delete=True cannot be
        # reopened by name while it is still open.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(self.text.encode("utf-8"))
            tmp_path = f.name
        try:
            client.embed(tmp_path)
        finally:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass

    @classmethod
    def example(cls) -> "Document":
        """Return a sample document built from ``SAMPLE_TEXT``."""
        # Use ``cls`` so subclasses (e.g. ImageDocument) get their own type.
        return cls(
            text=SAMPLE_TEXT,
            metadata={"filename": "README.md", "category": "codebase"},
        )

    @classmethod
    def class_name(cls) -> str:
        return "Document"
class ImageDocument(Document, ImageNode):
    """Data document containing an image.

    Combines the Document API with ImageNode's image fields. Base order matters:
    Document precedes ImageNode in the MRO, so Document's overrides (e.g.
    ``get_type`` returning ``ObjectType.DOCUMENT``, ``__str__``) win over
    ImageNode's.
    """

    @classmethod
    def class_name(cls) -> str:
        return "ImageDocument"
@dataclass
class QueryBundle(DataClassJsonMixin):
    """
    Query bundle.

    Carries the original query string together with optional inputs that
    accompany it (an image path, custom embedding strings, a precomputed
    embedding).

    Args:
        query_str (str): the original user-specified query string.
            This is currently used by all non embedding-based queries.
        image_path (Optional[str]): path of a single image used as query input.
        custom_embedding_strs (list[str]): list of strings used for embedding the query.
            This is currently used by all embedding-based queries.
        embedding (list[float]): the stored embedding for the query.
    """

    query_str: str
    # using single image path as query input
    image_path: Optional[str] = None
    custom_embedding_strs: Optional[List[str]] = None
    embedding: Optional[List[float]] = None

    @property
    def embedding_strs(self) -> List[str]:
        """Custom embedding strs if given; else ``[query_str]`` (``[]`` if empty)."""
        if self.custom_embedding_strs is not None:
            return self.custom_embedding_strs
        return [] if len(self.query_str) == 0 else [self.query_str]

    @property
    def embedding_image(self) -> List[ImageType]:
        """Image path(s) to use for image retrieval; empty when no image is set."""
        return [] if self.image_path is None else [self.image_path]

    def __str__(self) -> str:
        """Return the raw query string."""
        return self.query_str


QueryType = Union[str, QueryBundle]