import functools
import gc
import os
import types
from typing import Any, Dict

import pytest
import torch
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version

import flashinfer

TORCH_COMPILE_FNS = [
    flashinfer.activation.silu_and_mul,
    flashinfer.activation.gelu_and_mul,
    flashinfer.activation.gelu_tanh_and_mul,
    flashinfer.cascade.merge_state,
    flashinfer.cascade.merge_state_in_place,
    flashinfer.cascade.merge_states,
    flashinfer.cascade.MultiLevelCascadeAttentionWrapper.run,
    flashinfer.cascade.BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward,
    flashinfer.cascade.BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward,
    flashinfer.decode.single_decode_with_kv_cache,
    flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.run,
    flashinfer.gemm.bmm_fp8,
    flashinfer.gemm.SegmentGEMMWrapper.run,
    flashinfer.norm.rmsnorm,
    flashinfer.norm.fused_add_rmsnorm,
    flashinfer.norm.gemma_rmsnorm,
    flashinfer.norm.gemma_fused_add_rmsnorm,
    flashinfer.page.append_paged_kv_cache,
    flashinfer.prefill.single_prefill_with_kv_cache,
    flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper.run,
    flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper.run,
    flashinfer.quantization.packbits,
    flashinfer.rope.apply_rope,
    flashinfer.rope.apply_rope_inplace,
    flashinfer.rope.apply_rope_pos_ids,
    flashinfer.rope.apply_rope_pos_ids_inplace,
    flashinfer.rope.apply_llama31_rope,
    flashinfer.rope.apply_llama31_rope_inplace,
    flashinfer.rope.apply_llama31_rope_pos_ids,
    flashinfer.rope.apply_llama31_rope_pos_ids_inplace,
    flashinfer.sampling.sampling_from_probs,
    flashinfer.sampling.sampling_from_logits,
    flashinfer.sampling.top_p_sampling_from_probs,
    flashinfer.sampling.top_k_sampling_from_probs,
    flashinfer.sampling.min_p_sampling_from_probs,
    flashinfer.sampling.top_k_top_p_sampling_from_probs,
    flashinfer.sampling.top_p_renorm_probs,
    flashinfer.sampling.top_k_renorm_probs,
    flashinfer.sampling.top_k_mask_logits,
    flashinfer.sampling.chain_speculative_sampling,
]

_TORCH_COMPILE_CACHE: Dict[str, Any] = dict()


def _set_torch_compile_options():
    import torch._dynamo.config

    torch._dynamo.config.cache_size_limit = 128


def _monkeypatch_add_torch_compile(func):
    """
    Replace the given function with its torch.compile version.
    """
    from torch._library.custom_ops import CustomOpDef

    if type(func) is types.FunctionType:
        fn = func
    elif isinstance(func, CustomOpDef):
        fn = func._init_fn
    else:
        raise ValueError(f"Unsupported fn type {type(func)}")

    fullname = fn.__module__ + "." + fn.__qualname__
    components = fullname.split(".")
    assert components[0] == "flashinfer"
    module = flashinfer
    for component in components[1:-1]:
        module = getattr(module, component)
    if not hasattr(module, components[-1]):
        raise ValueError(f"Failed to monkeypatch: {fullname}")

    def wrapper(*args, **kwargs):
        compiled = _TORCH_COMPILE_CACHE.get(fullname)
        if compiled is None:
            # Warmup -- JIT-compile / import the kernels.
            #
            # On the user side, the model also needs to be warmed up beforehand,
            # as suggested by the PyTorch CUDA Graphs docs (it is unclear whether
            # the same is recommended for torch.compile).
            #
            # For the convenience of FlashInfer testing, we do the warmup here,
            # on the first run of the function. The caveat is that the first
            # call runs twice: once to warm up, and once more through the
            # compiled version.
            func(*args, **kwargs)
            # Compile
            compiled = torch.compile(
                func,
                fullgraph=True,
                backend="inductor",
                mode="max-autotune-no-cudagraphs",
            )
            _TORCH_COMPILE_CACHE[fullname] = compiled
        return compiled(*args, **kwargs)

    setattr(module, fn.__name__, wrapper)
    print("Applied torch.compile to", fullname)
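
# A minimal sketch of what the monkeypatching above amounts to for a single
# function (hypothetical, for illustration only; the actual replacement is
# installed lazily by the `wrapper` closure in _monkeypatch_add_torch_compile):
#
#   compiled_rmsnorm = torch.compile(
#       flashinfer.norm.rmsnorm,
#       fullgraph=True,
#       backend="inductor",
#       mode="max-autotune-no-cudagraphs",
#   )
#   flashinfer.norm.rmsnorm = compiled_rmsnorm
#
# so existing tests that call flashinfer.norm.rmsnorm exercise the compiled
# kernel without any change to the test code.
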

def pytest_configure(config):
    if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1":
        if torch_version < TorchVersion("2.4"):
            pytest.skip("torch.compile requires torch >= 2.4")
        _set_torch_compile_options()
        for fn in TORCH_COMPILE_FNS:
            _monkeypatch_add_torch_compile(fn)


def is_cuda_oom_error_str(e: str) -> bool:
    return "CUDA" in e and "out of memory" in e


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # Skip tests that run out of GPU memory instead of failing them.
    try:
        item.runtest()
    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
        if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)):
            pytest.skip("Skipping due to OOM")
        else:
            raise


@functools.cache
def get_device_properties(device: torch.device):
    return torch.cuda.get_device_properties(device)


def clear_cuda_cache(device: torch.device) -> None:
    total_memory = get_device_properties(device).total_memory
    reserved_memory = torch.cuda.memory_reserved()
    # FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.9)
    threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.9"))
    if reserved_memory > threshold * total_memory:
        gc.collect()
        torch.cuda.empty_cache()


# Collected from a GSM8K trace in SGLang.
VARLEN_INDPTR_PARAMS = [
    [0, 1276, 2551, 3838, 5115, 6428, 7705, 8985, 10293, 11607, 12909, 14216, 15508],
    [0, 1320, 2637, 3926, 5208, 6494, 7795, 9130, 10415, 11704, 12995, 14270, 15551],
    [0, 1333, 2762, 4068, 5345, 6635, 7936, 9237, 10518, 11874, 13146, 14466, 15785],
    [0, 1310, 2603, 3937, 5231, 6527, 7799, 9107, 10411, 11698, 12978, 14277, 15559],
    [0, 1286, 2561, 3857, 5163, 6459, 7763, 9057, 10380, 11679, 12968, 14284, 15610],
    [0, 1350, 2667, 4003, 5347, 6631, 7919, 9208, 10524],
    [0, 1293, 2609, 3902, 5196, 6495, 7807, 9086, 10382, 11700, 12989, 14271, 15578],
    [0, 1276, 2551, 3838, 5115, 6428, 7705, 8985, 10293, 11607, 12909, 14216, 15508],
    [0, 1280, 2559, 3874, 5197, 6540, 7850, 9167, 10536, 11820, 13111, 14444, 15756],
    [0, 1306, 2598, 3895, 5181, 6458, 7750, 9047, 10333, 11635, 12906, 14207, 15514],
    [0, 1300, 2620, 3912, 5219, 6500, 7781, 9069, 10386, 11665, 12976, 14262, 15545],
    [0, 1800, 3600, 5400, 7200, 9000, 10799, 12620, 14441, 16261],
    [0, 1298, 2638],
    [0, 1284],
    [0, 1297, 2604],
    [0, 1276],
    [0, 1286, 2614, 3909, 5207, 6490, 7785, 9067, 10356, 11633, 12915, 14231, 15511],
    [0, 1312, 2613, 3899, 5203, 6492, 7811, 9151, 10439, 11757, 13052, 14364, 15646],
    [0, 1287],
    [0, 1353, 2684, 4039, 5326, 6615, 7932, 9217, 10528, 11862, 13207, 14490, 15785],
    [0, 1307],
    [0, 1301, 2587, 3949, 5263, 6620, 7933, 9226, 10521, 11895, 13179, 14531, 15822],
    [0, 1334, 2618, 3914, 5198, 6476, 7771, 9076, 10362, 11675, 12974, 14264, 15540],
    [0, 1323, 2606, 3900, 5209, 6487, 7770, 9074, 10397, 11694, 13047, 14411, 15719],
    [0, 1286, 2563, 3845, 5181, 6471, 7789, 9087, 10394, 11750, 13021, 14335, 15616],
    [0, 1310, 2605, 3889, 5214, 6518, 7794, 9143, 10465, 11802, 13134, 14438, 15713],
    [0, 1285, 2596, 3877, 5163, 6487, 7782, 9104, 10403],
    [0, 1299, 2594, 3921, 5222, 6494, 7777, 9098, 10406, 11736, 13026, 14317, 15621],
    [0, 1268, 2560],
    [0, 1536, 3061, 4578, 6177, 7774, 9378, 10958, 12636, 14292, 15954],
    [0, 1240],
    [0, 1362, 2653, 3930, 5201, 6505, 7808, 9094, 10421, 11720, 12994, 14285, 15584],
    [0, 1676, 3342, 5094, 6842, 8582, 10342, 12102, 13861, 15618],
    [0, 1329, 2656, 3977, 5253, 6532, 7831, 9150, 10444, 11749, 13090, 14388, 15675],
    [0, 1284, 2578, 3854, 5189, 6513, 7809, 9144, 10463, 11772, 13062, 14368, 15641],
    [0, 1286, 2651, 3960, 5286, 6573, 7857, 9146, 10445, 11723, 12995, 14270, 15537],
    [0, 1716, 3555, 5298, 7041, 8781],
    [0, 1281, 2574, 3866, 5176],
    [0, 1331, 2655, 3965, 5318, 6606, 7954, 9238, 10526, 11837, 13134, 14436, 15728],
    [0, 1754, 3508, 5259, 7028, 8796, 10564, 12332, 14097, 15862],
    [0, 1293, 2584, 3871, 5157, 6466, 7831, 9118, 10444, 11728, 13017, 14295, 15594],
    [0, 1279, 2586, 3876, 5170, 6440, 7768, 9067, 10351, 11651, 12936, 14239, 15542],
    [0, 1293, 2563, 3861, 5139, 6491, 7776, 9121, 10422, 11731, 13033, 14338, 15639],
    [0, 1292, 2632, 3933, 5257, 6576, 7881, 9178, 10455, 11796, 13095, 14385, 15685],
    [0, 1307],
    [0, 1307, 2590, 3897, 5206, 6527, 7826, 9104, 10400, 11696, 13022, 14326, 15615],
    [0, 1287, 2597, 3933, 5275, 6555, 7835, 9153, 10445, 11729, 13019, 14303, 15608],
    [0, 1294, 2589, 3904, 5205, 6504, 7803, 9087, 10375, 11671, 12970, 14279, 15615],
    [0, 1312, 2624, 3957, 5243, 6533, 7817, 9095, 10377, 11729, 13053, 14332, 15643],
    [0, 1278, 2579, 3858, 5147, 6461, 7745, 9038, 10376, 11654, 12941, 14265, 15592],
    [0, 1284, 2578, 3855, 5181, 6475, 7787, 9090, 10386, 11661, 13010, 14291, 15595],
    [0, 1293, 2604, 3893, 5211, 6526, 7859, 9139, 10439, 11723, 13071, 14369, 15669],
    [0, 1333, 2618, 3892, 5196, 6478, 7778, 9088, 10381, 11677, 12986, 14276, 15552],
    [0, 1287, 2569, 3876, 5156, 6463, 7754, 9053, 10363, 11642, 12946, 14230, 15501],
    [0, 1293, 2585, 3900, 5183, 6502, 7882, 9185, 10466, 11732, 13017, 14324, 15612],
    [0, 1323, 2597, 3877, 5183, 6483, 7793, 9084, 10417, 11719, 12999, 14294, 15587],
    [0, 1289, 2626, 3900, 5217, 6550, 7820, 9140, 10431, 11717, 13028, 14312, 15615],
    [0, 1306, 2588, 3890, 5192, 6540, 7858, 9170, 10492, 11772, 13051, 14348, 15636],
    [0, 1279, 2583, 3892, 5193, 6481, 7788, 9099, 10394, 11701, 13026, 14348, 15710],
    [0, 1287, 2599, 3939, 5223, 6523, 7822, 9102, 10435, 11714, 13006, 14294, 15622],
    [0, 1302, 2631, 3913, 5192, 6503, 7804, 9121, 10429, 11757, 13064, 14379, 15656],
    [0, 1278, 2569, 3914, 5211, 6480, 7805, 9089, 10383, 11687, 12971, 14281, 15605],
    [0, 1278, 2559, 3834, 5144, 6434, 7754, 9033, 10330, 11607, 12925, 14218, 15510],
    [0, 1319],
    [0, 1269, 2564, 3849, 5130, 6430, 7740, 9060, 10409, 11698, 13001, 14286, 15557],
    [0, 1288, 2592, 3867, 5214, 6491, 7793, 9110, 10416, 11729, 13020, 14318, 15625],
    [0, 1326, 2643, 3972, 5270, 6591, 7872, 9139, 10437, 11731, 13031, 14327, 15633],
    [0, 1284, 2560, 3857, 5134, 6436, 7728, 9041, 10345, 11625, 12940, 14242, 15530],
    [0, 1299, 2576],
    [0, 1296, 2574, 3866, 5162, 6448, 7745, 9020, 10294, 11588, 12895, 14218, 15525],
    [0, 1279, 2563, 3875, 5161, 6461, 7741, 9023, 10305, 11613, 12897, 14204, 15536],
    [0, 1273, 2553, 3848, 5210, 6493, 7775, 9058, 10375, 11695, 12984, 14278, 15588],
    [0, 1283, 2584, 3863, 5160, 6444, 7740, 9061, 10377, 11698, 12994, 14274, 15545],
    [0, 1329, 2648, 3962, 5309, 6622, 7930, 9242, 10544, 11828, 13183, 14476, 15809],
    [0, 1290, 2591, 3891, 5175, 6460, 7766, 9112, 10402, 11701, 13019, 14330, 15633],
    [0, 1333, 2673, 3958, 5270, 6589, 7911, 9203, 10549, 11841, 13146, 14471, 15776],
    [0, 1288, 2643, 3945, 5266, 6595, 7907, 9213, 10486, 11807, 13138, 14430, 15703],
    [0, 1306, 2620, 3944, 5260, 6569, 7852, 9144, 10460, 11785, 13075, 14368, 15672],
    [0, 1294, 2572, 3851, 5164, 6464, 7755, 9090, 10398, 11688, 13002, 14313, 15593],
    [0, 1340, 2651, 3959, 5258, 6545, 7836, 9157, 10465, 11772, 13065, 14368, 15747],
    [0, 1325, 2657, 3935, 5255, 6583, 7874, 9154, 10448, 11732, 13026, 14344, 15620],
    [0, 1764, 3551, 5336, 7121, 8905, 10688, 12471, 14252, 16054],
    [0, 1280, 2590, 3896, 5187, 6520, 7822, 9117, 10397, 11690, 12977, 14270, 15561],
    [0, 1285, 2577, 3862, 5198, 6477, 7762, 9130, 10412, 11694, 13049, 14358, 15666],
    [0, 1287, 2617, 3942, 5240, 6510, 7807, 9090, 10390, 11743, 13031, 14325, 15615],
    [0, 1310, 2584, 3990, 5291, 6598, 7908, 9192],
    [0, 1304, 2626, 3930, 5209, 6499, 7810, 9109, 10435, 11731, 13007, 14307, 15593],
    [0, 1308, 2612, 3927, 5227, 6515, 7812, 9146, 10447, 11731, 13017, 14317, 15602],
    [0, 1820, 3640, 5460, 7277, 9115, 10953, 12791, 14628],
    [0, 1289, 2594, 3903, 5196, 6499, 7799, 9077, 10386, 11662, 12959, 14243, 15543],
    [0, 1300, 2601, 3876, 5165, 6436, 7725, 9039, 10352, 11639, 12927, 14209, 15490],
    [0, 1837, 3674, 5206, 6693, 8229, 9790, 11329, 12910, 14474, 16037],
    [0, 1292, 2604, 3878, 5151, 6453, 7749, 9033, 10363, 11703, 13014, 14301, 15617],
    [0, 1275, 2556, 3843, 5147, 6427, 7712, 9003, 10311, 11600, 12970, 14264, 15545],
    [0, 1285, 2590, 3878, 5169, 6527, 7863, 9161, 10451, 11745, 13066, 14382, 15695],
    [0, 1340, 2635],
    [0, 1314, 2600, 3894, 5194, 6490, 7797, 9105, 10385, 11667, 12967, 14255, 15550],
    [0, 1308, 2605, 3956, 5254, 6582, 7865, 9160, 10459, 11758, 13045, 14341, 15623],
    [0, 1282, 2576, 3882, 5190, 6510, 7819, 9142, 10427, 11736, 13041, 14359, 15683],
    [0, 1300, 2614, 3924],
    [0, 1282, 2600, 3923, 5229, 6580, 7952, 9295, 10593, 11873, 13161, 14458, 15756],
    [0, 1286, 2578, 3884, 5184, 6494, 7779, 9078, 10356, 11677, 12976, 14256, 15560],
    [0, 1303, 2575, 3848, 5119, 6417, 7714, 9020, 10362, 11668, 12983, 14314, 15599],
    [0, 1291, 2584],
    [0, 1299, 2617, 3938, 5328, 6600, 7885, 9163, 10489, 11771, 13053, 14332, 15691],
    [0, 1305, 2617],
    [0, 1573, 3257, 4935, 6605, 8256, 9906, 11529, 13171, 14809],
    [0, 1299, 2591, 3885, 5165, 6445, 7744, 9111, 10413, 11725, 13000, 14304, 15614],
    [0, 1296],
    [0, 1295, 2570, 3912, 5252, 6527, 7806, 9121, 10408, 11710, 12988, 14270, 15585],
    [0, 1285, 2621, 3937, 5235, 6506, 7790, 9085, 10352, 11630, 12949, 14247, 15528],
    [0, 1297, 2575, 3868, 5146, 6436, 7775, 9066, 10376, 11708, 13005, 14365, 15649],
    [0, 1322, 2638, 3920, 5217, 6522, 7801, 9113, 10472, 11769, 13046, 14372, 15668],
    [0, 1272, 2539, 3871, 5146, 6471, 7791, 9069, 10360, 11688, 12968, 14262, 15580],
    [0, 1322, 2642, 3933, 5229, 6538, 7823, 9126, 10432, 11734, 13089, 14372, 15678],
    [0, 1310, 2658, 3987, 5316, 6608, 7878, 9171, 10463, 11757, 13060, 14356, 15660],
    [0, 1318, 2640, 3924, 5237, 6546, 7832, 9138, 10462, 11762, 13046, 14341, 15609],
    [0, 1280, 2558, 3850, 5191, 6495, 7820, 9113, 10401, 11717, 13040, 14314, 15614],
    [0, 1313, 2596, 3908, 5249, 6542, 7843, 9141, 10456, 11739, 13039, 14348, 15699],
    [0, 1309],
    [0, 1400, 2689],
    [0, 1362, 2646, 3947, 5228, 6517, 7824, 9116, 10402, 11683, 12976, 14271, 15583],
    [0, 1303, 2653, 3937, 5234, 6541, 7861, 9224, 10606, 11897, 13213, 14544, 15851],
    [0, 1309, 2636, 3924, 5216, 6500, 7775, 9085, 10380, 11696, 12999, 14337, 15613],
    [0, 1310, 2611, 3904, 5238, 6532, 7804, 9100, 10408, 11707, 13011],
    [0, 1313, 2646, 3956, 5263, 6587, 7949, 9257, 10555, 11837, 13104, 14394, 15724],
    [0, 1321, 2612, 3915, 5231, 6551, 7838, 9128, 10440, 11759, 13099, 14416, 15700],
    [0, 1283, 2592, 3872, 5194, 6467, 7751, 9040, 10321, 11673, 13010, 14304, 15602],
    [0, 1270, 2622, 3915, 5193, 6478, 7776, 9085, 10430, 11732, 13033, 14338],
    [0, 1296, 2631, 3955],
    [0, 1315, 2622, 3949, 5243, 6592, 7894, 9216, 10533, 11830, 13123, 14419, 15722],
    [0, 1296, 2590, 3913, 5221, 6504, 7778, 9125, 10426, 11782, 13051, 14328, 15637],
    [0, 1294, 2579, 3886, 5160, 6456, 7746, 9047, 10347, 11638, 12962, 14261, 15550],
    [0, 7],
    [0, 1298, 2599, 3887, 5201, 6506, 7843, 9158, 10456, 11749, 13058, 14337, 15630],
    [0, 1290, 2598, 3876, 5177, 6473, 7790, 9065, 10362, 11640, 12943, 14287, 15582],
    [0, 1333, 2623, 3903, 5189, 6467, 7759, 9063, 10388, 11729, 13022, 14310, 15626],
    [0, 1322, 2615, 3921, 5206, 6491, 7811, 9109, 10394, 11691, 12969, 14256, 15532],
    [0, 1302, 2610, 3942, 5267, 6545, 7859, 9154, 10460, 11733, 13053, 14326, 15661],
    [0, 1289, 2616],
    [0, 1291, 2640, 3932, 5229, 6547, 7903, 9205, 10547, 11857, 13171, 14484, 15771],
    [0, 1240],
    [0, 1289, 2665, 3954, 5276, 6576, 7883, 9167, 10535, 11868, 13215, 14548, 15862],
    [0, 1299, 2606, 3913, 5223, 6514, 7793, 9097, 10381, 11652, 12936, 14228, 15513],
    [0, 1334, 2615, 3932, 5214, 6511, 7818, 9109, 10403, 11701, 13036, 14306, 15648],
    [0, 1315, 2613, 3889, 5215, 6490, 7799, 9110, 10407, 11684, 13016, 14333, 15639],
    [0, 1304, 2591, 3907, 5275, 6563, 7887, 9203, 10539, 11836, 13169, 14459, 15745],
    [0, 1279, 2548, 3860, 5216, 6529, 7833, 9102, 10400, 11697, 13002, 14313, 15638],
    [0, 1284, 2569, 3861, 5165, 6452, 7768, 9056, 10424, 11748, 13064, 14361, 15697],
    [0, 1302, 2600],
    [0, 1289, 2586],
    [0, 1287, 2577, 3855],
]


def assert_close_with_mismatch_tolerance(
    actual: torch.Tensor,
    expected: torch.Tensor,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    max_mismatched_elements: int = 0,
):
    """
    Asserts that two tensors are close, allowing for a specified number of
    mismatched elements. This function implements the same comparison logic
    as torch.isclose.
    """
    # Ensure tensors are float for comparison
    actual_float = actual.float()
    expected_float = expected.float()

    # This is the core logic from torch.isclose:
    # a mismatch occurs if the difference is greater than the combined tolerance.
    mismatched = torch.abs(actual_float - expected_float) > (
        atol + rtol * torch.abs(expected_float)
    )

    num_mismatched = torch.sum(mismatched).item()
    if num_mismatched > max_mismatched_elements:
        # For a helpful error message, find the worst offenders.
        actual_flat = actual_float.flatten()
        expected_flat = expected_float.flatten()
        abs_diff = torch.abs(actual_flat - expected_flat)
        # Compute the relative difference with a small epsilon in the
        # denominator to avoid division by zero.
        rel_diff = abs_diff / (torch.abs(expected_flat) + 1e-12)
        total_elements = actual_flat.numel()
        raise AssertionError(
            f"Tensors are not close enough!\n"
            f"Mismatched elements: {num_mismatched} / {total_elements} "
            f"({100.0 * num_mismatched / total_elements:.2f}%)\n"
            f"Allowed mismatched elements: {max_mismatched_elements}, but found {num_mismatched}.\n"
            f"Greatest absolute difference: {torch.max(abs_diff).item():.4g} (atol={atol})\n"
            f"Greatest relative difference: {torch.max(rel_diff).item():.4g} (rtol={rtol})"
        )
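
# Example usage (hypothetical; `out` and `ref` are assumed to be same-shape
# tensors from a kernel under test and a reference implementation): tolerate
# a handful of outlier elements that a plain
# torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3) would reject.
#
#   assert_close_with_mismatch_tolerance(
#       out, ref, rtol=1e-3, atol=1e-3, max_mismatched_elements=5
#   )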