from typing import List, Optional, Tuple

import torch
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream


def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.ByteTensor:
    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernel.int8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )


def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.bmm_fp8.default(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
        get_cuda_stream(),
    )


def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out


def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
    )


def sgl_per_token_group_quant_int8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    int8_min: float,
    int8_max: float,
) -> None:
    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
        input, output_q, output_s, group_size, eps, int8_min, int8_max
    )


def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
        input, output_q, output_s, is_static
    )


def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    assert (
        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
    ), "Inputs/weights/outputs should not be empty!"
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
        inputs,
        weights,
        outputs,
        out_dtype,
        cublas_handle,
        get_cuda_stream(),
    )


def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)


def cutlass_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    assert a.ndim == 2 and b.ndim == 2
    m, n = a.shape[0], b.shape[0]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
        out, a, b, block_scale_a, block_scale_b, alpha
    )
    return out


def scaled_fp4_quant(
    input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return the quantized tensor and scales.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4.
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 (every two
        values packed into one uint8) and the float8_e4m3fn scaling factors in
        a swizzled layout.
    """
    assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
    other_dims = 1 if input.ndim == 1 else -1
    input = input.reshape(other_dims, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device

    assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
    assert input.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."

    # Two fp4 values will be packed into a uint8.
    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

    # We use the rounded values to store the swizzled values. Then, the scaling
    # factors in float8_e4m3fn are packed into an int32 for every 4 values.
    rounded_m = ((m + 128 - 1) // 128) * 128
    scale_n = n // block_size
    rounded_n = ((scale_n + 4 - 1) // 4) * 4
    output_scale = torch.empty(
        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
    )

    torch.ops.sgl_kernel.scaled_fp4_quant.default(
        output, input, output_scale, input_global_scale
    )
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale
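

# Illustrative usage sketch (not part of the kernel bindings): how scaled_fp4_quant
# and cutlass_scaled_fp4_mm defined above are typically combined. This assumes the
# sgl_kernel extension is built with the CUTLASS FP4 kernels and that an FP4-capable
# CUDA GPU is available. The constants 448.0 (fp8_e4m3 max) and 6.0 (fp4_e2m1 max)
# used to derive the per-tensor global scales follow a common convention but are an
# assumption here, not something defined in this module.
if __name__ == "__main__":
    dtype = torch.bfloat16
    a = torch.randn((128, 256), device="cuda", dtype=dtype)  # activations, last dim % 16 == 0
    b = torch.randn((64, 256), device="cuda", dtype=dtype)  # weights, row-major (n, k)

    # Per-tensor global scales map the dynamically computed block scales into fp8 range.
    a_global_scale = (448.0 * 6.0) / a.abs().max().float()
    b_global_scale = (448.0 * 6.0) / b.abs().max().float()

    # Quantize both operands to packed FP4 plus swizzled fp8_e4m3 block scales.
    a_fp4, a_block_scale = scaled_fp4_quant(a, a_global_scale)
    b_fp4, b_block_scale = scaled_fp4_quant(b, b_global_scale)

    # alpha undoes both global scales after the FP4 GEMM.
    alpha = 1.0 / (a_global_scale * b_global_scale)
    out = cutlass_scaled_fp4_mm(
        a_fp4, b_fp4, a_block_scale, b_block_scale, alpha, out_dtype=dtype
    )
    print(out.shape)  # torch.Size([128, 64])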