300 lines
11 KiB
Python
300 lines
11 KiB
Python
import json
|
|
import time
|
|
import unittest
|
|
|
|
import openai
|
|
|
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
|
from sglang.srt.utils import kill_process_tree
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
DEFAULT_URL_FOR_TEST,
|
|
CustomTestCase,
|
|
popen_launch_server,
|
|
)
|
|
|
|
|
|
class TestOpenAIServerFunctionCalling(CustomTestCase):
    """End-to-end tests for OpenAI-style function calling against a local server.

    One server process is launched for the whole class in ``setUpClass`` (with
    ``--tool-call-parser llama3``) and killed in ``tearDownClass``. Each test
    talks to it through the ``openai`` client.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server shared by every test in this class."""
        # Replace with the model needed for testing; otherwise reuse
        # DEFAULT_SMALL_MODEL_NAME_FOR_TEST.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # Start the local OpenAI-compatible server. Add any extra flags that
        # function calling needs to ``other_args``.
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--tool-call-parser",
                "llama3",
            ],
        )
        # From here on the OpenAI client is pointed at the server's /v1 routes.
        cls.base_url += "/v1"
        # NOTE(review): not read by any test in this class — confirm before removing.
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls):
        """Shut down the server launched in ``setUpClass`` (kill_process_tree on its PID)."""
        kill_process_tree(cls.process.pid)

    def _make_client(self):
        """Return an OpenAI client authenticated against the test server's /v1 API."""
        return openai.Client(api_key=self.api_key, base_url=self.base_url)

    def test_function_calling_format(self):
        """
        Test: Whether the function call format returned by the AI is correct.

        When returning a tool call, message.content should be None, and
        tool_calls should be a non-empty list whose first entry calls ``add``.
        """
        client = self._make_client()

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "add",
                    "description": "Compute the sum of two numbers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            # JSON Schema has no "int" type; "integer" is the valid name.
                            "a": {
                                "type": "integer",
                                "description": "A number",
                            },
                            "b": {
                                "type": "integer",
                                "description": "A number",
                            },
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]

        messages = [{"role": "user", "content": "Compute (3+5)"}]
        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.8,
            top_p=0.8,
            stream=False,
            tools=tools,
        )

        message = response.choices[0].message
        content = message.content
        tool_calls = message.tool_calls

        # unittest assertions (not bare ``assert``) so the checks still run
        # under ``python -O`` and report like the other tests in this class.
        self.assertIsNone(
            content,
            "When function call is successful, message.content should be None, "
            f"but got: {content}",
        )
        self.assertIsInstance(tool_calls, list, "tool_calls should be a non-empty list")
        self.assertGreater(len(tool_calls), 0, "tool_calls should be a non-empty list")

        function_name = tool_calls[0].function.name
        self.assertEqual(function_name, "add", "Function name should be 'add'")

    def test_function_calling_streaming_simple(self):
        """
        Test: Whether the function name can be correctly recognized in streaming mode.

        - Expect streaming to return at least one chunk.
        - Expect a function call to be found, and the function name to be correct.
        """
        client = self._make_client()

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The city to find the weather for",
                            },
                            "unit": {
                                "type": "string",
                                "description": "Weather unit (celsius or fahrenheit)",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["city", "unit"],
                    },
                },
            }
        ]

        messages = [{"role": "user", "content": "What is the temperature in Paris?"}]

        response_stream = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.8,
            top_p=0.8,
            stream=True,
            tools=tools,
        )

        chunks = list(response_stream)
        self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")

        found_function_name = False
        for chunk in chunks:
            # The OpenAI streaming format allows chunks with an empty ``choices``
            # list (e.g. usage-only chunks); skip them instead of raising IndexError.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            # Only chunks that carry tool-call deltas are of interest here.
            if not delta.tool_calls:
                continue
            function = delta.tool_calls[0].function
            if function is not None and function.name:
                self.assertEqual(
                    function.name,
                    "get_current_weather",
                    "Function name should be 'get_current_weather'",
                )
                found_function_name = True
                break

        self.assertTrue(
            found_function_name,
            "Target function name 'get_current_weather' was not found in the streaming chunks",
        )

    def test_function_calling_streaming_args_parsing(self):
        """
        Test: Whether the function call arguments returned in streaming mode can be
        correctly concatenated into valid JSON.

        - The user request requires multiple parameters.
        - The AI may return the arguments in chunks that need to be concatenated.
        """
        client = self._make_client()

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "add",
                    "description": "Compute the sum of two integers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            # JSON Schema has no "int" type; "integer" is the valid name.
                            "a": {
                                "type": "integer",
                                "description": "First integer",
                            },
                            "b": {
                                "type": "integer",
                                "description": "Second integer",
                            },
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]

        messages = [
            {"role": "user", "content": "Please sum 5 and 7, just call the function."}
        ]

        response_stream = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.9,
            top_p=0.9,
            stream=True,
            tools=tools,
        )

        argument_fragments = []
        function_name = None
        for chunk in response_stream:
            # Skip chunks without choices (allowed by the streaming format).
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if not delta.tool_calls:
                continue
            function = delta.tool_calls[0].function
            if function is None:
                continue
            # Record the function name on its first occurrence; later deltas
            # typically omit it, so keep the previously seen value.
            function_name = function.name or function_name
            # Arguments may arrive as JSON fragments split across several chunks.
            if function.arguments:
                argument_fragments.append(function.arguments)

        self.assertEqual(function_name, "add", "Function name should be 'add'")
        joined_args = "".join(argument_fragments)
        self.assertTrue(
            len(joined_args) > 0,
            "No parameter fragments were returned in the function call",
        )

        # Check whether the concatenated JSON is valid.
        try:
            args_obj = json.loads(joined_args)
        except json.JSONDecodeError:
            self.fail(
                "The concatenated tool call arguments are not valid JSON, parsing failed"
            )

        # Arguments must be a JSON object for the key checks below to be meaningful.
        self.assertIsInstance(args_obj, dict, "Tool call arguments should be a JSON object")
        self.assertIn("a", args_obj, "Missing parameter 'a'")
        self.assertIn("b", args_obj, "Missing parameter 'b'")
        # Compare as strings so both 5 and "5" are accepted.
        self.assertEqual(str(args_obj["a"]), "5", "Parameter a should be 5")
        self.assertEqual(str(args_obj["b"]), "7", "Parameter b should be 7")

    def test_function_call_strict(self):
        """
        Test: Whether function calling works when the tool is declared with
        ``"strict": True``.

        - Expect the model to call ``sub`` with ``int_a`` = 5 and ``int_b`` = 7,
          matching the declared parameter schema.
        """
        client = self._make_client()

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "sub",
                    "description": "Compute the difference of two integers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            # JSON Schema has no "int" type; strict mode relies on a
                            # valid schema, so use "integer".
                            "int_a": {
                                "type": "integer",
                                "description": "First integer",
                            },
                            "int_b": {
                                "type": "integer",
                                "description": "Second integer",
                            },
                        },
                        "required": ["int_a", "int_b"],
                    },
                    "strict": True,
                },
            }
        ]

        messages = [
            {"role": "user", "content": "Please compute 5 - 7, using your tool."}
        ]
        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.8,
            top_p=0.8,
            stream=False,
            tools=tools,
        )

        tool_calls = response.choices[0].message.tool_calls
        # Fail with a clear message instead of a TypeError/IndexError when the
        # model produced no tool call.
        self.assertTrue(tool_calls, "Expected at least one tool call in strict mode")

        function_name = tool_calls[0].function.name
        arguments = tool_calls[0].function.arguments
        args_obj = json.loads(arguments)

        self.assertEqual(function_name, "sub", "Function name should be 'sub'")
        self.assertIn("int_a", args_obj, "Missing parameter 'int_a'")
        self.assertIn("int_b", args_obj, "Missing parameter 'int_b'")
        self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
        self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
|
|
|
|
|
|
# Entry point when this file is executed directly rather than collected by a
# test runner; ``module="__main__"`` is unittest.main's default, spelled out.
if __name__ == "__main__":
    unittest.main(module="__main__")
|