sglang0.4.5.post1/test/srt/test_gptqmodel_dynamic.py

215 lines
7.0 KiB
Python

import time
import unittest
import requests
import torch
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def check_quant_method(model_path: str, use_marlin_kernel: bool) -> None:
    """Load a GPTQ ``dynamic`` checkpoint and check per-module quant methods.

    The model is built in-process with ``get_model`` on the ``"cuda"`` device
    using ``float16``, then selected submodules are inspected by name.

    Args:
        model_path: Model path/identifier forwarded to ``ServerArgs``.
        use_marlin_kernel: When true, quantized layers are expected to use
            ``GPTQMarlinLinearMethod``; otherwise ``GPTQLinearMethod``.

    Raises:
        AssertionError: If an inspected submodule has an unexpected quant
            method or quantization parameters.
    """
    from sglang.srt.configs.device_config import DeviceConfig
    from sglang.srt.configs.load_config import LoadConfig
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.distributed import (
        init_distributed_environment,
        initialize_model_parallel,
    )
    from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
    from sglang.srt.layers.quantization import get_dynamic_override
    from sglang.srt.model_loader import get_model
    from sglang.srt.server_args import ServerArgs

    # Single-process (world_size=1) distributed setup for the model loader.
    try:
        init_distributed_environment(
            backend="nccl",
            world_size=1,
            rank=0,
            local_rank=0,
            distributed_init_method="tcp://127.0.0.1:2646",
        )
        initialize_model_parallel(tensor_model_parallel_size=1)
        monkey_patch_vllm_parallel_state()
    except AssertionError:
        # ignore this error: tensor model parallel group is already initialized
        pass

    server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
    model_config = ModelConfig(
        server_args.model_path,
        trust_remote_code=server_args.trust_remote_code,
        revision=server_args.revision,
        context_length=server_args.context_length,
        model_override_args=server_args.json_model_override_args,
        is_embedding=server_args.is_embedding,
        dtype=server_args.dtype,
        quantization=server_args.quantization,
    )
    load_config = LoadConfig()
    device_config = DeviceConfig("cuda")
    model = get_model(
        model_config=model_config, load_config=load_config, device_config=device_config
    )

    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
    )

    from sglang.srt.layers.linear import UnquantizedLinearMethod

    # Expected linear-method class for every quantized layer.
    linear_method_cls = (
        GPTQMarlinLinearMethod if use_marlin_kernel else GPTQLinearMethod
    )

    for name, submodule in model.named_modules():
        if name == "lm_head":
            assert isinstance(submodule.quant_method, linear_method_cls)
        elif name == "model.layers.0.self_attn.qkv_proj":
            # The first layer is quantized using bits=4, group_size=128
            # desc_act=True
            assert isinstance(submodule.quant_method, linear_method_cls)
            config = submodule.quant_method.quant_config
            assert config.weight_bits == 4
            assert config.group_size == 128
            assert config.desc_act
        elif name == "model.layers.1.self_attn.qkv_proj":
            # The second layer is quantized using bits=8, group_size=32
            # desc_act=False
            assert isinstance(submodule.quant_method, linear_method_cls)
            config = submodule.quant_method.quant_config
            # These values are read through the per-layer dynamic override
            # lookup rather than directly from the config attributes.
            assert get_dynamic_override(config, layer_name=name, key="bits") == 8
            assert get_dynamic_override(config, layer_name=name, key="group_size") == 32
            assert not get_dynamic_override(config, layer_name=name, key="desc_act")
        elif (
            name == "model.layers.2.self_attn.qkv_proj"
            or name == "model.layers.2.mlp.gate_up_proj"
        ):
            # All other layers (layer index >= 2) are not quantized
            assert isinstance(submodule.quant_method, UnquantizedLinearMethod)

    # Drop the local reference to the loaded model before returning.
    del model
# GPTQ with dynamic per-module quantization control.
# The `dynamic` checkpoints are produced with GPTQModel (PyPI).
# This class tests the GPTQ fallback kernel (i.e. not Marlin).
class TestGPTQModelDynamic(CustomTestCase):
    """Server-level and module-level checks for the ``symFalse`` checkpoint.

    ``check_quant_method`` is called with ``use_marlin_kernel=False`` for
    this model, i.e. the plain GPTQ linear method is expected.
    """

    MODEL_PATH = (
        "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse"
    )

    @classmethod
    def setUpClass(cls):
        """Launch one server process shared by all tests in this class."""
        cls.model = cls.MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--dtype", "float16"],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the process tree of the server started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def run_decode(self, max_new_tokens):
        """POST a fixed prompt to ``/generate`` and return the decoded JSON.

        Args:
            max_new_tokens: Value sent as ``sampling_params.max_new_tokens``.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "max_new_tokens": max_new_tokens,
                    "temperature": 0.001,
                },
            },
        )
        return response.json()

    def test_throughput(self):
        """Check the answer mentions Paris and throughput is at least 140 tokens/s."""
        max_tokens = 256

        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed by wall-clock adjustments during the request.
        tic = time.perf_counter()
        result = self.run_decode(max_tokens)
        tok = time.perf_counter()

        print(f"result = `{result}`")
        self.assertIn("paris", result["text"].lower())

        # NOTE: throughput is computed from the requested token budget, not
        # from the number of tokens actually returned by the server.
        throughput = max_tokens / (tok - tic)
        print(f"Throughput: {throughput} tokens/s")
        self.assertGreaterEqual(throughput, 140)

    def test_gptq_module(self):
        """Verify per-module quant methods with the non-Marlin GPTQ kernel."""
        check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)
# GPTQ with dynamic per-module quantization control.
# The `dynamic` checkpoints are produced with GPTQModel (PyPI).
# This class tests the Marlin kernel.
class TestGPTQModelDynamicWithMarlin(CustomTestCase):
    """Server-level and module-level checks for the ``symTrue`` checkpoint.

    ``check_quant_method`` is called with ``use_marlin_kernel=True`` for this
    model, i.e. the GPTQ Marlin linear method is expected.
    """

    MODEL_PATH = (
        "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue"
    )

    @classmethod
    def setUpClass(cls):
        """Launch one server process shared by all tests in this class."""
        cls.model = cls.MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--dtype", "float16"],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the process tree of the server started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def run_decode(self, max_new_tokens):
        """POST a fixed prompt to ``/generate`` and return the decoded JSON.

        Args:
            max_new_tokens: Value sent as ``sampling_params.max_new_tokens``.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "max_new_tokens": max_new_tokens,
                    "temperature": 0.001,
                },
            },
        )
        return response.json()

    def test_throughput(self):
        """Check the answer mentions Paris and throughput is at least 140 tokens/s."""
        max_tokens = 256

        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed by wall-clock adjustments during the request.
        tic = time.perf_counter()
        result = self.run_decode(max_tokens)
        tok = time.perf_counter()

        print(f"result = `{result}`")
        # unittest assertion methods are used instead of bare `assert` so the
        # checks still run under `python -O` and report both operands on failure.
        self.assertIn("paris", result["text"].lower())

        # NOTE: throughput is computed from the requested token budget, not
        # from the number of tokens actually returned by the server.
        throughput = max_tokens / (tok - tic)
        print(f"Throughput: {throughput} tokens/s")
        self.assertGreaterEqual(throughput, 140)

    def test_gptq_marlin_module(self):
        """Verify per-module quant methods with the Marlin GPTQ kernel."""
        check_quant_method(self.MODEL_PATH, use_marlin_kernel=True)
# Start the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()