import json
import unittest

from xgrammar import GrammarCompiler, TokenizerInfo

from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST


def _weather_and_search_tools():
    """Return a fresh ``get_weather`` / ``search`` tool pair used as a test fixture."""
    return [
        Tool(
            type="function",
            function=Function(
                name="get_weather",
                description="Get weather information",
                parameters={
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "Location to get weather for",
                        },
                        "unit": {
                            "type": "string",
                            "description": "Temperature unit",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["location"],
                },
            ),
        ),
        Tool(
            type="function",
            function=Function(
                name="search",
                description="Search for information",
                parameters={
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query",
                        },
                    },
                    "required": ["query"],
                },
            ),
        ),
    ]


def _city_tools():
    """Return a fresh ``get_weather`` / ``get_tourist_attractions`` tool pair keyed on ``city``."""
    return [
        Tool(
            type="function",
            function=Function(
                name="get_weather",
                description="Get weather information",
                parameters={
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "City name",
                        }
                    },
                    "required": ["city"],
                },
            ),
        ),
        Tool(
            type="function",
            function=Function(
                name="get_tourist_attractions",
                description="Get tourist attractions",
                parameters={
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "City name",
                        }
                    },
                    "required": ["city"],
                },
            ),
        ),
    ]


class TestPythonicDetector(unittest.TestCase):
    """Streaming and one-shot parsing tests for ``PythonicDetector``."""

    def setUp(self):
        # Create sample tools for testing
        self.tools = _weather_and_search_tools()
        self.detector = PythonicDetector()

    def _feed_chunks(self, chunks):
        """Stream ``chunks`` through the detector.

        Returns a tuple ``(normal_text, call_name, parameters, last_result)``
        where the first three are the concatenation of what each increment
        produced (only the first call of every increment is considered).
        """
        normal_text = ""
        call_name = ""
        parameters = ""
        result = None
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
            if result.normal_text:
                normal_text += result.normal_text
            if result.calls:
                call_name += result.calls[0].name
                parameters += result.calls[0].parameters
        return normal_text, call_name, parameters, result

    def test_parse_streaming_no_brackets(self):
        """Test parsing text with no brackets (no tool calls)."""
        text = "This is just normal text without any tool calls."
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, text)
        self.assertEqual(result.calls, [])
        self.assertEqual(self.detector._buffer, "")  # Buffer should be cleared

    def test_parse_streaming_complete_tool_call(self):
        """Test parsing a complete tool call."""
        text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "Here's a tool call: ")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(
            self.detector._buffer, ""
        )  # Buffer should be cleared after processing

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "New York")
        self.assertEqual(params["unit"], "celsius")

    def test_parse_streaming_text_before_tool_call(self):
        """Test parsing text that appears before a tool call."""
        text = "This is some text before [get_weather(location='London')]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "This is some text before ")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "London")

    def test_parse_streaming_partial_tool_call(self):
        """Test parsing a partial tool call that spans multiple chunks."""
        # First chunk with opening bracket but no closing bracket
        text1 = "Let me check the weather: [get_weather(location="
        result1 = self.detector.parse_streaming_increment(text1, self.tools)

        self.assertEqual(result1.normal_text, "Let me check the weather: ")
        self.assertEqual(result1.calls, [])
        self.assertEqual(
            self.detector._buffer, "[get_weather(location="
        )  # Partial tool call remains in buffer

        # Second chunk completing the tool call
        text2 = "'Paris')]"
        result2 = self.detector.parse_streaming_increment(text2, self.tools)

        self.assertEqual(result2.normal_text, "")
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "get_weather")

        # Check the parameters
        params = json.loads(result2.calls[0].parameters)
        self.assertEqual(params["location"], "Paris")
        self.assertEqual(
            self.detector._buffer, ""
        )  # Buffer should be cleared after processing

    def test_parse_streaming_bracket_without_text_before(self):
        """Test parsing a tool call that starts at the beginning of the text."""
        text = "[search(query='python programming')]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "search")

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["query"], "python programming")

    def test_parse_streaming_text_after_tool_call(self):
        """Test parsing text that appears after a tool call."""
        # First chunk with complete tool call and some text after
        text = "[get_weather(location='Tokyo')] Here's the forecast:"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(
            self.detector._buffer, " Here's the forecast:"
        )  # Text after tool call remains in buffer

        # Process the remaining text in buffer
        result2 = self.detector.parse_streaming_increment("", self.tools)
        self.assertEqual(result2.normal_text, " Here's the forecast:")
        self.assertEqual(result2.calls, [])
        self.assertEqual(self.detector._buffer, "")  # Buffer should be cleared

    def test_parse_streaming_multiple_tool_calls(self):
        """Test parsing multiple tool calls in sequence."""
        text = "[get_weather(location='Berlin')] and [search(query='restaurants')]"

        # First tool call
        result1 = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(len(result1.calls), 1)
        self.assertEqual(result1.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]")

        # Second tool call
        result2 = self.detector.parse_streaming_increment("", self.tools)
        self.assertEqual(result2.normal_text, " and ")
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "search")
        self.assertEqual(self.detector._buffer, "")

    def test_parse_streaming_opening_bracket_only(self):
        """Test parsing text with only an opening bracket but no closing bracket."""
        text = "Let's try this: ["
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "Let's try this: ")
        self.assertEqual(result.calls, [])
        self.assertEqual(
            self.detector._buffer, "["
        )  # Opening bracket remains in buffer

    def test_parse_streaming_nested_brackets(self):
        """Test parsing tool calls with nested brackets in arguments."""
        # Test with list argument containing nested brackets
        text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "New York")
        self.assertEqual(params["unit"], "celsius")
        self.assertEqual(params["data"], [1, 2, 3])

    def test_parse_streaming_nested_brackets_dict(self):
        """Test parsing tool calls with nested dictionaries and lists."""
        # Test with nested dict and list arguments
        text = "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "search")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["query"], "test")
        self.assertEqual(params["config"]["options"], [1, 2])
        self.assertEqual(params["config"]["nested"]["key"], "value")

    def test_parse_streaming_multiple_tools_with_nested_brackets(self):
        """Test parsing multiple tool calls with nested brackets."""
        text = "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(self.detector._buffer, "")

        # Check first tool call
        params1 = json.loads(result.calls[0].parameters)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(params1["location"], "Paris")
        self.assertEqual(params1["data"], [10, 20])

        # Check second tool call
        params2 = json.loads(result.calls[1].parameters)
        self.assertEqual(result.calls[1].name, "search")
        self.assertEqual(params2["query"], "test")
        self.assertEqual(params2["filters"], ["a", "b"])

    def test_parse_streaming_partial_nested_brackets(self):
        """Test parsing partial tool calls with nested brackets across chunks."""
        # First chunk with nested brackets but incomplete
        text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2"
        result1 = self.detector.parse_streaming_increment(text1, self.tools)

        self.assertEqual(result1.normal_text, "Here's a call: ")
        self.assertEqual(result1.calls, [])
        self.assertEqual(
            self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2"
        )

        # Second chunk completing the nested brackets
        text2 = ", 3])]"
        result2 = self.detector.parse_streaming_increment(text2, self.tools)

        self.assertEqual(result2.normal_text, "")
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        params = json.loads(result2.calls[0].parameters)
        self.assertEqual(params["location"], "Tokyo")
        self.assertEqual(params["data"], [1, 2, 3])

    def test_parse_streaming_with_python_start_and_end_token(self):
        """Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks."""
        # Case 1: the start token itself is split across two chunks.
        chunks = [
            "Here's a call: ",
            "<|python_",
            "start|>[get_weather(location=",
            "'Tokyo', data=[1, 2",
            ", 3])]<|python_end|>",
        ]
        normal_text, call_name, parameters, result = self._feed_chunks(chunks)

        self.assertEqual(normal_text, "Here's a call: ")
        self.assertEqual(call_name, "get_weather")
        self.assertEqual(self.detector._buffer, "")
        self.assertEqual(
            result.normal_text, "", "Final result should have no normal text"
        )

        # Check the parameters
        params = json.loads(parameters)
        self.assertEqual(params["location"], "Tokyo")
        self.assertEqual(params["data"], [1, 2, 3])

        # Case 2: the whole message arrives in one chunk (same detector instance).
        chunks = [
            "Here's a call: <|python_start|>[get_weather(location='Tokyo', data=[1, 2, 3])]<|python_end|>"
        ]
        normal_text, call_name, parameters, result = self._feed_chunks(chunks)

        self.assertEqual(normal_text, "Here's a call: ")
        self.assertEqual(call_name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        params = json.loads(parameters)
        self.assertEqual(params["location"], "Tokyo")
        self.assertEqual(params["data"], [1, 2, 3])

    def test_detect_and_parse_with_python_start_and_end_token(self):
        """Test parsing a message that starts with <|python_start|> and contains a valid tool call."""
        text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars."
        result = self.detector.detect_and_parse(text, self.tools)

        # NOTE(review): this file was restored from a whitespace-collapsed copy;
        # confirm the exact spacing of the expected string below (the text around
        # the removed tool call may legitimately contain two spaces).
        self.assertEqual(
            result.normal_text,
            "User wants to get the weather in Mars. In this way we will get the weather in Mars.",
        )
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "Mars")
        self.assertEqual(params["unit"], "celsius")


class TestMistralDetector(unittest.TestCase):
    """One-shot parsing tests for ``MistralDetector``."""

    def setUp(self):
        """Set up test tools and detector for Mistral format testing."""
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="make_next_step_decision",
                    description="Test function for decision making",
                    parameters={
                        "type": "object",
                        "properties": {
                            "decision": {
                                "type": "string",
                                "description": "The next step to take",
                            },
                            "content": {
                                "type": "string",
                                "description": "The content of the next step",
                            },
                        },
                        "required": ["decision", "content"],
                    },
                ),
            ),
        ]
        self.detector = MistralDetector()

    def test_detect_and_parse_with_nested_brackets_in_content(self):
        """Test parsing Mistral format with nested brackets in JSON content.

        This test case specifically addresses the issue where the regex pattern
        was incorrectly truncating JSON when it contained nested brackets like [City Name].
        """
        # This is the exact problematic text from the original test failure
        test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"","content":"```\\nTOOL: Access a weather API or service\\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\\n```"}}]'

        result = self.detector.detect_and_parse(test_text, self.tools)

        # Verify that the parsing was successful
        self.assertEqual(len(result.calls), 1, "Should detect exactly one tool call")

        call = result.calls[0]
        self.assertEqual(
            call.name,
            "make_next_step_decision",
            "Should detect the correct function name",
        )

        # Verify that the parameters are valid JSON and contain the expected content
        params = json.loads(call.parameters)
        self.assertEqual(
            params["decision"], "", "Decision parameter should be empty string"
        )

        # The content should contain the full text including the nested brackets [City Name]
        expected_content = "```\nTOOL: Access a weather API or service\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\n```"
        self.assertEqual(
            params["content"],
            expected_content,
            "Content should include nested brackets without truncation",
        )

        # Verify that normal text is empty (since the entire input is a tool call)
        self.assertEqual(
            result.normal_text, "", "Normal text should be empty for pure tool call"
        )

    def test_detect_and_parse_simple_case(self):
        """Test parsing a simple Mistral format tool call without nested brackets."""
        test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"TOOL", "content":"Use weather API"}}]'

        result = self.detector.detect_and_parse(test_text, self.tools)

        self.assertEqual(len(result.calls), 1)
        call = result.calls[0]
        self.assertEqual(call.name, "make_next_step_decision")

        params = json.loads(call.parameters)
        self.assertEqual(params["decision"], "TOOL")
        self.assertEqual(params["content"], "Use weather API")

    def test_detect_and_parse_no_tool_calls(self):
        """Test parsing text without any tool calls."""
        test_text = "This is just normal text without any tool calls."

        result = self.detector.detect_and_parse(test_text, self.tools)

        self.assertEqual(len(result.calls), 0, "Should detect no tool calls")
        self.assertEqual(
            result.normal_text,
            test_text,
            "Should return the original text as normal text",
        )

    def test_detect_and_parse_with_text_before_tool_call(self):
        """Test parsing text that has content before the tool call."""
        test_text = 'Here is some text before the tool call: [TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"ANSWER", "content":"The answer is 42"}}]'

        result = self.detector.detect_and_parse(test_text, self.tools)

        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.normal_text, "Here is some text before the tool call:")

        call = result.calls[0]
        self.assertEqual(call.name, "make_next_step_decision")

        params = json.loads(call.parameters)
        self.assertEqual(params["decision"], "ANSWER")
        self.assertEqual(params["content"], "The answer is 42")


class TestEBNFGeneration(unittest.TestCase):
    """Checks that every detector's ``build_ebnf`` output has the expected shape and compiles."""

    def setUp(self):
        # Create sample tools for testing
        self.tools = _weather_and_search_tools()

        self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        tokenizer_info = TokenizerInfo.from_huggingface(self.tokenizer)
        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)

        # Initialize all detectors
        self.pythonic_detector = PythonicDetector()
        self.deepseekv3_detector = DeepSeekV3Detector()
        self.llama32_detector = Llama32Detector()
        self.mistral_detector = MistralDetector()
        self.qwen25_detector = Qwen25Detector()

    def _json_detectors(self):
        """Return the detectors that emit JSON-style arguments, keyed by display name."""
        return {
            "deepseekv3": self.deepseekv3_detector,
            "llama32": self.llama32_detector,
            "mistral": self.mistral_detector,
            "qwen25": self.qwen25_detector,
        }

    def _assert_ebnf_compiles(self, ebnf, name=None):
        """Fail the test unless ``ebnf`` compiles with the grammar compiler.

        ``name`` (optional) is the detector label to include in failure messages.
        A ``RuntimeError`` raised by ``compile_grammar`` is reported as a test
        failure rather than an error.
        """
        if name is None:
            ok_msg = "EBNF should be valid and compile successfully"
            fail_prefix = "Failed to compile EBNF"
        else:
            ok_msg = f"{name} EBNF should compile successfully"
            fail_prefix = f"Failed to compile {name} EBNF"

        try:
            ctx = self.grammar_compiler.compile_grammar(ebnf)
        except RuntimeError as e:
            self.fail(f"{fail_prefix}: {e}")
        self.assertIsNotNone(ctx, ok_msg)

    @staticmethod
    def _find_rule(ebnf, rule_name):
        """Return the first line of ``ebnf`` that defines ``rule_name``, or ``None``."""
        prefix = f"{rule_name} ::="
        return next(
            (line for line in ebnf.split("\n") if line.startswith(prefix)), None
        )

    def test_pythonic_detector_ebnf(self):
        """Test that the PythonicDetector generates valid EBNF."""
        ebnf = self.pythonic_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf)
        self.assertIn('"location" "=" basic_string', ebnf)
        self.assertIn('( "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") )', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)

    def test_deepseekv3_detector_ebnf(self):
        """Test that the DeepSeekV3Detector generates valid EBNF."""
        ebnf = self.deepseekv3_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn("<|tool▁calls▁begin|>", ebnf)
        self.assertIn("<|tool▁call▁begin|>function<|tool▁sep|>get_weather", ebnf)
        self.assertIn('\\"location\\"" ":" basic_string ', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)

    def test_llama32_detector_ebnf(self):
        """Test that the Llama32Detector generates valid EBNF."""
        ebnf = self.llama32_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
        self.assertIn('"\\"arguments\\"" ":"', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)

    def test_mistral_detector_ebnf(self):
        """Test that the MistralDetector generates valid EBNF."""
        ebnf = self.mistral_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn('"[TOOL_CALLS] ["', ebnf)
        self.assertIn("call_get_weather | call_search", ebnf)
        self.assertIn('"\\"arguments\\"" ":"', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)

    def test_qwen25_detector_ebnf(self):
        """Test that the Qwen25Detector generates valid EBNF."""
        ebnf = self.qwen25_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        # NOTE(review): asserting that the empty string is contained in `ebnf` can
        # never fail. This literal looks like it lost its content (possibly an
        # angle-bracket marker token) — confirm the intended value and restore it.
        self.assertIn("", ebnf)
        self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
        self.assertIn('"\\"arguments\\"" ":"', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)

    def test_weather_function_optional_parameter_handling(self):
        """Test that weather function with optional unit parameter generates correct EBNF without trailing commas."""
        # Create a weather tool with required location and optional unit
        weather_tool = Tool(
            type="function",
            function=Function(
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            ),
        )

        # Test all detectors with the weather tool
        detectors = {"pythonic": self.pythonic_detector, **self._json_detectors()}

        for name, detector in detectors.items():
            with self.subTest(detector=name):
                ebnf = detector.build_ebnf([weather_tool])
                self.assertIsNotNone(ebnf, f"{name} detector should generate EBNF")

                # Check that the EBNF properly handles optional parameters
                if name == "pythonic":
                    # Pythonic format: location="Paris" ( , ( unit=("celsius" | "fahrenheit") )?
                    self.assertIn('"location" "=" basic_string', ebnf)
                    # The comma should be inside the optional brackets for unit
                    self.assertIn('( "," ( "unit" "=" ', ebnf)
                else:
                    # JSON format: "location": "Paris" ( , ( "unit": ("celsius" | "fahrenheit") )?
                    self.assertIn('"location\\"" ":" basic_string', ebnf)
                    # The comma should be part of the optional group
                    # This pattern ensures no trailing comma when unit is omitted
                    self.assertIn('( "," ( "\\"unit\\"" ":"', ebnf)

                # Validate that the EBNF can be compiled
                self._assert_ebnf_compiles(ebnf, name)

    def test_multiple_optional_parameters_flexible_ordering(self):
        """Test that multiple optional parameters allow flexible ordering using llama.cpp approach."""
        # Create a tool with one required and multiple optional parameters
        test_tool = Tool(
            type="function",
            function=Function(
                name="test_func",
                description="Test function with multiple optional parameters",
                parameters={
                    "type": "object",
                    "properties": {
                        "required_field": {"type": "string"},
                        "opt1": {"type": "number"},
                        "opt2": {"type": "boolean"},
                        "opt3": {"type": "string"},
                    },
                    "required": ["required_field"],
                },
            ),
        )

        # Test JSON-based detectors (not pythonic)
        for name, detector in self._json_detectors().items():
            with self.subTest(detector=name):
                ebnf = detector.build_ebnf([test_tool])
                self.assertIsNotNone(ebnf, f"{name} detector should generate EBNF")

                # Extract the arguments rule so the checks below target it only
                args_rule = self._find_rule(ebnf, "arguments_test_func")
                self.assertIsNotNone(
                    args_rule, f"{name} should have arguments_test_func rule"
                )

                # Check required field
                self.assertIn('"required_field\\"" ":" basic_string', ebnf)

                # Check the structure for optional parameters
                # The pattern should be: required_field ( "," ( opt1 ... | opt2 ... | opt3 ... ) )?
                # This allows flexible ordering where any optional can be first

                # Check that optional parameters are in a group with comma
                self.assertIn(
                    '( ","',
                    args_rule,
                    f"{name} should have comma grouped with optional parameters",
                )

                # Check for the alternation pattern that allows flexible ordering
                # Should contain patterns like: opt1 ... | opt2 ... | opt3
                self.assertIn('"opt1\\"" ":" basic_number', args_rule)
                self.assertIn('"opt2\\"" ":" basic_boolean', args_rule)
                self.assertIn('"opt3\\"" ":" basic_string', args_rule)

                # Check for alternation (|) which allows skipping optional parameters
                self.assertIn(
                    "|",
                    args_rule,
                    f"{name} should use alternation for flexible optional ordering",
                )

                # Check that the pattern ends properly with closing braces
                self.assertTrue(
                    args_rule.endswith('"}"'),
                    f"{name} arguments rule should end with closing brace",
                )

                # Validate compilation
                self._assert_ebnf_compiles(ebnf, name)

    def test_all_optional_parameters_ordering(self):
        """Test the behavior when ALL parameters are optional - verifies ordering constraints."""
        # Create a tool with only optional parameters
        all_optional_tool = Tool(
            type="function",
            function=Function(
                name="optional_func",
                description="Function with all optional parameters",
                parameters={
                    "type": "object",
                    "properties": {
                        "opt1": {"type": "string"},
                        "opt2": {"type": "number"},
                        "opt3": {"type": "boolean"},
                    },
                    "required": [],  # No required parameters
                },
            ),
        )

        # Test JSON-based detectors
        for name, detector in self._json_detectors().items():
            with self.subTest(detector=name):
                ebnf = detector.build_ebnf([all_optional_tool])
                self.assertIsNotNone(ebnf, f"{name} detector should generate EBNF")

                # Extract the arguments rule
                args_rule = self._find_rule(ebnf, "arguments_optional_func")
                self.assertIsNotNone(
                    args_rule, f"{name} should have arguments_optional_func rule"
                )

                # When all parameters are optional, the pattern now uses alternation:
                # "{" ( opt1 ... | opt2 ... | opt3 ... )? "}"
                # This allows flexible ordering where any optional can appear first

                # Check the structure
                self.assertIn('"opt1\\"" ":" basic_string', args_rule)
                self.assertIn('"opt2\\"" ":" basic_number', args_rule)
                self.assertIn('"opt3\\"" ":" basic_boolean', args_rule)

                # The pattern SHOULD have alternation (|) for flexible ordering
                self.assertIn(
                    "|",
                    args_rule,
                    f"{name} should use alternation for flexible ordering even when all properties are optional",
                )

                # Validate compilation
                self._assert_ebnf_compiles(ebnf, name)


class TestBaseFormatDetector(unittest.TestCase):
    """Test buffer management and sequential tool index assignment in BaseFormatDetector."""

    def setUp(self):
        """Set up test detector and tools."""

        # Create a concrete implementation of BaseFormatDetector for testing
        class TestFormatDetector(BaseFormatDetector):
            def __init__(self):
                super().__init__()
                # NOTE(review): both tokens are empty strings here, which makes
                # `has_tool_call` below true for every input. They look like
                # marker tokens whose text was lost — confirm the intended values.
                self.bot_token = ""
                self.eot_token = ""

            def detect_and_parse(self, text, tools):
                # Not used in streaming tests
                pass

            def has_tool_call(self, text):
                return "" in text

            def structure_info(self):
                # Not used in streaming tests
                pass

            def build_ebnf(self, tools):
                # Not used in streaming tests
                pass

        self.detector = TestFormatDetector()
        self.tools = _city_tools()

    def test_sequential_tool_index_assignment(self):
        """Test that multiple tool calls get sequential tool_index values (0, 1, 2, ...)."""
        # Simulate streaming chunks for two consecutive tool calls
        # NOTE(review): the empty first/last chunks presumably stood for the
        # begin/end marker tokens — confirm against the detector's tokens.
        chunks = [
            "",
            '{"name": "get_weather", ',
            '"arguments": {"city": "Paris"}}',
            ", ",
            '{"name": "get_tourist_attractions", ',
            '"arguments": {"city": "London"}}',
            "",
        ]

        tool_indices_seen = []
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
            tool_indices_seen.extend(
                call.tool_index
                for call in result.calls or []
                if call.tool_index is not None
            )

        # Verify we got sequential tool indices
        unique_indices = sorted(set(tool_indices_seen))
        self.assertEqual(
            unique_indices,
            [0, 1],
            f"Expected sequential tool indices [0, 1], got {unique_indices}",
        )

    def test_buffer_content_preservation(self):
        """Test that buffer correctly preserves unprocessed content when tool completes."""
        # Test simpler scenario: tool completion followed by new tool start
        chunks = [
            "",
            '{"name": "get_weather", ',
            '"arguments": {"city": "Paris"}}',
            ", ",
            '{"name": "get_tourist_attractions", ',
            '"arguments": {"city": "London"}} ',
        ]

        tool_calls_seen = []
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
            # Only count calls with names (not just parameter updates)
            tool_calls_seen.extend(
                call.name for call in result.calls or [] if call.name
            )

        # Should see both tool names
        self.assertIn("get_weather", tool_calls_seen, "Should process first tool")
        self.assertIn(
            "get_tourist_attractions", tool_calls_seen, "Should process second tool"
        )

    def test_current_tool_id_increment_on_completion(self):
        """Test that current_tool_id increments when a tool completes."""
        # Initial state
        self.assertEqual(
            self.detector.current_tool_id, -1, "Should start with current_tool_id=-1"
        )

        # Stream the start of the first tool (its name only)
        chunks = [
            "",
            '{"name": "get_weather", ',
        ]

        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)

        self.assertEqual(
            self.detector.current_tool_id, 0, "current_tool_id should be 0"
        )
        self.assertEqual(
            result.calls[0].name, "get_weather", "The first tool should be get_weather"
        )
        self.assertEqual(
            result.calls[0].tool_index, 0, "The first tool index should be 0"
        )

        # Finish the first tool and begin the second tool's name - this should
        # show that current_tool_id is now 1
        result = self.detector.parse_streaming_increment(
            '"arguments": {"city": "Paris"}}, {"name": "get_', self.tools
        )
        self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')

        self.assertEqual(
            self.detector.current_tool_id,
            1,
            "current_tool_id should be 1 after first tool completes and second tool starts",
        )

        result = self.detector.parse_streaming_increment(
            'tourist_attractions", ', self.tools
        )

        # Second tool should have tool_index=1
        tourist_calls = [
            call for call in result.calls if call.name == "get_tourist_attractions"
        ]
        self.assertEqual(
            tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1"
        )

    def test_tool_name_streaming_with_correct_index(self):
        """Test that tool names are streamed with correct tool_index values."""
        # Process first tool
        self.detector.parse_streaming_increment("", self.tools)
        result1 = self.detector.parse_streaming_increment(
            '{"name": "get_weather", ', self.tools
        )

        # First tool name should have tool_index=0
        weather_calls = [call for call in result1.calls if call.name == "get_weather"]
        self.assertEqual(len(weather_calls), 1, "Should have one weather call")
        self.assertEqual(
            weather_calls[0].tool_index, 0, "First tool should have tool_index=0"
        )

        # Complete first tool
        self.detector.parse_streaming_increment(
            '"arguments": {"city": "Paris"}}', self.tools
        )

        # Start second tool
        self.detector.parse_streaming_increment(", ", self.tools)
        result2 = self.detector.parse_streaming_increment(
            '{"name": "get_tourist_attractions", ', self.tools
        )

        # Second tool name should have tool_index=1
        tourist_calls = [
            call for call in result2.calls if call.name == "get_tourist_attractions"
        ]
        self.assertEqual(
            len(tourist_calls), 1, "Should have one tourist attractions call"
        )
        self.assertEqual(
            tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1"
        )

    def test_buffer_reset_on_invalid_tool(self):
        """Test that buffer and state are reset when an invalid tool name is encountered."""
        # Start fresh with an invalid tool name from the beginning
        result = self.detector.parse_streaming_increment(
            '{"name": "invalid_tool", ', self.tools
        )

        # Should return empty result and reset state
        self.assertEqual(result.calls, [], "Should return no calls for invalid tool")
        self.assertEqual(
            self.detector.current_tool_id,
            -1,
            "current_tool_id should remain -1 for invalid tool",
        )
        self.assertEqual(
            self.detector._buffer, "", "Buffer should be cleared for invalid tool"
        )


class TestLlama32Detector(unittest.TestCase):
    """One-shot parsing tests for ``Llama32Detector``."""

    def setUp(self):
        """Set up test tools and detector for Llama 3.2 format testing."""
        self.tools = _city_tools()
        self.detector = Llama32Detector()

    def test_single_json(self):
        """Test parsing a single bare JSON tool call."""
        text = '{"name": "get_weather", "parameters": {"city": "Paris"}}'
        result = self.detector.detect_and_parse(text, self.tools)

        # unittest assertions instead of bare `assert`: they survive `python -O`
        # and report the mismatching values on failure.
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(result.normal_text, "")

    def test_multiple_json_with_separator(self):
        """Test parsing two tool calls separated by ``;`` after a python tag."""
        text = (
            '<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};'
            '{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}'
        )
        result = self.detector.detect_and_parse(text, self.tools)

        self.assertEqual(len(result.calls), 2)
        self.assertEqual(result.calls[1].name, "get_tourist_attractions")
        self.assertEqual(result.normal_text, "")

    def test_multiple_json_with_separator_customized(self):
        """Test parsing two tool calls that are each introduced by a python tag."""
        text = (
            '<|python_tag|>{"name": "get_weather", "parameters": {}}'
            '<|python_tag|>{"name": "get_tourist_attractions", "parameters": {}}'
        )
        result = self.detector.detect_and_parse(text, self.tools)

        self.assertEqual(len(result.calls), 2)
        self.assertEqual(result.calls[1].name, "get_tourist_attractions")
        self.assertEqual(result.normal_text, "")

    def test_json_with_trailing_text(self):
        """Test that text following a tool call is surfaced as normal text."""
        text = '{"name": "get_weather", "parameters": {}} Some follow-up text'
        result = self.detector.detect_and_parse(text, self.tools)

        self.assertEqual(len(result.calls), 1)
        self.assertIn("follow-up", result.normal_text)

    def test_invalid_then_valid_json(self):
        """Test that a malformed JSON prefix does not prevent parsing the valid call after it."""
        text = (
            '{"name": "get_weather", "parameters": {'  # malformed
            '{"name": "get_weather", "parameters": {}}'
        )
        result = self.detector.detect_and_parse(text, self.tools)

        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")

    def test_plain_text_only(self):
        """Test that plain text yields no calls and is returned unchanged."""
        text = "This is just plain explanation text."
        result = self.detector.detect_and_parse(text, self.tools)

        self.assertEqual(result.calls, [])
        self.assertEqual(result.normal_text, text)

    def test_with_python_tag_prefix(self):
        """Test that text before the python tag is kept as normal text."""
        text = 'Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}'
        result = self.detector.detect_and_parse(text, self.tools)

        self.assertEqual(len(result.calls), 1)
        self.assertTrue(result.normal_text.strip().startswith("Some intro."))


if __name__ == "__main__":
    unittest.main()