# sglang0.4.5.post1/python/sglang/srt/constrained/xgrammar_backend.py
#
# 184 lines
# 6.1 KiB
# Python

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with xgrammar backend."""
import json
import logging
from typing import List, Optional, Tuple, Union
import torch
from xgrammar import (
CompiledGrammar,
GrammarCompiler,
GrammarMatcher,
StructuralTagItem,
TokenizerInfo,
allocate_token_bitmask,
apply_token_bitmask_inplace,
)
from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend,
BaseGrammarObject,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Passed as `max_rollback_tokens` to the GrammarMatcher instances constructed
# below; `jump_and_retokenize` relies on the matcher's `rollback` call.
MAX_ROLLBACK_TOKENS = 200
class XGrammarGrammar(BaseGrammarObject):
    """Grammar object wrapping a single xgrammar ``GrammarMatcher``.

    It keeps the compiled grammar (``ctx``) next to the matcher so that
    ``copy`` can build a fresh matcher for the same grammar.
    """

    def __init__(
        self,
        matcher: GrammarMatcher,
        vocab_size: int,
        ctx: CompiledGrammar,
        override_stop_tokens: Optional[Union[List[int], int]],
    ) -> None:
        """Store the matcher together with the data needed to re-create it.

        Args:
            matcher: Matcher that tracks the decoding state.
            vocab_size: Vocabulary size associated with this grammar.
            ctx: Compiled grammar the matcher was built from.
            override_stop_tokens: Value forwarded to ``GrammarMatcher`` when
                ``copy`` constructs a new matcher.
        """
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx
        self.override_stop_tokens = override_stop_tokens
        self.finished = False

    def _accept_or_raise(self, token: int) -> None:
        """Feed ``token`` to the matcher and raise if it is rejected.

        The matcher call is deliberately made outside of an ``assert``
        statement: ``assert`` is removed under ``python -O``, which would
        silently skip the state update as well as the check.
        """
        accepted = self.matcher.accept_token(token)
        if not accepted:
            raise AssertionError(
                f"Token {token} was rejected by the grammar matcher"
            )

    def accept_token(self, token: int):
        """Advance the matcher by one token.

        Raises:
            AssertionError: If the matcher does not accept ``token``.
        """
        self._accept_or_raise(token)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """Return ``([], s)`` when the matcher reports a non-empty
        jump-forward string ``s``; otherwise return ``None``.

        ``tokenizer`` is not used by this implementation.
        """
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """Unpack the helper produced by ``try_jump_forward``.

        Returns:
            The jump-forward string and the constant ``-1`` as the state.
        """
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        """Re-synchronise the matcher after the output was re-tokenized.

        The matcher is rolled back to the end of the common prefix of
        ``old_output_ids`` and ``new_output_ids`` and then fed the remaining
        tokens of ``new_output_ids``. ``next_state`` is not used here.

        Raises:
            AssertionError: If the matcher rejects one of the new tokens.
        """
        # Length of the common prefix. ``zip`` stops at the shorter sequence,
        # so a ``new_output_ids`` shorter than ``old_output_ids`` is safe.
        k = 0
        for old_id, new_id in zip(old_output_ids, new_output_ids):
            if old_id != new_id:
                break
            k += 1

        # Roll back to the last token that is the same.
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for token in new_output_ids[k:]:
            self._accept_or_raise(token)

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """Allocate a token bitmask for ``batch_size`` rows.

        ``device`` is not used here; use ``move_vocab_mask`` to transfer the
        result.
        """
        return allocate_token_bitmask(batch_size, vocab_size)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Let the matcher fill its next-token bitmask into ``vocab_mask`` at ``idx``."""
        self.matcher.fill_next_token_bitmask(vocab_mask, idx)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Return ``vocab_mask`` moved to ``device`` with a non-blocking copy."""
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Apply ``vocab_mask`` to ``logits`` in place."""
        apply_token_bitmask_inplace(logits, vocab_mask)

    def copy(self):
        """Return a new ``XGrammarGrammar`` with a fresh matcher for the same
        compiled grammar and stop-token override."""
        matcher = GrammarMatcher(
            self.ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(
            matcher, self.vocab_size, self.ctx, self.override_stop_tokens
        )
class XGrammarGrammarBackend(BaseGrammarBackend):
    """Grammar backend that compiles constraints with xgrammar.

    Each ``dispatch_*`` method compiles one kind of constraint and returns an
    ``XGrammarGrammar``, or ``None`` (after logging a warning) when the
    constraint cannot be compiled.
    """

    def __init__(
        self,
        tokenizer,
        vocab_size: int,
    ):
        """Build the grammar compiler for ``tokenizer``.

        Args:
            tokenizer: Tokenizer handed to ``TokenizerInfo.from_huggingface``.
            vocab_size: Vocabulary size forwarded to the tokenizer info and
                to every grammar object created by this backend.
        """
        super().__init__()
        tokenizer_info = TokenizerInfo.from_huggingface(
            tokenizer, vocab_size=vocab_size
        )
        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
        self.vocab_size = vocab_size
        self.override_stop_tokens = None

    def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
        """Wrap a compiled grammar in a new matcher and grammar object."""
        # Forward the stop-token override so the matcher is built the same way
        # as in XGrammarGrammar.copy().
        matcher = GrammarMatcher(
            ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)

    def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a JSON schema.

        The special value ``"$$ANY$$"`` selects the built-in JSON grammar
        instead of a schema.

        Returns:
            The grammar object, or ``None`` if compilation raised
            ``RuntimeError``.
        """
        try:
            if key_string == "$$ANY$$":
                ctx = self.grammar_compiler.compile_builtin_json_grammar()
            else:
                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
        except RuntimeError as e:
            logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
            return None
        return self._from_context(ctx)

    def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile an EBNF grammar string.

        Returns:
            The grammar object, or ``None`` if compilation raised
            ``RuntimeError``.
        """
        try:
            ctx = self.grammar_compiler.compile_grammar(key_string)
        except RuntimeError as e:
            logger.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
            return None
        return self._from_context(ctx)

    def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a regular expression.

        Returns:
            The grammar object, or ``None`` if compilation raised
            ``RuntimeError``.
        """
        try:
            ctx = self.grammar_compiler.compile_regex(key_string)
        except RuntimeError as e:
            logger.warning(f"Skip invalid regex: regex={key_string}, {e=}")
            return None
        return self._from_context(ctx)

    def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a structural-tag specification given as a JSON string.

        The decoded object must provide a ``"structures"`` list whose items
        have ``"begin"``, ``"schema"`` and ``"end"`` entries, and a
        ``"triggers"`` entry.

        Returns:
            The grammar object, or ``None`` if the string is not valid JSON,
            lacks the required entries, or compilation raised
            ``RuntimeError``.
        """
        try:
            structural_tag = json.loads(key_string)
            tags = [
                StructuralTagItem(
                    begin=structure["begin"],
                    schema=json.dumps(structure["schema"]),
                    end=structure["end"],
                )
                for structure in structural_tag["structures"]
            ]
            ctx = self.grammar_compiler.compile_structural_tag(
                tags, structural_tag["triggers"]
            )
        except (RuntimeError, json.JSONDecodeError, KeyError, TypeError) as e:
            # Malformed input is treated like any other invalid constraint:
            # warn and return None instead of propagating the exception.
            logger.warning(
                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
            )
            return None
        return self._from_context(ctx)

    def reset(self):
        """Clear the grammar compiler's cache, if a compiler is set."""
        if self.grammar_compiler:
            self.grammar_compiler.clear_cache()