163 lines
5.7 KiB
Python
163 lines
5.7 KiB
Python
"""Code splitter."""
|
|
from typing import Any, Callable, List, Optional
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
from llama_index.node_parser.interface import TextSplitter
|
|
from llama_index.node_parser.node_utils import default_id_func
|
|
from llama_index.schema import Document
|
|
|
|
# Default value for CodeSplitter.chunk_lines (lines per chunk).
DEFAULT_CHUNK_LINES = 40
|
|
# Default value for CodeSplitter.chunk_lines_overlap (lines shared between chunks).
DEFAULT_LINES_OVERLAP = 15
|
|
# Default value for CodeSplitter.max_chars (size cap applied while chunking).
DEFAULT_MAX_CHARS = 1500
|
|
|
|
|
|
class CodeSplitter(TextSplitter):
    """Split code using an AST parser.

    The source is parsed with a tree-sitter parser and the syntax tree is
    walked top-down, packing sibling nodes into chunks that stay within
    ``max_chars`` where possible.

    NOTE(review): within this class only ``max_chars`` is consulted by the
    splitting logic; ``chunk_lines`` and ``chunk_lines_overlap`` are stored
    but not read here -- confirm whether the base class uses them.

    Thank you to Kevin Lu / SweepAI for suggesting this elegant code splitting solution.
    https://docs.sweep.dev/blogs/chunking-2m-files
    """

    language: str = Field(
        description="The programming language of the code being split."
    )
    chunk_lines: int = Field(
        default=DEFAULT_CHUNK_LINES,
        description="The number of lines to include in each chunk.",
        gt=0,
    )
    chunk_lines_overlap: int = Field(
        default=DEFAULT_LINES_OVERLAP,
        description="How many lines of code each chunk overlaps with.",
        gt=0,
    )
    max_chars: int = Field(
        default=DEFAULT_MAX_CHARS,
        description="Maximum number of characters per chunk.",
        gt=0,
    )
    _parser: Any = PrivateAttr()

    def __init__(
        self,
        language: str,
        chunk_lines: int = DEFAULT_CHUNK_LINES,
        chunk_lines_overlap: int = DEFAULT_LINES_OVERLAP,
        max_chars: int = DEFAULT_MAX_CHARS,
        parser: Any = None,
        callback_manager: Optional[CallbackManager] = None,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        id_func: Optional[Callable[[int, Document], str]] = None,
    ) -> None:
        """Initialize a CodeSplitter.

        Args:
            language: Language name handed to
                ``tree_sitter_languages.get_parser`` when no parser is given.
            chunk_lines: Stored on the ``chunk_lines`` field.
            chunk_lines_overlap: Stored on the ``chunk_lines_overlap`` field.
            max_chars: Size cap used while packing syntax nodes into chunks.
            parser: Optional pre-built tree-sitter ``Parser``. When ``None``,
                one is looked up for ``language``.
            callback_manager: Callback manager; an empty ``CallbackManager``
                is created when omitted.
            include_metadata: Forwarded unchanged to the base initializer.
            include_prev_next_rel: Forwarded unchanged to the base initializer.
            id_func: Node id function; ``default_id_func`` when omitted.

        Raises:
            ImportError: ``tree_sitter_languages`` is not installed and no
                parser was supplied.
            ValueError: The resolved parser is not a tree-sitter ``Parser``.
        """
        from tree_sitter import Parser

        if parser is None:
            try:
                import tree_sitter_languages

                parser = tree_sitter_languages.get_parser(language)
            except ImportError:
                raise ImportError(
                    "Please install tree_sitter_languages to use CodeSplitter. "
                    "Or pass in a parser object."
                )
            except Exception:
                # Give the user a hint, then let the original error propagate.
                print(
                    f"Could not get parser for language {language}. Check "
                    "https://github.com/grantjenks/py-tree-sitter-languages#license "
                    "for a list of valid languages."
                )
                raise
        if not isinstance(parser, Parser):
            raise ValueError("Parser must be a tree-sitter Parser object.")

        # Kept in its original position: assigned before the base initializer runs.
        self._parser = parser

        callback_manager = callback_manager or CallbackManager([])
        id_func = id_func or default_id_func

        super().__init__(
            language=language,
            chunk_lines=chunk_lines,
            chunk_lines_overlap=chunk_lines_overlap,
            max_chars=max_chars,
            callback_manager=callback_manager,
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            id_func=id_func,
        )

    @classmethod
    def from_defaults(
        cls,
        language: str,
        chunk_lines: int = DEFAULT_CHUNK_LINES,
        chunk_lines_overlap: int = DEFAULT_LINES_OVERLAP,
        max_chars: int = DEFAULT_MAX_CHARS,
        callback_manager: Optional[CallbackManager] = None,
        parser: Any = None,
    ) -> "CodeSplitter":
        """Create a CodeSplitter with default values.

        All arguments are forwarded to the constructor, including
        ``callback_manager`` (previously accepted but silently dropped).
        """
        return cls(
            language=language,
            chunk_lines=chunk_lines,
            chunk_lines_overlap=chunk_lines_overlap,
            max_chars=max_chars,
            parser=parser,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this splitter class."""
        return "CodeSplitter"

    def _chunk_node(self, node: Any, text: str, last_end: int = 0) -> List[str]:
        """Pack the children of ``node`` into chunks of source text.

        Args:
            node: Syntax node exposing ``children``, and whose children expose
                ``start_byte`` / ``end_byte``.
            text: The full source that was parsed.
            last_end: Byte offset (into the UTF-8 encoding of ``text``) where
                the next chunk should start.

        Returns:
            Chunks in source order. Text between consecutive children is kept
            with the child that follows it.
        """
        # Node offsets are byte positions, so slice the encoded source rather
        # than the str; otherwise any non-ASCII character shifts every chunk.
        return self._chunk_node_bytes(node, text.encode("utf-8"), last_end)

    def _chunk_node_bytes(
        self, node: Any, text_bytes: bytes, last_end: int = 0
    ) -> List[str]:
        """Recursive worker for ``_chunk_node`` operating on encoded source."""
        new_chunks: List[str] = []
        current_chunk = ""
        for child in node.children:
            child_size = child.end_byte - child.start_byte
            if child_size > self.max_chars:
                # Child is too big on its own: flush what we have first.
                if len(current_chunk) > 0:
                    new_chunks.append(current_chunk)
                current_chunk = ""
                if child.children:
                    # Recursively chunk the child.
                    new_chunks.extend(
                        self._chunk_node_bytes(child, text_bytes, last_end)
                    )
                else:
                    # A leaf cannot be subdivided; emit it whole instead of
                    # silently dropping its text.
                    new_chunks.append(
                        text_bytes[last_end : child.end_byte].decode("utf-8")
                    )
            elif len(current_chunk) + child_size > self.max_chars:
                # Child would make the current chunk too big, so start a new chunk.
                new_chunks.append(current_chunk)
                current_chunk = text_bytes[last_end : child.end_byte].decode("utf-8")
            else:
                current_chunk += text_bytes[last_end : child.end_byte].decode("utf-8")
            last_end = child.end_byte
        if len(current_chunk) > 0:
            new_chunks.append(current_chunk)
        return new_chunks

    def split_text(self, text: str) -> List[str]:
        """Split incoming code and return chunks using the AST.

        Args:
            text: Source code to split.

        Returns:
            The chunks, each stripped of leading and trailing whitespace.

        Raises:
            ValueError: The first top-level node of the parse tree has type
                ``"ERROR"``.
        """
        with self.callback_manager.event(
            CBEventType.CHUNKING, payload={EventPayload.CHUNKS: [text]}
        ) as event:
            tree = self._parser.parse(bytes(text, "utf-8"))

            if (
                not tree.root_node.children
                or tree.root_node.children[0].type != "ERROR"
            ):
                chunks = [
                    chunk.strip() for chunk in self._chunk_node(tree.root_node, text)
                ]
                event.on_end(
                    payload={EventPayload.CHUNKS: chunks},
                )

                return chunks
            else:
                raise ValueError(f"Could not parse code with language {self.language}.")
|
|
|
|
# TODO: set up auto-language detection using something like https://github.com/yoeo/guesslang.
|