sglang0.4.5.post1/python/sglang/lang/backend/openai.py

472 lines
16 KiB
Python

import dataclasses
import logging
import time
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import openai
import tiktoken
except ImportError as e:
openai = tiktoken = e
logger = logging.getLogger(__name__)
def create_logit_bias_int(tokenizer):
"""Get logit bias for integer numbers."""
int_token_ids = []
tokens = tokenizer._mergeable_ranks
for token, token_id in tokens.items():
s = tokenizer.decode([token_id])
if all([c.isdigit() for c in s]) or s in [" "]:
int_token_ids.append(token_id)
if len(int_token_ids) >= 300: # OpenAI API limit
break
special_tokens = tokenizer._special_tokens
mask = {t: 100 for t in int_token_ids[:299]}
mask[special_tokens["<|endoftext|>"]] = 100
return mask
INSTRUCT_MODEL_NAMES = [
"gpt-3.5-turbo-instruct",
]
@dataclasses.dataclass
class TokenUsage:
prompt_tokens: int
completion_tokens: int
def reset(self):
self.prompt_tokens = self.completion_tokens = 0
class OpenAI(BaseBackend):
def __init__(
self,
model_name: str,
is_chat_model: Optional[bool] = None,
chat_template: Optional[ChatTemplate] = None,
is_azure: bool = False,
*args,
**kwargs,
):
super().__init__()
if isinstance(openai, Exception):
raise openai
if is_azure:
self.client = openai.AzureOpenAI(*args, **kwargs)
else:
self.client = openai.OpenAI(*args, **kwargs)
self.model_name = model_name
try:
self.tokenizer = tiktoken.encoding_for_model(model_name)
except KeyError:
self.tokenizer = tiktoken.get_encoding("cl100k_base")
self.logit_bias_int = create_logit_bias_int(self.tokenizer)
self.chat_template = chat_template or get_chat_template_by_model_path(
model_name
)
if is_chat_model is not None:
self.is_chat_model = is_chat_model
else:
if model_name in INSTRUCT_MODEL_NAMES:
self.is_chat_model = False
else:
self.is_chat_model = True
self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
# Usage
self.token_usage = TokenUsage(0, 0)
# API speculative execution
# TODO(ying): This does not support multi-threading (run_batch)
self.spec_kwargs = {}
self.spec_format = []
self.spec_max_num_tries = 3
def get_chat_template(self):
return self.chat_template
def _prepare_spec_execution(
self,
sampling_params: SglSamplingParams,
num_api_spec_tokens: int,
spec_var_name: str,
):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
else:
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
params = sampling_params.to_openai_kwargs()
for key, value in params.items():
if key in ["stop"]:
continue
if key in ["max_tokens"]:
warnings.warn(
"The parameter max_tokens will be overwritten by speculated number of tokens."
)
continue
if key not in self.spec_kwargs:
self.spec_kwargs[key] = value
else:
assert (
value == self.spec_kwargs[key]
), "sampling parameters should be consistent if turn on api speculative execution."
self.spec_format.append(
{"text": "", "stop": params["stop"], "name": spec_var_name}
)
return "", {}
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
spec_var_name: str = None,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if s.num_api_spec_tokens is None:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported if api speculative execution is off. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
"Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
)
prompt = s.messages_
else:
return self._prepare_spec_execution(
sampling_params, s.num_api_spec_tokens, spec_var_name
)
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
if self.model_name.startswith("o1") or self.model_name.startswith("o3"):
kwargs.pop("max_tokens", None)
else:
kwargs.pop("max_completion_tokens", None)
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
# Keep the returned list (or string) as is.
elif sampling_params.dtype in [str, "str", "string"]:
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_ + '"',
stop='"',
**kwargs,
)
# Wrap each element in quotes if we have a list.
if isinstance(comp, list):
comp = ['"' + x + '"' for x in comp]
else:
comp = '"' + comp + '"'
elif sampling_params.dtype in [int, "int"]:
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_,
logit_bias=self.logit_bias_int,
stop=[" "],
**kwargs,
)
# Leave as a list if that's what is returned.
else:
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
return comp, {}
def spec_fill(self, value: str):
assert self.is_chat_model
self.spec_format.append({"text": value, "stop": None, "name": None})
def spec_pattern_match(self, comp):
for i, term in enumerate(self.spec_format):
text = term["text"]
if text != "":
if comp.startswith(text):
comp = comp[len(text) :]
else:
return False
else:
pos = comp.find(term["stop"])
if pos != -1:
term["text"] = comp[:pos]
comp = comp[pos:]
else:
if i == len(self.spec_format) - 1:
term["text"] = comp
else:
return False
return True
def role_end_generate(
self,
s: StreamExecutor,
):
if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
return
comp = ""
if not all(x["name"] is None for x in self.spec_format):
# TODO(ying): throw errors or warnings
for i in range(self.spec_max_num_tries):
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.messages_,
**self.spec_kwargs,
)
# Use a string for pattern matching.
comp_for_match = comp[0] if isinstance(comp, list) else comp
if self.spec_pattern_match(comp_for_match):
break
for term in self.spec_format:
s.text_ += term["text"]
name = term["name"]
if name is not None:
s.variables[name] = term["text"]
s.meta_info[name] = {}
s.variable_event[name].set()
self.spec_kwargs = {}
self.spec_format = []
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
)
prompt = s.messages_
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
generator = openai_completion_stream(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
return generator
else:
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
choices_method: ChoicesSamplingMethod,
) -> ChoicesDecision:
"""Note: `choices_method` is not used by the OpenAI backend."""
if self.is_chat_model:
raise NotImplementedError(
"select/choices is not supported for chat models. "
"Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
)
n_choices = len(choices)
token_ids = [self.tokenizer.encode(x) for x in choices]
scores = [0] * n_choices
valid = [len(x) > 0 for x in token_ids]
prompt_tokens = self.tokenizer.encode(s.text_)
max_len = max([len(x) for x in token_ids])
for step in range(max_len):
# Build logit bias
logit_bias = {}
for i in range(n_choices):
if valid[i]:
logit_bias[token_ids[i][step]] = 100
# Call API
ret = self.client.completions.create(
model=self.model_name,
prompt=prompt_tokens,
logit_bias=logit_bias,
max_tokens=1,
temperature=temperature,
)
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
self.token_usage.completion_tokens = ret.usage.completion_tokens
# TODO:
# 1. return logits as the scores
# 2. compute logits of the full choice
# 3. consider chunk-based decoding
# Update valid
hit = False
for i in range(n_choices):
if valid[i]:
if step == len(token_ids[i]) - 1:
valid[i] = False
if ret_token == token_ids[i][step]:
scores[i] += 1
hit = True
else:
valid[i] = False
assert hit
if np.sum(valid) <= 1:
break
prompt_tokens.append(ret_token)
return ChoicesDecision(
decision=choices[np.argmax(scores)],
meta_info={"scores": scores},
)
def openai_completion(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
) -> Union[str, List[str]]:
# if "ebnf" is in kwargs, warn and remove
if "ebnf" in kwargs:
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
del kwargs["ebnf"]
for attempt in range(retries):
try:
if is_chat:
if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop")
ret = client.chat.completions.create(messages=prompt, **kwargs)
if len(ret.choices) == 1:
comp = ret.choices[0].message.content
else:
comp = [c.message.content for c in ret.choices]
else:
ret = client.completions.create(prompt=prompt, **kwargs)
if isinstance(prompt, (list, tuple)):
comp = [c.text for c in ret.choices]
else:
comp = ret.choices[0].text
if len(ret.choices) > 1:
comp = [c.text for c in ret.choices]
token_usage.prompt_tokens += ret.usage.prompt_tokens
token_usage.completion_tokens += ret.usage.completion_tokens
break
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
time.sleep(5)
if attempt == retries - 1:
raise e
except Exception as e:
logger.error(f"RuntimeError {e}.")
raise e
return comp
def openai_completion_stream(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
# if "ebnf" is in kwargs, warn and remove
if "ebnf" in kwargs:
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
del kwargs["ebnf"]
for attempt in range(retries):
try:
if is_chat:
if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop")
generator = client.chat.completions.create(
messages=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:
continue
try:
content = ret.choices[0].delta.content
except IndexError:
content = None
yield content or "", {}
else:
generator = client.completions.create(
prompt=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:
continue
content = ret.choices[0].text
yield content or "", {}
token_usage.prompt_tokens += ret.usage.prompt_tokens
token_usage.completion_tokens += ret.usage.completion_tokens
break
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
time.sleep(5)
if attempt == retries - 1:
raise e
except Exception as e:
logger.error(f"RuntimeError {e}.")
raise e