import torch import flashinfer.utils as utils FLOAT4_E2M1_MAX = 6.0 # E2M1 to float # 0111 -> 6 # 0110 -> 4 # 0101 -> 3 # 0100 -> 2 # 0011 -> 1.5 # 0010 -> 1 # 0001 -> 0.5 # 0000 -> 0 E2M1_TO_FLOAT32 = [ 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, ] def cast_from_fp4(x): # The fp4 values are packed in uint8 as [v_1st | v_2nd] v_2nd = x & 0xF v_1st = (x >> 4) & 0xF c = torch.stack((v_2nd, v_1st), dim=-1) new_shape = c.shape[:-2] + ( c.shape[-2] * c.shape[-1], ) # fuse the dim added by stack lookup_table = torch.tensor(E2M1_TO_FLOAT32, device=c.device) out = lookup_table[c.to(torch.long)].reshape(new_shape).to(torch.float32) return out def cast_to_fp4(x): sign = torch.sign(x) x = torch.abs(x) x[(x >= 0.0) & (x <= 0.25)] = 0.0 x[(x > 0.25) & (x < 0.75)] = 0.5 x[(x >= 0.75) & (x <= 1.25)] = 1.0 x[(x > 1.25) & (x < 1.75)] = 1.5 x[(x >= 1.75) & (x <= 2.5)] = 2.0 x[(x > 2.5) & (x < 3.5)] = 3.0 x[(x >= 3.5) & (x <= 5.0)] = 4.0 x[x > 5.0] = 6.0 return x * sign def get_reciprocal(x): if isinstance(x, torch.Tensor): return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) elif isinstance(x, (float, int)): return 0.0 if x == 0 else 1.0 / x else: raise TypeError("Input must be a float, int, or a torch.Tensor.") def ref_fp4_quant(x, global_scale, block_size, sf_use_ue8m0=False): assert isinstance(global_scale, (float, int)) or global_scale.dtype == torch.float32 sliced_shape = x.shape[:-1] + (x.shape[-1] // block_size, block_size) sliced_x = torch.reshape(x, sliced_shape) vec_max = torch.max(torch.abs(sliced_x), dim=-1, keepdim=True)[0].to(torch.float32) scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) if sf_use_ue8m0: scale = (scale.view(torch.int32) + 0x007FFFFF) & 0x7F800000 scale = scale.view(torch.float32) else: scale = scale.to(torch.float8_e4m3fn).to(torch.float32) output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) scaled_x = sliced_x.to(torch.float32) * output_scale clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(x.shape) return cast_to_fp4(clipped_x), scale.squeeze(-1) def recover_swizzled_scales(scale, m, n, block_size, sf_start_index=0): assert sf_start_index + m <= scale.shape[0] full_m = scale.shape[0] scale_n = n // block_size rounded_n = utils.round_up(scale_n, 4) # Recover the swizzled scaling factor to linear layout tmp = torch.reshape(scale, (1, full_m // 128, rounded_n // 4, 32, 4, 4)) tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) result = torch.reshape(tmp, (full_m, rounded_n)).to(torch.float32) return result[sf_start_index : sf_start_index + m, :scale_n]