196 lines
7.2 KiB
Python
196 lines
7.2 KiB
Python
"""
|
|
Tests for the execution of separate_reasoning functionality in sglang.
|
|
|
|
Usage:
|
|
python3 -m unittest test/lang/test_separate_reasoning_execution.py
|
|
"""
|
|
|
|
import threading
|
|
import time
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from sglang import assistant, gen, separate_reasoning, user
|
|
from sglang.lang.interpreter import StreamExecutor
|
|
from sglang.lang.ir import SglGen, SglSeparateReasoning
|
|
from sglang.test.test_utils import CustomTestCase
|
|
|
|
|
|
# Helper for creating the events that tests hand to the executor.
def create_daemon_event() -> threading.Event:
    """Return a fresh, unset ``threading.Event``.

    Despite the name, nothing here is daemonized: a ``threading.Event`` is a
    plain synchronization flag and does not own a thread, so it cannot keep
    the interpreter alive by itself. Callers that want waiting threads to be
    released must keep a reference to the event and ``set()`` it during
    cleanup.

    Returns:
        A new event whose internal flag is initially false.
    """
    return threading.Event()
|
|
|
|
|
|
class MockReasoningParser:
    """Minimal stand-in for a reasoning parser used by the tests in this file.

    Instead of really splitting model output, both parse methods return the
    input text twice, each copy tagged with a prefix naming the configured
    model type. Each method also records that it was called so tests can
    assert on which code path was taken.
    """

    def __init__(self, model_type: str):
        """Remember the model type and reset the call-tracking flags."""
        self.model_type = model_type
        self.parse_non_stream_called = False
        self.parse_stream_chunk_called = False

    def _tag(self, text: str):
        """Return ``(reasoning, normal_text)`` built by prefixing ``text``."""
        reasoning = f"[REASONING from {self.model_type}]: {text}"
        normal_text = f"[NORMAL from {self.model_type}]: {text}"
        return reasoning, normal_text

    def parse_non_stream(self, full_text: str):
        """Record the call and return the tagged pair for the complete text.

        Returns:
            A ``(reasoning, normal_text)`` tuple of prefixed strings.
        """
        self.parse_non_stream_called = True
        return self._tag(full_text)

    def parse_stream_chunk(self, chunk_text: str):
        """Record the call and return the tagged pair for one streamed chunk.

        Returns:
            A ``(reasoning, normal_text)`` tuple of prefixed strings.
        """
        self.parse_stream_chunk_called = True
        return self._tag(chunk_text)
|
|
|
|
|
|
class TestSeparateReasoningExecution(CustomTestCase):
    """Tests for executing ``separate_reasoning`` with a mocked parser.

    ``ReasoningParser`` is patched at
    ``sglang.srt.parser.reasoning_parser.ReasoningParser`` in every test, so
    no real parser is constructed. Events handed to the executor are tracked
    in ``self.events`` and released in ``tearDown``.
    """

    def setUp(self):
        """Set up for the test."""
        super().setUp()
        # Store any events created during the test
        self.events = []

    def tearDown(self):
        """Release every tracked event so no waiting thread is left blocked."""
        try:
            super().tearDown()
        finally:
            # Run even if the parent teardown raised: set all events to
            # ensure any waiting threads are released.
            for event in self.events:
                event.set()

    @patch("sglang.srt.parser.reasoning_parser.ReasoningParser")
    def test_execute_separate_reasoning(self, mock_parser_class):
        """Test that _execute_separate_reasoning correctly calls the ReasoningParser."""
        # Setup mock parser
        mock_parser = MockReasoningParser("deepseek-r1")
        mock_parser_class.return_value = mock_parser

        # Create a mock backend to avoid AttributeError in __del__
        mock_backend = MagicMock()

        # Create a StreamExecutor with necessary setup
        executor = StreamExecutor(
            backend=mock_backend,
            arguments={},
            default_sampling_para={},
            chat_template={
                "role_map": {"user": "user", "assistant": "assistant"}
            },  # Simple chat template
            stream=False,
            use_thread=False,
        )

        # Set up the executor with a variable and its value
        var_name = "test_var"
        reasoning_name = f"{var_name}_reasoning_content"
        var_value = "Test content"
        executor.variables = {var_name: var_value}

        # Create events and track them for cleanup
        var_event = create_daemon_event()
        reasoning_event = create_daemon_event()
        self.events.extend([var_event, reasoning_event])

        executor.variable_event = {var_name: var_event, reasoning_name: reasoning_event}
        executor.variable_event[var_name].set()  # Mark as ready

        # Set up the current role
        executor.cur_role = "assistant"
        executor.cur_role_begin_pos = 0
        executor.text_ = var_value

        # Create a gen expression and a separate_reasoning expression
        gen_expr = SglGen(var_name)
        expr = SglSeparateReasoning("deepseek-r1", expr=gen_expr)

        # Execute separate_reasoning
        executor._execute_separate_reasoning(expr)

        # Verify that the parser was created with the correct model type
        mock_parser_class.assert_called_once_with("deepseek-r1")

        # Verify that parse_non_stream was called
        self.assertTrue(mock_parser.parse_non_stream_called)

        # Verify that the variables were updated correctly
        self.assertIn(reasoning_name, executor.variables)
        self.assertEqual(
            executor.variables[reasoning_name],
            f"[REASONING from deepseek-r1]: {var_value}",
        )
        self.assertEqual(
            executor.variables[var_name], f"[NORMAL from deepseek-r1]: {var_value}"
        )

        # Verify that the variable event was set
        self.assertIn(reasoning_name, executor.variable_event)
        self.assertTrue(executor.variable_event[reasoning_name].is_set())

        # Verify that the text was updated
        self.assertEqual(executor.text_, f"[NORMAL from deepseek-r1]: {var_value}")

    @patch("sglang.srt.parser.reasoning_parser.ReasoningParser")
    def test_reasoning_parser_integration(self, mock_parser_class):
        """Test the integration between separate_reasoning and ReasoningParser."""
        # Setup mock parsers for different model types
        parsers = {
            "deepseek-r1": MockReasoningParser("deepseek-r1"),
            "qwen3": MockReasoningParser("qwen3"),
        }

        # Configure the mock to return different parsers based on model type
        def get_parser(model_type):
            if model_type in parsers:
                return parsers[model_type]
            raise ValueError(f"Unsupported model type: {model_type}")

        mock_parser_class.side_effect = get_parser

        test_text = "This is a test"
        for model_type, expected_parser in parsers.items():
            with self.subTest(model_type=model_type):
                # Go through the patched class so the side_effect dispatch
                # configured above is what actually selects the parser.
                parser = mock_parser_class(model_type)
                self.assertIs(parser, expected_parser)

                reasoning, normal_text = parser.parse_non_stream(test_text)

                self.assertTrue(parser.parse_non_stream_called)
                self.assertEqual(
                    reasoning, f"[REASONING from {model_type}]: {test_text}"
                )
                self.assertEqual(
                    normal_text, f"[NORMAL from {model_type}]: {test_text}"
                )

    @patch("sglang.srt.parser.reasoning_parser.ReasoningParser")
    def test_reasoning_parser_invalid_model(self, mock_parser_class):
        """Test that ReasoningParser raises an error for invalid model types."""

        # Configure the mock to raise an error for invalid model types
        def get_parser(model_type):
            if model_type in ["deepseek-r1", "qwen3"]:
                return MockReasoningParser(model_type)
            elif model_type is None:
                raise ValueError("Model type must be specified")
            else:
                raise ValueError(f"Unsupported model type: {model_type}")

        mock_parser_class.side_effect = get_parser

        with self.assertRaises(ValueError) as context:
            mock_parser_class("invalid-model")
        self.assertIn("Unsupported model type", str(context.exception))

        with self.assertRaises(ValueError) as context:
            mock_parser_class(None)
        self.assertIn("Model type must be specified", str(context.exception))
|
|
|
|
|
|
# Allow running this file directly as a script; unittest discovers and runs
# the test cases defined in this module.
if __name__ == "__main__":
    unittest.main()
|