import os
import warnings
from typing import Optional

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

try:
    import vertexai
    from vertexai.preview.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )
except ImportError as e:
    GenerativeModel = e


class VertexAI(BaseBackend):
    def __init__(self, model_name, safety_settings=None):
        super().__init__()

        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")
        self.safety_settings = safety_settings

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )

        comp = ret.text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_vertexai_input(self, text, images):
        input = []
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                input.append(text_seg)
            input.append(Image.from_bytes(image_base64_data))
        text_seg = text_segs.pop(0)
        if text_seg != "":
            input.append(text_seg)
        return input

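    # Illustrative mapping (a sketch based on the conversion below, not an
    # exhaustive spec): an OpenAI-style message such as
    #   {"role": "user", "content": [
    #       {"type": "text", "text": "What is in this image?"},
    #       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<data>"}}]}
    # is turned into a Vertex AI message of the form
    #   {"role": "user", "parts": [
    #       {"text": "What is in this image?"},
    #       {"inline_data": {"data": "<data>", "mime_type": "image/jpeg"}}]}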
    def messages_to_vertexai_input(self, messages):
        vertexai_message = []
        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
                warnings.warn("Warning: system prompt is not supported in VertexAI.")
                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue
            if msg["role"] == "user":
                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }

            # images
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            vertexai_message.append(vertexai_msg)
        return vertexai_message
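

# Usage sketch (commented out so this module stays import-only). The model
# name below is only an example, not a value this backend requires; any
# Vertex AI generative model the project can access should work.
# GCP_PROJECT_ID must be set in the environment, and GCP_LOCATION is read
# if present (see __init__ above).
#
#   import sglang as sgl
#
#   backend = VertexAI("gemini-pro")
#   sgl.set_default_backend(backend)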