# sglang_v0.5.2/flashinfer_0.3.1/tests/conftest.py

import functools
import gc
import os
import types
from typing import Any, Dict

import pytest
import torch
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version

import flashinfer
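
# Public FlashInfer entry points that pytest_configure() below wraps with
# torch.compile when FLASHINFER_TEST_TORCH_COMPILE=1 is set in the environment.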
TORCH_COMPILE_FNS = [
flashinfer.activation.silu_and_mul,
flashinfer.activation.gelu_and_mul,
flashinfer.activation.gelu_tanh_and_mul,
flashinfer.cascade.merge_state,
flashinfer.cascade.merge_state_in_place,
flashinfer.cascade.merge_states,
flashinfer.cascade.MultiLevelCascadeAttentionWrapper.run,
flashinfer.cascade.BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward,
flashinfer.cascade.BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward,
flashinfer.decode.single_decode_with_kv_cache,
flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.run,
flashinfer.gemm.bmm_fp8,
flashinfer.gemm.SegmentGEMMWrapper.run,
flashinfer.norm.rmsnorm,
flashinfer.norm.fused_add_rmsnorm,
flashinfer.norm.gemma_rmsnorm,
flashinfer.norm.gemma_fused_add_rmsnorm,
flashinfer.page.append_paged_kv_cache,
flashinfer.prefill.single_prefill_with_kv_cache,
flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper.run,
flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper.run,
flashinfer.quantization.packbits,
flashinfer.rope.apply_rope,
flashinfer.rope.apply_rope_inplace,
flashinfer.rope.apply_rope_pos_ids,
flashinfer.rope.apply_rope_pos_ids_inplace,
flashinfer.rope.apply_llama31_rope,
flashinfer.rope.apply_llama31_rope_inplace,
flashinfer.rope.apply_llama31_rope_pos_ids,
flashinfer.rope.apply_llama31_rope_pos_ids_inplace,
flashinfer.sampling.sampling_from_probs,
flashinfer.sampling.sampling_from_logits,
flashinfer.sampling.top_p_sampling_from_probs,
flashinfer.sampling.top_k_sampling_from_probs,
flashinfer.sampling.min_p_sampling_from_probs,
flashinfer.sampling.top_k_top_p_sampling_from_probs,
flashinfer.sampling.top_p_renorm_probs,
flashinfer.sampling.top_k_renorm_probs,
flashinfer.sampling.top_k_mask_logits,
flashinfer.sampling.chain_speculative_sampling,
]

_TORCH_COMPILE_CACHE: Dict[str, Any] = dict()


def _set_torch_compile_options():
    import torch._dynamo.config

    torch._dynamo.config.cache_size_limit = 128


def _monkeypatch_add_torch_compile(func):
    """
    Replace the given function with its torch.compile version.
    """
    from torch._library.custom_ops import CustomOpDef

    if type(func) is types.FunctionType:
        fn = func
    elif isinstance(func, CustomOpDef):
        fn = func._init_fn
    else:
        raise ValueError(f"Unsupported fn type {type(func)}")

    fullname = fn.__module__ + "." + fn.__qualname__
    components = fullname.split(".")
    assert components[0] == "flashinfer"
    module = flashinfer
    for component in components[1:-1]:
        module = getattr(module, component)
    if not hasattr(module, components[-1]):
        raise ValueError(f"Failed to monkeypatch: {fullname}")

    def wrapper(*args, **kwargs):
        compiled = _TORCH_COMPILE_CACHE.get(fullname)
        if compiled is None:
            # Warmup -- JIT compile / import the kernels.
            #
            # Users are expected to warm up the model beforehand as well, as
            # suggested by the PyTorch CUDA Graphs docs (it is unclear whether
            # the same is recommended for torch.compile).
            #
            # For the convenience of FlashInfer testing, we do the warmup here,
            # on the first run of the function. The caveat is that the first
            # call runs twice: once to warm up, and once more through the
            # compiled version.
            func(*args, **kwargs)
            # Compile
            compiled = torch.compile(
                func,
                fullgraph=True,
                backend="inductor",
                mode="max-autotune-no-cudagraphs",
            )
            # Cache under the same key used for the lookup above so later calls
            # reuse the compiled function instead of recompiling every time.
            _TORCH_COMPILE_CACHE[fullname] = compiled
        return compiled(*args, **kwargs)

    setattr(module, fn.__name__, wrapper)
    print("Applied torch.compile to", fullname)


def pytest_configure(config):
    if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1":
        if torch_version < TorchVersion("2.4"):
            pytest.skip("torch.compile requires torch >= 2.4")
        _set_torch_compile_options()
        for fn in TORCH_COMPILE_FNS:
            _monkeypatch_add_torch_compile(fn)
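
# To exercise the functions in TORCH_COMPILE_FNS under torch.compile, run the
# tests with the flag checked above set in the environment, e.g. (illustrative
# invocation; adjust the test path as needed):
#   FLASHINFER_TEST_TORCH_COMPILE=1 pytest tests/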


def is_cuda_oom_error_str(e: str) -> bool:
    return "CUDA" in e and "out of memory" in e


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # Skip tests that hit CUDA out-of-memory errors instead of failing them.
    try:
        item.runtest()
    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
        if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)):
            pytest.skip("Skipping due to OOM")
        else:
            raise


@functools.cache
def get_device_properties(device: torch.device):
    return torch.cuda.get_device_properties(device)


def clear_cuda_cache(device: torch.device) -> None:
    total_memory = get_device_properties(device).total_memory
    reserved_memory = torch.cuda.memory_reserved()
    # FLASHINFER_TEST_MEMORY_THRESHOLD: fraction of total device memory that
    # PyTorch may keep reserved before the cache is emptied (default: 0.9).
    threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.9"))
    if reserved_memory > threshold * total_memory:
        gc.collect()
        torch.cuda.empty_cache()
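
# Tests can call clear_cuda_cache(torch.device("cuda:0")) between cases to
# release cached allocator blocks once reserved memory crosses the threshold
# above (illustrative call; the actual device depends on the test).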


# Collected from a GSM8K trace in sglang.
VARLEN_INDPTR_PARAMS = [
[
0,
1276,
2551,
3838,
5115,
6428,
7705,
8985,
10293,
11607,
12909,
14216,
15508,
],
[
0,
1320,
2637,
3926,
5208,
6494,
7795,
9130,
10415,
11704,
12995,
14270,
15551,
],
[
0,
1333,
2762,
4068,
5345,
6635,
7936,
9237,
10518,
11874,
13146,
14466,
15785,
],
[
0,
1310,
2603,
3937,
5231,
6527,
7799,
9107,
10411,
11698,
12978,
14277,
15559,
],
[
0,
1286,
2561,
3857,
5163,
6459,
7763,
9057,
10380,
11679,
12968,
14284,
15610,
],
[0, 1350, 2667, 4003, 5347, 6631, 7919, 9208, 10524],
[
0,
1293,
2609,
3902,
5196,
6495,
7807,
9086,
10382,
11700,
12989,
14271,
15578,
],
[
0,
1276,
2551,
3838,
5115,
6428,
7705,
8985,
10293,
11607,
12909,
14216,
15508,
],
[
0,
1280,
2559,
3874,
5197,
6540,
7850,
9167,
10536,
11820,
13111,
14444,
15756,
],
[
0,
1306,
2598,
3895,
5181,
6458,
7750,
9047,
10333,
11635,
12906,
14207,
15514,
],
[
0,
1300,
2620,
3912,
5219,
6500,
7781,
9069,
10386,
11665,
12976,
14262,
15545,
],
[0, 1800, 3600, 5400, 7200, 9000, 10799, 12620, 14441, 16261],
[0, 1298, 2638],
[0, 1284],
[0, 1297, 2604],
[0, 1276],
[
0,
1286,
2614,
3909,
5207,
6490,
7785,
9067,
10356,
11633,
12915,
14231,
15511,
],
[
0,
1312,
2613,
3899,
5203,
6492,
7811,
9151,
10439,
11757,
13052,
14364,
15646,
],
[0, 1287],
[
0,
1353,
2684,
4039,
5326,
6615,
7932,
9217,
10528,
11862,
13207,
14490,
15785,
],
[0, 1307],
[
0,
1301,
2587,
3949,
5263,
6620,
7933,
9226,
10521,
11895,
13179,
14531,
15822,
],
[
0,
1334,
2618,
3914,
5198,
6476,
7771,
9076,
10362,
11675,
12974,
14264,
15540,
],
[
0,
1323,
2606,
3900,
5209,
6487,
7770,
9074,
10397,
11694,
13047,
14411,
15719,
],
[
0,
1286,
2563,
3845,
5181,
6471,
7789,
9087,
10394,
11750,
13021,
14335,
15616,
],
[
0,
1310,
2605,
3889,
5214,
6518,
7794,
9143,
10465,
11802,
13134,
14438,
15713,
],
[0, 1285, 2596, 3877, 5163, 6487, 7782, 9104, 10403],
[
0,
1299,
2594,
3921,
5222,
6494,
7777,
9098,
10406,
11736,
13026,
14317,
15621,
],
[0, 1268, 2560],
[0, 1536, 3061, 4578, 6177, 7774, 9378, 10958, 12636, 14292, 15954],
[0, 1240],
[
0,
1362,
2653,
3930,
5201,
6505,
7808,
9094,
10421,
11720,
12994,
14285,
15584,
],
[0, 1676, 3342, 5094, 6842, 8582, 10342, 12102, 13861, 15618],
[
0,
1329,
2656,
3977,
5253,
6532,
7831,
9150,
10444,
11749,
13090,
14388,
15675,
],
[
0,
1284,
2578,
3854,
5189,
6513,
7809,
9144,
10463,
11772,
13062,
14368,
15641,
],
[
0,
1286,
2651,
3960,
5286,
6573,
7857,
9146,
10445,
11723,
12995,
14270,
15537,
],
[0, 1716, 3555, 5298, 7041, 8781],
[0, 1281, 2574, 3866, 5176],
[
0,
1331,
2655,
3965,
5318,
6606,
7954,
9238,
10526,
11837,
13134,
14436,
15728,
],
[0, 1754, 3508, 5259, 7028, 8796, 10564, 12332, 14097, 15862],
[
0,
1293,
2584,
3871,
5157,
6466,
7831,
9118,
10444,
11728,
13017,
14295,
15594,
],
[
0,
1279,
2586,
3876,
5170,
6440,
7768,
9067,
10351,
11651,
12936,
14239,
15542,
],
[
0,
1293,
2563,
3861,
5139,
6491,
7776,
9121,
10422,
11731,
13033,
14338,
15639,
],
[
0,
1292,
2632,
3933,
5257,
6576,
7881,
9178,
10455,
11796,
13095,
14385,
15685,
],
[0, 1307],
[
0,
1307,
2590,
3897,
5206,
6527,
7826,
9104,
10400,
11696,
13022,
14326,
15615,
],
[
0,
1287,
2597,
3933,
5275,
6555,
7835,
9153,
10445,
11729,
13019,
14303,
15608,
],
[
0,
1294,
2589,
3904,
5205,
6504,
7803,
9087,
10375,
11671,
12970,
14279,
15615,
],
[
0,
1312,
2624,
3957,
5243,
6533,
7817,
9095,
10377,
11729,
13053,
14332,
15643,
],
[
0,
1278,
2579,
3858,
5147,
6461,
7745,
9038,
10376,
11654,
12941,
14265,
15592,
],
[
0,
1284,
2578,
3855,
5181,
6475,
7787,
9090,
10386,
11661,
13010,
14291,
15595,
],
[
0,
1293,
2604,
3893,
5211,
6526,
7859,
9139,
10439,
11723,
13071,
14369,
15669,
],
[
0,
1333,
2618,
3892,
5196,
6478,
7778,
9088,
10381,
11677,
12986,
14276,
15552,
],
[
0,
1287,
2569,
3876,
5156,
6463,
7754,
9053,
10363,
11642,
12946,
14230,
15501,
],
[
0,
1293,
2585,
3900,
5183,
6502,
7882,
9185,
10466,
11732,
13017,
14324,
15612,
],
[
0,
1323,
2597,
3877,
5183,
6483,
7793,
9084,
10417,
11719,
12999,
14294,
15587,
],
[
0,
1289,
2626,
3900,
5217,
6550,
7820,
9140,
10431,
11717,
13028,
14312,
15615,
],
[
0,
1306,
2588,
3890,
5192,
6540,
7858,
9170,
10492,
11772,
13051,
14348,
15636,
],
[
0,
1279,
2583,
3892,
5193,
6481,
7788,
9099,
10394,
11701,
13026,
14348,
15710,
],
[
0,
1287,
2599,
3939,
5223,
6523,
7822,
9102,
10435,
11714,
13006,
14294,
15622,
],
[
0,
1302,
2631,
3913,
5192,
6503,
7804,
9121,
10429,
11757,
13064,
14379,
15656,
],
[
0,
1278,
2569,
3914,
5211,
6480,
7805,
9089,
10383,
11687,
12971,
14281,
15605,
],
[
0,
1278,
2559,
3834,
5144,
6434,
7754,
9033,
10330,
11607,
12925,
14218,
15510,
],
[0, 1319],
[
0,
1269,
2564,
3849,
5130,
6430,
7740,
9060,
10409,
11698,
13001,
14286,
15557,
],
[
0,
1288,
2592,
3867,
5214,
6491,
7793,
9110,
10416,
11729,
13020,
14318,
15625,
],
[
0,
1326,
2643,
3972,
5270,
6591,
7872,
9139,
10437,
11731,
13031,
14327,
15633,
],
[
0,
1284,
2560,
3857,
5134,
6436,
7728,
9041,
10345,
11625,
12940,
14242,
15530,
],
[0, 1299, 2576],
[
0,
1296,
2574,
3866,
5162,
6448,
7745,
9020,
10294,
11588,
12895,
14218,
15525,
],
[
0,
1279,
2563,
3875,
5161,
6461,
7741,
9023,
10305,
11613,
12897,
14204,
15536,
],
[
0,
1273,
2553,
3848,
5210,
6493,
7775,
9058,
10375,
11695,
12984,
14278,
15588,
],
[
0,
1283,
2584,
3863,
5160,
6444,
7740,
9061,
10377,
11698,
12994,
14274,
15545,
],
[
0,
1329,
2648,
3962,
5309,
6622,
7930,
9242,
10544,
11828,
13183,
14476,
15809,
],
[
0,
1290,
2591,
3891,
5175,
6460,
7766,
9112,
10402,
11701,
13019,
14330,
15633,
],
[
0,
1333,
2673,
3958,
5270,
6589,
7911,
9203,
10549,
11841,
13146,
14471,
15776,
],
[
0,
1288,
2643,
3945,
5266,
6595,
7907,
9213,
10486,
11807,
13138,
14430,
15703,
],
[
0,
1306,
2620,
3944,
5260,
6569,
7852,
9144,
10460,
11785,
13075,
14368,
15672,
],
[
0,
1294,
2572,
3851,
5164,
6464,
7755,
9090,
10398,
11688,
13002,
14313,
15593,
],
[
0,
1340,
2651,
3959,
5258,
6545,
7836,
9157,
10465,
11772,
13065,
14368,
15747,
],
[
0,
1325,
2657,
3935,
5255,
6583,
7874,
9154,
10448,
11732,
13026,
14344,
15620,
],
[0, 1764, 3551, 5336, 7121, 8905, 10688, 12471, 14252, 16054],
[
0,
1280,
2590,
3896,
5187,
6520,
7822,
9117,
10397,
11690,
12977,
14270,
15561,
],
[
0,
1285,
2577,
3862,
5198,
6477,
7762,
9130,
10412,
11694,
13049,
14358,
15666,
],
[
0,
1287,
2617,
3942,
5240,
6510,
7807,
9090,
10390,
11743,
13031,
14325,
15615,
],
[0, 1310, 2584, 3990, 5291, 6598, 7908, 9192],
[
0,
1304,
2626,
3930,
5209,
6499,
7810,
9109,
10435,
11731,
13007,
14307,
15593,
],
[
0,
1308,
2612,
3927,
5227,
6515,
7812,
9146,
10447,
11731,
13017,
14317,
15602,
],
[0, 1820, 3640, 5460, 7277, 9115, 10953, 12791, 14628],
[
0,
1289,
2594,
3903,
5196,
6499,
7799,
9077,
10386,
11662,
12959,
14243,
15543,
],
[
0,
1300,
2601,
3876,
5165,
6436,
7725,
9039,
10352,
11639,
12927,
14209,
15490,
],
[0, 1837, 3674, 5206, 6693, 8229, 9790, 11329, 12910, 14474, 16037],
[
0,
1292,
2604,
3878,
5151,
6453,
7749,
9033,
10363,
11703,
13014,
14301,
15617,
],
[
0,
1275,
2556,
3843,
5147,
6427,
7712,
9003,
10311,
11600,
12970,
14264,
15545,
],
[
0,
1285,
2590,
3878,
5169,
6527,
7863,
9161,
10451,
11745,
13066,
14382,
15695,
],
[0, 1340, 2635],
[
0,
1314,
2600,
3894,
5194,
6490,
7797,
9105,
10385,
11667,
12967,
14255,
15550,
],
[
0,
1308,
2605,
3956,
5254,
6582,
7865,
9160,
10459,
11758,
13045,
14341,
15623,
],
[
0,
1282,
2576,
3882,
5190,
6510,
7819,
9142,
10427,
11736,
13041,
14359,
15683,
],
[0, 1300, 2614, 3924],
[
0,
1282,
2600,
3923,
5229,
6580,
7952,
9295,
10593,
11873,
13161,
14458,
15756,
],
[
0,
1286,
2578,
3884,
5184,
6494,
7779,
9078,
10356,
11677,
12976,
14256,
15560,
],
[
0,
1303,
2575,
3848,
5119,
6417,
7714,
9020,
10362,
11668,
12983,
14314,
15599,
],
[0, 1291, 2584],
[
0,
1299,
2617,
3938,
5328,
6600,
7885,
9163,
10489,
11771,
13053,
14332,
15691,
],
[0, 1305, 2617],
[0, 1573, 3257, 4935, 6605, 8256, 9906, 11529, 13171, 14809],
[
0,
1299,
2591,
3885,
5165,
6445,
7744,
9111,
10413,
11725,
13000,
14304,
15614,
],
[0, 1296],
[
0,
1295,
2570,
3912,
5252,
6527,
7806,
9121,
10408,
11710,
12988,
14270,
15585,
],
[
0,
1285,
2621,
3937,
5235,
6506,
7790,
9085,
10352,
11630,
12949,
14247,
15528,
],
[
0,
1297,
2575,
3868,
5146,
6436,
7775,
9066,
10376,
11708,
13005,
14365,
15649,
],
[
0,
1322,
2638,
3920,
5217,
6522,
7801,
9113,
10472,
11769,
13046,
14372,
15668,
],
[
0,
1272,
2539,
3871,
5146,
6471,
7791,
9069,
10360,
11688,
12968,
14262,
15580,
],
[
0,
1322,
2642,
3933,
5229,
6538,
7823,
9126,
10432,
11734,
13089,
14372,
15678,
],
[
0,
1310,
2658,
3987,
5316,
6608,
7878,
9171,
10463,
11757,
13060,
14356,
15660,
],
[
0,
1318,
2640,
3924,
5237,
6546,
7832,
9138,
10462,
11762,
13046,
14341,
15609,
],
[
0,
1280,
2558,
3850,
5191,
6495,
7820,
9113,
10401,
11717,
13040,
14314,
15614,
],
[
0,
1313,
2596,
3908,
5249,
6542,
7843,
9141,
10456,
11739,
13039,
14348,
15699,
],
[0, 1309],
[0, 1400, 2689],
[
0,
1362,
2646,
3947,
5228,
6517,
7824,
9116,
10402,
11683,
12976,
14271,
15583,
],
[
0,
1303,
2653,
3937,
5234,
6541,
7861,
9224,
10606,
11897,
13213,
14544,
15851,
],
[
0,
1309,
2636,
3924,
5216,
6500,
7775,
9085,
10380,
11696,
12999,
14337,
15613,
],
[0, 1310, 2611, 3904, 5238, 6532, 7804, 9100, 10408, 11707, 13011],
[
0,
1313,
2646,
3956,
5263,
6587,
7949,
9257,
10555,
11837,
13104,
14394,
15724,
],
[
0,
1321,
2612,
3915,
5231,
6551,
7838,
9128,
10440,
11759,
13099,
14416,
15700,
],
[
0,
1283,
2592,
3872,
5194,
6467,
7751,
9040,
10321,
11673,
13010,
14304,
15602,
],
[0, 1270, 2622, 3915, 5193, 6478, 7776, 9085, 10430, 11732, 13033, 14338],
[0, 1296, 2631, 3955],
[
0,
1315,
2622,
3949,
5243,
6592,
7894,
9216,
10533,
11830,
13123,
14419,
15722,
],
[
0,
1296,
2590,
3913,
5221,
6504,
7778,
9125,
10426,
11782,
13051,
14328,
15637,
],
[
0,
1294,
2579,
3886,
5160,
6456,
7746,
9047,
10347,
11638,
12962,
14261,
15550,
],
[0, 7],
[
0,
1298,
2599,
3887,
5201,
6506,
7843,
9158,
10456,
11749,
13058,
14337,
15630,
],
[
0,
1290,
2598,
3876,
5177,
6473,
7790,
9065,
10362,
11640,
12943,
14287,
15582,
],
[
0,
1333,
2623,
3903,
5189,
6467,
7759,
9063,
10388,
11729,
13022,
14310,
15626,
],
[
0,
1322,
2615,
3921,
5206,
6491,
7811,
9109,
10394,
11691,
12969,
14256,
15532,
],
[
0,
1302,
2610,
3942,
5267,
6545,
7859,
9154,
10460,
11733,
13053,
14326,
15661,
],
[0, 1289, 2616],
[
0,
1291,
2640,
3932,
5229,
6547,
7903,
9205,
10547,
11857,
13171,
14484,
15771,
],
[0, 1240],
[
0,
1289,
2665,
3954,
5276,
6576,
7883,
9167,
10535,
11868,
13215,
14548,
15862,
],
[
0,
1299,
2606,
3913,
5223,
6514,
7793,
9097,
10381,
11652,
12936,
14228,
15513,
],
[
0,
1334,
2615,
3932,
5214,
6511,
7818,
9109,
10403,
11701,
13036,
14306,
15648,
],
[
0,
1315,
2613,
3889,
5215,
6490,
7799,
9110,
10407,
11684,
13016,
14333,
15639,
],
[
0,
1304,
2591,
3907,
5275,
6563,
7887,
9203,
10539,
11836,
13169,
14459,
15745,
],
[
0,
1279,
2548,
3860,
5216,
6529,
7833,
9102,
10400,
11697,
13002,
14313,
15638,
],
[
0,
1284,
2569,
3861,
5165,
6452,
7768,
9056,
10424,
11748,
13064,
14361,
15697,
],
[0, 1302, 2600],
[0, 1289, 2586],
[0, 1287, 2577, 3855],
]
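
# Each list above is a prefix-sum "indptr" array: element i is the start offset
# of request i, so per-request sequence lengths are the consecutive differences.
# A minimal sketch of recovering the lengths (illustrative only; not used by the
# tests directly):
#
#   indptr = VARLEN_INDPTR_PARAMS[0]
#   seq_lens = [end - start for start, end in zip(indptr, indptr[1:])]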


def assert_close_with_mismatch_tolerance(
    actual: torch.Tensor,
    expected: torch.Tensor,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    max_mismatched_elements: int = 0,
):
    """
    Assert that two tensors are close, allowing a bounded number of mismatched
    elements. Uses the same element-wise comparison as torch.isclose.
    """
    # Ensure tensors are float for comparison.
    actual_float = actual.float()
    expected_float = expected.float()

    # Core logic from torch.isclose: an element mismatches if its difference
    # exceeds the combined absolute and relative tolerance.
    mismatched = torch.abs(actual_float - expected_float) > (
        atol + rtol * torch.abs(expected_float)
    )
    num_mismatched = torch.sum(mismatched).item()

    if num_mismatched > max_mismatched_elements:
        # For a helpful error message, report the worst offenders.
        actual_flat = actual_float.flatten()
        expected_flat = expected_float.flatten()
        abs_diff = torch.abs(actual_flat - expected_flat)
        # Compute the relative difference with a small epsilon in the
        # denominator to avoid division by zero where expected is zero.
        rel_diff = abs_diff / (torch.abs(expected_flat) + 1e-12)
        total_elements = actual_flat.numel()
        raise AssertionError(
            f"Tensors are not close enough!\n"
            f"Mismatched elements: {num_mismatched} / {total_elements} "
            f"({100.0 * num_mismatched / total_elements:.2f}%)\n"
            f"Allowed mismatched elements: {max_mismatched_elements}, but found {num_mismatched}.\n"
            f"Greatest absolute difference: {torch.max(abs_diff).item():.4g} (atol={atol})\n"
            f"Greatest relative difference: {torch.max(rel_diff).item():.4g} (rtol={rtol})"
        )
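

# Illustrative usage from a test body (hypothetical tensors; not part of this
# conftest):
#
#   ref = torch.randn(1024, device="cuda")
#   out = ref.clone()
#   out[0] += 1.0
#   assert_close_with_mismatch_tolerance(
#       out, ref, rtol=1e-3, atol=1e-3, max_mismatched_elements=1
#   )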