sglang0.4.5.post1/python/sglang/srt/constrained/outlines_backend.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with outlines backend."""
import json
import logging
from typing import Dict, List, Optional, Tuple, Union
import interegular
import torch
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)
from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap

try:
    from outlines.fsm.json_schema import build_regex_from_schema
except ImportError:
    # Newer outlines releases ship the JSON-schema helpers in outlines_core.
    from outlines_core.fsm.json_schema import build_regex_from_schema

logger = logging.getLogger(__name__)


class OutlinesGrammar(BaseGrammarObject):
    """Grammar state for a single request, driven by an outlines RegexGuide
    plus an optional jump-forward map."""

    def __init__(
        self,
        guide: RegexGuide,
        jump_forward_map: Union[OutlinesJumpForwardMap, None],
    ) -> None:
        self.guide = guide
        self.jump_forward_map = jump_forward_map
        self.state = 0
        self.finished = False

    def accept_token(self, token: int):
        self.state = self.guide.get_next_state(self.state, token)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
        if not self.jump_forward_map:
            return None

        jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
        if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
            return None

        # Preprocess the jump-forward string: strip leading UTF-8 continuation
        # bytes (0x80-0xBF) so the jump starts at a character boundary.
        suffix_bytes = []
        continuation_range = range(0x80, 0xC0)
        cur_state = self.state
        while (
            len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
        ):
            # continuation bytes
            byte_edge = jump_forward_bytes.pop(0)
            suffix_bytes.append(byte_edge[0])
            cur_state = byte_edge[1]

        suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
        suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
        return suffix_ids, cur_state
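
    # Illustrative jump-forward flow for callers of this class (`grammar`,
    # `tokenizer`, `old_ids`, and `new_ids` are hypothetical names):
    #
    #   helper = grammar.try_jump_forward(tokenizer)
    #   if helper is not None:
    #       jump_str, next_state = grammar.jump_forward_str_state(helper)
    #       grammar.jump_and_retokenize(old_ids, new_ids, next_state)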

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, cur_state = helper
        return self.jump_forward_map.jump_forward_symbol(cur_state)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        self.state = next_state

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        tokens = torch.tensor(
            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
        ).to(vocab_mask.device, non_blocking=True)
        vocab_mask = vocab_mask[idx]
        # Block every token, then clear the entries the guide currently allows.
        vocab_mask.fill_(1)
        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
        logits.masked_fill_(vocab_mask, float("-inf"))

    def copy(self):
        return OutlinesGrammar(self.guide, self.jump_forward_map)
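
    # Illustrative per-step masking flow for a batch of requests (`grammars`,
    # `logits`, `vocab_size`, and `device` are hypothetical names):
    #
    #   vocab_mask = grammars[0].allocate_vocab_mask(vocab_size, len(grammars), device)
    #   for i, grammar in enumerate(grammars):
    #       grammar.fill_vocab_mask(vocab_mask, i)
    #   vocab_mask = OutlinesGrammar.move_vocab_mask(vocab_mask, device)
    #   OutlinesGrammar.apply_vocab_mask(logits, vocab_mask)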


class OutlinesGrammarBackend(BaseGrammarBackend):
    """Grammar backend that compiles regex and JSON-schema constraints with outlines."""

    def __init__(
        self,
        tokenizer,
        whitespace_pattern: Optional[str],
    ):
        super().__init__()

        try:
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
        except AttributeError:
            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
            origin_pad_token_id = tokenizer.pad_token_id

            def fset(self, value):
                self._value = value

            type(tokenizer).pad_token_id = property(
                fget=type(tokenizer).pad_token_id.fget, fset=fset
            )
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token = (
                self.outlines_tokenizer.tokenizer.pad_token
            )
            self.outlines_tokenizer.vocabulary = (
                self.outlines_tokenizer.tokenizer.get_vocab()
            )
        self.whitespace_pattern = whitespace_pattern

    def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
        try:
            if hasattr(RegexGuide, "from_regex"):
                # outlines >= 0.1.1
                guide = RegexGuide.from_regex(regex, self.outlines_tokenizer)
            else:
                # outlines <= 0.0.46
                guide = RegexGuide(regex, self.outlines_tokenizer)
        except interegular.patterns.InvalidSyntax as e:
            logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
            return None

        jump_forward_map = None
        return OutlinesGrammar(guide, jump_forward_map)
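
    # Illustrative example (`hf_tokenizer` and `token_id` are hypothetical):
    #
    #   backend = OutlinesGrammarBackend(hf_tokenizer, whitespace_pattern=None)
    #   grammar = backend.dispatch_regex(r"(yes|no)")
    #   grammar.accept_token(token_id)  # advances grammar.state via the guide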

    def dispatch_ebnf(self, key_string: str):
        return super().dispatch_ebnf(key_string)

    def dispatch_structural_tag(self, key_string: str):
        return super().dispatch_structural_tag(key_string)

    def dispatch_json(self, key_string: str):
        try:
            regex = build_regex_from_object(
                key_string,
                whitespace_pattern=self.whitespace_pattern,
            )
        except (NotImplementedError, json.decoder.JSONDecodeError) as e:
            logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
            return None
        return self._compile_regex(regex)

    def dispatch_regex(self, key_string: str):
        return self._compile_regex(key_string)


def build_regex_from_object(
    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
):
    # Accept a pydantic model class, a schema dict, or a raw JSON-schema string.
    if isinstance(object, type(BaseModel)):
        schema = json.dumps(object.model_json_schema())
    elif isinstance(object, Dict):
        schema = json.dumps(object)
    else:
        schema = object
    return build_regex_from_schema(schema, whitespace_pattern)
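

# Illustrative usage of build_regex_from_object (the `Answer` model below is a
# hypothetical example, not defined elsewhere in sglang):
#
#   from pydantic import BaseModel
#
#   class Answer(BaseModel):
#       answer: str
#
#   regex = build_regex_from_object(Answer)                # pydantic model class
#   regex = build_regex_from_object({"type": "string"})    # JSON-schema dict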