import functools
import gc
import os
import types
from typing import Any, Dict

import pytest
import torch
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version

import flashinfer

TORCH_COMPILE_FNS = [
    flashinfer.activation.silu_and_mul,
    flashinfer.activation.gelu_and_mul,
    flashinfer.activation.gelu_tanh_and_mul,
    flashinfer.cascade.merge_state,
    flashinfer.cascade.merge_state_in_place,
    flashinfer.cascade.merge_states,
    flashinfer.cascade.MultiLevelCascadeAttentionWrapper.run,
    flashinfer.cascade.BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward,
    flashinfer.cascade.BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward,
    flashinfer.decode.single_decode_with_kv_cache,
    flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.run,
    flashinfer.gemm.bmm_fp8,
    flashinfer.gemm.SegmentGEMMWrapper.run,
    flashinfer.norm.rmsnorm,
    flashinfer.norm.fused_add_rmsnorm,
    flashinfer.norm.gemma_rmsnorm,
    flashinfer.norm.gemma_fused_add_rmsnorm,
    flashinfer.page.append_paged_kv_cache,
    flashinfer.prefill.single_prefill_with_kv_cache,
    flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper.run,
    flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper.run,
    flashinfer.quantization.packbits,
    flashinfer.rope.apply_rope,
    flashinfer.rope.apply_rope_inplace,
    flashinfer.rope.apply_rope_pos_ids,
    flashinfer.rope.apply_rope_pos_ids_inplace,
    flashinfer.rope.apply_llama31_rope,
    flashinfer.rope.apply_llama31_rope_inplace,
    flashinfer.rope.apply_llama31_rope_pos_ids,
    flashinfer.rope.apply_llama31_rope_pos_ids_inplace,
    flashinfer.sampling.sampling_from_probs,
    flashinfer.sampling.sampling_from_logits,
    flashinfer.sampling.top_p_sampling_from_probs,
    flashinfer.sampling.top_k_sampling_from_probs,
    flashinfer.sampling.min_p_sampling_from_probs,
    flashinfer.sampling.top_k_top_p_sampling_from_probs,
    flashinfer.sampling.top_p_renorm_probs,
    flashinfer.sampling.top_k_renorm_probs,
    flashinfer.sampling.top_k_mask_logits,
    flashinfer.sampling.chain_speculative_sampling,
]

_TORCH_COMPILE_CACHE: Dict[str, Any] = dict()


def _set_torch_compile_options():
    import torch._dynamo.config

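    # Assumption (not stated in the original): the dynamo recompile cache is
    # keyed per (function, input-shape) combination, and its default limit
    # (8 in recent torch releases) is easily exhausted by a parameterized
    # suite, causing silent fallbacks to eager mode; raising it avoids that.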
    torch._dynamo.config.cache_size_limit = 128

def _monkeypatch_add_torch_compile(func):
    """
    Replace the given function with its torch.compile version.
    """

    from torch._library.custom_ops import CustomOpDef

    if type(func) is types.FunctionType:
        fn = func
    elif isinstance(func, CustomOpDef):
        fn = func._init_fn
    else:
        raise ValueError(f"Unsupported fn type {type(func)}")

    fullname = fn.__module__ + "." + fn.__qualname__
    components = fullname.split(".")
    assert components[0] == "flashinfer"
    module = flashinfer
    for component in components[1:-1]:
        module = getattr(module, component)
    if not hasattr(module, components[-1]):
        raise ValueError(f"Failed to monkeypatch: {fullname}")

    def wrapper(*args, **kwargs):
        compiled = _TORCH_COMPILE_CACHE.get(fullname)
        if compiled is None:
            # Warmup -- JIT compile / import the kernels.
            #
            # On the user side, the model also needs to be warmed up
            # beforehand, as suggested by the PyTorch CUDA Graphs docs
            # (it is unclear whether this is also recommended for
            # torch.compile).
            #
            # For the convenience of FlashInfer testing, we do the warmup here,
            # on the first run of the function. The caveat is that the first
            # call will run twice: once to warmup, and another through the
            # compiled version.
            func(*args, **kwargs)

            # Compile
            compiled = torch.compile(
                func,
                fullgraph=True,
                backend="inductor",
                mode="max-autotune-no-cudagraphs",
            )
            # Key by fullname so the lookup above hits on subsequent calls;
            # keying by fn.__name__ here would miss on every call and
            # recompile each time.
            _TORCH_COMPILE_CACHE[fullname] = compiled

        return compiled(*args, **kwargs)

    setattr(module, fn.__name__, wrapper)
    print("Applied torch.compile to", fullname)
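

# Worked example of the resolution above (illustrative, using one real entry):
# for flashinfer.norm.rmsnorm, fullname is "flashinfer.norm.rmsnorm", the loop
# walks module = flashinfer -> flashinfer.norm, and setattr installs `wrapper`
# as flashinfer.norm.rmsnorm, so tests calling it exercise the compiled path.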


def pytest_configure(config):
    if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1":
        if torch_version < TorchVersion("2.4"):
            pytest.skip("torch.compile requires torch >= 2.4")
        _set_torch_compile_options()
        for fn in TORCH_COMPILE_FNS:
            _monkeypatch_add_torch_compile(fn)
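

# Example invocation (assumed workflow; the test directory path is
# hypothetical and not enforced anywhere in this file):
#
#   FLASHINFER_TEST_TORCH_COMPILE=1 pytest tests/
#
# runs the suite with every function above swapped for its compiled wrapper.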


def is_cuda_oom_error_str(e: str) -> bool:
    return "CUDA" in e and "out of memory" in e


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # Skip tests that hit CUDA OOM instead of reporting them as failures.
    try:
        item.runtest()
    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
        if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)):
            pytest.skip("Skipping due to OOM")
        else:
            raise


@functools.cache
def get_device_properties(device: torch.device):
    return torch.cuda.get_device_properties(device)


def clear_cuda_cache(device: torch.device) -> None:
    total_memory = get_device_properties(device).total_memory
    reserved_memory = torch.cuda.memory_reserved()

    # FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.9)
    threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.9"))

    if reserved_memory > threshold * total_memory:
        gc.collect()
        torch.cuda.empty_cache()
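

# A minimal usage sketch (assumption; no such fixture exists in this file):
# a test module could reclaim cached GPU memory between parameterized cases
# with an autouse fixture along these lines:
#
#   @pytest.fixture(autouse=True)
#   def _reclaim_gpu_memory():
#       yield
#       clear_cuda_cache(torch.device("cuda:0"))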


# Collected from a GSM8K trace in SGLang.
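# Each entry is a CSR-style indptr array over a batch of variable-length
# requests: request i occupies token positions [indptr[i], indptr[i + 1]), so
# indptr[i + 1] - indptr[i] is its length and indptr[-1] is the batch's total
# token count. For example, a hypothetical batch with lengths [3, 5, 2] would
# have indptr [0, 3, 8, 10].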
VARLEN_INDPTR_PARAMS = [
    [0, 1276, 2551, 3838, 5115, 6428, 7705, 8985, 10293, 11607, 12909, 14216, 15508],
    [0, 1320, 2637, 3926, 5208, 6494, 7795, 9130, 10415, 11704, 12995, 14270, 15551],
    [0, 1333, 2762, 4068, 5345, 6635, 7936, 9237, 10518, 11874, 13146, 14466, 15785],
    [0, 1310, 2603, 3937, 5231, 6527, 7799, 9107, 10411, 11698, 12978, 14277, 15559],
    [0, 1286, 2561, 3857, 5163, 6459, 7763, 9057, 10380, 11679, 12968, 14284, 15610],
    [0, 1350, 2667, 4003, 5347, 6631, 7919, 9208, 10524],
    [0, 1293, 2609, 3902, 5196, 6495, 7807, 9086, 10382, 11700, 12989, 14271, 15578],
    [0, 1276, 2551, 3838, 5115, 6428, 7705, 8985, 10293, 11607, 12909, 14216, 15508],
    [0, 1280, 2559, 3874, 5197, 6540, 7850, 9167, 10536, 11820, 13111, 14444, 15756],
    [0, 1306, 2598, 3895, 5181, 6458, 7750, 9047, 10333, 11635, 12906, 14207, 15514],
    [0, 1300, 2620, 3912, 5219, 6500, 7781, 9069, 10386, 11665, 12976, 14262, 15545],
    [0, 1800, 3600, 5400, 7200, 9000, 10799, 12620, 14441, 16261],
    [0, 1298, 2638],
    [0, 1284],
    [0, 1297, 2604],
    [0, 1276],
    [0, 1286, 2614, 3909, 5207, 6490, 7785, 9067, 10356, 11633, 12915, 14231, 15511],
    [0, 1312, 2613, 3899, 5203, 6492, 7811, 9151, 10439, 11757, 13052, 14364, 15646],
    [0, 1287],
    [0, 1353, 2684, 4039, 5326, 6615, 7932, 9217, 10528, 11862, 13207, 14490, 15785],
    [0, 1307],
    [0, 1301, 2587, 3949, 5263, 6620, 7933, 9226, 10521, 11895, 13179, 14531, 15822],
    [0, 1334, 2618, 3914, 5198, 6476, 7771, 9076, 10362, 11675, 12974, 14264, 15540],
    [0, 1323, 2606, 3900, 5209, 6487, 7770, 9074, 10397, 11694, 13047, 14411, 15719],
    [0, 1286, 2563, 3845, 5181, 6471, 7789, 9087, 10394, 11750, 13021, 14335, 15616],
    [0, 1310, 2605, 3889, 5214, 6518, 7794, 9143, 10465, 11802, 13134, 14438, 15713],
    [0, 1285, 2596, 3877, 5163, 6487, 7782, 9104, 10403],
    [0, 1299, 2594, 3921, 5222, 6494, 7777, 9098, 10406, 11736, 13026, 14317, 15621],
    [0, 1268, 2560],
    [0, 1536, 3061, 4578, 6177, 7774, 9378, 10958, 12636, 14292, 15954],
    [0, 1240],
    [0, 1362, 2653, 3930, 5201, 6505, 7808, 9094, 10421, 11720, 12994, 14285, 15584],
    [0, 1676, 3342, 5094, 6842, 8582, 10342, 12102, 13861, 15618],
    [0, 1329, 2656, 3977, 5253, 6532, 7831, 9150, 10444, 11749, 13090, 14388, 15675],
    [0, 1284, 2578, 3854, 5189, 6513, 7809, 9144, 10463, 11772, 13062, 14368, 15641],
    [0, 1286, 2651, 3960, 5286, 6573, 7857, 9146, 10445, 11723, 12995, 14270, 15537],
    [0, 1716, 3555, 5298, 7041, 8781],
    [0, 1281, 2574, 3866, 5176],
    [0, 1331, 2655, 3965, 5318, 6606, 7954, 9238, 10526, 11837, 13134, 14436, 15728],
    [0, 1754, 3508, 5259, 7028, 8796, 10564, 12332, 14097, 15862],
    [0, 1293, 2584, 3871, 5157, 6466, 7831, 9118, 10444, 11728, 13017, 14295, 15594],
    [0, 1279, 2586, 3876, 5170, 6440, 7768, 9067, 10351, 11651, 12936, 14239, 15542],
    [0, 1293, 2563, 3861, 5139, 6491, 7776, 9121, 10422, 11731, 13033, 14338, 15639],
    [0, 1292, 2632, 3933, 5257, 6576, 7881, 9178, 10455, 11796, 13095, 14385, 15685],
    [0, 1307],
    [0, 1307, 2590, 3897, 5206, 6527, 7826, 9104, 10400, 11696, 13022, 14326, 15615],
    [0, 1287, 2597, 3933, 5275, 6555, 7835, 9153, 10445, 11729, 13019, 14303, 15608],
    [0, 1294, 2589, 3904, 5205, 6504, 7803, 9087, 10375, 11671, 12970, 14279, 15615],
    [0, 1312, 2624, 3957, 5243, 6533, 7817, 9095, 10377, 11729, 13053, 14332, 15643],
    [0, 1278, 2579, 3858, 5147, 6461, 7745, 9038, 10376, 11654, 12941, 14265, 15592],
    [0, 1284, 2578, 3855, 5181, 6475, 7787, 9090, 10386, 11661, 13010, 14291, 15595],
    [0, 1293, 2604, 3893, 5211, 6526, 7859, 9139, 10439, 11723, 13071, 14369, 15669],
    [0, 1333, 2618, 3892, 5196, 6478, 7778, 9088, 10381, 11677, 12986, 14276, 15552],
    [0, 1287, 2569, 3876, 5156, 6463, 7754, 9053, 10363, 11642, 12946, 14230, 15501],
    [0, 1293, 2585, 3900, 5183, 6502, 7882, 9185, 10466, 11732, 13017, 14324, 15612],
    [0, 1323, 2597, 3877, 5183, 6483, 7793, 9084, 10417, 11719, 12999, 14294, 15587],
    [0, 1289, 2626, 3900, 5217, 6550, 7820, 9140, 10431, 11717, 13028, 14312, 15615],
    [0, 1306, 2588, 3890, 5192, 6540, 7858, 9170, 10492, 11772, 13051, 14348, 15636],
    [0, 1279, 2583, 3892, 5193, 6481, 7788, 9099, 10394, 11701, 13026, 14348, 15710],
    [0, 1287, 2599, 3939, 5223, 6523, 7822, 9102, 10435, 11714, 13006, 14294, 15622],
    [0, 1302, 2631, 3913, 5192, 6503, 7804, 9121, 10429, 11757, 13064, 14379, 15656],
    [0, 1278, 2569, 3914, 5211, 6480, 7805, 9089, 10383, 11687, 12971, 14281, 15605],
    [0, 1278, 2559, 3834, 5144, 6434, 7754, 9033, 10330, 11607, 12925, 14218, 15510],
    [0, 1319],
    [0, 1269, 2564, 3849, 5130, 6430, 7740, 9060, 10409, 11698, 13001, 14286, 15557],
    [0, 1288, 2592, 3867, 5214, 6491, 7793, 9110, 10416, 11729, 13020, 14318, 15625],
    [0, 1326, 2643, 3972, 5270, 6591, 7872, 9139, 10437, 11731, 13031, 14327, 15633],
    [0, 1284, 2560, 3857, 5134, 6436, 7728, 9041, 10345, 11625, 12940, 14242, 15530],
    [0, 1299, 2576],
    [0, 1296, 2574, 3866, 5162, 6448, 7745, 9020, 10294, 11588, 12895, 14218, 15525],
    [0, 1279, 2563, 3875, 5161, 6461, 7741, 9023, 10305, 11613, 12897, 14204, 15536],
    [0, 1273, 2553, 3848, 5210, 6493, 7775, 9058, 10375, 11695, 12984, 14278, 15588],
    [0, 1283, 2584, 3863, 5160, 6444, 7740, 9061, 10377, 11698, 12994, 14274, 15545],
    [0, 1329, 2648, 3962, 5309, 6622, 7930, 9242, 10544, 11828, 13183, 14476, 15809],
    [0, 1290, 2591, 3891, 5175, 6460, 7766, 9112, 10402, 11701, 13019, 14330, 15633],
    [0, 1333, 2673, 3958, 5270, 6589, 7911, 9203, 10549, 11841, 13146, 14471, 15776],
    [0, 1288, 2643, 3945, 5266, 6595, 7907, 9213, 10486, 11807, 13138, 14430, 15703],
    [0, 1306, 2620, 3944, 5260, 6569, 7852, 9144, 10460, 11785, 13075, 14368, 15672],
    [0, 1294, 2572, 3851, 5164, 6464, 7755, 9090, 10398, 11688, 13002, 14313, 15593],
    [0, 1340, 2651, 3959, 5258, 6545, 7836, 9157, 10465, 11772, 13065, 14368, 15747],
    [0, 1325, 2657, 3935, 5255, 6583, 7874, 9154, 10448, 11732, 13026, 14344, 15620],
    [0, 1764, 3551, 5336, 7121, 8905, 10688, 12471, 14252, 16054],
    [0, 1280, 2590, 3896, 5187, 6520, 7822, 9117, 10397, 11690, 12977, 14270, 15561],
    [0, 1285, 2577, 3862, 5198, 6477, 7762, 9130, 10412, 11694, 13049, 14358, 15666],
    [0, 1287, 2617, 3942, 5240, 6510, 7807, 9090, 10390, 11743, 13031, 14325, 15615],
    [0, 1310, 2584, 3990, 5291, 6598, 7908, 9192],
    [0, 1304, 2626, 3930, 5209, 6499, 7810, 9109, 10435, 11731, 13007, 14307, 15593],
    [0, 1308, 2612, 3927, 5227, 6515, 7812, 9146, 10447, 11731, 13017, 14317, 15602],
    [0, 1820, 3640, 5460, 7277, 9115, 10953, 12791, 14628],
    [0, 1289, 2594, 3903, 5196, 6499, 7799, 9077, 10386, 11662, 12959, 14243, 15543],
    [0, 1300, 2601, 3876, 5165, 6436, 7725, 9039, 10352, 11639, 12927, 14209, 15490],
    [0, 1837, 3674, 5206, 6693, 8229, 9790, 11329, 12910, 14474, 16037],
    [0, 1292, 2604, 3878, 5151, 6453, 7749, 9033, 10363, 11703, 13014, 14301, 15617],
    [0, 1275, 2556, 3843, 5147, 6427, 7712, 9003, 10311, 11600, 12970, 14264, 15545],
    [0, 1285, 2590, 3878, 5169, 6527, 7863, 9161, 10451, 11745, 13066, 14382, 15695],
    [0, 1340, 2635],
    [0, 1314, 2600, 3894, 5194, 6490, 7797, 9105, 10385, 11667, 12967, 14255, 15550],
    [0, 1308, 2605, 3956, 5254, 6582, 7865, 9160, 10459, 11758, 13045, 14341, 15623],
    [0, 1282, 2576, 3882, 5190, 6510, 7819, 9142, 10427, 11736, 13041, 14359, 15683],
    [0, 1300, 2614, 3924],
    [0, 1282, 2600, 3923, 5229, 6580, 7952, 9295, 10593, 11873, 13161, 14458, 15756],
    [0, 1286, 2578, 3884, 5184, 6494, 7779, 9078, 10356, 11677, 12976, 14256, 15560],
    [0, 1303, 2575, 3848, 5119, 6417, 7714, 9020, 10362, 11668, 12983, 14314, 15599],
    [0, 1291, 2584],
    [0, 1299, 2617, 3938, 5328, 6600, 7885, 9163, 10489, 11771, 13053, 14332, 15691],
    [0, 1305, 2617],
    [0, 1573, 3257, 4935, 6605, 8256, 9906, 11529, 13171, 14809],
    [0, 1299, 2591, 3885, 5165, 6445, 7744, 9111, 10413, 11725, 13000, 14304, 15614],
    [0, 1296],
    [0, 1295, 2570, 3912, 5252, 6527, 7806, 9121, 10408, 11710, 12988, 14270, 15585],
    [0, 1285, 2621, 3937, 5235, 6506, 7790, 9085, 10352, 11630, 12949, 14247, 15528],
    [0, 1297, 2575, 3868, 5146, 6436, 7775, 9066, 10376, 11708, 13005, 14365, 15649],
    [0, 1322, 2638, 3920, 5217, 6522, 7801, 9113, 10472, 11769, 13046, 14372, 15668],
    [0, 1272, 2539, 3871, 5146, 6471, 7791, 9069, 10360, 11688, 12968, 14262, 15580],
    [0, 1322, 2642, 3933, 5229, 6538, 7823, 9126, 10432, 11734, 13089, 14372, 15678],
    [0, 1310, 2658, 3987, 5316, 6608, 7878, 9171, 10463, 11757, 13060, 14356, 15660],
    [0, 1318, 2640, 3924, 5237, 6546, 7832, 9138, 10462, 11762, 13046, 14341, 15609],
    [0, 1280, 2558, 3850, 5191, 6495, 7820, 9113, 10401, 11717, 13040, 14314, 15614],
    [0, 1313, 2596, 3908, 5249, 6542, 7843, 9141, 10456, 11739, 13039, 14348, 15699],
    [0, 1309],
    [0, 1400, 2689],
    [0, 1362, 2646, 3947, 5228, 6517, 7824, 9116, 10402, 11683, 12976, 14271, 15583],
    [0, 1303, 2653, 3937, 5234, 6541, 7861, 9224, 10606, 11897, 13213, 14544, 15851],
    [0, 1309, 2636, 3924, 5216, 6500, 7775, 9085, 10380, 11696, 12999, 14337, 15613],
    [0, 1310, 2611, 3904, 5238, 6532, 7804, 9100, 10408, 11707, 13011],
    [0, 1313, 2646, 3956, 5263, 6587, 7949, 9257, 10555, 11837, 13104, 14394, 15724],
    [0, 1321, 2612, 3915, 5231, 6551, 7838, 9128, 10440, 11759, 13099, 14416, 15700],
    [0, 1283, 2592, 3872, 5194, 6467, 7751, 9040, 10321, 11673, 13010, 14304, 15602],
    [0, 1270, 2622, 3915, 5193, 6478, 7776, 9085, 10430, 11732, 13033, 14338],
    [0, 1296, 2631, 3955],
    [0, 1315, 2622, 3949, 5243, 6592, 7894, 9216, 10533, 11830, 13123, 14419, 15722],
    [0, 1296, 2590, 3913, 5221, 6504, 7778, 9125, 10426, 11782, 13051, 14328, 15637],
    [0, 1294, 2579, 3886, 5160, 6456, 7746, 9047, 10347, 11638, 12962, 14261, 15550],
    [0, 7],
    [0, 1298, 2599, 3887, 5201, 6506, 7843, 9158, 10456, 11749, 13058, 14337, 15630],
    [0, 1290, 2598, 3876, 5177, 6473, 7790, 9065, 10362, 11640, 12943, 14287, 15582],
    [0, 1333, 2623, 3903, 5189, 6467, 7759, 9063, 10388, 11729, 13022, 14310, 15626],
    [0, 1322, 2615, 3921, 5206, 6491, 7811, 9109, 10394, 11691, 12969, 14256, 15532],
    [0, 1302, 2610, 3942, 5267, 6545, 7859, 9154, 10460, 11733, 13053, 14326, 15661],
    [0, 1289, 2616],
    [0, 1291, 2640, 3932, 5229, 6547, 7903, 9205, 10547, 11857, 13171, 14484, 15771],
    [0, 1240],
    [0, 1289, 2665, 3954, 5276, 6576, 7883, 9167, 10535, 11868, 13215, 14548, 15862],
    [0, 1299, 2606, 3913, 5223, 6514, 7793, 9097, 10381, 11652, 12936, 14228, 15513],
    [0, 1334, 2615, 3932, 5214, 6511, 7818, 9109, 10403, 11701, 13036, 14306, 15648],
    [0, 1315, 2613, 3889, 5215, 6490, 7799, 9110, 10407, 11684, 13016, 14333, 15639],
    [0, 1304, 2591, 3907, 5275, 6563, 7887, 9203, 10539, 11836, 13169, 14459, 15745],
    [0, 1279, 2548, 3860, 5216, 6529, 7833, 9102, 10400, 11697, 13002, 14313, 15638],
    [0, 1284, 2569, 3861, 5165, 6452, 7768, 9056, 10424, 11748, 13064, 14361, 15697],
    [0, 1302, 2600],
    [0, 1289, 2586],
    [0, 1287, 2577, 3855],
]

def assert_close_with_mismatch_tolerance(
    actual: torch.Tensor,
    expected: torch.Tensor,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    max_mismatched_elements: int = 0,
):
    """
    Assert that two tensors are close, allowing for a specified number of
    mismatched elements. Uses the same closeness criterion as torch.isclose.
    """
    # Ensure tensors are float for comparison
    actual_float = actual.float()
    expected_float = expected.float()

    # Core criterion from torch.isclose: an element mismatches if its
    # absolute difference exceeds the combined tolerance.
    mismatched = torch.abs(actual_float - expected_float) > (
        atol + rtol * torch.abs(expected_float)
    )

    num_mismatched = torch.sum(mismatched).item()

    if num_mismatched > max_mismatched_elements:
        # For a helpful error message, find the worst offenders.
        actual_flat = actual_float.flatten()
        expected_flat = expected_float.flatten()
        abs_diff = torch.abs(actual_flat - expected_flat)

        # Compute the relative difference with a small epsilon in the
        # denominator to avoid division by zero where expected is zero.
        rel_diff = abs_diff / (torch.abs(expected_flat) + 1e-12)

        total_elements = actual_flat.numel()

        raise AssertionError(
            f"Tensors are not close enough!\n"
            f"Mismatched elements: {num_mismatched} / {total_elements} "
            f"({100.0 * num_mismatched / total_elements:.2f}%)\n"
            f"Allowed mismatched elements: {max_mismatched_elements}, but found {num_mismatched}.\n"
            f"Greatest absolute difference: {torch.max(abs_diff).item():.4g} (atol={atol})\n"
            f"Greatest relative difference: {torch.max(rel_diff).item():.4g} (rtol={rtol})"
        )
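

# Usage sketch with illustrative tolerances (not values mandated by any test):
#
#   assert_close_with_mismatch_tolerance(
#       out, ref, rtol=1e-3, atol=1e-3, max_mismatched_elements=5
#   )
#
# passes if at most 5 elements of `out` fall outside the tolerance envelope
# of `ref`.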