91 lines
2.4 KiB
Python
91 lines
2.4 KiB
Python
from typing import Mapping, Optional
|
|
|
|
from sglang.lang.backend.base_backend import BaseBackend
|
|
from sglang.lang.chat_template import get_chat_template_by_model_path
|
|
from sglang.lang.interpreter import StreamExecutor
|
|
from sglang.lang.ir import SglSamplingParams
|
|
|
|
try:
    import litellm
except ImportError as e:
    # Defer the failure: store the ImportError in `litellm` so importing this
    # module never crashes when the package is missing. LiteLLM.__init__
    # detects the sentinel and re-raises the original error on first use.
    litellm = e
# NOTE: this line runs even when the import failed above. In that case it
# sets `num_retries` as an attribute on the stored exception object, which
# keeps the default-argument expression `litellm.num_retries` in
# LiteLLM.__init__ from raising AttributeError at class-definition time.
litellm.num_retries = 1
|
|
|
|
|
|
class LiteLLM(BaseBackend):
    """SGLang backend that routes generation through the LiteLLM gateway.

    Wraps ``litellm.completion`` so that any model LiteLLM supports can be
    driven by the SGLang interpreter, in both blocking and streaming modes.
    """

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = litellm.num_retries,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        """Create a LiteLLM backend.

        Args:
            model_name: Model identifier understood by LiteLLM.
            chat_template: Optional explicit chat template; when ``None``,
                one is inferred from ``model_name``.
            api_key: Provider API key forwarded to LiteLLM.
            organization: Optional organization identifier.
            base_url: Optional override for the provider endpoint.
            timeout: Per-request timeout in seconds.
            max_retries: Retry count; defaults to ``litellm.num_retries``.
            default_headers: Extra HTTP headers sent with each request.

        Raises:
            ImportError: If the ``litellm`` package is not installed
                (the module-level import stored the error as a sentinel).
        """
        super().__init__()

        # The module-level try/except stored the ImportError in `litellm`
        # when the package is missing; surface it here, at first use.
        if isinstance(litellm, Exception):
            raise litellm

        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        # Client-level options forwarded verbatim to every
        # litellm.completion() call.
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        """Return the chat template used to render prompts for this model."""
        return self.chat_template

    def _get_messages(self, s: StreamExecutor):
        """Build the chat message list from the executor state.

        Uses the structured chat history when present; otherwise wraps the
        accumulated raw text as a single user message.
        """
        if s.messages_:
            return s.messages_
        return [{"role": "user", "content": s.text_}]

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run a blocking completion.

        Returns:
            A ``(text, meta_info)`` pair where ``meta_info`` is an empty dict.
        """
        ret = litellm.completion(
            model=self.model_name,
            messages=self._get_messages(s),
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        comp = ret.choices[0].message.content

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion, yielding ``(text_chunk, meta_info)`` pairs.

        Chunks whose delta carries no content (e.g. role-only deltas) are
        skipped rather than yielded as ``None``.
        """
        ret = litellm.completion(
            model=self.model_name,
            messages=self._get_messages(s),
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            text = chunk.choices[0].delta.content
            if text is not None:
                yield text, {}