from typing import List, Optional, Tuple

import torch

from sglang.srt.layers.quantization.int8_kernel import (
    per_token_group_quant_int8,
    w8a8_block_int8_matmul,
)


def apply_w8a8_block_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply a linear layer using block-wise int8 weights and int8 activations.

    The activation is flattened to 2D, quantized on the fly with
    ``per_token_group_quant_int8`` (group size ``block_size[1]``), multiplied
    with ``weight`` through ``w8a8_block_int8_matmul`` and finally reshaped
    back to the leading dimensions of ``input``.

    Args:
        input: Activation tensor; its last dimension is the feature dimension.
        weight: Quantized weight; ``weight.shape[0]`` is used as the output
            feature size (presumably an int8 ``(N, K)`` tensor -- confirm
            against the kernel).
        block_size: Two-element list; ``block_size[1]`` is the activation
            quantization group size and the full list is forwarded to the
            matmul kernel.
        weight_scale: Scales of ``weight``, forwarded to the matmul kernel.
        input_scale: Must be ``None``; activation scales are always computed
            dynamically here.
        bias: Optional bias added to the matmul result.

    Returns:
        Tensor of dtype ``input.dtype`` and shape
        ``(*input.shape[:-1], weight.shape[0])``.

    Raises:
        AssertionError: If ``input_scale`` is not ``None``.
    """
    assert input_scale is None, "static input_scale is not supported"

    # View input as a 2D matrix for the int8 kernels.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
    output = w8a8_block_int8_matmul(
        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
    )

    if bias is not None:
        output = output + bias
    # Adding the bias may promote the dtype, so cast back before reshaping.
    return output.to(dtype=input.dtype).view(*output_shape)


def input_to_int8(
    x: torch.Tensor, dtype: torch.dtype = torch.int8
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to an integer dtype with a single tensor-wise scale.

    Args:
        x: Tensor to quantize.
        dtype: Target integer dtype; its range is read via ``torch.iinfo``.

    Returns:
        A tuple ``(x_q, inv_scale)`` where ``x_q`` is the contiguous quantized
        tensor of type ``dtype`` and ``inv_scale`` is the float32 factor that
        maps quantized values back to the original range.
    """
    iinfo = torch.iinfo(dtype)
    min_val, max_val = x.aminmax()
    # Clamp the absolute maximum so an all-zero input does not divide by zero.
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    int8_min, int8_max = iinfo.min, iinfo.max
    scale = int8_max / amax
    x_scl_sat = (x * scale).clamp(min=int8_min, max=int8_max)
    # NOTE(review): the float -> integer cast truncates toward zero instead of
    # rounding to nearest; confirm this is the intended quantization.
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> torch.Tensor:
    """Dequantize a block-wise quantized 2D tensor.

    Every ``block_size[0] x block_size[1]`` tile of ``x_q_block`` is
    multiplied by its matching entry in ``x_s``. Tiles on the bottom/right
    edge may be smaller when the shape is not a multiple of the block size.

    Args:
        x_q_block: Block-wise quantized 2D tensor of shape ``(n, k)``.
        x_s: Per-block scales of shape
            ``(ceil(n / block_size[0]), ceil(k / block_size[1]))``.
        block_size: ``[block_n, block_k]`` tile sizes.

    Returns:
        A new float32 tensor of shape ``(n, k)``; ``x_q_block`` is never
        modified.

    Raises:
        AssertionError: If the shape of ``x_s`` does not match the tile grid.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0], "x_s.shape[0] does not match the number of row tiles"
    assert k_tiles == x_s.shape[1], "x_s.shape[1] does not match the number of column tiles"

    # copy=True: without it, a float32 input would be returned as-is by .to()
    # and the in-place scaling below would silently overwrite the caller's data.
    x_dq_block = x_q_block.to(torch.float32, copy=True)

    for j in range(n_tiles):
        rows = slice(j * block_n, (j + 1) * block_n)
        for i in range(k_tiles):
            # Slice ends past the tensor bounds are clamped automatically, so
            # partial edge tiles need no special handling.
            x_dq_block[rows, i * block_k : (i + 1) * block_k] *= x_s[j, i]

    return x_dq_block