184 lines
6.1 KiB
Python
184 lines
6.1 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Constrained decoding with xgrammar backend."""
|
|
|
|
import json
|
|
import logging
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from xgrammar import (
|
|
CompiledGrammar,
|
|
GrammarCompiler,
|
|
GrammarMatcher,
|
|
StructuralTagItem,
|
|
TokenizerInfo,
|
|
allocate_token_bitmask,
|
|
apply_token_bitmask_inplace,
|
|
)
|
|
|
|
from sglang.srt.constrained.base_grammar_backend import (
|
|
BaseGrammarBackend,
|
|
BaseGrammarObject,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Value passed as `max_rollback_tokens` to each GrammarMatcher constructed in
# this module.
MAX_ROLLBACK_TOKENS = 200
|
|
|
|
|
|
class XGrammarGrammar(BaseGrammarObject):
    """Grammar object backed by an xgrammar ``GrammarMatcher``.

    Keeps the matcher together with the compiled grammar and the settings it
    was built with, so that :meth:`copy` can build a fresh matcher over the
    same compiled grammar.
    """

    def __init__(
        self,
        matcher: GrammarMatcher,
        vocab_size: int,
        ctx: CompiledGrammar,
        override_stop_tokens: Optional[Union[List[int], int]],
    ) -> None:
        """Store the matcher and the values needed to rebuild it.

        Args:
            matcher: Matcher that tracks progress through the grammar.
            vocab_size: Vocabulary size recorded on this object.
            ctx: Compiled grammar the matcher was created from.
            override_stop_tokens: Stop-token override forwarded to new
                matchers created by :meth:`copy`; may be ``None``.
        """
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx
        self.override_stop_tokens = override_stop_tokens
        # Starts out False; nothing in this class modifies it afterwards.
        self.finished = False

    def accept_token(self, token: int):
        """Feed a single token id to the matcher.

        Raises:
            AssertionError: If the matcher reports that it did not accept
                the token.
        """
        # The matcher call must always execute, so it cannot live inside an
        # `assert` statement: `python -O` strips asserts together with the
        # expression they evaluate.
        if not self.matcher.accept_token(token):
            raise AssertionError(f"Token {token} was not accepted by the matcher")

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """Ask the matcher for a jump-forward string.

        Args:
            tokenizer: Not used by this implementation.

        Returns:
            ``([], s)`` when the matcher returns a non-empty string ``s``,
            otherwise ``None``.
        """
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """Unpack the helper produced by :meth:`try_jump_forward`.

        Returns:
            The jump-forward string and ``-1`` as the state value.
        """
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        """Re-sync the matcher after the output tokens were re-tokenized.

        Rolls the matcher back over the tokens of ``old_output_ids`` that
        come after the common prefix with ``new_output_ids``, then feeds
        the remaining tokens of ``new_output_ids``.

        Args:
            old_output_ids: Token ids the matcher has been fed so far.
            new_output_ids: Token ids the matcher should have been fed.
            next_state: Not used by this implementation.

        Raises:
            AssertionError: If the matcher does not accept one of the new
                tokens.
        """
        # Length of the common prefix. `zip` stops at the shorter sequence,
        # so a `new_output_ids` shorter than `old_output_ids` is handled
        # instead of raising IndexError.
        k = 0
        for old_id, new_id in zip(old_output_ids, new_output_ids):
            if old_id != new_id:
                break
            k += 1

        # Roll back to the last token that is the same.
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for token in new_output_ids[k:]:
            # Explicit check rather than `assert`, so the call still runs
            # under `python -O`.
            if not self.matcher.accept_token(token):
                raise AssertionError(
                    f"Token {token} was not accepted by the matcher"
                )

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """Allocate a token bitmask via ``allocate_token_bitmask``.

        Args:
            vocab_size: Vocabulary size passed to the allocator.
            batch_size: Batch size passed to the allocator.
            device: Not used by this implementation; see
                :meth:`move_vocab_mask` for moving the mask to a device.
        """
        return allocate_token_bitmask(batch_size, vocab_size)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Have the matcher fill its next-token bitmask into ``vocab_mask`` at ``idx``."""
        self.matcher.fill_next_token_bitmask(vocab_mask, idx)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Return ``vocab_mask`` moved to ``device`` with ``non_blocking=True``."""
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Apply ``vocab_mask`` to ``logits`` in place."""
        apply_token_bitmask_inplace(logits, vocab_mask)

    def copy(self):
        """Return a new ``XGrammarGrammar`` with a fresh matcher.

        The new matcher is built from the same compiled grammar and
        stop-token override as this object.
        """
        matcher = GrammarMatcher(
            self.ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(
            matcher, self.vocab_size, self.ctx, self.override_stop_tokens
        )
|
|
|
|
|
|
class XGrammarGrammarBackend(BaseGrammarBackend):
    """Grammar backend that compiles constraints with xgrammar.

    Each ``dispatch_*`` method compiles one kind of constraint and returns
    an ``XGrammarGrammar`` for it, or ``None`` when the input is rejected.
    """

    def __init__(
        self,
        tokenizer,
        vocab_size: int,
    ):
        """Build the grammar compiler for ``tokenizer``.

        Args:
            tokenizer: Tokenizer handed to ``TokenizerInfo.from_huggingface``.
            vocab_size: Vocabulary size passed to the tokenizer info and
                recorded for the grammar objects created later.
        """
        super().__init__()

        tokenizer_info = TokenizerInfo.from_huggingface(
            tokenizer, vocab_size=vocab_size
        )

        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
        self.vocab_size = vocab_size
        # No stop-token override is configured by this constructor.
        self.override_stop_tokens = None

    def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
        """Wrap a compiled grammar in a new matcher and grammar object."""
        # Forward the stop-token override to the matcher as well, matching
        # what XGrammarGrammar.copy() does when it rebuilds the matcher.
        matcher = GrammarMatcher(
            ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)

    def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a JSON constraint.

        Args:
            key_string: A JSON schema string, or the sentinel ``"$$ANY$$"``
                to select the compiler's built-in JSON grammar.

        Returns:
            The grammar object, or ``None`` if compilation raised
            ``RuntimeError``.
        """
        try:
            if key_string == "$$ANY$$":
                ctx = self.grammar_compiler.compile_builtin_json_grammar()
            else:
                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
        except RuntimeError as e:
            logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
            return None
        return self._from_context(ctx)

    def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile an EBNF grammar string.

        Returns:
            The grammar object, or ``None`` if compilation raised
            ``RuntimeError``.
        """
        try:
            ctx = self.grammar_compiler.compile_grammar(key_string)
        except RuntimeError as e:
            logger.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
            return None
        return self._from_context(ctx)

    def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a regular expression.

        Returns:
            The grammar object, or ``None`` if compilation raised
            ``RuntimeError``.
        """
        try:
            ctx = self.grammar_compiler.compile_regex(key_string)
        except RuntimeError as e:
            logger.warning(f"Skip invalid regex: regex={key_string}, {e=}")
            return None
        return self._from_context(ctx)

    def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a structural-tag specification.

        Args:
            key_string: JSON text of an object with a ``"structures"`` list
                (items carrying ``"begin"``, ``"schema"`` and ``"end"``)
                and a ``"triggers"`` entry.

        Returns:
            The grammar object, or ``None`` if the JSON cannot be parsed,
            a required key is missing, or compilation raised
            ``RuntimeError``.
        """
        try:
            structural_tag = json.loads(key_string)
            tags = [
                StructuralTagItem(
                    begin=structure["begin"],
                    schema=json.dumps(structure["schema"]),
                    end=structure["end"],
                )
                for structure in structural_tag["structures"]
            ]
            ctx = self.grammar_compiler.compile_structural_tag(
                tags, structural_tag["triggers"]
            )
        except (RuntimeError, json.JSONDecodeError, KeyError, TypeError) as e:
            # Malformed JSON or a missing/ill-typed field is treated like a
            # compile failure: skip the constraint instead of propagating.
            logger.warning(
                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
            )
            return None
        return self._from_context(ctx)

    def reset(self):
        """Clear the grammar compiler's cache, if a compiler is set."""
        if self.grammar_compiler:
            self.grammar_compiler.clear_cache()
|