# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with the outlines backend."""

import json
import logging
from typing import Dict, List, Optional, Tuple, Union

import interegular
import torch
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)
from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap

# The JSON-schema-to-regex helper moved between packages across outlines
# releases; fall back to outlines_core for newer versions.
try:
    from outlines.fsm.json_schema import build_regex_from_schema
except ImportError:
    from outlines_core.fsm.json_schema import build_regex_from_schema

logger = logging.getLogger(__name__)


class OutlinesGrammar(BaseGrammarObject):
    def __init__(
        self,
        guide: RegexGuide,
        jump_forward_map: Union[OutlinesJumpForwardMap, None],
    ) -> None:
        self.guide = guide
        self.jump_forward_map = jump_forward_map
        self.state = 0  # start state of the regex FSM
        self.finished = False

    def accept_token(self, token: int):
        self.state = self.guide.get_next_state(self.state, token)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
        if not self.jump_forward_map:
            return None

        jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
        if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
            return None

        # Preprocess the jump-forward string: consume leading UTF-8
        # continuation bytes (0x80-0xBF) so the jump starts at a character
        # boundary, and emit them as byte-fallback tokens (e.g. "<0x9F>").
        suffix_bytes = []
        continuation_range = range(0x80, 0xC0)
        cur_state = self.state
        while (
            len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
        ):
            byte_edge = jump_forward_bytes.pop(0)
            suffix_bytes.append(byte_edge[0])
            cur_state = byte_edge[1]

        suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
        suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
        return suffix_ids, cur_state

    def jump_forward_str_state(self, helper: Tuple[List[int], int]) -> Tuple[str, int]:
        _, cur_state = helper
        return self.jump_forward_map.jump_forward_symbol(cur_state)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        self.state = next_state

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        # The mask is already allocated on the target device; nothing to move.
        return vocab_mask

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        tokens = torch.tensor(
            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
        ).to(vocab_mask.device, non_blocking=True)
        # Mark every token as forbidden, then clear the tokens the FSM allows.
        vocab_mask = vocab_mask[idx]
        vocab_mask.fill_(1)
        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
        logits.masked_fill_(vocab_mask, float("-inf"))

    def copy(self):
        return OutlinesGrammar(self.guide, self.jump_forward_map)

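# A minimal sketch (not part of the upstream module) of how a scheduler is
# expected to drive the mask API above for one decoding step: allocate the
# mask, let the grammar flip its row, apply the mask to the logits, then
# feed the sampled token back into the FSM. The greedy argmax stands in for
# the real sampler, and the function name is purely illustrative.
def _example_constrained_step(grammar: OutlinesGrammar, logits: torch.Tensor) -> int:
    # logits: a 1-D tensor of shape [vocab_size] for a single request.
    vocab_mask = grammar.allocate_vocab_mask(
        vocab_size=logits.shape[-1], batch_size=1, device=logits.device
    )
    grammar.fill_vocab_mask(vocab_mask, idx=0)
    # True entries mark tokens the FSM rejects in the current state.
    OutlinesGrammar.apply_vocab_mask(logits, vocab_mask[0])
    token = int(torch.argmax(logits).item())
    grammar.accept_token(token)  # advance the regex FSM
    return token
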
class OutlinesGrammarBackend(BaseGrammarBackend):
    def __init__(
        self,
        tokenizer,
        whitespace_pattern: Optional[str],
    ):
        super().__init__()

        try:
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
        except AttributeError:
            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
            origin_pad_token_id = tokenizer.pad_token_id

            def fset(self, value):
                self._value = value

            type(tokenizer).pad_token_id = property(
                fget=type(tokenizer).pad_token_id.fget, fset=fset
            )
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token = (
                self.outlines_tokenizer.tokenizer.pad_token
            )
            self.outlines_tokenizer.vocabulary = (
                self.outlines_tokenizer.tokenizer.get_vocab()
            )
        self.whitespace_pattern = whitespace_pattern

    def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
        try:
            if hasattr(RegexGuide, "from_regex"):
                # outlines >= 0.1.1
                guide = RegexGuide.from_regex(regex, self.outlines_tokenizer)
            else:
                # outlines <= 0.0.46
                guide = RegexGuide(regex, self.outlines_tokenizer)
        except interegular.patterns.InvalidSyntax as e:
            logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
            return None

        jump_forward_map = None
        return OutlinesGrammar(guide, jump_forward_map)

    def dispatch_ebnf(self, key_string: str):
        return super().dispatch_ebnf(key_string)

    def dispatch_structural_tag(self, key_string: str):
        return super().dispatch_structural_tag(key_string)

    def dispatch_json(self, key_string: str):
        try:
            regex = build_regex_from_object(
                key_string,
                whitespace_pattern=self.whitespace_pattern,
            )
        except (NotImplementedError, json.decoder.JSONDecodeError) as e:
            logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
            return None
        return self._compile_regex(regex)

    def dispatch_regex(self, key_string: str):
        return self._compile_regex(key_string)


def build_regex_from_object(
    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
):
    if isinstance(object, type(BaseModel)):
        schema = json.dumps(object.model_json_schema())
    elif isinstance(object, Dict):
        schema = json.dumps(object)
    else:
        schema = object
    return build_regex_from_schema(schema, whitespace_pattern)
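
# Usage sketch (an illustration, not part of the upstream module): compile a
# grammar from a JSON schema string and from a raw regex, then inspect the
# start-state mask. The "gpt2" checkpoint is a placeholder; any Hugging Face
# tokenizer accepted by TransformerTokenizer should work. Runs only when the
# module is executed directly.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    backend = OutlinesGrammarBackend(hf_tokenizer, whitespace_pattern=None)

    schema = json.dumps(
        {"type": "object", "properties": {"answer": {"type": "string"}}}
    )
    for grammar in (
        backend.dispatch_json(schema),
        backend.dispatch_regex(r"(yes|no)"),
    ):
        if grammar is not None:
            mask = grammar.allocate_vocab_mask(
                vocab_size=len(hf_tokenizer), batch_size=1, device="cpu"
            )
            grammar.fill_vocab_mask(mask, idx=0)
            # Allowed tokens are the entries the FSM left as False.
            print("tokens allowed in start state:", int((~mask[0]).sum()))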