""" Copyright (c) 2023 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import random import time from typing import Tuple, Any import os import sys import numpy as np import torch from einops import rearrange, reduce, repeat from flashinfer.utils import round_up def _ceil_to_ue8m0(x: torch.Tensor): """imported from DeepGEMM""" assert x.view(-1).amax().item() > 0 return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """imported from DeepGEMM""" assert x.dim() == 2 and x.size(1) % 128 == 0 m, n = x.shape x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) sf = _ceil_to_ue8m0(x_amax / 448.0) return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """imported from DeepGEMM""" assert x.dim() == 2 m, n = x.shape x_padded = torch.zeros( (round_up(m, 128), round_up(n, 128)), dtype=x.dtype, device=x.device ) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) sf = _ceil_to_ue8m0(x_amax / 448.0) x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( x_view.size(0), x_view.size(2) ) def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): """ Quantizes a 2D or 3D tensor to FP8. Args: x (torch.Tensor): The 2D or 3D input tensor. scale_shape (tuple): The shape of the scale tensor. tile_shape (tuple): The shape of the tiles. scale_major_mode (str): The tiling order, "K" for row-major like, or another value for column-major like. Returns: tuple: A tuple containing the quantized FP8 tensor and the calculated float32 scales. """ # 1. Assertions and Initial Setup ndim = x.ndim assert ndim in [2, 3], f"x.ndim must be 2 or 3, but got {ndim}" assert ndim == len(scale_shape) == len(tile_shape) fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_amax = torch.tensor(fp8_info.max, device=x.device, dtype=torch.float32) # 2. Tiling and Scale Calculation if ndim == 2: s0, s1 = scale_shape t0, t1 = tile_shape if scale_major_mode == "K": # Tile x and find the max absolute value in each tile x_tiled = rearrange(x, "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4) x_scale = abs_max / fp8_amax x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) # Broadcast scales back to the original tensor shape scales_repeated = repeat(x_scale, "s0 s1 -> (s0 t0) (s1 t1)", t0=t0, t1=t1) else: # Handle column-major tiling x_tiled = rearrange(x, "(s1 t0) (s0 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4) x_scale = abs_max / fp8_amax x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) # Permute scale axes before repeating to match layout scales_permuted = rearrange(x_scale, "s0 s1 -> s1 s0") scales_repeated = repeat( scales_permuted, "s1 s0 -> (s1 t0) (s0 t1)", t0=t0, t1=t1 ) elif ndim == 3: s0, s1, s2 = scale_shape t0, t1, t2 = tile_shape if scale_major_mode == "K": # Tile x and find the max absolute value in each tile x_tiled = rearrange( x, "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 ) abs_max = reduce( x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max" ).clamp(1e-4) x_scale = abs_max / fp8_amax x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) # Broadcast scales back to the original tensor shape scales_repeated = repeat( x_scale, "s0 s1 s2 -> (s0 t0) (s1 t1) (s2 t2)", t0=t0, t1=t1, t2=t2 ) else: # Handle layout where the last two axes are swapped x_tiled = rearrange( x, "(s0 t0) (s2 t1) (s1 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 ) abs_max = reduce( x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max" ).clamp(1e-4) x_scale = abs_max / fp8_amax x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) # Permute scale axes before repeating to match layout scales_permuted = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1") scales_repeated = repeat( scales_permuted, "s0 s2 s1 -> (s0 t0) (s2 t1) (s1 t2)", t0=t0, t1=t1, t2=t2, ) # 3. Final Quantization # Divide the original tensor by the broadcasted scales x_fp32 = x / (scales_repeated + 1e-8) # Convert the result to the target FP8 format x_fp8 = x_fp32.to(torch.float8_e4m3fn) return x_fp8, x_scale def dequantize_fp8(x, x_scale, scale_major_mode): """ Quantizes a 2D or 3D tensor to FP8. Args: x (torch.Tensor): The 2D or 3D input tensor. scale_shape (tuple): The shape of the scale tensor. tile_shape (tuple): The shape of the tiles. scale_major_mode (str): The tiling order, "K" for row-major like, or another value for column-major like. Returns: tuple: A tuple containing the quantized FP8 tensor and the calculated float32 scales. """ # 1. Assertions and Initial Setup ndim = x.ndim assert ndim in [2, 3], f"x.ndim must be 2 or 3, but got {ndim}" assert ndim == len(x_scale.shape) # 2. Tiling and Scale Calculation if ndim == 2: if scale_major_mode == "K": s0, s1 = x_scale.shape else: s1, s0 = x_scale.shape x = rearrange( x.to(torch.float32), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1 ) if scale_major_mode == "K": x_scale = rearrange(x_scale, "s0 s1 -> s0 s1 1 1") else: x_scale = rearrange(x_scale, "s0 s1 -> s1 s0 1 1") out = rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)") elif ndim == 3: if scale_major_mode == "K": s0, s1, s2 = x_scale.shape else: s0, s2, s1 = x_scale.shape x = rearrange( x.to(torch.float32), "(s0 t0) (s1 t1) (s2 t2)-> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2, ) if scale_major_mode == "K": x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1") else: x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1") out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)") return out def set_seed(random_seed): """ Set random seed for reproducibility during testing. Args: random_seed (int): Random seed to set. Returns: None """ torch.manual_seed(random_seed) random.seed(random_seed) np.random.seed(random_seed) if torch.cuda.is_available(): torch.cuda.manual_seed(random_seed) torch.cuda.manual_seed_all(random_seed) def sleep_after_kernel_run(execution_time): """ Sleep after kernel run. Dynamically adjust sleep time up to 1 sec based on execution time. Args: execution_time (float): Kernel execution time in milliseconds. Returns: None """ if not math.isinf(execution_time): sleep_time = np.min([execution_time / 200, 1.0]) else: sleep_time = 0.01 time.sleep(sleep_time) return def attention_flops( batch_size, qo_seqlen, kv_seqlen, head_dim_qk, head_dim_vo, num_qo_heads, causal, ): """ Calculate FLOPs for a given attention layer. Assumes all sequence lengths are the same within the batch Args: batch_size (int): Batch size. qo_seqlen (int): Sequence length of the query. Assumed same within the batch. kv_seqlen (int): Sequence length of the key and value. Assumed same within the batch. head_dim_qk (int): Head dimension of the query and key. head_dim_vo (int): Head dimension of the value. num_qo_heads (int): Number of query heads. causal (bool): Whether to use causal masking. FLOPs is halved for causal masking. Returns: total_flops (int): Total FLOPs for the layer. """ if causal: bmm1_flops = ( batch_size * (2 * kv_seqlen - qo_seqlen) * qo_seqlen * num_qo_heads * head_dim_qk ) bmm2_flops = ( batch_size * (2 * kv_seqlen - qo_seqlen) * qo_seqlen * num_qo_heads * head_dim_vo ) else: bmm1_flops = 2 * batch_size * qo_seqlen * kv_seqlen * num_qo_heads * head_dim_qk bmm2_flops = 2 * batch_size * qo_seqlen * kv_seqlen * num_qo_heads * head_dim_vo total_flops = bmm1_flops + bmm2_flops return total_flops def attention_flops_with_actual_seq_lens( actual_seq_lens_q, actual_seq_lens_kv, head_dim_qk, head_dim_vo, num_qo_heads, causal, ): """ Calculate FLOPs for a given attention layer with actual sequence lengths where actual sequence lengths are provided as 1D tensors. Args: actual_seq_lens_q (torch.Tensor): Array of actual sequence lengths of the query. actual_seq_lens_kv (torch.Tensor): Array of actual sequence lengths of the key and value. head_dim_qk (int): Head dimension of the query and key. head_dim_vo (int): Head dimension of the value. num_qo_heads (int): Number of query heads. causal (bool): Whether to use causal masking. Note: Causal must be false for decode as this function assumes qo_seqlen == kv_seqlen. Returns: total_flops (int): Total FLOPs for the layer. """ if causal: bmm1_flops = ( torch.dot( 2 * actual_seq_lens_kv.to(torch.float32) - actual_seq_lens_q.to(torch.float32), actual_seq_lens_q.to(torch.float32), ) * num_qo_heads * head_dim_qk ) bmm2_flops = ( torch.dot( 2 * actual_seq_lens_kv.to(torch.float32) - actual_seq_lens_q.to(torch.float32), actual_seq_lens_q.to(torch.float32), ) * num_qo_heads * head_dim_vo ) else: bmm1_flops = ( 2 * torch.dot( actual_seq_lens_kv.to(torch.float32), actual_seq_lens_q.to(torch.float32), ) * num_qo_heads * head_dim_qk ) bmm2_flops = ( 2 * torch.dot( actual_seq_lens_kv.to(torch.float32), actual_seq_lens_q.to(torch.float32), ) * num_qo_heads * head_dim_vo ) total_flops = bmm1_flops + bmm2_flops return total_flops def attention_tflops_per_sec( batch_size, qo_seqlen, kv_seqlen, head_dim_qk, head_dim_vo, num_qo_heads, causal, time, ): """ Calculate TFLOPS per second for a given attention layer. Assumes all sequence lengths are the same within the batch. Args: batch_size (int): Batch size. qo_seqlen (int): Sequence length of the query. kv_seqlen (int): Sequence length of the key and value. head_dim_qk (int): Head dimension of the query and key. head_dim_vo (int): Head dimension of the value. num_qo_heads (int): Number of query heads. causal (bool): Whether to use causal masking. time (float): Execution time in milliseconds. Returns: tflops_per_sec (float): TFLOPS per second for the layer. """ f = attention_flops( batch_size, qo_seqlen, kv_seqlen, head_dim_qk, head_dim_vo, num_qo_heads, causal, ) return f / time / 1e9 if not math.isnan(time) else 0.0 def attention_tflops_per_sec_with_actual_seq_lens( actual_seq_lens_q, actual_seq_lens_kv, head_dim_qk, head_dim_vo, num_qo_heads, causal, time, ): """ Calculate TFLOPS per second for a given attention layer with actual sequence lengths. Does not assume all sequence lengths are the same within the batch. Args: actual_seq_lens_q (torch.Tensor): Array of actual sequence lengths of the query. actual_seq_lens_kv (torch.Tensor): Array of actual sequence lengths of the key and value. head_dim_qk (int): Head dimension of the query and key. head_dim_vo (int): Head dimension of the value. num_qo_heads (int): Number of query heads. causal (bool): Whether to use causal masking. time (float): Execution time in milliseconds. Returns: tflops_per_sec (float): TFLOPS per second for the layer. """ f = attention_flops_with_actual_seq_lens( actual_seq_lens_q, actual_seq_lens_kv, head_dim_qk, head_dim_vo, num_qo_heads, causal, ) return f.item() / time / 1e9 if not math.isnan(time) else 0.0 def attention_tb_per_sec( batch_size, qo_seqlen, kv_seqlen, head_dim_qk, head_dim_vo, num_qo_heads, num_kv_heads, time, q_dtype=torch.bfloat16, kv_dtype=torch.bfloat16, o_dtype=torch.bfloat16, ): """ Calculate TB per second perf achieved for a given attention layer. Assumes all sequence lengths are the same within the batch. Args: batch_size (int): Batch size. qo_seqlen (int): Sequence length of the query. kv_seqlen (int): Sequence length of the key and value. head_dim_qk (int): Head dimension of the query and key. head_dim_vo (int): Head dimension of the value. num_qo_heads (int): Number of query heads. num_kv_heads (int): Number of key and value heads. time (float): Execution time in milliseconds. q_dtype (torch.dtype): Data type of the query. kv_dtype (torch.dtype): Data type of the key and value. o_dtype (torch.dtype): Data type of the output. Returns: tb_per_sec (float): TB per second for the layer. """ q_bytes = batch_size * qo_seqlen * num_qo_heads * head_dim_qk * q_dtype.itemsize k_bytes = batch_size * kv_seqlen * num_kv_heads * head_dim_qk * kv_dtype.itemsize v_bytes = batch_size * kv_seqlen * num_kv_heads * head_dim_vo * kv_dtype.itemsize o_bytes = batch_size * qo_seqlen * num_qo_heads * head_dim_vo * o_dtype.itemsize total_bytes = q_bytes + k_bytes + v_bytes + o_bytes time_in_sec = time / 1e3 bytes_in_tb = total_bytes / 1e12 # TB not TiB return bytes_in_tb / time_in_sec if not math.isnan(time) else 0.0 def attention_tb_per_sec_with_actual_seq_lens( actual_seq_lens_q, actual_seq_lens_kv, head_dim_qk, head_dim_vo, num_qo_heads, num_kv_heads, time, q_dtype=torch.bfloat16, kv_dtype=torch.bfloat16, o_dtype=torch.bfloat16, ): """ Calculate TB per second perf achieved for a given attention layer with actual sequence lengths. Does not assume all sequence lengths are the same within the batch. Args: actual_seq_lens_q (torch.Tensor): Array of actual sequence lengths of the query. actual_seq_lens_kv (torch.Tensor): Array of actual sequence lengths of the key and value. head_dim_qk (int): Head dimension of the query and key. head_dim_vo (int): Head dimension of the value. num_qo_heads (int): Number of query heads. num_kv_heads (int): Number of key and value heads. time (float): Execution time in milliseconds. q_dtype (torch.dtype): Data type of the query. kv_dtype (torch.dtype): Data type of the key and value. o_dtype (torch.dtype): Data type of the output. Returns: tb_per_sec (float): TB per second for the layer. """ q_bytes = ( torch.sum(actual_seq_lens_q) * num_qo_heads * head_dim_qk * q_dtype.itemsize ) k_bytes = ( torch.sum(actual_seq_lens_kv) * num_kv_heads * head_dim_qk * kv_dtype.itemsize ) v_bytes = ( torch.sum(actual_seq_lens_kv) * num_kv_heads * head_dim_vo * kv_dtype.itemsize ) o_bytes = ( torch.sum(actual_seq_lens_q) * num_qo_heads * head_dim_vo * o_dtype.itemsize ) total_bytes = (q_bytes + k_bytes + v_bytes + o_bytes).item() time_in_sec = time / 1e3 bytes_in_tb = total_bytes / 1e12 # TB not TiB return bytes_in_tb / time_in_sec if not math.isnan(time) else 0.0 def bench_gpu_time( fn, dry_run_iters: int = None, repeat_iters: int = None, dry_run_time_ms: int = 25, repeat_time_ms: int = 100, l2_flush: bool = True, l2_flush_size_mb: int = 256, l2_flush_device: str = "cuda", sleep_after_run: bool = False, ): """ Benchmark kernel execution time without using CUDA graphs. Measures kernel launch latency + actual kernel execution time for fn(). Can flush L2 cache and sleep after the run. Number of dry run and actual run iterations can be set by iteration count or time: - If dry_run_iters and repeat_iters are provided, provided iteration count will be used. - If dry_run_iters and repeat_iters are not provided, dry_run_time_ms and repeat_time_ms will be used. Returns an array of measured times so that the caller can compute statistics. Args: fn: Function to benchmark. dry_run_iters: Number of dry runs during which times does not count. If not provided, dry_run_time_ms will be used. repeat_iters: Number of iterations. If not provided, repeat_time_ms will be used. dry_run_time_ms: Time to run the dry run in milliseconds. repeat_time_ms: Time to run the repeat in milliseconds. l2_flush: Whether to flush L2 cache. l2_flush_size_mb: Size of the L2 cache to flush. l2_flush_device: Device that needs to flush L2 cache. sleep_after_run: Whether to sleep after the run. Sleep time is dynamically set. Returns: measured_times: List of measured times. """ start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) if l2_flush: l2_flush_size = int(l2_flush_size_mb) * 1024 * 1024 buffer = torch.empty(l2_flush_size, device=l2_flush_device, dtype=torch.int8) ## Estimate kernel execution time by running the kernel 5 times measurement_iters = 5 torch.cuda.synchronize() fn() # Call once to exclude initial overhead torch.cuda.synchronize() start_event.record() for _ in range(measurement_iters): if l2_flush: buffer.zero_() fn() end_event.record() torch.cuda.synchronize() estimated_kernel_execution_time = ( start_event.elapsed_time(end_event) / measurement_iters ) ## Set dry run and repeat iterations if dry_run_iters is None: dry_run_iters = max(1, int(dry_run_time_ms / estimated_kernel_execution_time)) if repeat_iters is None: repeat_iters = max(1, int(repeat_time_ms / estimated_kernel_execution_time)) # Dry runs torch.cuda.synchronize() for _ in range(dry_run_iters): if l2_flush: buffer.zero_() fn() torch.cuda.synchronize() # Actual run start_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat_iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat_iters)] torch.cuda.synchronize() for iter_idx in range(repeat_iters): if l2_flush: buffer.zero_() start_events[iter_idx].record() fn() end_events[iter_idx].record() if sleep_after_run: sleep_after_kernel_run(estimated_kernel_execution_time) # Synchronize once outside of the loop to avoid synchronization overhead torch.cuda.synchronize() measured_times = [] for iter_idx in range(repeat_iters): measured_times.append(start_events[iter_idx].elapsed_time(end_events[iter_idx])) return measured_times def bench_gpu_time_with_cudagraph( fn, dry_run_iters: int = None, repeat_iters: int = None, dry_run_time_ms: int = 25, repeat_time_ms: int = 100, num_iters_within_graph: int = 10, l2_flush: bool = True, l2_flush_size_mb: int = 256, l2_flush_device: str = "cuda", sleep_after_run: bool = False, ): """ Benchmark GPU time using by constructing CUDA graphs with kernel launch and then replaying the graph. Increasing the number of iterations within graph can amortize kernel launch latency to help obtain measurements close to GPU kernel time of fn(). Can flush L2 cache and sleep after the run. Number of dry run and actual run iterations can be set by iteration count or time: - If dry_run_iters and repeat_iters are provided, provided iteration count will be used. - If dry_run_iters and repeat_iters are not provided, dry_run_time_ms and repeat_time_ms will be used. Returns an array of measured times so that the caller can compute statistics. Uses PyTorch's API to construt and use CUDA Graphs. Also see PyTorch's post on CUDA Graphs: https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/ Args: fn: Function to benchmark. dry_run_iters: Number of dry runs during which times does not count. If not provided, dry_run_time_ms will be used. repeat_iters: Number of iterations. If not provided, repeat_time_ms will be used. dry_run_time_ms: Time to run the dry run in milliseconds. repeat_time_ms: Time to run the repeat in milliseconds. num_iters_within_graph: Number of iterations to run within the graph. l2_flush: Whether to flush L2 cache. l2_flush_size_mb: Size of the L2 cache to flush. l2_flush_device: Device that needs to flush L2 cache. sleep_after_run: Whether to sleep after the run. Sleep time is dynamically set. Returns: measured_times: List of measured times. """ start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) if l2_flush: l2_flush_size = int(l2_flush_size_mb) * 1024 * 1024 buffer = torch.empty(l2_flush_size, device=l2_flush_device, dtype=torch.int8) # Warmup run torch.cuda.synchronize() s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): fn() torch.cuda.current_stream().wait_stream(s) # Capture kernel in graph g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): for _ in range(num_iters_within_graph): fn() torch.cuda.synchronize() ## Estimate kernel execution time by running the kernel 5 times measurement_iters = 5 start_event.record() for _ in range(measurement_iters): if l2_flush: buffer.zero_() g.replay() end_event.record() torch.cuda.synchronize() estimated_kernel_execution_time = ( start_event.elapsed_time(end_event) / measurement_iters ) ## Set dry run and repeat iterations if dry_run_iters is None: dry_run_iters = max(1, int(dry_run_time_ms / estimated_kernel_execution_time)) if repeat_iters is None: repeat_iters = max(1, int(repeat_time_ms / estimated_kernel_execution_time)) # Dry run torch.cuda.synchronize() for _ in range(dry_run_iters): if l2_flush: buffer.zero_() g.replay() torch.cuda.synchronize() # Actual run start_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat_iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat_iters)] torch.cuda.synchronize() for iter_idx in range(repeat_iters): if l2_flush: buffer.zero_() start_events[iter_idx].record() g.replay() end_events[iter_idx].record() if sleep_after_run: sleep_after_kernel_run(estimated_kernel_execution_time) # Synchronize once outside of the loop to avoid synchronization overhead torch.cuda.synchronize() measured_times = [] for iter_idx in range(repeat_iters): measured_times.append( start_events[iter_idx].elapsed_time(end_events[iter_idx]) / num_iters_within_graph ) return measured_times class empty_suppress: def __enter__(self): return self def __exit__(self, *_): pass class suppress_stdout_stderr: def __enter__(self): self.outnull_file = open(os.devnull, "w") self.errnull_file = open(os.devnull, "w") self.old_stdout_fileno_undup = sys.stdout.fileno() self.old_stderr_fileno_undup = sys.stderr.fileno() self.old_stdout_fileno = os.dup(sys.stdout.fileno()) self.old_stderr_fileno = os.dup(sys.stderr.fileno()) self.old_stdout = sys.stdout self.old_stderr = sys.stderr os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) sys.stdout = self.outnull_file sys.stderr = self.errnull_file return self def __exit__(self, *_): sys.stdout = self.old_stdout sys.stderr = self.old_stderr os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) os.close(self.old_stdout_fileno) os.close(self.old_stderr_fileno) self.outnull_file.close() self.errnull_file.close() # copied from DeepGEMM def bench_kineto( fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, trace_path: str = None, flush_l2: bool = True, with_multiple_kernels: bool = False, ): # Conflict with Nsight Systems using_nsys = int(os.environ.get("DG_NSYS_PROFILING", 0)) # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle flush_l2_size = int(8e9 // 4) # For some auto-tuning kernels with prints fn() # Profile suppress = ( suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress ) with suppress(): schedule = ( torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None ) profiler: Any = ( torch.profiler.profile( activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule ) if not using_nsys else empty_suppress() ) with profiler: for _i in range(2): for _ in range(num_tests): if flush_l2: torch.empty( flush_l2_size, dtype=torch.int, device="cuda" ).zero_() fn() if not using_nsys: profiler.step() # Return 1 if using Nsight Systems if using_nsys: return 1 # Parse the profiling table assert isinstance(kernel_names, (str, tuple)) is_tuple = isinstance(kernel_names, tuple) prof_lines = ( profiler.key_averages() .table(sort_by="cuda_time_total", max_name_column_width=100) .split("\n") ) kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) if not with_multiple_kernels: for name in kernel_names: assert sum([name in line for line in prof_lines]) == 1, ( f"Errors of the kernel {name} in the profiling table" ) # Save chrome traces if trace_path is not None: profiler.export_chrome_trace(trace_path) # Return average kernel times units = {"ms": 1e3, "us": 1e6} kernel_times = [] for name in kernel_names: total_time = 0.0 total_num = 0 for line in prof_lines: if name in line: time_str = line.split()[-2] num_str = line.split()[-1] for unit, scale in units.items(): if unit in time_str: total_time += ( float(time_str.replace(unit, "")) / scale * int(num_str) ) total_num += int(num_str) break kernel_times.append(total_time / total_num) return tuple(kernel_times) if is_tuple else kernel_times[0] def count_bytes(*tensors): total = 0 for t in tensors: if isinstance(t, (tuple, list)): total += count_bytes(*t) elif t is not None: total += t.numel() * t.element_size() return total