import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)


class TestSlidingWindowAttentionTriton(CustomTestCase):
    """Test sliding window attention with the triton attention backend."""

    @classmethod
    def setUpClass(cls):
        """Set up shared test configuration; each test launches its own server."""
        # Gemma3 model supports sliding window attention
        cls.model = "google/gemma-3-4b-it"
        cls.base_url = DEFAULT_URL_FOR_TEST
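
        # Shared server flags: triton attention backend, an 8192-token context
        # window, and a fixed random seed so runs are reproducible.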
        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "triton",
            "--context-length",
            "8192",
            "--random-seed",
            "42",
        ]

        cls.short_context_prompt = "The capital of France is"

        # Test prompt longer than the sliding window size.
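        # Each copy of the passage is a few dozen tokens, so 100 copies yield a
        # prompt of a few thousand tokens: within the 8192-token context limit
        # but long enough to spill past the model's local attention window.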
        cls.long_context_prompt = (
            """
Once upon a time, there was a mountain. In the mountain, there was a temple. In the temple, there was an old monk telling a story. The story was:
"""
            * 100
        )
        cls.long_context_prompt += "\nNow, summarize the story in one sentence:"

    def _test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=200,
            num_threads=32,
        )

        metrics = run_eval(args)
        print(f"MMLU metrics with sliding window: {metrics}")
        self.assertGreaterEqual(metrics["score"], 0.60)

    def _test_short_context_generation(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.short_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )
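
        # temperature=0 selects greedy decoding, which keeps the "paris" check
        # below stable for a given model and backend.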
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn("paris", result["text"].lower())
        print(f"Short context generation result: {result['text']}")

    def _test_long_context_generation(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.long_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )
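
        # Summaries of the repeated story vary, so only assert that the server
        # survives the long prompt and returns a non-empty completion.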
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertGreater(len(result["text"].strip()), 0)
        print(f"Long context generation result: {result['text'][:100]}...")

    @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
    def test_no_cuda_graph(self):
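        # --disable-cuda-graph runs decode eagerly, exercising the triton
        # sliding-window kernels without CUDA graph capture.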
        self.no_cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args + ["--disable-cuda-graph"],
        )

        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            kill_process_tree(self.no_cuda_graph_process.pid)

    def test_cuda_graph(self):
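        # The default launch keeps CUDA graph capture enabled, so this covers
        # the captured decode path with sliding window attention.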
        self.cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args,
        )

        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            kill_process_tree(self.cuda_graph_process.pid)


if __name__ == "__main__":
    unittest.main()