# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for OpenAI API protocol models"""

import json
import time
import unittest
from typing import Dict, List, Optional

from pydantic import ValidationError

from sglang.srt.entrypoints.openai.protocol import (
    BatchRequest,
    BatchResponse,
    ChatCompletionMessageContentImagePart,
    ChatCompletionMessageContentTextPart,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatCompletionTokenLogprob,
    ChatMessage,
    ChoiceLogprobs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    DeltaMessage,
    EmbeddingObject,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    FileDeleteResponse,
    FileRequest,
    FileResponse,
    Function,
    FunctionResponse,
    JsonSchemaResponseFormat,
    LogProbs,
    ModelCard,
    ModelList,
    MultimodalEmbeddingInput,
    ResponseFormat,
    ScoringRequest,
    ScoringResponse,
    StreamOptions,
    StructuralTagResponseFormat,
    Tool,
    ToolCall,
    ToolChoice,
    TopLogprob,
    UsageInfo,
)


class TestModelCard(unittest.TestCase):
    """Checks for the ModelCard protocol model."""

    def test_model_card_serialization(self):
        """A dumped ModelCard exposes its id, object tag and max_model_len."""
        dumped = ModelCard(id="test-model", max_model_len=4096).model_dump()

        self.assertEqual(dumped["id"], "test-model")
        self.assertEqual(dumped["object"], "model")
        self.assertEqual(dumped["max_model_len"], 4096)
class TestModelList(unittest.TestCase):
    """Checks for the ModelList protocol model."""

    def test_empty_model_list(self):
        """A ModelList built with no arguments is an empty "list" object."""
        listing = ModelList()

        self.assertEqual(listing.object, "list")
        self.assertEqual(len(listing.data), 0)

    def test_model_list_with_cards(self):
        """Cards passed as ``data`` are kept, in the order given."""
        first = ModelCard(id="model-1")
        second = ModelCard(id="model-2", max_model_len=2048)

        listing = ModelList(data=[first, second])

        self.assertEqual(len(listing.data), 2)
        self.assertEqual(listing.data[0].id, "model-1")
        self.assertEqual(listing.data[1].id, "model-2")


class TestCompletionRequest(unittest.TestCase):
    """Checks for the CompletionRequest protocol model."""

    def test_basic_completion_request(self):
        """With only model and prompt given, the other fields hold defaults."""
        req = CompletionRequest(model="test-model", prompt="Hello world")

        # Values supplied explicitly above.
        self.assertEqual(req.model, "test-model")
        self.assertEqual(req.prompt, "Hello world")

        # Values expected when the caller does not set them.
        self.assertEqual(req.max_tokens, 16)
        self.assertEqual(req.temperature, 1.0)
        self.assertEqual(req.n, 1)
        self.assertFalse(req.stream)
        self.assertFalse(req.echo)

    def test_completion_request_sglang_extensions(self):
        """SGLang-specific fields are stored on the request unchanged."""
        extras = {
            "top_k": 50,
            "min_p": 0.1,
            "repetition_penalty": 1.1,
            "regex": r"\d+",
            "json_schema": '{"type": "object"}',
            "lora_path": "/path/to/lora",
        }

        req = CompletionRequest(model="test-model", prompt="Hello", **extras)

        # Each extension field must read back exactly as it was passed in.
        for field_name, expected in extras.items():
            self.assertEqual(getattr(req, field_name), expected)

    def test_completion_request_validation_errors(self):
        """Omitting required fields raises a pydantic ValidationError."""
        # No arguments at all.
        self.assertRaises(ValidationError, CompletionRequest)
        # A model but no prompt.
        self.assertRaises(ValidationError, CompletionRequest, model="test-model")


class TestChatCompletionRequest(unittest.TestCase):
    """Checks for the ChatCompletionRequest protocol model."""

    def test_basic_chat_completion_request(self):
        """A single user message is parsed and the defaults are applied."""
        req = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
        )

        self.assertEqual(req.model, "test-model")
        self.assertEqual(len(req.messages), 1)

        only_message = req.messages[0]
        self.assertEqual(only_message.role, "user")
        self.assertEqual(only_message.content, "Hello")

        # Expected values for fields the caller left unset.
        self.assertEqual(req.temperature, 0.7)
        self.assertFalse(req.stream)
        # No tools were supplied, so tool_choice is expected to be "none".
        self.assertEqual(req.tool_choice, "none")

    def test_chat_completion_tool_choice_validation(self):
        """tool_choice is "none" without tools and "auto" with tools."""
        conversation = [{"role": "user", "content": "Hello"}]

        without_tools = ChatCompletionRequest(
            model="test-model", messages=conversation
        )
        self.assertEqual(without_tools.tool_choice, "none")

        function_tool = {
            "type": "function",
            "function": {"name": "test_func", "description": "Test function"},
        }
        with_tools = ChatCompletionRequest(
            model="test-model", messages=conversation, tools=[function_tool]
        )
        self.assertEqual(with_tools.tool_choice, "auto")

    def test_chat_completion_sglang_extensions(self):
        """SGLang-specific chat fields are stored on the request."""
        template_kwargs = {"custom_param": "value"}

        req = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            top_k=40,
            min_p=0.05,
            separate_reasoning=False,
            stream_reasoning=False,
            chat_template_kwargs=template_kwargs,
        )

        self.assertEqual(req.top_k, 40)
        self.assertEqual(req.min_p, 0.05)
        self.assertFalse(req.separate_reasoning)
        self.assertFalse(req.stream_reasoning)
        self.assertEqual(req.chat_template_kwargs, {"custom_param": "value"})


class TestModelSerialization(unittest.TestCase):
    """Checks how hidden_states behaves when a response is serialized."""

    def _dump_first_choice(self, hidden_states):
        """Build a one-choice response and return that choice as dumped.

        The response is serialized with ``exclude_none=True``; the returned
        value is the dict for ``choices[0]``.
        """
        only_choice = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content="Hello"),
            finish_reason="stop",
            hidden_states=hidden_states,
        )
        full_response = ChatCompletionResponse(
            id="test-id",
            model="test-model",
            choices=[only_choice],
            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
        )
        return full_response.model_dump(exclude_none=True)["choices"][0]

    def test_hidden_states_excluded_when_none(self):
        """hidden_states=None must not appear in the exclude_none dump."""
        dumped_choice = self._dump_first_choice(None)

        self.assertNotIn("hidden_states", dumped_choice)

    def test_hidden_states_included_when_not_none(self):
        """A concrete hidden_states list survives the exclude_none dump."""
        dumped_choice = self._dump_first_choice([0.1, 0.2, 0.3])

        self.assertIn("hidden_states", dumped_choice)
        self.assertEqual(dumped_choice["hidden_states"], [0.1, 0.2, 0.3])


class TestValidationEdgeCases(unittest.TestCase):
    """Edge-case and validation scenarios for the request models."""

    def test_invalid_tool_choice_type(self):
        """An integer tool_choice is rejected with a ValidationError."""
        conversation = [{"role": "user", "content": "Hello"}]

        with self.assertRaises(ValidationError):
            ChatCompletionRequest(
                model="test-model", messages=conversation, tool_choice=123
            )

    def test_negative_token_limits(self):
        """A negative max_tokens is rejected with a ValidationError."""
        with self.assertRaises(ValidationError):
            CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)

    def test_model_serialization_roundtrip(self):
        """A request rebuilt from its own model_dump keeps the checked fields."""
        before = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.7,
            max_tokens=100,
        )

        # Dump to a plain dict, then construct a new request from that dict.
        after = ChatCompletionRequest(**before.model_dump())

        self.assertEqual(after.model, before.model)
        self.assertEqual(after.temperature, before.temperature)
        self.assertEqual(after.max_tokens, before.max_tokens)
        self.assertEqual(len(after.messages), len(before.messages))


if __name__ == "__main__":
    unittest.main(verbosity=2)