from typing import Tuple

import torch
import triton
import triton.language as tl

fused_softcap_autotune = triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
        triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
    ],
    key=["n_ele"],
)


@triton.jit
def fused_softcap_kernel(
    output_ptr,
    input_ptr,
    n_ele,
    softcap_const: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise softcap: tanh(x / c) * c, computed in float32 via the
    # identity tanh(y) = (exp(2y) - 1) / (exp(2y) + 1).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_ele
    x = tl.load(input_ptr + offsets, mask=mask)
    fx = x.to(tl.float32)
    fxs = fx / softcap_const
    exped = tl.exp(2 * fxs)
    top = exped - 1
    bottom = exped + 1
    output = top / bottom * softcap_const
    tl.store(output_ptr + offsets, output, mask=mask)


fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)


def fused_softcap(x, softcap_const, autotune=False):
    output = torch.empty_like(x, dtype=torch.float32)
    n_elements = output.numel()
    if autotune:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
    else:
        fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
            output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
        )
    return output


# cast to float + softcap
class Softcap:
    def __init__(self, softcap_const: float):
        self.softcap_const = softcap_const

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.is_cuda:
            return self.forward_cuda(x)
        else:
            return self.forward_native(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x.float() / self.softcap_const) * self.softcap_const

    def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
        return fused_softcap(x, self.softcap_const, autotune=autotune)
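

# Illustrative usage sketch: compare the fused Triton softcap against the eager
# PyTorch reference. The helper name `_softcap_example`, the shapes, and the
# tolerances are arbitrary choices for demonstration; a CUDA device is assumed.
def _softcap_example() -> None:
    x = torch.randn(4, 4096, device="cuda", dtype=torch.bfloat16)
    softcap = Softcap(30.0)
    y_fused = softcap(x)  # dispatches to the Triton kernel on CUDA, returns float32
    y_ref = softcap.forward_native(x)  # tanh(x / c) * c in float32
    torch.testing.assert_close(y_fused, y_ref, rtol=1e-3, atol=1e-3)

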
rmsnorm_autotune = triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
    ],
    key=["hidden_dim"],
)


@triton.jit
def fused_dual_residual_rmsnorm_kernel(
    output_ptr,
    mid_ptr,
    activ_ptr,
    residual_ptr,
    weight1_ptr,
    weight2_ptr,
    eps: tl.constexpr,
    hidden_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    input_start = pid * hidden_dim

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_dim

    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
    a = a_.to(tl.float32)
    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)

    r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
    w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
    w1 = w1_.to(tl.float32)

    a2r = r + (a / rms * w1).to(r.dtype)
    tl.store(
        mid_ptr + input_start + offsets,
        a2r,
        mask=mask,
    )

    a2r = a2r.to(tl.float32)
    rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)

    w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
    w2 = w2_.to(tl.float32)

    tl.store(
        output_ptr + input_start + offsets,
        a2r / rms2 * w2,  # implicitly casts to output dtype here
        mask=mask,
    )


fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
    fused_dual_residual_rmsnorm_kernel
)


def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
    assert len(x.shape) == 2
    assert x.shape == residual.shape and x.dtype == residual.dtype
    output, mid = torch.empty_like(x), torch.empty_like(x)
    bs, hidden_dim = x.shape
    if autotune:
        fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
            output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
        )
    else:
        config = {
            "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
            "num_warps": max(
                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
            ),
        }

        fused_dual_residual_rmsnorm_kernel[(bs,)](
            output,
            mid,
            x,
            residual,
            weight1,
            weight2,
            eps=eps,
            hidden_dim=hidden_dim,
            **config,
        )

    return output, mid
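

# Illustrative sketch checking fused_dual_residual_rmsnorm against an eager
# reference that mirrors the kernel's float32 math. The helper names
# `_dual_rmsnorm_reference` / `_dual_rmsnorm_example`, the shapes, and the
# tolerances are arbitrary; a CUDA device is assumed.
def _dual_rmsnorm_reference(x, residual, weight1, weight2, eps):
    def rmsnorm(t, w):
        tf = t.float()
        rms = torch.sqrt(tf.pow(2).mean(-1, keepdim=True) + eps)
        return (tf / rms * w.float()).to(t.dtype)

    mid = residual + rmsnorm(x, weight1)
    return rmsnorm(mid, weight2), mid


def _dual_rmsnorm_example() -> None:
    bs, hidden = 8, 4096
    x = torch.randn(bs, hidden, device="cuda", dtype=torch.bfloat16)
    residual = torch.randn_like(x)
    w1 = torch.randn(hidden, device="cuda", dtype=torch.bfloat16)
    w2 = torch.randn_like(w1)
    eps = 1e-6

    out, mid = fused_dual_residual_rmsnorm(x, residual, w1, w2, eps)
    out_ref, mid_ref = _dual_rmsnorm_reference(x, residual, w1, w2, eps)
    torch.testing.assert_close(mid, mid_ref, rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(out, out_ref, rtol=2e-2, atol=2e-2)

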
@triton.jit
def fused_rmsnorm_kernel(
    output_ptr,
    activ_ptr,
    weight_ptr,
    eps: tl.constexpr,
    hidden_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    input_start = pid * hidden_dim

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_dim

    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
    a = a_.to(tl.float32)
    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)

    w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
    w1 = w1_.to(tl.float32)

    a_rms = a / rms * w1

    tl.store(
        output_ptr + input_start + offsets,
        a_rms,  # implicitly casts to output dtype here
        mask=mask,
    )


def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
    assert len(x.shape) == 2
    if inplace:
        output = x
    else:
        output = torch.empty_like(x)
    bs, hidden_dim = x.shape
    config = {
        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
        "num_warps": max(
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
        ),
    }

    fused_rmsnorm_kernel[(bs,)](
        output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
    )
    return output


class FusedDualResidualRMSNorm:
    """
    Fused implementation of
        y = RMSNorm2(RMSNorm1(x) + residual)
    """

    def __init__(self, rmsnorm1, rmsnorm2) -> None:
        # rmsnorm2 is the norm applied after the residual add (i.e. after rmsnorm1)
        self.rmsnorm1 = rmsnorm1
        self.rmsnorm2 = rmsnorm2
        self.variance_epsilon = self.rmsnorm1.variance_epsilon
        assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
        assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(
        self, x: torch.Tensor, residual: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if x.is_cuda:
            return self.forward_cuda(x, residual)
        else:
            return self.forward_flashinfer(x, residual)

    def forward_cuda(
        self, x: torch.Tensor, residual: torch.Tensor, autotune=False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return fused_dual_residual_rmsnorm(
            x,
            residual,
            self.rmsnorm1.weight,
            self.rmsnorm2.weight,
            self.variance_epsilon,
            autotune=autotune,
        )

    def forward_flashinfer(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        normed1 = self.rmsnorm1(x)
        residual = normed1 + residual
        return self.rmsnorm2(residual), residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        normed1 = self.rmsnorm1.forward_native(x)
        residual = normed1 + residual
        return self.rmsnorm2.forward_native(residual), residual
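

# Illustrative sketch for the single-tensor fused_rmsnorm path. The helper name
# `_rmsnorm_example`, the shapes, and the tolerances are arbitrary; a CUDA
# device is assumed.
def _rmsnorm_example() -> None:
    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    weight = torch.randn(4096, device="cuda", dtype=torch.float16)
    eps = 1e-6

    out = fused_rmsnorm(x, weight, eps)

    # Eager reference mirroring the kernel: normalize in float32, multiply by
    # the weight in float32, then cast back to the input dtype.
    xf = x.float()
    rms = torch.sqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    out_ref = (xf / rms * weight.float()).to(x.dtype)
    torch.testing.assert_close(out, out_ref, rtol=2e-2, atol=2e-2)

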
# gelu on first half of vector
@triton.jit
def gelu_and_mul_kernel(
    out_hidden_states_ptr,  # (bs, hidden_dim)
    out_scales_ptr,  # (bs,)
    hidden_states_ptr,  # (bs, hidden_dim * 2)
    quant_max: tl.constexpr,
    static_scale: tl.constexpr,
    hidden_dim: tl.constexpr,  # the output hidden_dim
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)

    input_start = pid * hidden_dim * 2
    output_start = pid * hidden_dim

    input1_offs = tl.arange(0, BLOCK_SIZE)
    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
    output_offs = tl.arange(0, BLOCK_SIZE)

    x1 = tl.load(
        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
    ).to(tl.float32)
    x3 = tl.load(
        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
    ).to(tl.float32)

    # gelu
    # cast down before mul to better match training?
    gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
    out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)

    if quant_max is not None:
        raise NotImplementedError()

    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)


def gelu_and_mul_triton(
    hidden_states,
    scales=None,
    quantize=None,  # dtype to quantize to
    out=None,
):
    bs, in_hidden_dim = hidden_states.shape
    hidden_dim = in_hidden_dim // 2

    if out is None:
        out_hidden_states = torch.empty(
            (bs, hidden_dim),
            dtype=quantize or hidden_states.dtype,
            device=hidden_states.device,
        )
    else:
        assert out.shape == (bs, hidden_dim)
        assert out.dtype == (quantize or hidden_states.dtype)
        out_hidden_states = out
    out_scales = None
    static_scale = False
    if quantize is not None:
        if scales is None:
            out_scales = torch.empty(
                (bs,), dtype=torch.float32, device=hidden_states.device
            )
        else:
            out_scales = scales
            static_scale = True

    config = {
        # 8 ele per thread (not tuned)
        "num_warps": max(
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
        ),
    }

    gelu_and_mul_kernel[(bs,)](
        out_hidden_states,
        out_scales,
        hidden_states,
        quant_max=torch.finfo(quantize).max if quantize is not None else None,
        static_scale=static_scale,
        hidden_dim=hidden_dim,
        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
        **config,
    )

    if quantize is not None:
        return out_hidden_states, out_scales
    else:
        return out_hidden_states, None
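

# Illustrative sketch for the unquantized gelu_and_mul_triton path against an
# eager reference. The helper name `_gelu_and_mul_example`, the shapes, and the
# tolerances are arbitrary; a CUDA device is assumed.
def _gelu_and_mul_example() -> None:
    x = torch.randn(4, 2 * 4096, device="cuda", dtype=torch.float16)
    out, scales = gelu_and_mul_triton(x)
    assert scales is None  # no quantization requested

    # Reference: exact (erf-based) GELU on the first half, multiplied by the
    # second half, with the GELU computed in float32 and cast back down.
    x1, x3 = x.chunk(2, dim=-1)
    out_ref = x3 * torch.nn.functional.gelu(x1.float()).to(x.dtype)
    torch.testing.assert_close(out, out_ref, rtol=2e-2, atol=2e-2)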