import os
from typing import List, Optional, Tuple

import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    _enable_jit_deepgemm,
    per_token_group_quant_fp8,
    static_quant_fp8,
    w8a8_block_fp8_matmul,
)
from sglang.srt.utils import (
    get_bool_env_var,
    get_cuda_version,
    get_device_capability,
    is_cuda,
    is_hip,
)

try:
    import vllm
    from vllm import _custom_ops as ops

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

_is_hip = is_hip()
if _is_hip and get_bool_env_var("CK_MOE"):
    from aiter import gemm_a8w8_blockscale

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm

    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8

# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale.
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)

_TORCH_VERSION = torch.__version__.split("+")[0]
try:
    _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
except ValueError:
    _TORCH_VERSION_TUPLE = (0, 0, 0)

# Whether this platform supports the torch._scaled_mm rowwise-scaling feature.
# The check is done once at import time because the underlying queries are
# time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (
    _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
)


def cutlass_fp8_supported():
    if not _is_cuda:
        return False
    major, minor = get_device_capability()
    cuda_version = get_cuda_version()
    if major >= 9:
        return cuda_version >= (12, 0)
    elif major == 8 and minor == 9:
        return cuda_version >= (12, 4)
    return False


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 10000000 (-128) represents zero in e4m3fn
    # but NaN in e4m3fnuz. So here we set it to 0.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bits representation, the e4m3fnuz value is half of
    # the e4m3fn value, so we should double the scaling factor to
    # get the same dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale


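# A minimal usage sketch (not part of the original module): convert an e4m3fn
# weight plus scale to the e4m3fnuz layout used on ROCm. Because the same bit
# pattern decodes to half the value in e4m3fnuz, doubling the scale keeps the
# dequantized values unchanged. The helper name and shapes are illustrative only;
# it runs on CPU with any PyTorch build that exposes the float8 dtypes.
def _example_normalize_e4m3fn_to_e4m3fnuz():
    w = torch.randn(64, 64).to(torch.float8_e4m3fn)
    w_scale = torch.tensor(0.01, dtype=torch.float32)
    w_fnuz, w_scale_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz(w, w_scale)
    # Dequantized values match: w_fnuz.float() * w_scale_fnuz ~= w.float() * w_scale
    return w_fnuz, w_scale_fnuz

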
def cutlass_block_fp8_supported() -> bool:
    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
        return False
    if _is_cuda:
        major, minor = torch.cuda.get_device_capability()
        sm_version = major * 10 + minor
        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
        if cuda_version >= (12, 0) and sm_version >= 90:
            return True
    return False


CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()


def apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert input_scale is None
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    # TODO: add more robust shape check here
    shape_supported_by_cutlass = (
        weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
    )
    if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=True
        )
        output = fp8_blockwise_scaled_mm(
            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
        )
    elif _is_hip and get_bool_env_var("CK_MOE"):
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=False
        )
        output = torch.zeros(
            [q_input.shape[0], weight.shape[0]],
            dtype=input.dtype,
            device=q_input.device,
        )
        gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
    else:
        if _enable_jit_deepgemm:
            q_input, x_scale = per_token_group_quant_fp8(
                input_2d,
                block_size[1],
                column_major_scales=True,
                scale_tma_aligned=True,
            )
        else:
            q_input, x_scale = per_token_group_quant_fp8(
                input_2d, block_size[1], column_major_scales=False
            )
        output = w8a8_block_fp8_matmul(
            q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
        )

    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)


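# A minimal usage sketch (not part of the original module), assuming a CUDA
# device and 128x128 weight blocks: `weight` is an FP8 (N, K) matrix and
# `weight_scale` holds one float32 scale per 128x128 block. The helper name,
# shapes, and 4-token batch are illustrative only.
def _example_apply_w8a8_block_fp8_linear():
    N, K = 512, 1024
    x = torch.randn(4, K, dtype=torch.bfloat16, device="cuda")
    weight = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
    weight_scale = torch.rand(N // 128, K // 128, dtype=torch.float32, device="cuda")
    # Returns a (4, N) bfloat16 tensor, i.e. x @ weight.T with block-wise dequantization.
    return apply_w8a8_block_fp8_linear(x, weight, [128, 128], weight_scale)

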
def input_to_float8(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """This function quantizes input values to float8 values with tensor-wise quantization."""
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    fp8_max = finfo.max
    if _is_hip:
        fp8_max = 224.0
    scale = fp8_max / amax
    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


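# A minimal usage sketch (not part of the original module): tensor-wise FP8
# quantization round trip. The returned scale is the reciprocal of the
# quantization scale, so dequantization is a single multiply. The helper name
# and shapes are illustrative only; this runs on CPU.
def _example_input_to_float8():
    x = torch.randn(16, 32)
    x_fp8, scale = input_to_float8(x)
    x_approx = x_fp8.to(torch.float32) * scale
    return x_approx  # close to x, up to fp8 rounding error

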
def block_quant_to_tensor_quant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert a block-wise quantized tensor to tensor-wise quantization.

    The inputs are the block-wise quantized tensor `x_q_block`, its block-wise
    quantization scale `x_s`, and the block size.
    The outputs are the tensor-wise quantized tensor and its tensor-wise scale.
    Note that only float8 is supported for now.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    x_dq_block = x_q_block.to(torch.float32)

    x_dq_block_tiles = [
        [
            x_dq_block[
                j * block_n : min((j + 1) * block_n, n),
                i * block_k : min((i + 1) * block_k, k),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]

    for i in range(k_tiles):
        for j in range(n_tiles):
            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
    return x_q_tensor, scale


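# A minimal usage sketch (not part of the original module): collapse 128x128
# block scales into a single tensor-wise scale. The random tensors stand in for
# a real block-quantized FP8 weight and its per-block scales; the helper name
# and shapes are illustrative only. Runs on CPU.
def _example_block_quant_to_tensor_quant():
    n, k, block_size = 256, 512, [128, 128]
    x_q_block = torch.randn(n, k).to(torch.float8_e4m3fn)
    x_s = torch.rand(n // 128, k // 128, dtype=torch.float32)
    x_q_tensor, scale = block_quant_to_tensor_quant(x_q_block, x_s, block_size)
    return x_q_tensor, scale  # same shape as x_q_block, plus one global scale

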
def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_scale_ub: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = True,
    use_per_token_if_dynamic: bool = False,
) -> torch.Tensor:
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[1]]

    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
    if input_scale is not None:
        assert input_scale.numel() == 1
        # broadcast per-tensor scale to per-token scale when supporting cutlass
        qinput, x_scale = static_quant_fp8(
            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
        )
    else:
        # default use per-token quantization if dynamic
        if _is_cuda:
            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
        else:
            qinput, x_scale = per_token_group_quant_fp8(
                input_2d, group_size=input_2d.shape[1]
            )

    if cutlass_fp8_supported:
        try:
            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
                # Fall back to vllm cutlass w8a8 fp8 kernel
                output = ops.cutlass_scaled_mm(
                    qinput,
                    weight,
                    out_dtype=input.dtype,
                    scale_a=x_scale,
                    scale_b=weight_scale,
                    bias=bias,
                )
            else:
                assert (
                    weight_scale.numel() == weight.shape[1]
                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
                output = fp8_scaled_mm(
                    qinput,
                    weight,
                    x_scale,
                    weight_scale,
                    out_dtype=input.dtype,
                    bias=bias,
                )
            return output.view(*output_shape)
        except (ImportError, NameError, AttributeError):
            pass

    # torch.scaled_mm supports per tensor weights + activations only
    # so fallback to naive if per channel or per token
    else:
        per_tensor_weights = weight_scale.numel() == 1
        per_tensor_activations = x_scale.numel() == 1

        if per_tensor_weights and per_tensor_activations:
            # Fused GEMM_DQ
            output = torch._scaled_mm(
                qinput,
                weight,
                out_dtype=input.dtype,
                scale_a=x_scale,
                scale_b=weight_scale,
                bias=bias,
            )
            # A fix for discrepancy in scaled_mm which returns tuple
            # for torch < 2.5 and a single value in torch >= 2.5
            if type(output) is tuple and len(output) == 2:
                output = output[0]

            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)

        else:
            # Fallback for channelwise case, where we use unfused DQ
            # due to limitations with scaled_mm

            # Symmetric quantized GEMM by definition computes the following:
            #   C = (s_x * X) (s_w * W) + bias
            # This is equivalent to dequantizing the weights and activations
            # before applying a GEMM.
            #
            # In order to compute quantized operands, a quantized kernel
            # will rewrite the above like so:
            #   C = s_w * s_x * (X * W) + bias
            #
            # For the scaled_mm fallback case, we break this down, since it
            # does not support s_w being a vector.

            # Making sure the dummy tensor is on the same device as the weight
            global TORCH_DEVICE_IDENTITY
            if TORCH_DEVICE_IDENTITY.device != weight.device:
                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

            # GEMM
            # This computes C = (X * W).
            # Output in fp32 to allow subsequent ops to happen in-place
            output = torch._scaled_mm(
                qinput,
                weight,
                scale_a=TORCH_DEVICE_IDENTITY,
                scale_b=TORCH_DEVICE_IDENTITY,
                out_dtype=torch.float32,
            )
            # A fix for discrepancy in scaled_mm which returns tuple
            # for torch < 2.5 and a single value in torch >= 2.5
            if type(output) is tuple and len(output) == 2:
                output = output[0]
            # Unpad (undo num_token_padding)
            output = torch.narrow(output, 0, 0, input_2d.shape[0])
            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

            # DQ
            # C = sw * sx * (X * W) + bias
            output = output * x_scale * weight_scale.t()
            if bias is not None:
                output = output + bias
            return output.to(dtype=input.dtype).view(*output_shape)


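# A minimal usage sketch (not part of the original module), assuming a CUDA
# device with FP8 support. `weight` is passed as a (K, N) column-major view, the
# layout torch._scaled_mm expects for its second operand, and the per-tensor
# weight scale comes from input_to_float8. With cutlass disabled and a dynamic
# per-token activation scale, this exercises the unfused dequant fallback.
# The helper name and shapes are illustrative only.
def _example_apply_fp8_linear():
    M, K, N = 8, 256, 512
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    weight_fp8, weight_scale = input_to_float8(torch.randn(N, K, device="cuda"))
    return apply_fp8_linear(
        x,
        weight_fp8.t(),  # (K, N) view of the (N, K) FP8 weight
        weight_scale,
        cutlass_fp8_supported=False,
    )  # -> (M, N) tensor in x.dtype

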
def maybe_create_device_identity():
    # Allocate dummy ones tensor for torch._scaled_mm
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
    """
    This class executes an FP8 linear layer using cutlass if supported and
    torch.scaled_mm otherwise.
    It needs to be a class instead of a method so that config can be read
    in the __init__ method, as reading config is not allowed inside forward.
    """

    def __init__(
        self,
        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
        use_per_token_if_dynamic: bool = False,
        pad_output: Optional[bool] = None,
    ):
        self.cutlass_fp8_supported = cutlass_fp8_supported
        self.use_per_token_if_dynamic = use_per_token_if_dynamic

        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16.
        # This could change in the future.
        # We also don't pad when using torch.compile,
        # as it breaks with dynamic shapes.
        if pad_output is None:
            enable_torch_compile = os.environ.get(
                "SGLANG_ENABLE_TORCH_COMPILE", "0"
            ).lower() in ("1", "true", "yes")
            pad_output = not enable_torch_compile
        self.output_padding = 17 if pad_output else None

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
        input_scale_ub: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        # TODO(luka) remove this parameter in favor of __init__
        use_per_token_if_dynamic: Optional[bool] = None,
    ) -> torch.Tensor:
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        # If dynamic, layer.input_scale is None and x_scale is computed from x.
        # If static, layer.input_scale is a scalar and x_scale is input_scale.

        # View input as 2D matrix for fp8 methods
        input_2d = input.view(-1, input.shape[-1])
        output_shape = [*input.shape[:-1], weight.shape[1]]

        # TODO(luka) this is here because currently MLA only decides this
        # during the forward method instead of in __init__.
        if use_per_token_if_dynamic is None:
            use_per_token_if_dynamic = self.use_per_token_if_dynamic

        # cutlass_scaled_mm supports per-tensor/channel W and per-tensor/token A;
        # the sgl-kernel fp8_scaled_mm currently supports per-channel W only.
        if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
            if _is_cuda:
                qinput, x_scale = sgl_scaled_fp8_quant(
                    input_2d,
                    input_scale,
                    use_per_token_if_dynamic=use_per_token_if_dynamic,
                )
            else:
                qinput, x_scale = ops.scaled_fp8_quant(
                    input_2d,
                    input_scale,
                    scale_ub=input_scale_ub,
                    use_per_token_if_dynamic=use_per_token_if_dynamic,
                )

            # Fused GEMM_DQ
            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
                # Fall back to vllm cutlass w8a8 fp8 kernel
                output = ops.cutlass_scaled_mm(
                    qinput,
                    weight,
                    out_dtype=input.dtype,
                    scale_a=x_scale,
                    scale_b=weight_scale,
                    bias=bias,
                )
            else:
                assert (
                    weight_scale.numel() == weight.shape[1]
                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
                output = fp8_scaled_mm(
                    qinput,
                    weight,
                    x_scale,
                    weight_scale,
                    out_dtype=input.dtype,
                    bias=bias,
                )
            return output.view(*output_shape)

        # torch.scaled_mm supports per tensor weights + activations only
        # so fallback to naive if per channel or per token
        else:
            # Maybe apply padding to output, see comment in __init__
            if _is_cuda:
                qinput, x_scale = sgl_scaled_fp8_quant(
                    input_2d,
                    input_scale,
                    use_per_token_if_dynamic=use_per_token_if_dynamic,
                )
                if self.output_padding:
                    pad_size = max(self.output_padding - qinput.shape[0], 0)
                    if pad_size > 0:
                        qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
            else:
                qinput, x_scale = ops.scaled_fp8_quant(
                    input_2d,
                    input_scale,
                    num_token_padding=self.output_padding,
                    use_per_token_if_dynamic=use_per_token_if_dynamic,
                )

            per_tensor_weights = weight_scale.numel() == 1
            per_tensor_activations = x_scale.numel() == 1

            if per_tensor_weights and per_tensor_activations:
                # Fused GEMM_DQ
                output = torch._scaled_mm(
                    qinput,
                    weight,
                    out_dtype=input.dtype,
                    scale_a=x_scale,
                    scale_b=weight_scale,
                    bias=bias,
                )
                # A fix for discrepancy in scaled_mm which returns tuple
                # for torch < 2.5 and a single value in torch >= 2.5
                if type(output) is tuple and len(output) == 2:
                    output = output[0]

                return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)

            elif (
                use_per_token_if_dynamic
                and not per_tensor_weights
                and not per_tensor_activations
                and USE_ROWWISE_TORCH_SCALED_MM
            ):
                # For now validated on the ROCm platform only: fp8 rowwise scaling
                # in torch._scaled_mm was introduced in
                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
                # and ROCm 6.3, which only exists in torch 2.7 and above.
                # On CUDA, please validate whether torch._scaled_mm supports
                # rowwise scaled GEMM before enabling this path.
                # Fused GEMM_DQ rowwise GEMM
                output = torch._scaled_mm(
                    qinput,
                    weight,
                    out_dtype=input.dtype,
                    scale_a=x_scale,
                    scale_b=weight_scale.t(),
                    bias=bias,
                )

                output = torch.narrow(output, 0, 0, input_2d.shape[0])
                output = output.view(*output_shape)
                return output

            else:
                # Fallback for channelwise case, where we use unfused DQ
                # due to limitations with scaled_mm

                # Symmetric quantized GEMM by definition computes the following:
                #   C = (s_x * X) (s_w * W) + bias
                # This is equivalent to dequantizing the weights and activations
                # before applying a GEMM.
                #
                # In order to compute quantized operands, a quantized kernel
                # will rewrite the above like so:
                #   C = s_w * s_x * (X * W) + bias
                #
                # For the scaled_mm fallback case, we break this down, since it
                # does not support s_w being a vector.

                # GEMM
                # This computes C = (X * W).
                # Output in fp32 to allow subsequent ops to happen in-place

                global TORCH_DEVICE_IDENTITY
                if TORCH_DEVICE_IDENTITY.device != weight.device:
                    TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

                output = torch._scaled_mm(
                    qinput,
                    weight,
                    scale_a=TORCH_DEVICE_IDENTITY,
                    scale_b=TORCH_DEVICE_IDENTITY,
                    out_dtype=torch.float32,
                )
                # A fix for discrepancy in scaled_mm which returns tuple
                # for torch < 2.5 and a single value in torch >= 2.5
                if type(output) is tuple and len(output) == 2:
                    output = output[0]
                # Unpad (undo num_token_padding)
                output = torch.narrow(output, 0, 0, input_2d.shape[0])
                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

                # DQ
                # C = sw * sx * (X * W) + bias
                output = output * x_scale * weight_scale.t()
                if bias is not None:
                    output = output + bias
                return output.to(dtype=input.dtype).view(*output_shape)
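

# A minimal usage sketch (not part of the original module), assuming a CUDA
# device with FP8 support. The op is constructed once, so the cutlass check and
# the padding policy are resolved outside the forward pass, and is then applied
# with a per-tensor weight scale, which routes through the torch._scaled_mm
# fallback. The helper name and shapes are illustrative only.
def _example_fp8_linear_op():
    op = Fp8LinearOp(use_per_token_if_dynamic=True)
    M, K, N = 8, 256, 512
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    weight_fp8, weight_scale = input_to_float8(torch.randn(N, K, device="cuda"))
    return op.apply(x, weight_fp8.t(), weight_scale)  # -> (M, N) tensor in x.dtype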