# sglang0.4.5.post1/python/sglang/lang/backend/vertexai.py
import os
import warnings
from typing import Optional
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
# The vertexai SDK is an optional dependency. Failure to import is deferred:
# the caught ImportError is stashed in GenerativeModel and re-raised only when
# a VertexAI backend is actually constructed (see VertexAI.__init__).
try:
    import vertexai
    from vertexai.preview.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )
except ImportError as e:
    GenerativeModel = e
class VertexAI(BaseBackend):
    """sglang backend that runs programs against Google VertexAI generative models.

    Credentials and endpoint are taken from the environment:
      - ``GCP_PROJECT_ID`` (required)
      - ``GCP_LOCATION`` (optional; VertexAI picks its default when unset)

    Raises:
        ImportError: on construction, when the ``vertexai`` SDK is not installed
            (the import block stores the original ImportError in
            ``GenerativeModel`` and it is re-raised here).
        KeyError: when ``GCP_PROJECT_ID`` is not set in the environment.
    """

    def __init__(self, model_name, safety_settings=None):
        super().__init__()

        # Surface the deferred ImportError from the optional-import block
        # at construction time rather than at module import time.
        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")
        self.safety_settings = safety_settings

    def get_chat_template(self):
        """Return the chat template used to render prompts for this backend."""
        return self.chat_template

    def _build_prompt(self, s: StreamExecutor):
        """Build the VertexAI request payload from the executor state.

        Multi-turn chats are converted from the OpenAI message format;
        single-turn prompts are sent as plain text, interleaved with images
        when any are attached.
        """
        if s.messages_:
            return self.messages_to_vertexai_input(s.messages_)
        # single-turn
        if s.cur_images:
            return self.text_to_vertexai_input(s.text_, s.cur_images)
        return s.text_

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation.

        Returns:
            tuple[str, dict]: the generated text and an (empty) meta-info dict.
        """
        prompt = self._build_prompt(s)
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        comp = ret.text
        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one streaming generation; yields (text_chunk, meta_info) pairs."""
        prompt = self._build_prompt(s)
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_vertexai_input(self, text, images):
        """Interleave text segments and images for a single-turn request.

        ``text`` is expected to contain one chat-template image token per
        entry in ``images``; each ``(image_path, image_data)`` pair is
        inserted at its token's position. Raises IndexError when the token
        count does not match the number of images.
        """
        parts = []
        # Split on the image token; text segments alternate with images.
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                parts.append(text_seg)
            # NOTE(review): the variable name suggests base64, but
            # Image.from_bytes expects raw bytes — confirm what
            # StreamExecutor.cur_images actually stores.
            parts.append(Image.from_bytes(image_base64_data))
        text_seg = text_segs.pop(0)
        if text_seg != "":
            parts.append(text_seg)
        return parts

    def messages_to_vertexai_input(self, messages):
        """Convert OpenAI-style messages to the VertexAI message format.

        VertexAI has no system role, so a system prompt is emulated by a
        synthetic user/model exchange. Extra list entries in a message's
        content are treated as ``image_url`` attachments and inlined as
        base64 JPEG data.

        Raises:
            ValueError: for a role other than system/user/assistant.
        """
        vertexai_message = []
        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
                warnings.warn("Warning: system prompt is not supported in VertexAI.")
                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue

            if msg["role"] == "user":
                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }
            else:
                # Previously an unknown role fell through to an unbound
                # local and crashed with NameError; fail with a clear error.
                raise ValueError(f"Unsupported message role: {msg['role']!r}")

            # Attach images: any content entries after the leading text part.
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                # data URL: "data:<mime>;base64,<payload>"
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            vertexai_message.append(vertexai_msg)
        return vertexai_message