# sglang.0.4.8.post1/sglang/test/srt/test_triton_sliding_window.py
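
"""Tests for sliding window attention with the Triton attention backend.

Each test launches an SGLang server running google/gemma-3-4b-it (a model
that uses sliding window attention) and checks short-context generation,
long-context generation, and MMLU accuracy, with and without CUDA graphs.

Run directly (assumes a CUDA GPU and access to the model weights):

    python test_triton_sliding_window.py
"""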

import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)


class TestSlidingWindowAttentionTriton(CustomTestCase):
    """Test sliding window attention functionality with the Triton backend."""

    @classmethod
    def setUpClass(cls):
        """Configure the model, server flags, and prompts shared by all tests."""
        # The Gemma3 model supports sliding window attention.
        cls.model = "google/gemma-3-4b-it"
        cls.base_url = DEFAULT_URL_FOR_TEST
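        # Shared server flags: force the Triton attention backend, cap the
        # context length at 8192 tokens, and fix the random seed so runs
        # are reproducible.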
        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "triton",
            "--context-length",
            "8192",
            "--random-seed",
            "42",
        ]
        cls.short_context_prompt = "The capital of France is"
        # Test prompt longer than the sliding window size
        cls.long_context_prompt = (
            """
Once upon a time, there was a mountain. In the mountain, there was a temple. In the temple, there was an old monk telling a story. The story was:
"""
            * 100
        )
        cls.long_context_prompt += "\nNow, summarize the story in one sentence:"

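    # Methods prefixed with an underscore are not collected by unittest
    # directly; the test_* entry points below run them against whichever
    # server they launched.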
    def _test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=200,
            num_threads=32,
        )
        metrics = run_eval(args)
        print(f"MMLU metrics with sliding window: {metrics}")
        self.assertGreaterEqual(metrics["score"], 0.60)

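    # The short prompt fits well inside the sliding window; with greedy
    # decoding (temperature 0) the completion should reliably mention Paris.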
    def _test_short_context_generation(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.short_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn("paris", result["text"].lower())
        print(f"Short context generation result: {result['text']}")

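    # The repeated prompt exceeds the sliding window, so this exercises the
    # windowed attention path; we only assert that generation is non-empty.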
    def _test_long_context_generation(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.long_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertGreater(len(result["text"].strip()), 0)
        print(f"Long context generation result: {result['text'][:100]}...")

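    # Runs the full suite with CUDA graphs disabled, exercising the eager
    # execution path of the Triton kernels; skipped in CI to save time.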
    @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
    def test_no_cuda_graph(self):
        self.no_cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args + ["--disable-cuda-graph"],
        )
        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            kill_process_tree(self.no_cuda_graph_process.pid)

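    # The same checks with CUDA graphs enabled (the default), covering graph
    # capture and replay for the sliding-window code path.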
    def test_cuda_graph(self):
        self.cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args,
        )
        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            kill_process_tree(self.cuda_graph_process.pid)


if __name__ == "__main__":
    unittest.main()