773 lines
23 KiB
Python
773 lines
23 KiB
Python
"""Base schema for data structures."""
|
|
import json
|
|
import textwrap
|
|
import uuid
|
|
from abc import abstractmethod
|
|
from dataclasses import dataclass
|
|
from enum import Enum, auto
|
|
from hashlib import sha256
|
|
from io import BytesIO
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
|
|
from dataclasses_json import DataClassJsonMixin
|
|
from typing_extensions import Self
|
|
|
|
from llama_index.bridge.pydantic import BaseModel, Field
|
|
from llama_index.utils import SAMPLE_TEXT, truncate_text
|
|
|
|
if TYPE_CHECKING:
|
|
from haystack.schema import Document as HaystackDocument
|
|
from semantic_kernel.memory.memory_record import MemoryRecord
|
|
|
|
from llama_index.bridge.langchain import Document as LCDocument
|
|
|
|
|
|
# Default format strings: the first combines a metadata string with the content,
# the second renders a single metadata key/value pair.
DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
DEFAULT_METADATA_TMPL = "{key}: {value}"

# NOTE: for pretty printing
TRUNCATE_LENGTH = 350
WRAP_WIDTH = 70

# An image is handed around either as a string or as an in-memory byte buffer.
ImageType = Union[str, BytesIO]
|
|
|
|
|
|
class BaseComponent(BaseModel):
    """Base component object to capture class names."""

    class Config:
        @staticmethod
        def schema_extra(schema: Dict[str, Any], model: "BaseComponent") -> None:
            """Add class name to schema."""
            schema["properties"]["class_name"] = {
                "title": "Class Name",
                "type": "string",
                "default": model.class_name(),
            }

    @classmethod
    def class_name(cls) -> str:
        """
        Get the class name, used as a unique ID in serialization.

        This provides a key that makes serialization robust against actual class
        name changes.
        """
        return "base_component"

    def json(self, **kwargs: Any) -> str:
        """Serialize to a JSON string; delegates to :meth:`to_json`."""
        return self.to_json(**kwargs)

    def dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Return the parent model's dict with a ``class_name`` entry added."""
        data = super().dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def __getstate__(self) -> Dict[str, Any]:
        """Build the pickle state, dropping entries that cannot be pickled.

        Removes the ``tokenizer`` attribute, every attribute whose name ends in
        ``_fn`` or whose string form contains ``<lambda>``, and all private
        attribute values.
        """
        # Work on copies: the parent may hand back the instance's own
        # ``__dict__``, and stripping entries from that mapping in place would
        # silently remove them from the live object being pickled.
        state = dict(super().__getstate__())
        attrs = dict(state["__dict__"])

        # tiktoken is not pickleable
        attrs.pop("tokenizer", None)

        # remove local functions
        removable = [
            key
            for key, val in attrs.items()
            if key.endswith("_fn") or "<lambda>" in str(val)
        ]
        for key in removable:
            attrs.pop(key, None)

        state["__dict__"] = attrs

        # remove private attributes -- kind of dangerous
        state["__private_attribute_values__"] = {}

        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore from pickle state.

        Re-runs ``__init__`` with the stored attributes so that all variables
        initialize; if that raises, falls back to the parent ``__setstate__``.
        """
        try:
            self.__init__(**state["__dict__"])  # type: ignore
        except Exception:
            # Fall back to the default __setstate__ method
            super().__setstate__(state)

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Return :meth:`dict` output, with ``class_name`` set."""
        data = self.dict(**kwargs)
        data["class_name"] = self.class_name()
        return data

    def to_json(self, **kwargs: Any) -> str:
        """Dump :meth:`to_dict` output with ``json.dumps``."""
        data = self.to_dict(**kwargs)
        return json.dumps(data)

    # TODO: return type here not supported by current mypy version
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        """Construct an instance from ``data``, with ``kwargs`` taking precedence.

        The ``class_name`` entry is discarded before construction. ``data`` is
        not modified: the merge happens on a new dict so callers can keep
        reusing the mapping they passed in.
        """
        merged = {**data, **kwargs}
        merged.pop("class_name", None)
        return cls(**merged)

    @classmethod
    def from_json(cls, data_str: str, **kwargs: Any) -> Self:  # type: ignore
        """Parse ``data_str`` as JSON and build an instance via :meth:`from_dict`."""
        data = json.loads(data_str)
        return cls.from_dict(data, **kwargs)
|
|
|
|
|
|
class TransformComponent(BaseComponent):
    """Base class for transform components."""

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    def __call__(self, nodes: List["BaseNode"], **kwargs: Any) -> List["BaseNode"]:
        """Transform nodes."""

    async def acall(self, nodes: List["BaseNode"], **kwargs: Any) -> List["BaseNode"]:
        """Async transform nodes.

        This default simply runs the synchronous ``__call__`` and returns its
        result.
        """
        transformed = self.__call__(nodes, **kwargs)
        return transformed
|
|
|
|
|
|
class NodeRelationship(str, Enum):
    """Node relationships used in `BaseNode` class.

    Attributes:
        SOURCE: The node is the source document.
        PREVIOUS: The node is the previous node in the document.
        NEXT: The node is the next node in the document.
        PARENT: The node is the parent node in the document.
        CHILD: The node is a child node in the document.

    """

    # Values come from ``auto()``, so they depend on declaration order;
    # do not reorder the members.
    SOURCE = auto()
    PREVIOUS = auto()
    NEXT = auto()
    PARENT = auto()
    CHILD = auto()
|
|
|
|
|
|
class ObjectType(str, Enum):
    # Kinds of object returned by the ``get_type`` methods in this module.
    # Values come from ``auto()``, so they depend on declaration order;
    # do not reorder the members.
    TEXT = auto()
    IMAGE = auto()
    INDEX = auto()
    DOCUMENT = auto()
|
|
|
|
|
|
class MetadataMode(str, Enum):
    # Selects which metadata keys are rendered into text:
    # LLM and EMBED apply their respective exclusion lists, NONE renders
    # no metadata at all, ALL applies no exclusions.
    ALL = "all"
    EMBED = "embed"
    LLM = "llm"
    NONE = "none"
|
|
|
|
|
|
class RelatedNodeInfo(BaseComponent):
    # Reference to another node: its id plus optional type, metadata and hash.
    node_id: str
    node_type: Optional[ObjectType] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    hash: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Return the name recorded for this class in serialized output."""
        return "RelatedNodeInfo"
|
|
|
|
|
|
# A relationship value is either one related node or a list of them.
RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
|
|
|
|
|
|
# Node classes for indexes
|
|
class BaseNode(BaseComponent):
    """Base node Object.

    Generic abstract interface for retrievable nodes

    """

    class Config:
        allow_population_by_field_name = True
        # hash is computed on local field, during the validation process
        validate_assignment = True

    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
    )
    embedding: Optional[List[float]] = Field(
        default=None, description="Embedding of the node."
    )

    # metadata fields
    # - injected as part of the text shown to LLMs as context
    # - injected as part of the text for generating embeddings
    # - used by vector DBs for metadata filtering
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="A flat dictionary of metadata fields",
        alias="extra_info",
    )
    excluded_embed_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are excluded from text for the embed model.",
    )
    excluded_llm_metadata_keys: List[str] = Field(
        default_factory=list,
        description="Metadata keys that are excluded from text for the LLM.",
    )
    relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
        default_factory=dict,
        description="A mapping of relationships to other node information.",
    )

    @classmethod
    @abstractmethod
    def get_type(cls) -> str:
        """Get Object type."""

    @abstractmethod
    def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
        """Get object content."""

    @abstractmethod
    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata string."""

    @abstractmethod
    def set_content(self, value: Any) -> None:
        """Set the content of the node."""

    @property
    @abstractmethod
    def hash(self) -> str:
        """Get hash of node."""

    @property
    def node_id(self) -> str:
        """Alias for ``id_``."""
        return self.id_

    @node_id.setter
    def node_id(self, value: str) -> None:
        self.id_ = value

    def _single_relation(
        self, relationship: NodeRelationship, label: str
    ) -> Optional[RelatedNodeInfo]:
        """Look up a relationship that must hold exactly one ``RelatedNodeInfo``.

        Args:
            relationship: key to look up in ``self.relationships``.
            label: leading word of the error message (e.g. ``"Source"``).

        Returns:
            The stored ``RelatedNodeInfo``, or ``None`` when the key is absent.

        Raises:
            ValueError: if the stored value is not a single ``RelatedNodeInfo``.
        """
        if relationship not in self.relationships:
            return None

        relation = self.relationships[relationship]
        if not isinstance(relation, RelatedNodeInfo):
            raise ValueError(f"{label} object must be a single RelatedNodeInfo object")
        return relation

    @property
    def source_node(self) -> Optional[RelatedNodeInfo]:
        """Source object node.

        Extracted from the relationships field.

        """
        # Same check as prev/next/parent: anything that is not a single
        # RelatedNodeInfo is rejected, not only lists.
        return self._single_relation(NodeRelationship.SOURCE, "Source")

    @property
    def prev_node(self) -> Optional[RelatedNodeInfo]:
        """Prev node."""
        return self._single_relation(NodeRelationship.PREVIOUS, "Previous")

    @property
    def next_node(self) -> Optional[RelatedNodeInfo]:
        """Next node."""
        return self._single_relation(NodeRelationship.NEXT, "Next")

    @property
    def parent_node(self) -> Optional[RelatedNodeInfo]:
        """Parent node."""
        return self._single_relation(NodeRelationship.PARENT, "Parent")

    @property
    def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
        """Child nodes.

        Returns ``None`` when no CHILD relationship is stored; raises
        ``ValueError`` when the stored value is not a list.
        """
        if NodeRelationship.CHILD not in self.relationships:
            return None

        relation = self.relationships[NodeRelationship.CHILD]
        if not isinstance(relation, list):
            raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
        return relation

    @property
    def ref_doc_id(self) -> Optional[str]:
        """Deprecated: Get ref doc id."""
        source_node = self.source_node
        if source_node is None:
            return None
        return source_node.node_id

    @property
    def extra_info(self) -> Dict[str, Any]:
        """TODO: DEPRECATED: Extra info."""
        return self.metadata

    def __str__(self) -> str:
        """Return the node id followed by truncated, wrapped content."""
        source_text_truncated = truncate_text(
            self.get_content().strip(), TRUNCATE_LENGTH
        )
        source_text_wrapped = textwrap.fill(
            f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
        )
        return f"Node ID: {self.node_id}\n{source_text_wrapped}"

    def get_embedding(self) -> List[float]:
        """Get embedding.

        Errors if embedding is None.

        """
        if self.embedding is None:
            raise ValueError("embedding not set.")
        return self.embedding

    def as_related_node_info(self) -> RelatedNodeInfo:
        """Get node as RelatedNodeInfo."""
        return RelatedNodeInfo(
            node_id=self.node_id,
            node_type=self.get_type(),
            metadata=self.metadata,
            hash=self.hash,
        )
|
|
|
|
|
|
class TextNode(BaseNode):
    # Node whose content is a text string, rendered together with its metadata
    # through the two templates below.
    text: str = Field(default="", description="Text content of the node.")
    start_char_idx: Optional[int] = Field(
        default=None, description="Start char index of the node."
    )
    end_char_idx: Optional[int] = Field(
        default=None, description="End char index of the node."
    )
    text_template: str = Field(
        default=DEFAULT_TEXT_NODE_TMPL,
        description=(
            "Template for how text is formatted, with {content} and "
            "{metadata_str} placeholders."
        ),
    )
    metadata_template: str = Field(
        default=DEFAULT_METADATA_TMPL,
        description=(
            "Template for how metadata is formatted, with {key} and "
            "{value} placeholders."
        ),
    )
    metadata_seperator: str = Field(
        default="\n",
        description="Separator between metadata fields when converting to string.",
    )

    @classmethod
    def class_name(cls) -> str:
        """Return the name recorded for this class in serialized output."""
        return "TextNode"

    @property
    def hash(self) -> str:
        """SHA-256 hex digest of the text concatenated with the metadata repr."""
        identity = str(self.text) + str(self.metadata)
        digest = sha256(identity.encode("utf-8", "surrogatepass"))
        return str(digest.hexdigest())

    @classmethod
    def get_type(cls) -> str:
        """Get Object type."""
        return ObjectType.TEXT

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """Return the text, prefixed by metadata when the mode yields any.

        With an empty metadata string the raw text is returned unchanged;
        otherwise ``text_template`` is filled in and the result stripped.
        """
        metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
        if metadata_str:
            rendered = self.text_template.format(
                content=self.text, metadata_str=metadata_str
            )
            return rendered.strip()
        return self.text

    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Render the metadata entries allowed by ``mode`` as one string."""
        if mode == MetadataMode.NONE:
            return ""

        # Pick the exclusion list that applies to this mode (none for ALL).
        if mode == MetadataMode.LLM:
            excluded: List[str] = self.excluded_llm_metadata_keys
        elif mode == MetadataMode.EMBED:
            excluded = self.excluded_embed_metadata_keys
        else:
            excluded = []

        allowed_keys = set(self.metadata.keys()).difference(excluded)

        # Iterate the metadata dict itself so the output keeps its order.
        rendered_pairs = [
            self.metadata_template.format(key=key, value=str(value))
            for key, value in self.metadata.items()
            if key in allowed_keys
        ]
        return self.metadata_seperator.join(rendered_pairs)

    def set_content(self, value: str) -> None:
        """Set the content of the node."""
        self.text = value

    def get_node_info(self) -> Dict[str, Any]:
        """Return the start and end character indices as a dict."""
        return {"start": self.start_char_idx, "end": self.end_char_idx}

    def get_text(self) -> str:
        """Return the content with no metadata included."""
        return self.get_content(metadata_mode=MetadataMode.NONE)

    @property
    def node_info(self) -> Dict[str, Any]:
        """Deprecated: Get node info."""
        return self.get_node_info()
|
|
|
|
|
|
# TODO: legacy backport of old Node class
# ``Node`` is simply another name for ``TextNode``.
Node = TextNode
|
|
|
|
|
|
class ImageNode(TextNode):
    """Node with image."""

    # TODO: store reference instead of actual image
    # base64 encoded image str
    image: Optional[str] = None
    image_path: Optional[str] = None
    image_url: Optional[str] = None
    image_mimetype: Optional[str] = None
    text_embedding: Optional[List[float]] = Field(
        default=None,
        description="Text embedding of image node, if text field is filled out",
    )

    @classmethod
    def get_type(cls) -> str:
        """Return the IMAGE object type."""
        return ObjectType.IMAGE

    @classmethod
    def class_name(cls) -> str:
        """Return the name recorded for this class in serialized output."""
        return "ImageNode"

    def resolve_image(self) -> ImageType:
        """Resolve an image such that PIL can read it.

        Sources are tried in order: the base64 ``image`` string, then
        ``image_path``, then ``image_url``. Raises ``ValueError`` when none
        of them is set.
        """
        if self.image is not None:
            import base64

            decoded = base64.b64decode(self.image)
            return BytesIO(decoded)

        if self.image_path is not None:
            return self.image_path

        if self.image_url is None:
            raise ValueError("No image found in node.")

        # load image from URL
        import requests

        response = requests.get(self.image_url)
        return BytesIO(response.content)
|
|
|
|
|
|
class IndexNode(TextNode):
    """Node with reference to any object.

    This can include other indices, query engines, retrievers.

    This can also include other nodes (though this is overlapping with `relationships`
    on the Node class).

    """

    index_id: str
    obj: Any = Field(exclude=True)

    @classmethod
    def from_text_node(
        cls,
        node: TextNode,
        index_id: str,
    ) -> "IndexNode":
        """Create index node from text node.

        Args:
            node: node whose attributes are copied.
            index_id: index id for the new node; replaces any ``index_id``
                already present in ``node``'s attributes.
        """
        # copy all attributes from text node, add index id
        node_data = node.dict()
        # Assign instead of passing ``index_id`` next to ``**node.dict()``:
        # when ``node`` is itself an IndexNode its dict already has that key,
        # and a duplicated keyword argument raises TypeError.
        node_data["index_id"] = index_id
        return cls(**node_data)

    @classmethod
    def get_type(cls) -> str:
        """Return the INDEX object type."""
        return ObjectType.INDEX

    @classmethod
    def class_name(cls) -> str:
        """Return the name recorded for this class in serialized output."""
        return "IndexNode"
|
|
|
|
|
|
class NodeWithScore(BaseComponent):
    # Pairs a node with an optional score and forwards common accessors
    # to the wrapped node.
    node: BaseNode
    score: Optional[float] = None

    def __str__(self) -> str:
        """Return the node's string form followed by a ``Score:`` line."""
        if self.score is None:
            score_str = "None"
        else:
            score_str = f"{self.score: 0.3f}"
        return f"{self.node}\nScore: {score_str}\n"

    def get_score(self, raise_error: bool = False) -> float:
        """Get score.

        When no score is set, returns ``0.0`` unless ``raise_error`` is true,
        in which case ``ValueError`` is raised.
        """
        if self.score is not None:
            return self.score
        if raise_error:
            raise ValueError("Score not set.")
        return 0.0

    @classmethod
    def class_name(cls) -> str:
        """Return the name recorded for this class in serialized output."""
        return "NodeWithScore"

    ##### pass through methods to BaseNode #####
    @property
    def node_id(self) -> str:
        """The wrapped node's ``node_id``."""
        return self.node.node_id

    @property
    def id_(self) -> str:
        """The wrapped node's ``id_``."""
        return self.node.id_

    @property
    def text(self) -> str:
        """The wrapped node's ``text``; only available for ``TextNode``."""
        if not isinstance(self.node, TextNode):
            raise ValueError("Node must be a TextNode to get text.")
        return self.node.text

    @property
    def metadata(self) -> Dict[str, Any]:
        """The wrapped node's ``metadata``."""
        return self.node.metadata

    @property
    def embedding(self) -> Optional[List[float]]:
        """The wrapped node's ``embedding``."""
        return self.node.embedding

    def get_text(self) -> str:
        """Return the wrapped node's ``get_text()``; only for ``TextNode``."""
        if not isinstance(self.node, TextNode):
            raise ValueError("Node must be a TextNode to get text.")
        return self.node.get_text()

    def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """Return the wrapped node's content for ``metadata_mode``."""
        return self.node.get_content(metadata_mode=metadata_mode)

    def get_embedding(self) -> List[float]:
        """Return the wrapped node's ``get_embedding()``."""
        return self.node.get_embedding()
|
|
|
|
|
|
# Document Classes for Readers
|
|
|
|
|
|
class Document(TextNode):
    """Generic interface for a data document.

    This document connects to data sources.

    """

    # TODO: A lot of backwards compatibility logic here, clean up
    id_: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the node.",
        alias="doc_id",
    )

    # Legacy attribute names and the field each one is redirected to on assignment.
    _compat_fields = {"doc_id": "id_", "extra_info": "metadata"}

    @classmethod
    def get_type(cls) -> str:
        """Get Document type."""
        return ObjectType.DOCUMENT

    @property
    def doc_id(self) -> str:
        """Get document ID."""
        return self.id_

    def __str__(self) -> str:
        """Return the doc id followed by truncated, wrapped content."""
        truncated = truncate_text(self.get_content().strip(), TRUNCATE_LENGTH)
        wrapped = textwrap.fill(f"Text: {truncated}\n", width=WRAP_WIDTH)
        return f"Doc ID: {self.doc_id}\n{wrapped}"

    def get_doc_id(self) -> str:
        """TODO: Deprecated: Get document ID."""
        return self.id_

    def __setattr__(self, name: str, value: object) -> None:
        """Set an attribute, translating legacy names via ``_compat_fields``."""
        target = self._compat_fields.get(name, name)
        super().__setattr__(target, value)

    def to_langchain_format(self) -> "LCDocument":
        """Convert struct to LangChain document format."""
        from llama_index.bridge.langchain import Document as LCDocument

        metadata = self.metadata or {}
        return LCDocument(page_content=self.text, metadata=metadata)

    @classmethod
    def from_langchain_format(cls, doc: "LCDocument") -> "Document":
        """Convert struct from LangChain document format."""
        return cls(text=doc.page_content, metadata=doc.metadata)

    def to_haystack_format(self) -> "HaystackDocument":
        """Convert struct to Haystack document format."""
        from haystack.schema import Document as HaystackDocument

        return HaystackDocument(
            content=self.text,
            meta=self.metadata,
            embedding=self.embedding,
            id=self.id_,
        )

    @classmethod
    def from_haystack_format(cls, doc: "HaystackDocument") -> "Document":
        """Convert struct from Haystack document format."""
        return cls(
            text=doc.content,
            metadata=doc.meta,
            embedding=doc.embedding,
            id_=doc.id,
        )

    def to_embedchain_format(self) -> Dict[str, Any]:
        """Convert struct to EmbedChain document format."""
        payload = {"content": self.text, "meta_data": self.metadata}
        return {"doc_id": self.id_, "data": payload}

    @classmethod
    def from_embedchain_format(cls, doc: Dict[str, Any]) -> "Document":
        """Convert struct from EmbedChain document format."""
        payload = doc["data"]
        return cls(
            text=payload["content"],
            metadata=payload["meta_data"],
            id_=doc["doc_id"],
        )

    def to_semantic_kernel_format(self) -> "MemoryRecord":
        """Convert struct to Semantic Kernel document format."""
        import numpy as np
        from semantic_kernel.memory.memory_record import MemoryRecord

        # A missing or empty embedding is passed on as None.
        embedding = np.array(self.embedding) if self.embedding else None
        return MemoryRecord(
            id=self.id_,
            text=self.text,
            additional_metadata=self.get_metadata_str(),
            embedding=embedding,
        )

    @classmethod
    def from_semantic_kernel_format(cls, doc: "MemoryRecord") -> "Document":
        """Convert struct from Semantic Kernel document format."""
        if doc._embedding is not None:
            embedding = doc._embedding.tolist()
        else:
            embedding = None
        return cls(
            text=doc._text,
            metadata={"additional_metadata": doc._additional_metadata},
            embedding=embedding,
            id_=doc._id,
        )

    def to_vectorflow(self, client: Any) -> None:
        """Send a document to vectorflow, since they don't have a document object."""
        # write document to temp file
        import tempfile

        with tempfile.NamedTemporaryFile() as tmp:
            tmp.write(self.text.encode("utf-8"))
            tmp.flush()
            client.embed(tmp.name)

    @classmethod
    def example(cls) -> "Document":
        """Return a sample document built from ``SAMPLE_TEXT``."""
        sample_metadata = {"filename": "README.md", "category": "codebase"}
        return Document(text=SAMPLE_TEXT, metadata=sample_metadata)

    @classmethod
    def class_name(cls) -> str:
        """Return the name recorded for this class in serialized output."""
        return "Document"
|
|
|
|
|
|
class ImageDocument(Document, ImageNode):
    """Data document containing an image."""

    # Combines Document and ImageNode; only the serialized class name is overridden.
    @classmethod
    def class_name(cls) -> str:
        """Return the name recorded for this class in serialized output."""
        return "ImageDocument"
|
|
|
|
|
|
@dataclass
class QueryBundle(DataClassJsonMixin):
    """
    Query bundle.

    This dataclass contains the original query string and associated transformations.

    Args:
        query_str (str): the original user-specified query string.
            This is currently used by all non embedding-based queries.
        custom_embedding_strs (list[str]): list of strings used for embedding the query.
            This is currently used by all embedding-based queries.
        embedding (list[float]): the stored embedding for the query.
    """

    query_str: str
    # using single image path as query input
    image_path: Optional[str] = None
    custom_embedding_strs: Optional[List[str]] = None
    embedding: Optional[List[float]] = None

    @property
    def embedding_strs(self) -> List[str]:
        """Use custom embedding strs if specified, otherwise use query str."""
        if self.custom_embedding_strs is not None:
            return self.custom_embedding_strs
        # An empty query string yields no embedding strings at all.
        if len(self.query_str) == 0:
            return []
        return [self.query_str]

    @property
    def embedding_image(self) -> List[ImageType]:
        """Use image path for image retrieval."""
        return [] if self.image_path is None else [self.image_path]

    def __str__(self) -> str:
        """Convert to string representation."""
        return self.query_str
|
|
|
|
|
|
# A query is given either as a plain string or as a full QueryBundle.
QueryType = Union[str, QueryBundle]
|