228 lines
8.2 KiB
Python
228 lines
8.2 KiB
Python
"""
|
|
Test weight version functionality.
|
|
|
|
This test suite verifies the weight_version feature implementation including:
|
|
1. Default weight_version setting
|
|
2. /get_weight_version endpoint
|
|
3. /update_weight_version endpoint
|
|
4. /generate request meta_info contains weight_version
|
|
5. OpenAI API response metadata contains weight_version
|
|
"""
|
|
|
|
import unittest
|
|
|
|
import requests
|
|
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
CustomTestCase,
|
|
popen_launch_server,
|
|
)
|
|
|
|
|
|
class TestWeightVersion(CustomTestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
"""Start server once for all tests with custom weight version."""
|
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|
cls.base_url = "http://127.0.0.1:30000"
|
|
cls.process = popen_launch_server(
|
|
cls.model,
|
|
base_url=cls.base_url,
|
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
other_args=[
|
|
"--weight-version",
|
|
"test_version_1.0",
|
|
"--attention-backend",
|
|
"flashinfer",
|
|
],
|
|
)
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
"""Terminate server after all tests complete."""
|
|
if cls.process:
|
|
cls.process.terminate()
|
|
|
|
def test_weight_version_comprehensive(self):
|
|
"""Comprehensive test for all weight_version functionality."""
|
|
|
|
response = requests.get(f"{self.base_url}/get_model_info")
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertIn("weight_version", data)
|
|
self.assertEqual(data["weight_version"], "test_version_1.0")
|
|
|
|
response = requests.get(f"{self.base_url}/get_weight_version")
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertIn("weight_version", data)
|
|
self.assertEqual(data["weight_version"], "test_version_1.0")
|
|
|
|
request_data = {
|
|
"text": "Hello, how are you?",
|
|
"sampling_params": {
|
|
"temperature": 0.0,
|
|
"max_new_tokens": 5,
|
|
},
|
|
}
|
|
response = requests.post(f"{self.base_url}/generate", json=request_data)
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertIn("meta_info", data)
|
|
self.assertIn("weight_version", data["meta_info"])
|
|
self.assertEqual(data["meta_info"]["weight_version"], "test_version_1.0")
|
|
|
|
request_data = {
|
|
"model": self.model,
|
|
"messages": [{"role": "user", "content": "Hello"}],
|
|
"max_tokens": 5,
|
|
"temperature": 0.0,
|
|
}
|
|
response = requests.post(
|
|
f"{self.base_url}/v1/chat/completions", json=request_data
|
|
)
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertIn("metadata", data)
|
|
self.assertIn("weight_version", data["metadata"])
|
|
self.assertEqual(data["metadata"]["weight_version"], "test_version_1.0")
|
|
|
|
request_data = {
|
|
"model": self.model,
|
|
"prompt": "Hello",
|
|
"max_tokens": 5,
|
|
"temperature": 0.0,
|
|
}
|
|
response = requests.post(f"{self.base_url}/v1/completions", json=request_data)
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertIn("metadata", data)
|
|
self.assertIn("weight_version", data["metadata"])
|
|
self.assertEqual(data["metadata"]["weight_version"], "test_version_1.0")
|
|
|
|
update_data = {
|
|
"new_version": "updated_version_2.0",
|
|
"abort_all_requests": False,
|
|
}
|
|
response = requests.post(
|
|
f"{self.base_url}/update_weight_version", json=update_data
|
|
)
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertTrue(data["success"])
|
|
self.assertEqual(data["new_version"], "updated_version_2.0")
|
|
|
|
response = requests.get(f"{self.base_url}/get_weight_version")
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertEqual(data["weight_version"], "updated_version_2.0")
|
|
|
|
gen_data = {
|
|
"text": "Test persistence",
|
|
"sampling_params": {"temperature": 0.0, "max_new_tokens": 3},
|
|
}
|
|
response = requests.post(f"{self.base_url}/generate", json=gen_data)
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertEqual(data["meta_info"]["weight_version"], "updated_version_2.0")
|
|
|
|
chat_data = {
|
|
"model": self.model,
|
|
"messages": [{"role": "user", "content": "Test"}],
|
|
"max_tokens": 3,
|
|
"temperature": 0.0,
|
|
}
|
|
response = requests.post(f"{self.base_url}/v1/chat/completions", json=chat_data)
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertEqual(data["metadata"]["weight_version"], "updated_version_2.0")
|
|
|
|
update_data = {"new_version": "final_version_3.0", "abort_all_requests": True}
|
|
response = requests.post(
|
|
f"{self.base_url}/update_weight_version", json=update_data
|
|
)
|
|
self.assertEqual(response.status_code, 200)
|
|
data = response.json()
|
|
self.assertTrue(data["success"])
|
|
self.assertEqual(data["new_version"], "final_version_3.0")
|
|
|
|
# Check /get_weight_version
|
|
response = requests.get(f"{self.base_url}/get_weight_version")
|
|
self.assertEqual(response.status_code, 200)
|
|
self.assertEqual(response.json()["weight_version"], "final_version_3.0")
|
|
|
|
# Check /get_model_info
|
|
response = requests.get(f"{self.base_url}/get_model_info")
|
|
self.assertEqual(response.status_code, 200)
|
|
self.assertEqual(response.json()["weight_version"], "final_version_3.0")
|
|
|
|
# Check /generate meta_info
|
|
response = requests.post(
|
|
f"{self.base_url}/generate",
|
|
json={
|
|
"text": "Final test",
|
|
"sampling_params": {"temperature": 0.0, "max_new_tokens": 2},
|
|
},
|
|
)
|
|
self.assertEqual(response.status_code, 200)
|
|
self.assertEqual(
|
|
response.json()["meta_info"]["weight_version"], "final_version_3.0"
|
|
)
|
|
|
|
# Check OpenAI chat metadata
|
|
response = requests.post(
|
|
f"{self.base_url}/v1/chat/completions",
|
|
json={
|
|
"model": self.model,
|
|
"messages": [{"role": "user", "content": "Final"}],
|
|
"max_tokens": 2,
|
|
"temperature": 0.0,
|
|
},
|
|
)
|
|
self.assertEqual(response.status_code, 200)
|
|
self.assertEqual(
|
|
response.json()["metadata"]["weight_version"], "final_version_3.0"
|
|
)
|
|
|
|
print("All weight_version functionality tests passed!")
|
|
|
|
def test_update_weight_version_with_weight_updates(self):
|
|
"""Test that weight_version can be updated along with weight updates using real model data."""
|
|
print("Testing weight_version update with real weight operations...")
|
|
|
|
# Get current model info for reference
|
|
model_info_response = requests.get(f"{self.base_url}/get_model_info")
|
|
self.assertEqual(model_info_response.status_code, 200)
|
|
current_model_path = model_info_response.json()["model_path"]
|
|
|
|
update_data = {
|
|
"model_path": current_model_path,
|
|
"load_format": "auto",
|
|
"abort_all_requests": False,
|
|
"weight_version": "disk_update_v2.0.0",
|
|
}
|
|
|
|
response = requests.post(
|
|
f"{self.base_url}/update_weights_from_disk", json=update_data
|
|
)
|
|
self.assertEqual(
|
|
response.status_code,
|
|
200,
|
|
f"update_weights_from_disk failed with status {response.status_code}",
|
|
)
|
|
|
|
# Verify version was updated
|
|
version_response = requests.get(f"{self.base_url}/get_weight_version")
|
|
self.assertEqual(version_response.status_code, 200)
|
|
self.assertEqual(
|
|
version_response.json()["weight_version"], "disk_update_v2.0.0"
|
|
)
|
|
|
|
print("Weight update with weight_version test completed!")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|