import logging
from typing import List, Optional

import torch
import triton
import triton.language as tl

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()
if _is_cuda:
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8,
    )

logger = logging.getLogger(__name__)


@triton.jit
def deepep_permute_triton_kernel(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    OutDtype = gateup_input_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    src_ptr = input_ptr + src_idx * hidden_size

    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size
        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)

        for idx in range(topk):
            dst_idx = tl.load(src2dst_ptr + idx)
            if dst_idx >= 0:
                dst_ptr = gateup_input_ptr + dst_idx * hidden_size
                tl.store(dst_ptr + offset, in_data, mask=mask)


@triton.jit
def deepep_post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk
    store_ptr = output_ptr + src_idx * hidden_size

    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size
        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            dst_idx = tl.load(src2dst_ptr + idx)
            if dst_idx >= 0:
                weight_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weight_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)


@triton.jit
def compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    tl.store(src2dst + src_id, dst_id, mask=mask)


@triton.jit
def deepep_compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    num_invalid = tl.load(num_minus_one)
    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)


def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.empty(
        num_experts + 1, device=topk_ids.device, dtype=torch.int64
    )
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)

    # Find offset
    expert_ids = torch.arange(
        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
    )
    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
    # seg_indptr[0] counts the negative (invalid) topk ids, which sort to the front.
    num_minus_one = seg_indptr[0]
    seg_indptr = seg_indptr - num_minus_one

    BLOCK_SIZE = 512
    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
    deepep_compute_src2dst_triton_kernel[grid](
        reorder_ids,
        src2dst,
        topk_ids.numel(),
        num_minus_one,
        BLOCK_SIZE,
    )
    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
    return reorder_topk_ids, src2dst, seg_indptr
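
# Worked example (illustrative, not from the original source): for
# topk_ids = [[-1, 1], [0, 1]] and num_experts = 2, the flattened ids sort to
# [-1, 0, 1, 1], so num_minus_one = 1, seg_indptr becomes [0, 1, 3], the invalid
# slot is dropped (reorder_topk_ids = [0, 1, 1]), and src2dst = [-1, 1, 0, 2].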
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
    expert = tl.program_id(0)
    low = 0
    high = num_toks - 1
    target_location = -1
    while low <= high:
        mid = (low + high) // 2

        if tl.load(reorder_topk_ids + mid) > expert:
            high = mid - 1
        else:
            low = mid + 1
            target_location = mid
    tl.store(seg_indptr + expert + 1, target_location + 1)


def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

    seg_indptr = torch.zeros(
        num_experts + 1, device=topk_ids.device, dtype=torch.int64
    )
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )

    BLOCK_SIZE = 512
    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
    compute_src2dst_triton_kernel[grid](
        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
    )
    return reorder_topk_ids, src2dst, seg_indptr


@triton.jit
def pre_reorder_triton_kernel(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    OutDtype = gateup_input_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    src_ptr = input_ptr + src_idx * hidden_size

    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id >= start_expert_id and expert_id <= end_expert_id:
            if a1_scales_ptr is not None:
                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
            else:
                scale = 1.0

            dst_idx = tl.load(src2dst_ptr + idx)
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
                out_data = (in_data * scale).to(OutDtype)
                tl.store(dst_ptr + offset, out_data, mask=mask)


@triton.jit
def silu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # silu & mul & quantize
            gate_output = gate_output * tl.sigmoid(gate_output)
            gate_output = gate_output.to(InDtype)

            silu_mul_output = gate_output * up_output * scale
            silu_mul_output = silu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
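
# Illustrative only (not part of the original module): an eager PyTorch sketch of the
# per-row computation done by silu_and_mul_triton_kernel, assuming the first half of
# `gateup_row` is the gate projection and the second half the up projection. The
# optional `scale` mirrors the per-expert dequant scale; dtype casts are simplified.
def _silu_and_mul_reference(
    gateup_row: torch.Tensor, scale: float = 1.0
) -> torch.Tensor:
    half = gateup_row.shape[-1] // 2
    gate = gateup_row[..., :half].float()
    up = gateup_row[..., half:]
    return torch.nn.functional.silu(gate).to(up.dtype) * up * scale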
@triton.jit
def tanh(x):
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def gelu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # gelu & mul & quantize
            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
            # sqrt(2/pi)
            kAlpha = 0.7978845608028654
            gate_output = (
                0.5
                * gate_output
                * (
                    1
                    + tanh(
                        kAlpha
                        * (
                            gate_output
                            + 0.044715 * gate_output * gate_output * gate_output
                        )
                    )
                )
            )
            gate_output = gate_output.to(InDtype)

            gelu_mul_output = gate_output * up_output * scale
            gelu_mul_output = gelu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)


@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    computed = False
    store_ptr = output_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size

        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            if expert_id >= start_expert_id and expert_id <= end_expert_id:
                computed = True
                dst_idx = tl.load(src2dst_ptr + idx)
                weight_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weight_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)

    if computed == False:
        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < hidden_size
            tl.store(
                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
            )


@triton.jit
def compute_m_range(
    pid,
    batch_size,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    BLOCK_SIZE_M: tl.constexpr,
):
    idx = 0
    for bs in range(batch_size):
        tiles = tl.load(m_num_tiles_indptr + bs)
        if pid >= tiles:
            idx = bs

    idx_start = tl.load(m_num_tiles_indptr + idx)

    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
    expert_id = tl.load(weight_indices + idx)
    return m_range_start, m_range_end, expert_id
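
# Worked example (illustrative, not from the original source): with
# seg_indptr = [0, 5, 5, 12] and BLOCK_SIZE_M = 4, compute_m_num_tiles_indptr
# produces m_num_tiles_indptr = [0, 2, 2, 4]; tile pid = 3 then belongs to
# segment 2, covers rows [9, 12) of `a`, and uses the expert weight_indices[2].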
@triton.jit
def grouped_gemm_triton_kernel(
    a,
    b,
    c,
    batch_size,
    N,
    K,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    scale_a,
    scale_b,
    use_fp8_w8a8: tl.constexpr,
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    a_stride_0: tl.constexpr,
    b_stride_0: tl.constexpr,
    b_stride_1: tl.constexpr,
    as_stride_0: tl.constexpr,
    as_stride_1: tl.constexpr,
    bs_stride_0: tl.constexpr,
    bs_stride_2: tl.constexpr,
    bs_stride_1: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    c_dtype = c.dtype.element_ty

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
    if pid_m >= total_m_block:
        return

    m_range_start, m_range_end, expert_id = compute_m_range(
        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
    )
    if m_range_end - m_range_start == 0:
        return

    n_range_start = pid_n * BLOCK_SIZE_N
    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)

    offs_am = tl.arange(0, BLOCK_SIZE_M)
    offs_bn = tl.arange(0, BLOCK_SIZE_N)

    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
    b_ptr = b + (
        (expert_id * b_stride_0)
        + (n_range_start + offs_bn[:, None]) * b_stride_1
        + offs_k[None, :]
    )

    if group_k > 0 and group_n > 0:
        a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
        offs_bsn = (n_range_start + offs_bn) // group_n
        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a_tile = tl.load(
            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        b_tile = tl.load(
            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )

        if group_k > 0 and group_n > 0:
            k_start = k * BLOCK_SIZE_K
            offs_ks = k_start // group_k
            a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
            b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
            accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
        else:
            accumulator = tl.dot(a_tile, b_tile.T, accumulator)
        a_ptr += BLOCK_SIZE_K
        b_ptr += BLOCK_SIZE_K

    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
        scale_a_value = tl.load(scale_a + expert_id)
        scale_b_value = tl.load(scale_b + expert_id)
        accumulator *= scale_a_value * scale_b_value

    c_tile = accumulator.to(c_dtype)

    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
    tl.store(c_ptr, c_tile, mask=c_mask)


@triton.jit
def compute_m_num_tiles_indptr(
    m_num_tiles_indptr,
    seg_indptr,
    batch_size: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    for bs in range(batch_size):
        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
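
# Illustrative only (not part of the original module): an eager PyTorch sketch of the
# segmented GEMM that grouped_gemm_triton performs in the unquantized case, assuming
# weight_column_major expert weights `b` of shape [num_experts, N, K] and rows of `a`
# already grouped contiguously per segment as described by `seg_indptr`.
def _grouped_gemm_reference(
    a: torch.Tensor,
    b: torch.Tensor,
    seg_indptr: torch.Tensor,
    weight_indices: torch.Tensor,
) -> torch.Tensor:
    c = a.new_zeros((a.shape[0], b.shape[1]))
    for i in range(seg_indptr.numel() - 1):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        if end > start:
            # Each contiguous segment of rows is multiplied by its expert's [N, K] weight.
            c[start:end] = a[start:end] @ b[weight_indices[i]].t()
    return c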
def grouped_gemm_triton(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    batch_size: int,
    weight_column_major: bool,
    seg_indptr: Optional[torch.Tensor] = None,
    weight_indices: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    scale_a: Optional[torch.Tensor] = None,
    scale_b: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
):
    assert weight_column_major == True  # TODO: more
    if use_fp8_w8a8 and block_shape is None:
        assert scale_a is not None and scale_b is not None

    if block_shape is not None:
        assert len(block_shape) == 2
        block_n, block_k = block_shape[0], block_shape[1]
        if _is_cuda:
            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
        else:
            a, scale_a = per_token_group_quant_fp8(a, block_k)

        assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
        assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
        assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]

    # TODO: adjust config or tune kernel
    # Reduce block size to prevent L40 shared memory overflow.
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
    }

    m_num_tiles_indptr = torch.zeros(
        batch_size + 1, device=a.device, dtype=torch.int64
    )
    compute_m_num_tiles_indptr[(1,)](
        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
    )

    grid = lambda META: (
        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
    )

    grouped_gemm_triton_kernel[grid](
        a,
        b,
        c,
        batch_size,
        b.size(1),
        b.size(2),
        seg_indptr,
        weight_indices,
        m_num_tiles_indptr,
        scale_a,
        scale_b,
        use_fp8_w8a8,
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        a.stride(0),
        b.stride(0),
        b.stride(1),
        scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
        scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
        scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
        scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
        scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
        **config,
    )
    return c
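

# Illustrative smoke test only (not part of the original module). The tensor shapes,
# dtypes, and num_experts below are assumptions for demonstration; the Triton kernels
# require a CUDA device to run.
if __name__ == "__main__":
    if torch.cuda.is_available():
        topk_ids = torch.tensor([[0, 2], [1, 2]], device="cuda", dtype=torch.int32)
        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
            topk_ids, num_experts=4
        )
        # seg_indptr[e] .. seg_indptr[e + 1] delimits the sorted rows routed to expert e.
        print(reorder_topk_ids, src2dst, seg_indptr)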