sglang_v0.5.2/sglang/test/srt/models/test_vlm_models.py

import argparse
import glob
import json
import os
import random
import subprocess
import sys
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)

# VLM models for testing
MODELS = [
    SimpleNamespace(model="google/gemma-3-27b-it", mmmu_accuracy=0.45),
    SimpleNamespace(
        model="Qwen/Qwen2.5-VL-3B-Instruct",
        mmmu_accuracy=0.4,
    ),
    SimpleNamespace(model="openbmb/MiniCPM-V-2_6", mmmu_accuracy=0.4),
]


class TestVLMModels(CustomTestCase):
    parsed_args = None  # Class variable to store parsed CLI args

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH

        # Set the OpenAI API key and base URL environment variables; lmms-eval
        # needs them to reach the locally launched OpenAI-compatible server.
        os.environ["OPENAI_API_KEY"] = cls.api_key
        os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"

    def _detect_eviction_in_logs(self, log_output):
        """Detect whether cache-eviction events occurred in the log output.

        Returns a (detected, count) tuple.
        """
        eviction_keywords = ["Cache eviction: evicted"]
        eviction_detected = False
        eviction_count = 0

        for line in log_output.split("\n"):
            if any(keyword in line for keyword in eviction_keywords):
                eviction_detected = True
                eviction_count += 1
                print(f"Eviction detected: {line.strip()}")

        return eviction_detected, eviction_count

    def run_mmmu_eval(
        self,
        model_version: str,
        output_path: str,
        *,
        env: dict | None = None,
    ):
        """
        Evaluate a VLM on the MMMU validation set with lmms-eval.

        Only `model_version` (the checkpoint under test) varies between runs;
        everything else is fixed. We evaluate only the validation split due to
        resource constraints.
        """
        # -------- fixed settings --------
        model = "openai_compatible"
        tp = 1
        tasks = "mmmu_val"
        batch_size = 2
        log_suffix = "openai_compatible"
        os.makedirs(output_path, exist_ok=True)

        # -------- compose --model_args --------
        model_args = f'model_version="{model_version}",' f"tp={tp}"

        # -------- build command list --------
        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            model,
            "--model_args",
            model_args,
            "--tasks",
            tasks,
            "--batch_size",
            str(batch_size),
            "--log_samples",
            "--log_samples_suffix",
            log_suffix,
            "--output_path",
            str(output_path),
        ]
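
        # For illustration only: with model_version="Qwen/Qwen2.5-VL-3B-Instruct"
        # (one of the entries in MODELS) and output_path="./logs", the list above
        # corresponds to roughly this shell command:
        #
        #   python3 -m lmms_eval \
        #       --model openai_compatible \
        #       --model_args 'model_version="Qwen/Qwen2.5-VL-3B-Instruct",tp=1' \
        #       --tasks mmmu_val \
        #       --batch_size 2 \
        #       --log_samples \
        #       --log_samples_suffix openai_compatible \
        #       --output_path ./logs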
        subprocess.run(
            cmd,
            check=True,
            timeout=3600,
        )

    def _run_vlm_mmmu_test(
        self,
        model,
        output_path,
        test_name="",
        custom_env=None,
        log_level="info",
        capture_output=False,
    ):
        """
        Common helper to run the VLM MMMU benchmark test.

        Args:
            model: Model to test.
            output_path: Path for output logs.
            test_name: Optional test name suffix for logging.
            custom_env: Optional extra environment variables for the server process.
            log_level: Log level for the server (default: "info").
            capture_output: Whether to capture server stdout/stderr to temp files.
        """
        print(f"\nTesting model: {model.model}{test_name}")
        process = None
        mmmu_accuracy = 0  # Initialized so later references survive exceptions
        server_output = ""
        stdout_file = None
        stderr_file = None

        try:
            # Prepare environment variables
            process_env = os.environ.copy()
            if custom_env:
                process_env.update(custom_env)

            # Redirect server stdout/stderr to temp files if requested
            if capture_output:
                stdout_file = open("/tmp/server_stdout.log", "w")
                stderr_file = open("/tmp/server_stderr.log", "w")

            # Launch the server under test
            process = popen_launch_server(
                model.model,
                base_url=self.base_url,
                timeout=self.time_out,
                api_key=self.api_key,
                other_args=[
                    "--trust-remote-code",
                    "--cuda-graph-max-bs",
                    "32",
                    "--enable-multimodal",
                    "--mem-fraction-static",
                    str(self.parsed_args.mem_fraction_static),  # Use class variable
                    "--log-level",
                    log_level,
                ],
                env=process_env,
                return_stdout_stderr=(
                    (stdout_file, stderr_file) if capture_output else None
                ),
            )
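
            # Note: popen_launch_server (imported above from sglang.test.test_utils)
            # starts an OpenAI-compatible SGLang server as a subprocess, forwarding
            # the entries in `other_args` to it and waiting until it is reachable at
            # base_url. The exact launch command is an implementation detail of that
            # helper; the flags above are the only part this test controls directly.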

            # Run evaluation
            self.run_mmmu_eval(model.model, output_path)

            # Get the result file
            result_file_path = glob.glob(f"{output_path}/*.json")[0]
            with open(result_file_path, "r") as f:
                result = json.load(f)
            print(f"Result{test_name}:\n{result}")

            # Process the result
            mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
            print(
                f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}"
            )

            # Capture server output if requested
            if capture_output and process:
                server_output = self._read_output_from_files()

            # Assert performance meets the expected threshold
            self.assertGreaterEqual(
                mmmu_accuracy,
                model.mmmu_accuracy,
                f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected "
                f"threshold ({model.mmmu_accuracy:.4f}){test_name}",
            )

            return server_output
        except Exception as e:
            print(f"Error testing {model.model}{test_name}: {e}")
            self.fail(f"Test failed for {model.model}{test_name}: {e}")
        finally:
            # Ensure process cleanup happens regardless of success/failure
            if process is not None and process.poll() is None:
                print(f"Cleaning up process {process.pid}")
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    print(f"Error killing process: {e}")

            # Clean up temporary log files
            if capture_output:
                if stdout_file:
                    stdout_file.close()
                if stderr_file:
                    stderr_file.close()
                for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]:
                    try:
                        if os.path.exists(filename):
                            os.remove(filename)
                    except Exception as e:
                        print(f"Error removing {filename}: {e}")

    def _read_output_from_files(self):
        """Read captured server stdout/stderr back from the temp log files."""
        output_lines = []
        log_files = [
            ("/tmp/server_stdout.log", "[STDOUT]"),
            ("/tmp/server_stderr.log", "[STDERR]"),
        ]
        for filename, tag in log_files:
            try:
                if os.path.exists(filename):
                    with open(filename, "r") as f:
                        for line in f:
                            output_lines.append(f"{tag} {line.rstrip()}")
            except Exception as e:
                print(f"Error reading {tag.lower()} file: {e}")
        return "\n".join(output_lines)

    def test_vlm_mmmu_benchmark(self):
        """Test VLM models against the MMMU benchmark."""
        models_to_test = MODELS
        if is_in_ci():
            # In CI, test a single randomly chosen model to keep runtime bounded.
            models_to_test = [random.choice(MODELS)]

        for model in models_to_test:
            self._run_vlm_mmmu_test(model, "./logs")

    def test_vlm_mmmu_benchmark_with_small_cache(self):
        """Test VLM models against MMMU with a small embedding cache to force eviction."""
        models_to_test = MODELS
        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model in models_to_test:
            custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}
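            # A 5 MB multimodal embedding cache is deliberately tiny, so repeated
            # image requests during the MMMU run should trigger evictions (the
            # server's default cache size is an implementation detail and is
            # assumed to be much larger).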

            # Run the test with output capture
            server_output = self._run_vlm_mmmu_test(
                model,
                "./logs_small_cache",
                test_name=" with small embedding cache (evict test)",
                custom_env=custom_env,
                log_level="debug",  # Enable debug logging for eviction detection
                capture_output=True,  # Capture server output
            )

            # Print server output for debugging
            print("Server output:\n", server_output)

            # Analyze server output for eviction events
            eviction_detected, eviction_count = self._detect_eviction_in_logs(
                server_output
            )

            # Assert that eviction was detected, since we deliberately used a small cache
            self.assertTrue(
                eviction_detected,
                f"Expected eviction events to be detected with the small cache (5 MB), "
                f"but none were found. The cache size may still be too large for the "
                f"workload, or the eviction logic may not be working. "
                f"Total log content length: {len(server_output)} characters",
            )

            print(
                f"Eviction detection summary: {eviction_count} eviction events detected"
            )
            if eviction_detected:
                print("✅ Eviction logic successfully triggered and detected!")


if __name__ == "__main__":
    # Define and parse custom arguments here, before handing control to unittest.
    parser = argparse.ArgumentParser(description="Test VLM models")
    parser.add_argument(
        "--mem-fraction-static",
        type=float,
        help="Static memory fraction for the model",
        default=0.8,
    )
    args = parser.parse_args()

    # Store the parsed args on the test class so the test methods can read them.
    TestVLMModels.parsed_args = args

    # Run unittest with only the program name, so unittest itself never sees
    # the custom --mem-fraction-static flag.
    unittest.main(argv=[sys.argv[0]])
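
# Example invocation (the value is illustrative; 0.8 is the default):
#   python3 test_vlm_models.py --mem-fraction-static 0.7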