188 lines
6.3 KiB
Python
188 lines
6.3 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""
|
|
Faster constrained decoding with jump forward decoding / compressed finite state machine.
|
|
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
|
"""
|
|
|
|
import dataclasses
import logging
from collections import defaultdict
from typing import Optional

import interegular
from interegular import InvalidSyntax
from outlines.caching import cache as disk_cache

try:
    # outlines >= 0.1.0
    from outlines_core.fsm.outlines_core_rs import FSMInfo
    from outlines_core.fsm.regex import make_byte_level_fsm, make_deterministic_fsm
except ImportError:
    # outlines <= 0.0.46
    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
|
|
|
# Regex for a dotted-quad IPv4 address: four octets in the range 0-255
# (leading zeros allowed) separated by literal dots.
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclasses.dataclass
class JumpEdge:
    """The single outgoing edge recorded for an FSM state.

    All fields default to ``None``. The character-level pair
    (``symbol`` / ``symbol_next_state``) and the byte-level pair
    (``byte`` / ``byte_next_state``) are filled in independently by
    ``init_state_to_jump_forward``, so either pair may be unset.
    """

    # Character labelling the edge; None when no character-level edge is known.
    symbol: Optional[str] = None
    # State reached after consuming ``symbol``.
    symbol_next_state: Optional[int] = None
    # Byte value labelling the edge; None when no byte-level edge is known.
    byte: Optional[int] = None
    # State reached after consuming ``byte``.
    byte_next_state: Optional[int] = None
|
|
|
|
|
|
@disk_cache()
def init_state_to_jump_forward(regex_string):
    """Build the jump-forward table for ``regex_string``.

    The regex is parsed with ``interegular``, converted to a byte-level FSM
    (``keep_utf8=True``) and made deterministic. Two passes over the FSM
    transitions then record, per state, a ``JumpEdge``:

    * the character pass counts outgoing single-character symbols and stores
      ``symbol`` / ``symbol_next_state``;
    * the byte pass counts outgoing ASCII characters and hex-encoded byte
      symbols and stores ``byte`` / ``byte_next_state``.

    Final states start each pass with a count of one (they can terminate), and
    any state whose count exceeds one in a pass is removed from the table.

    Args:
        regex_string: Regular expression source.

    Returns:
        A dict mapping FSM state -> ``JumpEdge``, or ``None`` (after logging a
        warning) when ``interegular`` rejects ``regex_string`` as invalid
        syntax. The function is wrapped by the ``disk_cache`` decorator.
    """
    try:
        regex_pattern = interegular.parse_pattern(regex_string)
    except InvalidSyntax as e:
        logger.warning(f"skip invalid regex: {regex_string}, {e=}")
        return None

    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
    regex_fsm, _ = make_deterministic_fsm(byte_fsm)

    fsm_info: FSMInfo = regex_fsm.fsm_info

    # Invert symbol -> id into id -> [symbols]; more than one symbol may map
    # to the same id.
    symbol_to_id = fsm_info.alphabet_symbol_mapping
    id_to_symbol = {}
    for symbol, id_ in symbol_to_id.items():
        id_to_symbol.setdefault(id_, []).append(symbol)

    transitions = fsm_info.transitions

    outgoings_ct = defaultdict(int)
    # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    # Process the character level jump forward
    state_to_jump_forward = {}
    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            # An arbitrary symbol cannot be recognized as jump forward
            continue

        symbols = id_to_symbol[id_]
        for c in symbols:
            if len(c) > 1:
                # Skip byte level transitions like c = "5E"
                continue

            outgoings_ct[state] += 1
            if outgoings_ct[state] > 1:
                # More than one choice: this state is not a jump-forward state.
                if state in state_to_jump_forward:
                    del state_to_jump_forward[state]
                break

            state_to_jump_forward[state] = JumpEdge(
                symbol=c,
                symbol_next_state=next_state,
            )

    # Process the byte level jump forward
    outgoings_ct = defaultdict(int)
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            continue

        symbols = id_to_symbol[id_]
        for c in symbols:
            byte_ = None
            if len(c) == 1 and ord(c) < 0x80:
                # ASCII character
                byte_ = ord(c)
            elif len(c) > 1:
                # FIXME: This logic is due to the leading \x00
                # https://github.com/outlines-dev/outlines/pull/930
                # Decode the symbol being visited (not symbols[0]): drop its
                # one-character prefix and parse the remaining hex digits.
                byte_ = int(c[1:], 16)

            if byte_ is not None:
                outgoings_ct[state] += 1
                if outgoings_ct[state] > 1:
                    if state in state_to_jump_forward:
                        del state_to_jump_forward[state]
                    break
                # Reuse the edge created by the character pass when present so
                # one JumpEdge carries both the symbol and the byte view.
                e = state_to_jump_forward.get(state, JumpEdge())
                e.byte = byte_
                e.byte_next_state = next_state
                state_to_jump_forward[state] = e

    return state_to_jump_forward
|
|
|
|
|
|
class OutlinesJumpForwardMap:
    """Lookup helper over the jump-forward table built for one regex.

    Attributes:
        state_to_jump_forward: dict mapping FSM state -> edge object exposing
            ``symbol``, ``symbol_next_state``, ``byte`` and ``byte_next_state``.
    """

    def __init__(self, regex_string):
        """Build (or load from cache) the table for ``regex_string``."""
        # init_state_to_jump_forward returns None for a regex it cannot parse.
        # Fall back to an empty table so the lookups below report "nothing to
        # jump over" instead of raising TypeError on `state in None`.
        self.state_to_jump_forward = init_state_to_jump_forward(regex_string) or {}

    def jump_forward_symbol(self, state):
        """Follow character-level edges starting at ``state``.

        The walk stops at the first state that has no table entry or whose
        entry has no ``symbol``.

        Returns:
            ``(jump_forward_str, next_state)``: the concatenated symbols and
            the state where the walk stopped; ``("", state)`` when no edge
            could be followed.
        """
        pieces = []
        while True:
            edge = self.state_to_jump_forward.get(state)
            if edge is None or edge.symbol is None:
                break
            pieces.append(edge.symbol)
            state = edge.symbol_next_state

        return "".join(pieces), state

    def jump_forward_byte(self, state):
        """Follow byte-level edges starting at ``state``.

        Returns:
            ``None`` when ``state`` has no table entry; otherwise a list of
            ``(byte, next_state)`` pairs, one per edge followed.
        """
        if state not in self.state_to_jump_forward:
            return None

        jump_forward_bytes = []
        while state in self.state_to_jump_forward:
            e = self.state_to_jump_forward[state]
            # Internal invariant: every edge reached here carries byte info.
            assert e.byte is not None and e.byte_next_state is not None
            jump_forward_bytes.append((e.byte, e.byte_next_state))
            state = e.byte_next_state

        return jump_forward_bytes

    def is_jump_forward_symbol_state(self, state):
        """Return True when ``state`` has a table entry with a ``symbol``."""
        edge = self.state_to_jump_forward.get(state)
        return edge is not None and edge.symbol is not None
|
|
|
|
|
|
def test_main(regex_string):
    """Print the jump-forward results for every state in the regex's table.

    For each state, prints the character-level jump (only when the state's
    edge has a symbol) followed by the byte-level jump as hex values.
    """
    jump_forward_map = OutlinesJumpForwardMap(regex_string)
    table = jump_forward_map.state_to_jump_forward
    for state in table:
        if table[state].symbol is not None:
            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
            print(f"{state} -> {next_state}", jump_forward_str)

        bytes_ = jump_forward_map.jump_forward_byte(state)
        hex_values = [hex(pair[0]) for pair in bytes_]
        print(f"{state} -> {bytes_[-1][1]}", hex_values)
|
|
|
|
|
|
if __name__ == "__main__":
    import outlines

    # Clear the outlines cache before running the demo patterns (presumably so
    # the cached table builder is re-run rather than served from cache).
    outlines.caching.clear_cache()

    demo_patterns = (
        r"The google's DNS sever address is " + IP_REGEX,
        # Two alternatives sharing a UTF-8 byte prefix:
        # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
        # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
        r"霍格沃茨特快列车|霍比特人比尔博",
        r"[-+]?[0-9]+[ ]*",
    )
    for pattern in demo_patterns:
        test_main(pattern)
|