sglang_v0.5.2/sglang/test/srt/test_jinja_template_utils.py

311 lines
9.7 KiB
Python

"""
Unit tests for Jinja chat template utils.
"""
import unittest
from sglang.srt.parser.jinja_template_utils import (
detect_jinja_template_content_format,
process_content_for_template_format,
)
from sglang.test.test_utils import CustomTestCase
class TestTemplateContentFormatDetection(CustomTestCase):
    """Checks for Jinja chat-template format detection and message content processing.

    The ``test_detect_*`` cases feed template source text to
    ``detect_jinja_template_content_format`` and assert on the returned format
    name. The ``test_process_*`` cases pass a message dict to
    ``process_content_for_template_format`` and assert on the returned message
    and on the collector lists handed to the call.
    """

    @staticmethod
    def _run_process_content(message, content_format):
        """Invoke ``process_content_for_template_format`` with fresh, empty collectors.

        Args:
            message: The message dict to process.
            content_format: The content format name passed as the second argument.

        Returns:
            A tuple ``(processed, images, videos, audios, modalities)`` holding the
            call's return value and the four lists that were passed to it, so a
            test can inspect their contents after the call.
        """
        images, videos, audios, modalities = [], [], [], []
        processed = process_content_for_template_format(
            message, content_format, images, videos, audios, modalities
        )
        return processed, images, videos, audios, modalities

    def test_detect_llama4_openai_format(self):
        """A llama4-style template that iterates over content parts yields 'openai'."""
        template = """
{%- for message in messages %}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' %}
{{- '<|image|>' }}
{%- elif content['type'] == 'text' %}
{{- content['text'] | trim }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- endfor %}
"""
        detected = detect_jinja_template_content_format(template)
        self.assertEqual(detected, "openai")

    def test_detect_deepseek_string_format(self):
        """A deepseek-style template that concatenates message content yields 'string'."""
        template = """
{%- for message in messages %}
{%- if message['role'] == 'user' %}
{{- '<|User|>' + message['content'] + '<|Assistant|>' }}
{%- endif %}
{%- endfor %}
"""
        detected = detect_jinja_template_content_format(template)
        self.assertEqual(detected, "string")

    def test_detect_invalid_template(self):
        """Malformed template source is expected to fall back to 'string'."""
        broken_template = "{{{{ invalid jinja syntax }}}}"
        detected = detect_jinja_template_content_format(broken_template)
        self.assertEqual(detected, "string")

    def test_detect_empty_template(self):
        """An empty template is expected to fall back to 'string'."""
        detected = detect_jinja_template_content_format("")
        self.assertEqual(detected, "string")

    def test_detect_msg_content_pattern(self):
        """A template looping over ``msg.content`` (loop variable ``msg``) yields 'openai'."""
        template = """
[gMASK]<sop>
{%- for msg in messages %}
{%- if msg.role == 'system' %}
<|system|>
{{ msg.content }}
{%- elif msg.role == 'user' %}
<|user|>{{ '\n' }}
{%- if msg.content is string %}
{{ msg.content }}
{%- else %}
{%- for item in msg.content %}
{%- if item.type == 'video' or 'video' in item %}
<|begin_of_video|><|video|><|end_of_video|>
{%- elif item.type == 'image' or 'image' in item %}
<|begin_of_image|><|image|><|end_of_image|>
{%- elif item.type == 'text' %}
{{ item.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- elif msg.role == 'assistant' %}
{%- if msg.metadata %}
<|assistant|>{{ msg.metadata }}
{{ msg.content }}
{%- else %}
<|assistant|>
{{ msg.content }}
{%- endif %}
{%- endif %}
{%- endfor %}
{% if add_generation_prompt %}<|assistant|>
{% endif %}
"""
        detected = detect_jinja_template_content_format(template)
        self.assertEqual(detected, "openai")

    def test_detect_m_content_pattern(self):
        """A template looping over ``m.content`` (loop variable ``m``) yields 'openai'."""
        template = """
[gMASK]<sop>
{%- for m in messages %}
{%- if m.role == 'system' %}
<|system|>
{{ m.content }}
{%- elif m.role == 'user' %}
<|user|>{{ '\n' }}
{%- if m.content is string %}
{{ m.content }}
{%- else %}
{%- for item in m.content %}
{%- if item.type == 'video' or 'video' in item %}
<|begin_of_video|><|video|><|end_of_video|>
{%- elif item.type == 'image' or 'image' in item %}
<|begin_of_image|><|image|><|end_of_image|>
{%- elif item.type == 'text' %}
{{ item.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- elif m.role == 'assistant' %}
{%- if m.metadata %}
<|assistant|>{{ m.metadata }}
{{ m.content }}
{%- else %}
<|assistant|>
{{ m.content }}
{%- endif %}
{%- endif %}
{%- endfor %}
{% if add_generation_prompt %}<|assistant|>
{% endif %}
"""
        detected = detect_jinja_template_content_format(template)
        self.assertEqual(detected, "openai")

    def test_process_content_openai_format(self):
        """With 'openai' format, the image URL is collected and the part becomes ``{"type": "image"}``."""
        image_url = "http://example.com/image.jpg"
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Look at this image:"},
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": "What do you see?"},
            ],
        }

        processed, images, _videos, _audios, _modalities = self._run_process_content(
            message, "openai"
        )

        # Exactly one image entry should have been collected, carrying the URL.
        self.assertEqual(len(images), 1)
        self.assertEqual(images[0].url, image_url)

        # Text parts are kept as-is; the image_url part is reduced to a bare image marker.
        self.assertEqual(
            processed["content"],
            [
                {"type": "text", "text": "Look at this image:"},
                {"type": "image"},
                {"type": "text", "text": "What do you see?"},
            ],
        )
        self.assertEqual(processed["role"], "user")

    def test_process_content_string_format(self):
        """With 'string' format, text parts are joined and no image data is collected."""
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "http://example.com/image.jpg"},
                },
                {"type": "text", "text": "world"},
            ],
        }

        processed, images, _videos, _audios, _modalities = self._run_process_content(
            message, "string"
        )

        # The two text parts are expected to collapse into one space-separated string.
        self.assertEqual(processed["content"], "Hello world")
        self.assertEqual(processed["role"], "user")
        # The image collector must remain empty for this format.
        self.assertEqual(len(images), 0)

    def test_process_content_with_audio(self):
        """With 'openai' format, the audio URL is collected and the part becomes ``{"type": "audio"}``."""
        audio_url = "http://example.com/audio.mp3"
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Listen to this:"},
                {"type": "audio_url", "audio_url": {"url": audio_url}},
            ],
        }

        processed, _images, _videos, audios, _modalities = self._run_process_content(
            message, "openai"
        )

        # The audio collector receives the plain URL string.
        self.assertEqual(len(audios), 1)
        self.assertEqual(audios[0], audio_url)

        # The audio_url part is reduced to a bare audio marker.
        self.assertEqual(
            processed["content"],
            [
                {"type": "text", "text": "Listen to this:"},
                {"type": "audio"},
            ],
        )

    def test_process_content_already_string(self):
        """String content passes through unchanged and collects no image data."""
        message = {"role": "user", "content": "Hello world"}

        processed, images, _videos, _audios, _modalities = self._run_process_content(
            message, "openai"
        )

        self.assertEqual(processed["content"], "Hello world")
        self.assertEqual(processed["role"], "user")
        self.assertEqual(len(images), 0)

    def test_process_content_with_modalities(self):
        """A ``modalities`` entry on an image part is appended to the modalities collector."""
        message = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": "http://example.com/image.jpg"},
                    "modalities": ["vision"],
                }
            ],
        }

        _processed, _images, _videos, _audios, modalities = self._run_process_content(
            message, "openai"
        )

        # The part's modalities list is collected as a single entry.
        self.assertEqual(len(modalities), 1)
        self.assertEqual(modalities[0], ["vision"])

    def test_process_content_filter_none_values(self):
        """Keys whose value is ``None`` are absent from the processed message."""
        message = {
            "role": "user",
            "content": "Hello",
            "name": None,
            "tool_call_id": None,
        }

        processed, _images, _videos, _audios, _modalities = self._run_process_content(
            message, "string"
        )

        # Only the keys that carried real values should survive.
        self.assertEqual(set(processed.keys()), {"role", "content"})
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()