from typing import Any, Callable, List, Optional, Sequence, TypedDict

import numpy as np

from llama_index.bridge.pydantic import Field
from llama_index.callbacks.base import CallbackManager
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.node_parser.interface import NodeParser
from llama_index.node_parser.node_utils import build_nodes_from_splits, default_id_func
from llama_index.node_parser.text.utils import split_by_sentence_tokenizer
from llama_index.schema import BaseNode, Document
from llama_index.utils import get_tqdm_iterable

DEFAULT_OG_TEXT_METADATA_KEY = "original_text"


class SentenceCombination(TypedDict):
    sentence: str
    index: int
    combined_sentence: str
    combined_sentence_embedding: List[float]


class SemanticSplitterNodeParser(NodeParser):
    """Semantic node parser.

    Splits a document into Nodes, with each node being a group of semantically
    related sentences.

    Args:
        buffer_size (int): number of sentences to group together when
            evaluating semantic similarity
        breakpoint_percentile_threshold (int): the percentile of cosine
            dissimilarity that must be exceeded between a group of sentences
            and the next to form a node
        embed_model (BaseEmbedding): embedding model to use
        sentence_splitter (Optional[Callable]): splits text into sentences
        include_metadata (bool): whether to include metadata in nodes
        include_prev_next_rel (bool): whether to include prev/next relationships
    """

    sentence_splitter: Callable[[str], List[str]] = Field(
        default_factory=split_by_sentence_tokenizer,
        description="The text splitter to use when splitting documents.",
        exclude=True,
    )

    embed_model: BaseEmbedding = Field(
        description="The embedding model to use for semantic comparison",
    )

    buffer_size: int = Field(
        default=1,
        description=(
            "The number of sentences to group together when evaluating semantic similarity. "
            "Set to 1 to consider each sentence individually. "
            "Set to >1 to group sentences together."
        ),
    )

    breakpoint_percentile_threshold: int = Field(
        default=95,
        description=(
            "The percentile of cosine dissimilarity that must be exceeded between a "
            "group of sentences and the next to form a node. The smaller this "
            "number is, the more nodes will be generated."
        ),
    )

    # Declared so the ``original_text_metadata_key`` argument accepted (and
    # forwarded) by ``from_defaults`` maps to a real model field instead of
    # being silently dropped.
    original_text_metadata_key: str = Field(
        default=DEFAULT_OG_TEXT_METADATA_KEY,
        description="The metadata key under which to store the original text.",
    )

    @classmethod
    def class_name(cls) -> str:
        return "SemanticSplitterNodeParser"

    @classmethod
    def from_defaults(
        cls,
        embed_model: Optional[BaseEmbedding] = None,
        breakpoint_percentile_threshold: Optional[int] = 95,
        buffer_size: Optional[int] = 1,
        sentence_splitter: Optional[Callable[[str], List[str]]] = None,
        original_text_metadata_key: str = DEFAULT_OG_TEXT_METADATA_KEY,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        id_func: Optional[Callable[[int, Document], str]] = None,
    ) -> "SemanticSplitterNodeParser":
        callback_manager = callback_manager or CallbackManager([])

        sentence_splitter = sentence_splitter or split_by_sentence_tokenizer()
        embed_model = embed_model or OpenAIEmbedding()

        id_func = id_func or default_id_func

        return cls(
            embed_model=embed_model,
            breakpoint_percentile_threshold=breakpoint_percentile_threshold,
            buffer_size=buffer_size,
            sentence_splitter=sentence_splitter,
            original_text_metadata_key=original_text_metadata_key,
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            callback_manager=callback_manager,
            id_func=id_func,
        )
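
    # Hypothetical tuning example (for illustration): a larger buffer smooths
    # the per-sentence embeddings, and a lower percentile produces more,
    # smaller nodes.
    #
    #   parser = SemanticSplitterNodeParser.from_defaults(
    #       buffer_size=2,
    #       breakpoint_percentile_threshold=90,
    #   )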

    def _parse_nodes(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Parse document into nodes."""
        all_nodes: List[BaseNode] = []
        nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")

        for node in nodes_with_progress:
            split_nodes = self.build_semantic_nodes_from_documents(
                [node], show_progress
            )
            all_nodes.extend(split_nodes)

        return all_nodes

    def build_semantic_nodes_from_documents(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
    ) -> List[BaseNode]:
        """Build semantic nodes from documents."""
        all_nodes: List[BaseNode] = []
        for doc in documents:
            text = doc.text
            text_splits = self.sentence_splitter(text)

            sentences = self._build_sentence_groups(text_splits)

            combined_sentence_embeddings = self.embed_model.get_text_embedding_batch(
                [s["combined_sentence"] for s in sentences],
                show_progress=show_progress,
            )

            for i, embedding in enumerate(combined_sentence_embeddings):
                sentences[i]["combined_sentence_embedding"] = embedding

            distances = self._calculate_distances_between_sentence_groups(sentences)

            chunks = self._build_node_chunks(sentences, distances)

            nodes = build_nodes_from_splits(
                chunks,
                doc,
                id_func=self.id_func,
            )

            all_nodes.extend(nodes)

        return all_nodes
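
    # Worked example (illustration only): with buffer_size=1 and sentences
    # [s0, s1, s2], ``_build_sentence_groups`` produces the windows s0+s1,
    # s0+s1+s2, and s1+s2, so each sentence is embedded with one sentence of
    # context on either side before distances are computed.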

    def _build_sentence_groups(
        self, text_splits: List[str]
    ) -> List[SentenceCombination]:
        sentences: List[SentenceCombination] = [
            {
                "sentence": x,
                "index": i,
                "combined_sentence": "",
                "combined_sentence_embedding": [],
            }
            for i, x in enumerate(text_splits)
        ]

        # Group sentences and calculate embeddings for sentence groups
        for i in range(len(sentences)):
            combined_sentence = ""

            # Prepend up to ``buffer_size`` sentences before the current one.
            for j in range(i - self.buffer_size, i):
                if j >= 0:
                    combined_sentence += sentences[j]["sentence"]

            combined_sentence += sentences[i]["sentence"]

            # Append up to ``buffer_size`` sentences after the current one.
            for j in range(i + 1, i + 1 + self.buffer_size):
                if j < len(sentences):
                    combined_sentence += sentences[j]["sentence"]

            sentences[i]["combined_sentence"] = combined_sentence

        return sentences

    def _calculate_distances_between_sentence_groups(
        self, sentences: List[SentenceCombination]
    ) -> List[float]:
        distances = []
        for i in range(len(sentences) - 1):
            embedding_current = sentences[i]["combined_sentence_embedding"]
            embedding_next = sentences[i + 1]["combined_sentence_embedding"]

            similarity = self.embed_model.similarity(embedding_current, embedding_next)

            # Cosine dissimilarity: 0 means the neighboring groups point the
            # same way; larger values mean they are more semantically distant.
            distance = 1 - similarity

            distances.append(distance)

        return distances
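
    # Numeric illustration (hypothetical values): for distances
    # [0.10, 0.05, 0.60, 0.08], np.percentile(distances, 95) is ~0.525, so
    # only the 0.60 gap exceeds the threshold and a single breakpoint is
    # placed there.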

    def _build_node_chunks(
        self, sentences: List[SentenceCombination], distances: List[float]
    ) -> List[str]:
        chunks = []
        if len(distances) > 0:
            breakpoint_distance_threshold = np.percentile(
                distances, self.breakpoint_percentile_threshold
            )

            indices_above_threshold = [
                i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
            ]

            # Chunk sentences into semantic groups based on percentile breakpoints
            start_index = 0

            for index in indices_above_threshold:
                # ``distances[index]`` measures the gap between sentence
                # ``index`` and sentence ``index + 1``, so the current chunk
                # ends at ``index`` inclusive and the next chunk starts at
                # ``index + 1``.
                group = sentences[start_index : index + 1]
                combined_text = "".join([d["sentence"] for d in group])
                chunks.append(combined_text)

                start_index = index + 1

            if start_index < len(sentences):
                combined_text = "".join(
                    [d["sentence"] for d in sentences[start_index:]]
                )
                chunks.append(combined_text)
        else:
            # If, for some reason, we didn't get any distances (i.e. very
            # small documents), just treat the whole document as a single node.
            chunks = [" ".join([s["sentence"] for s in sentences])]

        return chunks
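

# Minimal smoke test (illustrative addition, not part of the library). It
# assumes MockEmbedding is importable from ``llama_index.embeddings`` in this
# version; MockEmbedding returns constant vectors, so every distance is 0 and
# the document collapses into one node. Swap in OpenAIEmbedding (API key
# required) to see real semantic breakpoints.
if __name__ == "__main__":
    from llama_index.embeddings import MockEmbedding

    parser = SemanticSplitterNodeParser.from_defaults(
        embed_model=MockEmbedding(embed_dim=8),
    )
    doc = Document(text="Cats purr. Dogs bark. Stocks fell sharply today.")
    for node in parser.get_nodes_from_documents([doc]):
        print(repr(node.get_content()))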