# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for OpenAI API protocol models"""

import json
import time
import unittest
from typing import Dict, List, Optional

from pydantic import ValidationError

from sglang.srt.entrypoints.openai.protocol import (
    BatchRequest,
    BatchResponse,
    ChatCompletionMessageContentImagePart,
    ChatCompletionMessageContentTextPart,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatCompletionTokenLogprob,
    ChatMessage,
    ChoiceLogprobs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    DeltaMessage,
    EmbeddingObject,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    FileDeleteResponse,
    FileRequest,
    FileResponse,
    Function,
    FunctionResponse,
    JsonSchemaResponseFormat,
    LogProbs,
    ModelCard,
    ModelList,
    MultimodalEmbeddingInput,
    ResponseFormat,
    ScoringRequest,
    ScoringResponse,
    StreamOptions,
    StructuralTagResponseFormat,
    Tool,
    ToolCall,
    ToolChoice,
    TopLogprob,
    UsageInfo,
)


class TestModelCard(unittest.TestCase):
    """Test ModelCard protocol model"""

    def test_model_card_serialization(self):
        """Test model card JSON serialization"""
        card = ModelCard(id="test-model", max_model_len=4096)
        data = card.model_dump()
        self.assertEqual(data["id"], "test-model")
        self.assertEqual(data["object"], "model")
        self.assertEqual(data["max_model_len"], 4096)


class TestModelList(unittest.TestCase):
    """Test ModelList protocol model"""

    def test_empty_model_list(self):
        """Test empty model list creation"""
        model_list = ModelList()
        self.assertEqual(model_list.object, "list")
        self.assertEqual(len(model_list.data), 0)

    def test_model_list_with_cards(self):
        """Test model list with model cards"""
        cards = [
            ModelCard(id="model-1"),
            ModelCard(id="model-2", max_model_len=2048),
        ]
        model_list = ModelList(data=cards)
        self.assertEqual(len(model_list.data), 2)
        self.assertEqual(model_list.data[0].id, "model-1")
        self.assertEqual(model_list.data[1].id, "model-2")
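
    # Added sketch, not part of the original suite: assumes model_dump() on a
    # ModelList recursively dumps its nested ModelCard entries, as pydantic v2
    # does for nested models.
    def test_model_list_serialization(self):
        """Test that a model list dumps its cards as plain dicts"""
        model_list = ModelList(data=[ModelCard(id="model-1")])
        data = model_list.model_dump()
        self.assertEqual(data["object"], "list")
        self.assertEqual(len(data["data"]), 1)
        self.assertEqual(data["data"][0]["id"], "model-1")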


class TestCompletionRequest(unittest.TestCase):
    """Test CompletionRequest protocol model"""

    def test_basic_completion_request(self):
        """Test basic completion request"""
        request = CompletionRequest(model="test-model", prompt="Hello world")
        self.assertEqual(request.model, "test-model")
        self.assertEqual(request.prompt, "Hello world")
        self.assertEqual(request.max_tokens, 16)  # default
        self.assertEqual(request.temperature, 1.0)  # default
        self.assertEqual(request.n, 1)  # default
        self.assertFalse(request.stream)  # default
        self.assertFalse(request.echo)  # default

    def test_completion_request_sglang_extensions(self):
        """Test completion request with SGLang-specific extensions"""
        request = CompletionRequest(
            model="test-model",
            prompt="Hello",
            top_k=50,
            min_p=0.1,
            repetition_penalty=1.1,
            regex=r"\d+",
            json_schema='{"type": "object"}',
            lora_path="/path/to/lora",
        )
        self.assertEqual(request.top_k, 50)
        self.assertEqual(request.min_p, 0.1)
        self.assertEqual(request.repetition_penalty, 1.1)
        self.assertEqual(request.regex, r"\d+")
        self.assertEqual(request.json_schema, '{"type": "object"}')
        self.assertEqual(request.lora_path, "/path/to/lora")

    def test_completion_request_validation_errors(self):
        """Test completion request validation errors"""
        with self.assertRaises(ValidationError):
            CompletionRequest()  # missing required fields

        with self.assertRaises(ValidationError):
            CompletionRequest(model="test-model")  # missing prompt


class TestChatCompletionRequest(unittest.TestCase):
    """Test ChatCompletionRequest protocol model"""

    def test_basic_chat_completion_request(self):
        """Test basic chat completion request"""
        messages = [{"role": "user", "content": "Hello"}]
        request = ChatCompletionRequest(model="test-model", messages=messages)
        self.assertEqual(request.model, "test-model")
        self.assertEqual(len(request.messages), 1)
        self.assertEqual(request.messages[0].role, "user")
        self.assertEqual(request.messages[0].content, "Hello")
        self.assertEqual(request.temperature, 0.7)  # default
        self.assertFalse(request.stream)  # default
        self.assertEqual(request.tool_choice, "none")  # default when no tools

    def test_chat_completion_tool_choice_validation(self):
        """Test tool choice validation logic"""
        messages = [{"role": "user", "content": "Hello"}]

        # No tools, tool_choice should default to "none"
        request1 = ChatCompletionRequest(model="test-model", messages=messages)
        self.assertEqual(request1.tool_choice, "none")

        # With tools, tool_choice should default to "auto"
        tools = [
            {
                "type": "function",
                "function": {"name": "test_func", "description": "Test function"},
            }
        ]
        request2 = ChatCompletionRequest(
            model="test-model", messages=messages, tools=tools
        )
        self.assertEqual(request2.tool_choice, "auto")

    def test_chat_completion_sglang_extensions(self):
        """Test chat completion with SGLang extensions"""
        messages = [{"role": "user", "content": "Hello"}]
        request = ChatCompletionRequest(
            model="test-model",
            messages=messages,
            top_k=40,
            min_p=0.05,
            separate_reasoning=False,
            stream_reasoning=False,
            chat_template_kwargs={"custom_param": "value"},
        )
        self.assertEqual(request.top_k, 40)
        self.assertEqual(request.min_p, 0.05)
        self.assertFalse(request.separate_reasoning)
        self.assertFalse(request.stream_reasoning)
        self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
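
    # Added sketch, not part of the original suite: assumes user message
    # content may also be a list of OpenAI-style parts that validate into the
    # imported ChatCompletionMessageContentTextPart / ...ImagePart models with
    # a "type" discriminator; the part field names follow the OpenAI API and
    # are an assumption about this protocol.
    def test_chat_completion_multimodal_content(self):
        """Test chat completion with text and image content parts"""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/cat.png"},
                    },
                ],
            }
        ]
        request = ChatCompletionRequest(model="test-model", messages=messages)
        self.assertEqual(len(request.messages[0].content), 2)
        self.assertEqual(request.messages[0].content[0].type, "text")
        self.assertEqual(request.messages[0].content[1].type, "image_url")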


class TestModelSerialization(unittest.TestCase):
    """Test model serialization with hidden states"""

    def test_hidden_states_excluded_when_none(self):
        """Test that None hidden_states are excluded with exclude_none=True"""
        choice = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content="Hello"),
            finish_reason="stop",
            hidden_states=None,
        )

        response = ChatCompletionResponse(
            id="test-id",
            model="test-model",
            choices=[choice],
            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
        )

        # Test exclude_none serialization (should exclude None hidden_states)
        data = response.model_dump(exclude_none=True)
        self.assertNotIn("hidden_states", data["choices"][0])

    def test_hidden_states_included_when_not_none(self):
        """Test that non-None hidden_states are included"""
        choice = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content="Hello"),
            finish_reason="stop",
            hidden_states=[0.1, 0.2, 0.3],
        )

        response = ChatCompletionResponse(
            id="test-id",
            model="test-model",
            choices=[choice],
            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
        )

        # Test exclude_none serialization (should include non-None hidden_states)
        data = response.model_dump(exclude_none=True)
        self.assertIn("hidden_states", data["choices"][0])
        self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])


class TestValidationEdgeCases(unittest.TestCase):
    """Test edge cases and validation scenarios"""

    def test_invalid_tool_choice_type(self):
        """Test invalid tool choice type"""
        messages = [{"role": "user", "content": "Hello"}]
        with self.assertRaises(ValidationError):
            ChatCompletionRequest(
                model="test-model", messages=messages, tool_choice=123
            )

    def test_negative_token_limits(self):
        """Test negative token limits"""
        with self.assertRaises(ValidationError):
            CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)

    def test_model_serialization_roundtrip(self):
        """Test that models can be serialized and deserialized"""
        original_request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.7,
            max_tokens=100,
        )

        # Serialize to dict
        data = original_request.model_dump()

        # Deserialize back
        restored_request = ChatCompletionRequest(**data)

        self.assertEqual(restored_request.model, original_request.model)
        self.assertEqual(restored_request.temperature, original_request.temperature)
        self.assertEqual(restored_request.max_tokens, original_request.max_tokens)
        self.assertEqual(
            len(restored_request.messages), len(original_request.messages)
        )
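
    # Added sketch, not part of the original suite: assumes the pydantic v2
    # model_dump_json() / model_validate_json() pair, consistent with the
    # model_dump() calls used throughout this file.
    def test_json_string_roundtrip(self):
        """Test that a chat request survives a JSON string round trip"""
        original = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
        )
        restored = ChatCompletionRequest.model_validate_json(
            original.model_dump_json()
        )
        self.assertEqual(restored.model, original.model)
        self.assertEqual(restored.messages[0].content, "Hello")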


if __name__ == "__main__":
    unittest.main(verbosity=2)