263 lines
9.2 KiB
Python
263 lines
9.2 KiB
Python
import json
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from sglang.srt.server_args import PortArgs, prepare_server_args
|
|
from sglang.test.test_utils import CustomTestCase
|
|
|
|
|
|
class TestPrepareServerArgs(CustomTestCase):
    """Checks the result of ``prepare_server_args`` for a small argv list."""

    def test_prepare_server_args(self):
        """Model path and JSON override args are exposed on the result."""
        override_json = '{"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}'
        cli_args = [
            "--model-path",
            "model_path",
            "--json-model-override-args",
            override_json,
        ]

        parsed = prepare_server_args(cli_args)

        self.assertEqual(parsed.model_path, "model_path")

        # The override is stored as a JSON string; decode it before comparing.
        expected_override = {"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}
        self.assertEqual(
            json.loads(parsed.json_model_override_args),
            expected_override,
        )
|
|
|
|
|
|
class TestPortArgs(unittest.TestCase):
    """Tests for ``PortArgs.init_new``.

    Every test patches ``sglang.srt.server_args.is_port_available`` so that it
    reports ports as available, and passes a ``MagicMock`` in place of real
    server args so only the attributes relevant to each scenario are set.
    """

    @staticmethod
    def _make_server_args(**overrides):
        """Return a ``MagicMock`` stand-in for server args.

        ``port`` is always set to 30000. Each keyword in ``overrides`` is
        assigned as an attribute on the mock; attributes that are not given
        remain auto-created ``MagicMock`` attributes.
        """
        server_args = MagicMock()
        server_args.port = 30000
        for attr_name, attr_value in overrides.items():
            setattr(server_args, attr_name, attr_value)
        return server_args

    def _assert_ipc_names_start_with(self, port_args, prefix):
        """Assert the tokenizer, scheduler-input and detokenizer IPC names
        all start with ``prefix``, reporting the offending value on failure.
        """
        for field in (
            "tokenizer_ipc_name",
            "scheduler_input_ipc_name",
            "detokenizer_ipc_name",
        ):
            value = getattr(port_args, field)
            self.assertTrue(
                value.startswith(prefix),
                f"{field}={value!r} does not start with {prefix!r}",
            )

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")
    def test_init_new_standard_case(self, mock_temp_file, mock_is_port_available):
        """Without dp attention, the IPC names are expected to use ``ipc://``."""
        mock_temp_file.return_value.name = "temp_file"

        server_args = self._make_server_args(enable_dp_attention=False)

        port_args = PortArgs.init_new(server_args)

        self._assert_ipc_names_start_with(port_args, "ipc://")
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    def test_init_new_with_single_node_dp_attention(self, mock_is_port_available):
        """Single node, dp attention, no dist-init address: expects loopback TCP."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=1,
            dist_init_addr=None,
        )

        port_args = PortArgs.init_new(server_args)

        self._assert_ipc_names_start_with(port_args, "tcp://127.0.0.1:")
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    def test_init_new_with_dp_rank(self, mock_is_port_available):
        """An explicit ``dp_rank`` is expected to select the scheduler input port."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=1,
            dist_init_addr="192.168.1.1:25000",
        )

        port_args = PortArgs.init_new(server_args, dp_rank=2)

        # NOTE(review): 25007 is the port expected for base port 25000 with
        # dp_rank=2; the offset rule lives in PortArgs.init_new -- confirm
        # there if this assertion starts failing.
        scheduler_name = port_args.scheduler_input_ipc_name
        self.assertTrue(
            scheduler_name.endswith(":25007"),
            f"scheduler_input_ipc_name={scheduler_name!r} does not end with ':25007'",
        )

        for field in ("tokenizer_ipc_name", "detokenizer_ipc_name"):
            value = getattr(port_args, field)
            self.assertTrue(
                value.startswith("tcp://192.168.1.1:"),
                f"{field}={value!r} does not start with 'tcp://192.168.1.1:'",
            )
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    def test_init_new_with_ipv4_address(self, mock_is_port_available):
        """A ``host:port`` IPv4 dist-init address is expected in the TCP names."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=2,
            dist_init_addr="192.168.1.1:25000",
        )

        port_args = PortArgs.init_new(server_args)

        self._assert_ipc_names_start_with(port_args, "tcp://192.168.1.1:")
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    def test_init_new_with_malformed_ipv4_address(self, mock_is_port_available):
        """An IPv4 address without a port is expected to raise ``AssertionError``."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=2,
            dist_init_addr="192.168.1.1",
        )

        with self.assertRaises(AssertionError) as context:
            PortArgs.init_new(server_args)

        self.assertIn(
            "please provide --dist-init-addr as host:port", str(context.exception)
        )

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    def test_init_new_with_malformed_ipv4_address_invalid_port(
        self, mock_is_port_available
    ):
        """A non-numeric IPv4 port is expected to raise ``ValueError``."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=2,
            dist_init_addr="192.168.1.1:abc",
        )

        # Only the exception type is checked here, so the context is not bound.
        with self.assertRaises(ValueError):
            PortArgs.init_new(server_args)

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_ipv6_address(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A bracketed ``[addr]:port`` IPv6 address is expected in the TCP names."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=2,
            dist_init_addr="[2001:db8::1]:25000",
        )

        port_args = PortArgs.init_new(server_args)

        self._assert_ipc_names_start_with(port_args, "tcp://[2001:db8::1]:")
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=False)
    def test_init_new_with_invalid_ipv6_address(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A bracketed address failing IPv6 validation is expected to raise."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=2,
            dist_init_addr="[invalid-ipv6]:25000",
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("invalid IPv6 address", str(context.exception))

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    def test_init_new_with_malformed_ipv6_address_missing_bracket(
        self, mock_is_port_available
    ):
        """An IPv6 address missing its closing ``]`` is expected to raise."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=2,
            dist_init_addr="[2001:db8::1:25000",
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("invalid IPv6 address format", str(context.exception))

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_missing_port(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A bracketed IPv6 address with no port is expected to raise."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=2,
            dist_init_addr="[2001:db8::1]",
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn(
            "a port must be specified in IPv6 address", str(context.exception)
        )

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_invalid_port(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A non-numeric port after a bracketed IPv6 address is expected to raise."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=2,
            dist_init_addr="[2001:db8::1]:abcde",
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("invalid port in IPv6 address", str(context.exception))

    @patch("sglang.srt.server_args.is_port_available", return_value=True)
    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_wrong_separator(
        self, mock_is_valid_ipv6, mock_is_port_available
    ):
        """A separator other than ``:`` after ``]`` is expected to raise."""
        server_args = self._make_server_args(
            enable_dp_attention=True,
            nnodes=2,
            dist_init_addr="[2001:db8::1]#25000",
        )

        with self.assertRaises(ValueError) as context:
            PortArgs.init_new(server_args)

        self.assertIn("expected ':' after ']'", str(context.exception))
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|