faiss_rag_enterprise/llama_index/node_parser/text/code.py

163 lines
5.7 KiB
Python

"""Code splitter."""
from typing import Any, Callable, List, Optional
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.node_parser.interface import TextSplitter
from llama_index.node_parser.node_utils import default_id_func
from llama_index.schema import Document
# Default for ``CodeSplitter.chunk_lines``.
DEFAULT_CHUNK_LINES: int = 40
# Default for ``CodeSplitter.chunk_lines_overlap``.
DEFAULT_LINES_OVERLAP: int = 15
# Default for ``CodeSplitter.max_chars`` (size bound used when chunking AST nodes).
DEFAULT_MAX_CHARS: int = 1_500
class CodeSplitter(TextSplitter):
    """Split code using an AST parser.

    Source text is parsed with a tree-sitter parser, and the resulting syntax
    tree is walked so that chunks break on node boundaries rather than at
    arbitrary character positions.

    Thank you to Kevin Lu / SweepAI for suggesting this elegant code splitting solution.
    https://docs.sweep.dev/blogs/chunking-2m-files
    """

    language: str = Field(
        description="The programming language of the code being split."
    )
    # NOTE(review): chunk_lines / chunk_lines_overlap are stored but not read by
    # _chunk_node or split_text below; only max_chars bounds chunk size here.
    chunk_lines: int = Field(
        default=DEFAULT_CHUNK_LINES,
        description="The number of lines to include in each chunk.",
        gt=0,
    )
    chunk_lines_overlap: int = Field(
        default=DEFAULT_LINES_OVERLAP,
        description="How many lines of code each chunk overlaps with.",
        gt=0,
    )
    max_chars: int = Field(
        default=DEFAULT_MAX_CHARS,
        description="Maximum number of characters per chunk.",
        gt=0,
    )

    # tree-sitter Parser instance used by split_text.
    _parser: Any = PrivateAttr()

    def __init__(
        self,
        language: str,
        chunk_lines: int = DEFAULT_CHUNK_LINES,
        chunk_lines_overlap: int = DEFAULT_LINES_OVERLAP,
        max_chars: int = DEFAULT_MAX_CHARS,
        parser: Any = None,
        callback_manager: Optional[CallbackManager] = None,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        id_func: Optional[Callable[[int, Document], str]] = None,
    ) -> None:
        """Initialize a CodeSplitter.

        Args:
            language: Language name. When ``parser`` is None it is passed to
                ``tree_sitter_languages.get_parser`` to build a parser.
            chunk_lines: Stored as the ``chunk_lines`` field.
            chunk_lines_overlap: Stored as the ``chunk_lines_overlap`` field.
            max_chars: Size bound used when grouping AST nodes into chunks.
            parser: Optional pre-built ``tree_sitter.Parser``.
            callback_manager: Callback manager; an empty one is created if None.
            include_metadata: Forwarded to the base splitter.
            include_prev_next_rel: Forwarded to the base splitter.
            id_func: Node id function; defaults to ``default_id_func``.

        Raises:
            ImportError: ``parser`` is None and ``tree_sitter_languages`` is
                not installed.
            ValueError: ``parser`` is not a ``tree_sitter.Parser``.
            Exception: whatever ``get_parser`` raises for an unsupported
                language is re-raised after printing a hint.
        """
        from tree_sitter import Parser

        if parser is None:
            try:
                import tree_sitter_languages
            except ImportError as err:
                raise ImportError(
                    "Please install tree_sitter_languages to use CodeSplitter. "
                    "Or pass in a parser object."
                ) from err
            try:
                parser = tree_sitter_languages.get_parser(language)
            except Exception:
                print(
                    f"Could not get parser for language {language}. Check "
                    "https://github.com/grantjenks/py-tree-sitter-languages#license "
                    "for a list of valid languages."
                )
                raise

        if not isinstance(parser, Parser):
            raise ValueError("Parser must be a tree-sitter Parser object.")

        # Set before super().__init__ as the original code did; this relies on the
        # bridged pydantic's private-attribute handling.
        self._parser = parser

        callback_manager = callback_manager or CallbackManager([])
        id_func = id_func or default_id_func

        super().__init__(
            language=language,
            chunk_lines=chunk_lines,
            chunk_lines_overlap=chunk_lines_overlap,
            max_chars=max_chars,
            callback_manager=callback_manager,
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            id_func=id_func,
        )

    @classmethod
    def from_defaults(
        cls,
        language: str,
        chunk_lines: int = DEFAULT_CHUNK_LINES,
        chunk_lines_overlap: int = DEFAULT_LINES_OVERLAP,
        max_chars: int = DEFAULT_MAX_CHARS,
        callback_manager: Optional[CallbackManager] = None,
        parser: Any = None,
    ) -> "CodeSplitter":
        """Create a CodeSplitter with default values.

        All arguments, including ``callback_manager``, are forwarded to the
        constructor.
        """
        return cls(
            language=language,
            chunk_lines=chunk_lines,
            chunk_lines_overlap=chunk_lines_overlap,
            max_chars=max_chars,
            parser=parser,
            # Previously accepted but silently dropped.
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used for serialization."""
        return "CodeSplitter"

    def _chunk_node(self, node: Any, text: str, last_end: int = 0) -> List[str]:
        """Group the children of ``node`` into chunks of roughly ``max_chars``.

        Args:
            node: tree-sitter node whose ``children`` are chunked.
            text: The source text that was parsed (as ``str``).
            last_end: Byte offset where the previous chunk ended.

        Returns:
            Chunk strings in source order (not stripped).
        """
        # tree-sitter offsets are byte offsets into the UTF-8 encoding, so all
        # slicing must happen on the encoded bytes, not on the str.
        return self._chunk_node_bytes(node, text.encode("utf-8"), last_end)

    def _chunk_node_bytes(
        self, node: Any, text_bytes: bytes, last_end: int = 0
    ) -> List[str]:
        """Recursive worker for ``_chunk_node`` operating on UTF-8 bytes."""
        new_chunks: List[str] = []
        current_chunk = ""
        for child in node.children:
            child_size = child.end_byte - child.start_byte
            # Text from the end of the previous sibling through this child,
            # so inter-node whitespace/comments are kept.
            piece = text_bytes[last_end : child.end_byte].decode("utf-8")
            if child_size > self.max_chars:
                # Child is too big: flush what we have, then split the child.
                if current_chunk:
                    new_chunks.append(current_chunk)
                current_chunk = ""
                if child.children:
                    new_chunks.extend(
                        self._chunk_node_bytes(child, text_bytes, last_end)
                    )
                else:
                    # A leaf cannot be split further; emit it whole rather
                    # than dropping its text.
                    new_chunks.append(piece)
            elif len(current_chunk) + child_size > self.max_chars:
                # Child would make the current chunk too big, so start a new chunk.
                new_chunks.append(current_chunk)
                current_chunk = piece
            else:
                current_chunk += piece
            last_end = child.end_byte
        if current_chunk:
            new_chunks.append(current_chunk)
        return new_chunks

    def split_text(self, text: str) -> List[str]:
        """Split incoming code and return chunks using the AST.

        Args:
            text: Source code to split.

        Returns:
            Non-empty, whitespace-stripped chunks in source order.

        Raises:
            ValueError: the first top-level node of the parse tree is an
                ``ERROR`` node.
        """
        with self.callback_manager.event(
            CBEventType.CHUNKING, payload={EventPayload.CHUNKS: [text]}
        ) as event:
            tree = self._parser.parse(bytes(text, "utf-8"))
            root = tree.root_node
            if root.children and root.children[0].type == "ERROR":
                raise ValueError(f"Could not parse code with language {self.language}.")

            stripped = (chunk.strip() for chunk in self._chunk_node(root, text))
            # Drop chunks that were whitespace only; they carry no content.
            chunks = [chunk for chunk in stripped if chunk]

            event.on_end(
                payload={EventPayload.CHUNKS: chunks},
            )
            return chunks
# TODO: set up auto-language detection using something like https://github.com/yoeo/guesslang.