import dataclasses import logging import time import warnings from typing import Callable, List, Optional, Union import numpy as np from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams try: import openai import tiktoken except ImportError as e: openai = tiktoken = e logger = logging.getLogger(__name__) def create_logit_bias_int(tokenizer): """Get logit bias for integer numbers.""" int_token_ids = [] tokens = tokenizer._mergeable_ranks for token, token_id in tokens.items(): s = tokenizer.decode([token_id]) if all([c.isdigit() for c in s]) or s in [" "]: int_token_ids.append(token_id) if len(int_token_ids) >= 300: # OpenAI API limit break special_tokens = tokenizer._special_tokens mask = {t: 100 for t in int_token_ids[:299]} mask[special_tokens["<|endoftext|>"]] = 100 return mask INSTRUCT_MODEL_NAMES = [ "gpt-3.5-turbo-instruct", ] @dataclasses.dataclass class TokenUsage: prompt_tokens: int completion_tokens: int def reset(self): self.prompt_tokens = self.completion_tokens = 0 class OpenAI(BaseBackend): def __init__( self, model_name: str, is_chat_model: Optional[bool] = None, chat_template: Optional[ChatTemplate] = None, is_azure: bool = False, *args, **kwargs, ): super().__init__() if isinstance(openai, Exception): raise openai if is_azure: self.client = openai.AzureOpenAI(*args, **kwargs) else: self.client = openai.OpenAI(*args, **kwargs) self.model_name = model_name try: self.tokenizer = tiktoken.encoding_for_model(model_name) except KeyError: self.tokenizer = tiktoken.get_encoding("cl100k_base") self.logit_bias_int = create_logit_bias_int(self.tokenizer) self.chat_template = chat_template or get_chat_template_by_model_path( model_name ) if is_chat_model is not None: self.is_chat_model = is_chat_model else: if model_name in INSTRUCT_MODEL_NAMES: self.is_chat_model = False else: self.is_chat_model = True self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0] # Usage self.token_usage = TokenUsage(0, 0) # API speculative execution # TODO(ying): This does not support multi-threading (run_batch) self.spec_kwargs = {} self.spec_format = [] self.spec_max_num_tries = 3 def get_chat_template(self): return self.chat_template def _prepare_spec_execution( self, sampling_params: SglSamplingParams, num_api_spec_tokens: int, spec_var_name: str, ): if "max_tokens" not in self.spec_kwargs: self.spec_kwargs["max_tokens"] = num_api_spec_tokens else: assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens params = sampling_params.to_openai_kwargs() for key, value in params.items(): if key in ["stop"]: continue if key in ["max_tokens"]: warnings.warn( "The parameter max_tokens will be overwritten by speculated number of tokens." ) continue if key not in self.spec_kwargs: self.spec_kwargs[key] = value else: assert ( value == self.spec_kwargs[key] ), "sampling parameters should be consistent if turn on api speculative execution." self.spec_format.append( {"text": "", "stop": params["stop"], "name": spec_var_name} ) return "", {} def generate( self, s: StreamExecutor, sampling_params: SglSamplingParams, spec_var_name: str = None, ): if sampling_params.dtype is None: if self.is_chat_model: if s.num_api_spec_tokens is None: if not s.text_.endswith(self.chat_prefix): raise RuntimeError( "This use case is not supported if api speculative execution is off. " "For OpenAI chat models, sgl.gen must be right after sgl.assistant. " "Example of adding api speculative execution: @function(num_api_spec_tokens=128)." ) prompt = s.messages_ else: return self._prepare_spec_execution( sampling_params, s.num_api_spec_tokens, spec_var_name ) else: prompt = s.text_ kwargs = sampling_params.to_openai_kwargs() if self.model_name.startswith("o1") or self.model_name.startswith("o3"): kwargs.pop("max_tokens", None) else: kwargs.pop("max_completion_tokens", None) comp = openai_completion( client=self.client, token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=prompt, **kwargs, ) # Keep the returned list (or string) as is. elif sampling_params.dtype in [str, "str", "string"]: assert ( not self.is_chat_model ), "constrained type not supported on chat model" kwargs = sampling_params.to_openai_kwargs() kwargs.pop("stop") comp = openai_completion( client=self.client, token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=s.text_ + '"', stop='"', **kwargs, ) # Wrap each element in quotes if we have a list. if isinstance(comp, list): comp = ['"' + x + '"' for x in comp] else: comp = '"' + comp + '"' elif sampling_params.dtype in [int, "int"]: assert ( not self.is_chat_model ), "constrained type not supported on chat model" kwargs = sampling_params.to_openai_kwargs() kwargs.pop("stop") comp = openai_completion( client=self.client, token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=s.text_, logit_bias=self.logit_bias_int, stop=[" "], **kwargs, ) # Leave as a list if that's what is returned. else: raise ValueError(f"Unknown dtype: {sampling_params.dtype}") return comp, {} def spec_fill(self, value: str): assert self.is_chat_model self.spec_format.append({"text": value, "stop": None, "name": None}) def spec_pattern_match(self, comp): for i, term in enumerate(self.spec_format): text = term["text"] if text != "": if comp.startswith(text): comp = comp[len(text) :] else: return False else: pos = comp.find(term["stop"]) if pos != -1: term["text"] = comp[:pos] comp = comp[pos:] else: if i == len(self.spec_format) - 1: term["text"] = comp else: return False return True def role_end_generate( self, s: StreamExecutor, ): if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix): return comp = "" if not all(x["name"] is None for x in self.spec_format): # TODO(ying): throw errors or warnings for i in range(self.spec_max_num_tries): comp = openai_completion( client=self.client, token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=s.messages_, **self.spec_kwargs, ) # Use a string for pattern matching. comp_for_match = comp[0] if isinstance(comp, list) else comp if self.spec_pattern_match(comp_for_match): break for term in self.spec_format: s.text_ += term["text"] name = term["name"] if name is not None: s.variables[name] = term["text"] s.meta_info[name] = {} s.variable_event[name].set() self.spec_kwargs = {} self.spec_format = [] def generate_stream( self, s: StreamExecutor, sampling_params: SglSamplingParams, ): if sampling_params.dtype is None: if self.is_chat_model: if not s.text_.endswith(self.chat_prefix): raise RuntimeError( "This use case is not supported. " "For OpenAI chat models, sgl.gen must be right after sgl.assistant" ) prompt = s.messages_ else: prompt = s.text_ kwargs = sampling_params.to_openai_kwargs() generator = openai_completion_stream( client=self.client, token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=prompt, **kwargs, ) return generator else: raise ValueError(f"Unknown dtype: {sampling_params.dtype}") def select( self, s: StreamExecutor, choices: List[str], temperature: float, choices_method: ChoicesSamplingMethod, ) -> ChoicesDecision: """Note: `choices_method` is not used by the OpenAI backend.""" if self.is_chat_model: raise NotImplementedError( "select/choices is not supported for chat models. " "Please try to use a non-chat model such as gpt-3.5-turbo-instruct" ) n_choices = len(choices) token_ids = [self.tokenizer.encode(x) for x in choices] scores = [0] * n_choices valid = [len(x) > 0 for x in token_ids] prompt_tokens = self.tokenizer.encode(s.text_) max_len = max([len(x) for x in token_ids]) for step in range(max_len): # Build logit bias logit_bias = {} for i in range(n_choices): if valid[i]: logit_bias[token_ids[i][step]] = 100 # Call API ret = self.client.completions.create( model=self.model_name, prompt=prompt_tokens, logit_bias=logit_bias, max_tokens=1, temperature=temperature, ) ret_str = ret.choices[0].text ret_token = self.tokenizer.encode(ret_str)[0] self.token_usage.prompt_tokens += ret.usage.prompt_tokens self.token_usage.completion_tokens = ret.usage.completion_tokens # TODO: # 1. return logits as the scores # 2. compute logits of the full choice # 3. consider chunk-based decoding # Update valid hit = False for i in range(n_choices): if valid[i]: if step == len(token_ids[i]) - 1: valid[i] = False if ret_token == token_ids[i][step]: scores[i] += 1 hit = True else: valid[i] = False assert hit if np.sum(valid) <= 1: break prompt_tokens.append(ret_token) return ChoicesDecision( decision=choices[np.argmax(scores)], meta_info={"scores": scores}, ) def openai_completion( client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs ) -> Union[str, List[str]]: # if "ebnf" is in kwargs, warn and remove if "ebnf" in kwargs: warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.") del kwargs["ebnf"] for attempt in range(retries): try: if is_chat: if "stop" in kwargs and kwargs["stop"] is None: kwargs.pop("stop") ret = client.chat.completions.create(messages=prompt, **kwargs) if len(ret.choices) == 1: comp = ret.choices[0].message.content else: comp = [c.message.content for c in ret.choices] else: ret = client.completions.create(prompt=prompt, **kwargs) if isinstance(prompt, (list, tuple)): comp = [c.text for c in ret.choices] else: comp = ret.choices[0].text if len(ret.choices) > 1: comp = [c.text for c in ret.choices] token_usage.prompt_tokens += ret.usage.prompt_tokens token_usage.completion_tokens += ret.usage.completion_tokens break except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e: logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...") time.sleep(5) if attempt == retries - 1: raise e except Exception as e: logger.error(f"RuntimeError {e}.") raise e return comp def openai_completion_stream( client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs ): # if "ebnf" is in kwargs, warn and remove if "ebnf" in kwargs: warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.") del kwargs["ebnf"] for attempt in range(retries): try: if is_chat: if "stop" in kwargs and kwargs["stop"] is None: kwargs.pop("stop") generator = client.chat.completions.create( messages=prompt, stream=True, stream_options={"include_usage": True}, **kwargs, ) for ret in generator: if len(ret.choices) == 0: continue try: content = ret.choices[0].delta.content except IndexError: content = None yield content or "", {} else: generator = client.completions.create( prompt=prompt, stream=True, stream_options={"include_usage": True}, **kwargs, ) for ret in generator: if len(ret.choices) == 0: continue content = ret.choices[0].text yield content or "", {} token_usage.prompt_tokens += ret.usage.prompt_tokens token_usage.completion_tokens += ret.usage.completion_tokens break except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e: logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...") time.sleep(5) if attempt == retries - 1: raise e except Exception as e: logger.error(f"RuntimeError {e}.") raise e