# Source: sglang.0.4.8.post1/sglang/test/srt/test_jinja_template_utils.py (226 lines, 7.2 KiB, Python)

"""
Unit tests for Jinja chat template utils.
"""
import unittest
from unittest.mock import patch
from sglang.srt.jinja_template_utils import (
detect_jinja_template_content_format,
process_content_for_template_format,
)
from sglang.test.test_utils import CustomTestCase
class TestTemplateContentFormatDetection(CustomTestCase):
    """Tests for Jinja chat-template content-format detection and processing.

    Two helpers are exercised:

    * ``detect_jinja_template_content_format`` -- a template that loops over
      the parts of ``message['content']`` (llama4 style) is classified as
      ``"openai"``; one that concatenates ``message['content']`` as a plain
      string (deepseek style) is ``"string"``.  Invalid or empty templates
      fall back to ``"string"``.
    * ``process_content_for_template_format`` -- rewrites a message dict for
      a given content format, appending extracted media URLs and modalities
      to caller-supplied output lists.
    """

    IMAGE_URL = "http://example.com/image.jpg"
    AUDIO_URL = "http://example.com/audio.mp3"

    @staticmethod
    def _process(msg_dict, content_format):
        """Run ``process_content_for_template_format`` with fresh output lists.

        Returns:
            A ``(result, image_data, audio_data, modalities)`` tuple, so each
            test can assert on the returned message *and* on every output
            list the function may have appended to.
        """
        image_data, audio_data, modalities = [], [], []
        result = process_content_for_template_format(
            msg_dict, content_format, image_data, audio_data, modalities
        )
        return result, image_data, audio_data, modalities

    # ------------------------------------------------------------------
    # detect_jinja_template_content_format
    # ------------------------------------------------------------------

    def test_detect_llama4_openai_format(self):
        """Test detection of llama4-style template (should be 'openai' format)."""
        llama4_pattern = """
{%- for message in messages %}
    {%- if message['content'] is string %}
        {{- message['content'] }}
    {%- else %}
        {%- for content in message['content'] %}
            {%- if content['type'] == 'image' %}
                {{- '<|image|>' }}
            {%- elif content['type'] == 'text' %}
                {{- content['text'] | trim }}
            {%- endif %}
        {%- endfor %}
    {%- endif %}
{%- endfor %}
"""
        result = detect_jinja_template_content_format(llama4_pattern)
        self.assertEqual(result, "openai")

    def test_detect_deepseek_string_format(self):
        """Test detection of deepseek-style template (should be 'string' format)."""
        deepseek_pattern = """
{%- for message in messages %}
    {%- if message['role'] == 'user' %}
        {{- '<|User|>' + message['content'] + '<|Assistant|>' }}
    {%- endif %}
{%- endfor %}
"""
        result = detect_jinja_template_content_format(deepseek_pattern)
        self.assertEqual(result, "string")

    def test_detect_invalid_template(self):
        """Test handling of invalid template (should default to 'string')."""
        invalid_pattern = "{{{{ invalid jinja syntax }}}}"
        result = detect_jinja_template_content_format(invalid_pattern)
        self.assertEqual(result, "string")

    def test_detect_empty_template(self):
        """Test handling of empty template (should default to 'string')."""
        result = detect_jinja_template_content_format("")
        self.assertEqual(result, "string")

    # ------------------------------------------------------------------
    # process_content_for_template_format
    # ------------------------------------------------------------------

    def test_process_content_openai_format(self):
        """'openai' format keeps the part list and normalizes image_url -> image."""
        msg_dict = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Look at this image:"},
                {"type": "image_url", "image_url": {"url": self.IMAGE_URL}},
                {"type": "text", "text": "What do you see?"},
            ],
        }
        result, image_data, audio_data, modalities = self._process(
            msg_dict, "openai"
        )

        # Only the image URL is extracted; nothing else is collected.
        self.assertEqual(image_data, [self.IMAGE_URL])
        self.assertEqual(audio_data, [])
        self.assertEqual(modalities, [])

        # Part order is preserved; the image part becomes a bare placeholder.
        expected_content = [
            {"type": "text", "text": "Look at this image:"},
            {"type": "image"},  # normalized from image_url
            {"type": "text", "text": "What do you see?"},
        ]
        self.assertEqual(result["content"], expected_content)
        self.assertEqual(result["role"], "user")

    def test_process_content_string_format(self):
        """'string' format flattens text parts and drops media entirely."""
        msg_dict = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Hello"},
                {"type": "image_url", "image_url": {"url": self.IMAGE_URL}},
                {"type": "text", "text": "world"},
            ],
        }
        result, image_data, audio_data, modalities = self._process(
            msg_dict, "string"
        )

        # Text parts are joined with a single space.
        self.assertEqual(result["content"], "Hello world")
        self.assertEqual(result["role"], "user")
        # Media is neither extracted nor recorded for string-format templates.
        self.assertEqual(image_data, [])
        self.assertEqual(audio_data, [])
        self.assertEqual(modalities, [])

    def test_process_content_with_audio(self):
        """'openai' format extracts audio URLs and normalizes audio_url -> audio."""
        msg_dict = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Listen to this:"},
                {"type": "audio_url", "audio_url": {"url": self.AUDIO_URL}},
            ],
        }
        result, image_data, audio_data, modalities = self._process(
            msg_dict, "openai"
        )

        # Audio goes to audio_data only -- never into the image list.
        self.assertEqual(audio_data, [self.AUDIO_URL])
        self.assertEqual(image_data, [])
        self.assertEqual(modalities, [])

        expected_content = [
            {"type": "text", "text": "Listen to this:"},
            {"type": "audio"},  # normalized from audio_url
        ]
        self.assertEqual(result["content"], expected_content)
        self.assertEqual(result["role"], "user")

    def test_process_content_already_string(self):
        """String content passes through unchanged and extracts nothing."""
        msg_dict = {"role": "user", "content": "Hello world"}
        result, image_data, audio_data, modalities = self._process(
            msg_dict, "openai"
        )

        self.assertEqual(result["content"], "Hello world")
        self.assertEqual(result["role"], "user")
        self.assertEqual(image_data, [])
        self.assertEqual(audio_data, [])
        self.assertEqual(modalities, [])

    def test_process_content_with_modalities(self):
        """A 'modalities' field on an image part is recorded alongside its URL."""
        msg_dict = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": self.IMAGE_URL},
                    "modalities": ["vision"],
                }
            ],
        }
        result, image_data, audio_data, modalities = self._process(
            msg_dict, "openai"
        )

        # The modalities list is appended as a single entry (not flattened).
        self.assertEqual(modalities, [["vision"]])
        # The image itself is still extracted and normalized as usual.
        self.assertEqual(image_data, [self.IMAGE_URL])
        self.assertEqual(audio_data, [])
        self.assertEqual(result["content"], [{"type": "image"}])

    def test_process_content_filter_none_values(self):
        """Test that None values are filtered out of processed messages."""
        msg_dict = {
            "role": "user",
            "content": "Hello",
            "name": None,
            "tool_call_id": None,
        }
        result, _, _, _ = self._process(msg_dict, "string")

        # None-valued keys are dropped; the remaining values are untouched.
        self.assertEqual(result, {"role": "user", "content": "Hello"})
if __name__ == "__main__":
    # Allow running this module directly (``python test_jinja_template_utils.py``);
    # unittest.main() discovers the TestCase classes defined above and runs them.
    unittest.main()