226 lines
8.4 KiB
Python
226 lines
8.4 KiB
Python
"""Token splitter."""
|
|
import logging
|
|
from typing import Callable, List, Optional
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
from llama_index.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
|
|
from llama_index.node_parser.interface import MetadataAwareTextSplitter
|
|
from llama_index.node_parser.node_utils import default_id_func
|
|
from llama_index.node_parser.text.utils import split_by_char, split_by_sep
|
|
from llama_index.schema import Document
|
|
from llama_index.utils import get_tokenizer
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
# NOTE: this is the number of tokens we reserve for metadata formatting
|
|
DEFAULT_METADATA_FORMAT_LEN = 2
|
|
|
|
|
|
class TokenTextSplitter(MetadataAwareTextSplitter):
    """Implementation of splitting text that looks at word tokens.

    Text is broken apart with the configured ``separator`` (falling back to
    ``backup_separators`` and finally individual characters), then the pieces
    are merged back into chunks of at most ``chunk_size`` tokens with
    ``chunk_overlap`` tokens of overlap between consecutive chunks.
    """

    chunk_size: int = Field(
        default=DEFAULT_CHUNK_SIZE,
        description="The token chunk size for each chunk.",
        gt=0,
    )
    chunk_overlap: int = Field(
        default=DEFAULT_CHUNK_OVERLAP,
        description="The token overlap of each chunk when splitting.",
        # BUG FIX: was `gte=0` — pydantic has no `gte` keyword, so the
        # constraint was silently ignored; `ge` actually enforces >= 0.
        ge=0,
    )
    separator: str = Field(
        default=" ", description="Default separator for splitting into words"
    )
    backup_separators: List[str] = Field(
        default_factory=list, description="Additional separators for splitting."
    )

    # Callable mapping a string to its token list (set in __init__).
    _tokenizer: Callable = PrivateAttr()
    # Split functions tried in order: main separator, each backup separator,
    # then a per-character fallback.
    _split_fns: List[Callable] = PrivateAttr()

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        tokenizer: Optional[Callable] = None,
        callback_manager: Optional[CallbackManager] = None,
        # NOTE(review): mutable default is shared across calls but never
        # mutated here; kept byte-compatible because an explicit ``None``
        # means "no backup separators" (see ``backup_separators or []``).
        separator: str = " ",
        backup_separators: Optional[List[str]] = ["\n"],
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        id_func: Optional[Callable[[int, Document], str]] = None,
    ):
        """Initialize with parameters.

        Args:
            chunk_size: Maximum number of tokens per chunk.
            chunk_overlap: Number of tokens shared between adjacent chunks.
            tokenizer: Callable mapping text to a token list; defaults to
                ``get_tokenizer()``.
            callback_manager: Callback manager for chunking events.
            separator: Primary separator for splitting into words.
            backup_separators: Extra separators tried when the primary one
                does not produce small enough splits.
            include_metadata: Whether nodes include metadata.
            include_prev_next_rel: Whether nodes track prev/next relations.
            id_func: Function generating node IDs; defaults to
                ``default_id_func``.

        Raises:
            ValueError: If ``chunk_overlap`` is larger than ``chunk_size``.
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        callback_manager = callback_manager or CallbackManager([])
        id_func = id_func or default_id_func
        self._tokenizer = tokenizer or get_tokenizer()

        all_seps = [separator] + (backup_separators or [])
        self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]

        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            backup_separators=backup_separators,
            callback_manager=callback_manager,
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            id_func=id_func,
        )

    @classmethod
    def from_defaults(
        cls,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        separator: str = " ",
        backup_separators: Optional[List[str]] = ["\n"],
        callback_manager: Optional[CallbackManager] = None,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        id_func: Optional[Callable[[int, Document], str]] = None,
    ) -> "TokenTextSplitter":
        """Initialize with default parameters."""
        callback_manager = callback_manager or CallbackManager([])
        return cls(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            backup_separators=backup_separators,
            callback_manager=callback_manager,
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            id_func=id_func,
        )

    @classmethod
    def class_name(cls) -> str:
        return "TokenTextSplitter"

    def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
        """Split text into chunks, reserving space required for metadata str.

        Raises:
            ValueError: If the metadata alone consumes the whole chunk budget.
        """
        metadata_len = len(self._tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN
        effective_chunk_size = self.chunk_size - metadata_len
        if effective_chunk_size <= 0:
            raise ValueError(
                f"Metadata length ({metadata_len}) is longer than chunk size "
                f"({self.chunk_size}). Consider increasing the chunk size or "
                "decreasing the size of your metadata to avoid this."
            )
        elif effective_chunk_size < 50:
            # FIX: route through the module logger rather than print(),
            # consistent with the rest of this module's diagnostics.
            _logger.warning(
                f"Metadata length ({metadata_len}) is close to chunk size "
                f"({self.chunk_size}). Resulting chunks are less than 50 tokens. "
                "Consider increasing the chunk size or decreasing the size of "
                "your metadata to avoid this."
            )

        return self._split_text(text, chunk_size=effective_chunk_size)

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        return self._split_text(text, chunk_size=self.chunk_size)

    def _split_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks up to chunk_size, emitting chunking events."""
        if text == "":
            return [text]

        with self.callback_manager.event(
            CBEventType.CHUNKING, payload={EventPayload.CHUNKS: [text]}
        ) as event:
            splits = self._split(text, chunk_size)
            chunks = self._merge(splits, chunk_size)

            event.on_end(
                payload={EventPayload.CHUNKS: chunks},
            )

        return chunks

    def _split(self, text: str, chunk_size: int) -> List[str]:
        """Break text into splits that are smaller than chunk size.

        The order of splitting is:
        1. split by separator
        2. split by backup separators (if any)
        3. split by characters

        NOTE: the splits contain the separators.
        """
        if len(self._tokenizer(text)) <= chunk_size:
            return [text]

        # Use the first split function that actually breaks the text apart.
        for split_fn in self._split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        new_splits = []
        for split in splits:
            split_len = len(self._tokenizer(split))
            if split_len <= chunk_size:
                new_splits.append(split)
            else:
                # recursively split with the next-finer separator
                new_splits.extend(self._split(split, chunk_size=chunk_size))
        return new_splits

    def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
        """Merge splits into chunks.

        The high-level idea is to keep adding splits to a chunk until we
        exceed the chunk size, then we start a new chunk with overlap.

        When we start a new chunk, we pop off the first element of the previous
        chunk until the total length is less than the chunk size.
        """
        chunks: List[str] = []

        cur_chunk: List[str] = []
        cur_len = 0
        for split in splits:
            split_len = len(self._tokenizer(split))
            if split_len > chunk_size:
                # BUG FIX: the two f-strings were previously passed as separate
                # arguments, which logging treated as %-format args for a
                # message with no placeholders, so the warning never rendered.
                _logger.warning(
                    f"Got a split of size {split_len}, "
                    f"larger than chunk size {chunk_size}."
                )

            # if we exceed the chunk size after adding the new split, then
            # we need to end the current chunk and start a new one
            if cur_len + split_len > chunk_size:
                # end the previous chunk
                chunk = "".join(cur_chunk).strip()
                if chunk:
                    chunks.append(chunk)

                # start a new chunk with overlap
                # keep popping off the first element of the previous chunk until:
                # 1. the current chunk length is less than chunk overlap
                # 2. the total length is less than chunk size
                while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size:
                    # pop off the first element
                    first_chunk = cur_chunk.pop(0)
                    cur_len -= len(self._tokenizer(first_chunk))

            cur_chunk.append(split)
            cur_len += split_len

        # handle the last chunk
        chunk = "".join(cur_chunk).strip()
        if chunk:
            chunks.append(chunk)

        return chunks
|